diff --git a/kinode/src/net/tcp/mod.rs b/kinode/src/net/tcp/mod.rs index faeaf9b3..41f40eb0 100644 --- a/kinode/src/net/tcp/mod.rs +++ b/kinode/src/net/tcp/mod.rs @@ -207,12 +207,16 @@ async fn recv_connection( &their_id, )?; - let (peer, peer_rx) = Peer::new(their_id.clone(), their_handshake.proxy_request); - data.peers.insert(their_id.name.clone(), peer).await; + // if we already have a connection to this peer, kill it so we + // don't build a duplicate connection + if let Some(mut peer) = data.peers.get_mut(&their_handshake.name) { + peer.kill(); + } - tokio::spawn(utils::maintain_connection( + let (mut peer, peer_rx) = Peer::new(their_id.clone(), their_handshake.proxy_request); + peer.handle = Some(tokio::spawn(utils::maintain_connection( their_handshake.name, - data.peers, + data.peers.clone(), PeerConnection { noise: noise.into_transport_mode()?, buf, @@ -221,7 +225,8 @@ async fn recv_connection( peer_rx, ext.kernel_message_tx, ext.print_tx, - )); + ))); + data.peers.insert(their_id.name.clone(), peer).await; Ok(()) } @@ -322,17 +327,17 @@ pub async fn recv_via_router( }; match connect_with_handshake_via_router(&ext, &peer_id, &router_id, stream).await { Ok(connection) => { - let (peer, peer_rx) = Peer::new(peer_id.clone(), false); - data.peers.insert(peer_id.name.clone(), peer).await; // maintain direct connection - tokio::spawn(utils::maintain_connection( - peer_id.name, + let (mut peer, peer_rx) = Peer::new(peer_id.clone(), false); + peer.handle = Some(tokio::spawn(utils::maintain_connection( + peer_id.name.clone(), data.peers.clone(), connection, peer_rx, ext.kernel_message_tx, ext.print_tx, - )); + ))); + data.peers.insert(peer_id.name, peer).await; } Err(e) => { print_debug(&ext.print_tx, &format!("net: error getting routed: {e}")).await; diff --git a/kinode/src/net/types.rs b/kinode/src/net/types.rs index 7b37d918..669393a9 100644 --- a/kinode/src/net/types.rs +++ b/kinode/src/net/types.rs @@ -175,6 +175,7 @@ pub struct Peer { /// associated with them. We can send them prompts to establish Passthroughs. pub routing_for: bool, pub sender: UnboundedSender, + pub handle: Option>, /// unix timestamp of last message sent *or* received pub last_message: u64, } @@ -189,6 +190,7 @@ impl Peer { identity, routing_for, sender: peer_tx, + handle: None, last_message: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() @@ -215,6 +217,12 @@ impl Peer { .unwrap() .as_secs() } + + pub fn kill(&mut self) { + if let Some(handle) = self.handle.take() { + handle.abort(); + } + } } /// [`Identity`], with additional fields for networking. #[derive(Clone)] diff --git a/kinode/src/net/ws/mod.rs b/kinode/src/net/ws/mod.rs index 0e2b9714..11a01077 100644 --- a/kinode/src/net/ws/mod.rs +++ b/kinode/src/net/ws/mod.rs @@ -187,17 +187,17 @@ pub async fn recv_via_router( }; match connect_with_handshake_via_router(&ext, &peer_id, &router_id, socket).await { Ok(connection) => { - let (peer, peer_rx) = Peer::new(peer_id.clone(), false); - data.peers.insert(peer_id.name.clone(), peer).await; // maintain direct connection - tokio::spawn(utils::maintain_connection( - peer_id.name, + let (mut peer, peer_rx) = Peer::new(peer_id.clone(), false); + peer.handle = Some(tokio::spawn(utils::maintain_connection( + peer_id.name.clone(), data.peers.clone(), connection, peer_rx, ext.kernel_message_tx, ext.print_tx, - )); + ))); + data.peers.insert(peer_id.name, peer).await; } Err(e) => { print_debug(&ext.print_tx, &format!("net: error getting routed: {e}")).await; @@ -263,12 +263,16 @@ async fn recv_connection( &their_id, )?; - let (peer, peer_rx) = Peer::new(their_id.clone(), their_handshake.proxy_request); - data.peers.insert(their_id.name.clone(), peer).await; + // if we already have a connection to this peer, kill it so we + // don't build a duplicate connection + if let Some(mut peer) = data.peers.get_mut(&their_handshake.name) { + peer.kill(); + } - tokio::spawn(utils::maintain_connection( + let (mut peer, peer_rx) = Peer::new(their_id.clone(), their_handshake.proxy_request); + peer.handle = Some(tokio::spawn(utils::maintain_connection( their_handshake.name, - data.peers, + data.peers.clone(), PeerConnection { noise: noise.into_transport_mode()?, buf, @@ -277,7 +281,8 @@ async fn recv_connection( peer_rx, ext.kernel_message_tx, ext.print_tx, - )); + ))); + data.peers.insert(their_id.name.clone(), peer).await; Ok(()) }