From 6814455baf8673ba796d10595bfee04535f5c4a8 Mon Sep 17 00:00:00 2001 From: dr-frmr Date: Tue, 28 May 2024 22:40:00 -0600 Subject: [PATCH] ugh: use custom `snow` for async --- Cargo.lock | 5 +- kinode/Cargo.toml | 4 +- kinode/src/net/tcp/utils.rs | 155 +++++++++++++++++++----------------- kinode/src/net/utils.rs | 2 + 4 files changed, 87 insertions(+), 79 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c4015859..340073c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5033,9 +5033,8 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snow" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "850948bee068e713b8ab860fe1adc4d109676ab4c3b621fd8147f06b261f2f85" +version = "0.9.0" +source = "git+https://github.com/dr-frmr/snow?branch=dr/extract_cipherstates#1d4eb5f6747aa59aabb32bbbe698fb4bb7dfb9a4" dependencies = [ "aes-gcm", "blake2", diff --git a/kinode/Cargo.toml b/kinode/Cargo.toml index 7957fc6c..929f6d42 100644 --- a/kinode/Cargo.toml +++ b/kinode/Cargo.toml @@ -80,7 +80,9 @@ serde_json = "1.0" serde_urlencoded = "0.7" sha2 = "0.10" sha3 = "0.10.8" -snow = { version = "0.9.5", features = ["ring-resolver"] } +# snow = { version = "0.9.5", features = ["ring-resolver"] } +# unfortunately need to use forked version for async use and in-place encryption +snow = { git = "https://github.com/dr-frmr/snow", branch = "dr/extract_cipherstates", features = ["ring-resolver"] } socket2 = "0.5.7" static_dir = "0.2.0" thiserror = "1.0" diff --git a/kinode/src/net/tcp/utils.rs b/kinode/src/net/tcp/utils.rs index 35c12444..8a7ec5af 100644 --- a/kinode/src/net/tcp/utils.rs +++ b/kinode/src/net/tcp/utils.rs @@ -6,7 +6,7 @@ use crate::net::{ use lib::types::core::{KernelMessage, MessageSender, NodeId, PrintSender}; use { tokio::io::{AsyncReadExt, AsyncWriteExt}, - tokio::net::TcpStream, + tokio::net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpStream}, tokio::sync::mpsc::UnboundedReceiver, }; @@ -29,112 +29,117 @@ pub async fn maintain_connection( .set_tcp_keepalive(&ka) .expect("failed to set tcp keepalive"); - loop { - tokio::select! { - maybe_recv = peer_rx.recv() => { - let Some(km) = maybe_recv else { - break - }; - let Ok(()) = send_protocol_message(&km, &mut conn).await else { - break - }; - }, - outer_len = recv_protocol_message_init(&mut conn.stream) => { - match outer_len { - Ok((read, outer_len)) => match recv_protocol_message(&mut conn, read, outer_len).await { - Ok(km) => { - if km.source.node != peer_name { - print_loud( - &print_tx, - &format!( - "net: got message with spoofed source from {peer_name}!" - ), - ).await; - break - } else { - kernel_message_tx.send(km).await.expect("net: fatal: kernel receiver died"); - continue - } - } - Err(e) => { - print_debug(&print_tx, &format!("net: error receiving message: {e}")).await; - break - } - } - Err(e) => { - print_debug(&print_tx, &format!("net: error receiving message: {e}")).await; - break + let (mut read_stream, mut write_stream) = conn.stream.into_split(); + let (mut our_cipher, mut their_cipher) = if conn.noise.is_initiator() { + // if initiator, we write with first and read with second + let snow::CipherStates(our_cipher, their_cipher) = conn.noise.extract_cipherstates(); + (our_cipher, their_cipher) + } else { + // if responder, we read with first and write with second + let snow::CipherStates(their_cipher, our_cipher) = conn.noise.extract_cipherstates(); + (our_cipher, their_cipher) + }; + + let write_buf = &mut [0; 65536]; + let write = async move { + while let Some(km) = peer_rx.recv().await { + let Ok(()) = + send_protocol_message(&km, &mut our_cipher, write_buf, &mut write_stream).await + else { + break; + }; + } + }; + + let read_buf = &mut conn.buf; + let read_peer_name = peer_name.clone(); + let read_print_tx = print_tx.clone(); + let read = async move { + loop { + match recv_protocol_message(&mut their_cipher, read_buf, &mut read_stream).await { + Ok(km) => { + if km.source.node != read_peer_name { + print_loud( + &read_print_tx, + &format!("net: got message with spoofed source from {read_peer_name}!"), + ) + .await; + break; + } else { + kernel_message_tx + .send(km) + .await + .expect("net: fatal: kernel receiver died"); } } - }, + Err(e) => { + print_debug( + &read_print_tx, + &format!("net: error receiving message: {e}"), + ) + .await; + break; + } + } } + }; + + tokio::select! { + _ = write => (), + _ = read => (), } - let _ = conn.stream.shutdown().await; + print_debug(&print_tx, &format!("net: connection lost with {peer_name}")).await; peers.remove(&peer_name); } async fn send_protocol_message( km: &KernelMessage, - conn: &mut PeerConnection, + cipher: &mut snow::CipherState, + buf: &mut [u8], + stream: &mut OwnedWriteHalf, ) -> anyhow::Result<()> { - println!( - "initiatior: {}, sending_nonce: {}, receiving_nonce: {}\r", - conn.noise.is_initiator(), - conn.noise.sending_nonce(), - conn.noise.receiving_nonce() - ); + println!("send_protocol_message\r"); let serialized = rmp_serde::to_vec(km)?; if serialized.len() > MESSAGE_MAX_SIZE as usize { return Err(anyhow::anyhow!("message too large")); } - let outer_len = (serialized.len() as u32).to_be_bytes(); - conn.stream.write_all(&outer_len).await?; - + stream.write_all(&outer_len).await?; + println!("1\r"); // 65519 = 65535 - 16 (TAGLEN) for payload in serialized.chunks(65519) { - let len = conn.noise.write_message(payload, &mut conn.buf)? as u16; - conn.stream.write_all(&len.to_be_bytes()).await?; - conn.stream.write_all(&conn.buf[..len as usize]).await?; + let len = cipher.encrypt(payload, buf)? as u16; + stream.write_all(&len.to_be_bytes()).await?; + println!(" 2\r"); + stream.write_all(&buf[..len as usize]).await?; + println!(" 3\r"); } - Ok(conn.stream.flush().await?) -} - -async fn recv_protocol_message_init(stream: &mut TcpStream) -> anyhow::Result<(usize, [u8; 4])> { - let mut outer_len = [0; 4]; - let read = stream.read(&mut outer_len).await?; - Ok((read, outer_len)) + println!("send_protocol_message flush\r"); + Ok(stream.flush().await?) } /// any error in receiving a message will result in the connection being closed. async fn recv_protocol_message( - conn: &mut PeerConnection, - already_read: usize, - mut outer_len: [u8; 4], + cipher: &mut snow::CipherState, + buf: &mut [u8], + stream: &mut OwnedReadHalf, ) -> anyhow::Result { - // fill out the rest of outer_len depending on how many bytes were read - if already_read < 4 { - conn.stream - .read_exact(&mut outer_len[already_read..]) - .await?; - } - let outer_len = u32::from_be_bytes(outer_len) as usize; + println!("recv_protocol_message\r"); + stream.read_exact(&mut buf[..4]).await?; + let outer_len = u32::from_be_bytes(buf[..4].try_into().unwrap()) as usize; let mut msg = vec![0; outer_len]; let mut ptr = 0; while ptr < outer_len { let mut inner_len = [0; 2]; - conn.stream.read_exact(&mut inner_len).await?; + stream.read_exact(&mut inner_len).await?; let inner_len = u16::from_be_bytes(inner_len); - conn.stream - .read_exact(&mut conn.buf[..inner_len as usize]) - .await?; - let read_len = conn - .noise - .read_message(&conn.buf[..inner_len as usize], &mut msg[ptr..])?; + stream.read_exact(&mut buf[..inner_len as usize]).await?; + let read_len = cipher.decrypt(&buf[..inner_len as usize], &mut msg[ptr..])?; ptr += read_len; } + println!("recv_protocol_message done\r"); Ok(rmp_serde::from_slice(&msg)?) } diff --git a/kinode/src/net/utils.rs b/kinode/src/net/utils.rs index f34b9928..006dab88 100644 --- a/kinode/src/net/utils.rs +++ b/kinode/src/net/utils.rs @@ -255,6 +255,7 @@ pub fn build_responder() -> (snow::HandshakeState, Vec) { ( builder .local_private_key(&keypair.private) + .unwrap() .build_responder() .expect("net: couldn't build responder?"), keypair.public, @@ -269,6 +270,7 @@ pub fn build_initiator() -> (snow::HandshakeState, Vec) { ( builder .local_private_key(&keypair.private) + .unwrap() .build_initiator() .expect("net: couldn't build initiator?"), keypair.public,