ugh: use custom snow for async

This commit is contained in:
dr-frmr 2024-05-28 22:40:00 -06:00
parent 55ed925277
commit 6814455baf
No known key found for this signature in database
4 changed files with 87 additions and 79 deletions

5
Cargo.lock generated
View File

@ -5033,9 +5033,8 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "snow"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "850948bee068e713b8ab860fe1adc4d109676ab4c3b621fd8147f06b261f2f85"
version = "0.9.0"
source = "git+https://github.com/dr-frmr/snow?branch=dr/extract_cipherstates#1d4eb5f6747aa59aabb32bbbe698fb4bb7dfb9a4"
dependencies = [
"aes-gcm",
"blake2",

View File

@ -80,7 +80,9 @@ serde_json = "1.0"
serde_urlencoded = "0.7"
sha2 = "0.10"
sha3 = "0.10.8"
snow = { version = "0.9.5", features = ["ring-resolver"] }
# snow = { version = "0.9.5", features = ["ring-resolver"] }
# unfortunately need to use forked version for async use and in-place encryption
snow = { git = "https://github.com/dr-frmr/snow", branch = "dr/extract_cipherstates", features = ["ring-resolver"] }
socket2 = "0.5.7"
static_dir = "0.2.0"
thiserror = "1.0"

View File

@ -6,7 +6,7 @@ use crate::net::{
use lib::types::core::{KernelMessage, MessageSender, NodeId, PrintSender};
use {
tokio::io::{AsyncReadExt, AsyncWriteExt},
tokio::net::TcpStream,
tokio::net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpStream},
tokio::sync::mpsc::UnboundedReceiver,
};
@ -29,112 +29,117 @@ pub async fn maintain_connection(
.set_tcp_keepalive(&ka)
.expect("failed to set tcp keepalive");
loop {
tokio::select! {
maybe_recv = peer_rx.recv() => {
let Some(km) = maybe_recv else {
break
};
let Ok(()) = send_protocol_message(&km, &mut conn).await else {
break
};
},
outer_len = recv_protocol_message_init(&mut conn.stream) => {
match outer_len {
Ok((read, outer_len)) => match recv_protocol_message(&mut conn, read, outer_len).await {
Ok(km) => {
if km.source.node != peer_name {
print_loud(
&print_tx,
&format!(
"net: got message with spoofed source from {peer_name}!"
),
).await;
break
} else {
kernel_message_tx.send(km).await.expect("net: fatal: kernel receiver died");
continue
}
}
Err(e) => {
print_debug(&print_tx, &format!("net: error receiving message: {e}")).await;
break
}
}
Err(e) => {
print_debug(&print_tx, &format!("net: error receiving message: {e}")).await;
break
let (mut read_stream, mut write_stream) = conn.stream.into_split();
let (mut our_cipher, mut their_cipher) = if conn.noise.is_initiator() {
// if initiator, we write with first and read with second
let snow::CipherStates(our_cipher, their_cipher) = conn.noise.extract_cipherstates();
(our_cipher, their_cipher)
} else {
// if responder, we read with first and write with second
let snow::CipherStates(their_cipher, our_cipher) = conn.noise.extract_cipherstates();
(our_cipher, their_cipher)
};
let write_buf = &mut [0; 65536];
let write = async move {
while let Some(km) = peer_rx.recv().await {
let Ok(()) =
send_protocol_message(&km, &mut our_cipher, write_buf, &mut write_stream).await
else {
break;
};
}
};
let read_buf = &mut conn.buf;
let read_peer_name = peer_name.clone();
let read_print_tx = print_tx.clone();
let read = async move {
loop {
match recv_protocol_message(&mut their_cipher, read_buf, &mut read_stream).await {
Ok(km) => {
if km.source.node != read_peer_name {
print_loud(
&read_print_tx,
&format!("net: got message with spoofed source from {read_peer_name}!"),
)
.await;
break;
} else {
kernel_message_tx
.send(km)
.await
.expect("net: fatal: kernel receiver died");
}
}
},
Err(e) => {
print_debug(
&read_print_tx,
&format!("net: error receiving message: {e}"),
)
.await;
break;
}
}
}
};
tokio::select! {
_ = write => (),
_ = read => (),
}
let _ = conn.stream.shutdown().await;
print_debug(&print_tx, &format!("net: connection lost with {peer_name}")).await;
peers.remove(&peer_name);
}
async fn send_protocol_message(
km: &KernelMessage,
conn: &mut PeerConnection,
cipher: &mut snow::CipherState,
buf: &mut [u8],
stream: &mut OwnedWriteHalf,
) -> anyhow::Result<()> {
println!(
"initiatior: {}, sending_nonce: {}, receiving_nonce: {}\r",
conn.noise.is_initiator(),
conn.noise.sending_nonce(),
conn.noise.receiving_nonce()
);
println!("send_protocol_message\r");
let serialized = rmp_serde::to_vec(km)?;
if serialized.len() > MESSAGE_MAX_SIZE as usize {
return Err(anyhow::anyhow!("message too large"));
}
let outer_len = (serialized.len() as u32).to_be_bytes();
conn.stream.write_all(&outer_len).await?;
stream.write_all(&outer_len).await?;
println!("1\r");
// 65519 = 65535 - 16 (TAGLEN)
for payload in serialized.chunks(65519) {
let len = conn.noise.write_message(payload, &mut conn.buf)? as u16;
conn.stream.write_all(&len.to_be_bytes()).await?;
conn.stream.write_all(&conn.buf[..len as usize]).await?;
let len = cipher.encrypt(payload, buf)? as u16;
stream.write_all(&len.to_be_bytes()).await?;
println!(" 2\r");
stream.write_all(&buf[..len as usize]).await?;
println!(" 3\r");
}
Ok(conn.stream.flush().await?)
}
async fn recv_protocol_message_init(stream: &mut TcpStream) -> anyhow::Result<(usize, [u8; 4])> {
let mut outer_len = [0; 4];
let read = stream.read(&mut outer_len).await?;
Ok((read, outer_len))
println!("send_protocol_message flush\r");
Ok(stream.flush().await?)
}
/// any error in receiving a message will result in the connection being closed.
async fn recv_protocol_message(
conn: &mut PeerConnection,
already_read: usize,
mut outer_len: [u8; 4],
cipher: &mut snow::CipherState,
buf: &mut [u8],
stream: &mut OwnedReadHalf,
) -> anyhow::Result<KernelMessage> {
// fill out the rest of outer_len depending on how many bytes were read
if already_read < 4 {
conn.stream
.read_exact(&mut outer_len[already_read..])
.await?;
}
let outer_len = u32::from_be_bytes(outer_len) as usize;
println!("recv_protocol_message\r");
stream.read_exact(&mut buf[..4]).await?;
let outer_len = u32::from_be_bytes(buf[..4].try_into().unwrap()) as usize;
let mut msg = vec![0; outer_len];
let mut ptr = 0;
while ptr < outer_len {
let mut inner_len = [0; 2];
conn.stream.read_exact(&mut inner_len).await?;
stream.read_exact(&mut inner_len).await?;
let inner_len = u16::from_be_bytes(inner_len);
conn.stream
.read_exact(&mut conn.buf[..inner_len as usize])
.await?;
let read_len = conn
.noise
.read_message(&conn.buf[..inner_len as usize], &mut msg[ptr..])?;
stream.read_exact(&mut buf[..inner_len as usize]).await?;
let read_len = cipher.decrypt(&buf[..inner_len as usize], &mut msg[ptr..])?;
ptr += read_len;
}
println!("recv_protocol_message done\r");
Ok(rmp_serde::from_slice(&msg)?)
}

View File

@ -255,6 +255,7 @@ pub fn build_responder() -> (snow::HandshakeState, Vec<u8>) {
(
builder
.local_private_key(&keypair.private)
.unwrap()
.build_responder()
.expect("net: couldn't build responder?"),
keypair.public,
@ -269,6 +270,7 @@ pub fn build_initiator() -> (snow::HandshakeState, Vec<u8>) {
(
builder
.local_private_key(&keypair.private)
.unwrap()
.build_initiator()
.expect("net: couldn't build initiator?"),
keypair.public,