net: if run out of passthroughs, remove oldest

TODO: update the time used to judge to least recently sent?
This commit is contained in:
hosted-fornet 2024-10-02 18:54:11 -07:00
parent 045b7a1c20
commit 569403b62c
3 changed files with 69 additions and 13 deletions

View File

@ -7,7 +7,7 @@ use types::{
WS_PROTOCOL,
};
use {
dashmap::{DashMap, DashSet},
dashmap::DashMap,
ring::signature::Ed25519KeyPair,
std::sync::Arc,
tokio::task::JoinSet,
@ -58,7 +58,7 @@ pub async fn networking(
let peers: Peers = Peers::new(max_peers, ext.kernel_message_tx.clone());
// only used by routers
let pending_passthroughs: PendingPassthroughs = Arc::new(DashMap::new());
let active_passthroughs: ActivePassthroughs = Arc::new(DashSet::new());
let active_passthroughs: ActivePassthroughs = Arc::new(DashMap::new());
let net_data = NetData {
pki,
@ -246,7 +246,7 @@ async fn handle_local_request(
data.active_passthroughs.len()
));
for p in data.active_passthroughs.iter() {
printout.push_str(&format!(" {} -> {}\r\n", p.0, p.1));
printout.push_str(&format!(" {} -> {}\r\n", p.key().0, p.key().1));
}
}

View File

@ -3,7 +3,7 @@ use lib::types::core::{
Identity, KernelMessage, MessageSender, NetworkErrorSender, NodeId, PrintSender,
};
use {
dashmap::{DashMap, DashSet},
dashmap::DashMap,
ring::signature::Ed25519KeyPair,
serde::{Deserialize, Serialize},
std::sync::Arc,
@ -132,7 +132,7 @@ pub type OnchainPKI = Arc<DashMap<String, Identity>>;
/// (from, target) -> from's socket
///
/// only used by routers
pub type PendingPassthroughs = Arc<DashMap<(NodeId, NodeId), PendingStream>>;
pub type PendingPassthroughs = Arc<DashMap<(NodeId, NodeId), (PendingStream, u64)>>;
pub enum PendingStream {
WebSocket(WebSocketStream<MaybeTlsStream<TcpStream>>),
Tcp(TcpStream),
@ -141,7 +141,7 @@ pub enum PendingStream {
/// (from, target)
///
/// only used by routers
pub type ActivePassthroughs = Arc<DashSet<(NodeId, NodeId)>>;
pub type ActivePassthroughs = Arc<DashMap<(NodeId, NodeId), (u64, KillSender)>>;
impl PendingStream {
pub fn is_ws(&self) -> bool {
@ -152,6 +152,8 @@ impl PendingStream {
}
}
type KillSender = tokio::sync::mpsc::Sender<()>;
pub struct Peer {
pub identity: Identity,
/// If true, we are routing for them and have a RoutingClientConnection

View File

@ -38,17 +38,58 @@ pub async fn create_passthrough(
socket_1: PendingStream,
) -> anyhow::Result<()> {
// if we already are at the max number of passthroughs, reject
if data.max_passthroughs == 0 {
return Err(anyhow::anyhow!(
"passthrough denied: this node has disallowed passthroughs. Start node with `--max-passthroughs <VAL>` to allow passthroughs"
));
}
// remove pending before checking bound because otherwise we stop
// ourselves from matching pending if this connection will be
// the max_passthroughs passthrough
let maybe_pending = data
.pending_passthroughs
.remove(&(target_id.name.clone(), from_id.name.clone()));
if data.active_passthroughs.len() + data.pending_passthroughs.len()
>= data.max_passthroughs as usize
{
return Err(anyhow::anyhow!("max passthroughs reached"));
let oldest_active = data
.active_passthroughs
.iter()
.min_by_key(|p| p.0);
let (oldest_active_key, oldest_active_time, oldest_active_kill_sender) = match oldest_active {
None => (None, get_now(), None),
Some(oldest_active) => {
let (oldest_active_key, oldest_active_val) = oldest_active.pair();
let oldest_active_key = oldest_active_key.clone();
let (oldest_active_time, oldest_active_kill_sender) = oldest_active_val.clone();
(Some(oldest_active_key), oldest_active_time, Some(oldest_active_kill_sender))
}
};
let oldest_pending = data
.pending_passthroughs
.iter()
.min_by_key(|p| p.1);
let (oldest_pending_key, oldest_pending_time) = match oldest_pending {
None => (None, get_now()),
Some(oldest_pending) => {
let (oldest_pending_key, oldest_pending_val) = oldest_pending.pair();
let oldest_pending_key = oldest_pending_key.clone();
let (_, oldest_pending_time) = oldest_pending_val;
(Some(oldest_pending_key), oldest_pending_time.clone())
}
};
if oldest_active_time < oldest_pending_time {
// active key is oldest
oldest_active_kill_sender.unwrap().send(()).await.unwrap();
data.active_passthroughs.remove(&oldest_active_key.unwrap());
} else {
// pending key is oldest
data.pending_passthroughs.remove(&oldest_pending_key.unwrap());
}
}
// if the target has already generated a pending passthrough for this source,
// immediately match them
if let Some(((from, target), pending_stream)) = data
.pending_passthroughs
.remove(&(target_id.name.clone(), from_id.name.clone()))
{
if let Some(((from, target), (pending_stream, _))) = maybe_pending {
tokio::spawn(maintain_passthrough(
from,
target,
@ -136,8 +177,9 @@ pub async fn create_passthrough(
// or if the target node connects to us with a matching passthrough.
// TODO it is currently possible to have dangling passthroughs in the map
// if the target is "connected" to us but nonresponsive.
let now = get_now();
data.pending_passthroughs
.insert((from_id.name, target_id.name), socket_1);
.insert((from_id.name, target_id.name), (socket_1, now));
Ok(())
}
@ -149,7 +191,9 @@ pub async fn maintain_passthrough(
socket_2: PendingStream,
active_passthroughs: ActivePassthroughs,
) {
active_passthroughs.insert((from.clone(), target.clone()));
let now = get_now();
let (kill_sender, mut kill_receiver) = tokio::sync::mpsc::channel(1);
active_passthroughs.insert((from.clone(), target.clone()), (now, kill_sender));
match (socket_1, socket_2) {
(PendingStream::Tcp(socket_1), PendingStream::Tcp(socket_2)) => {
// do not use bidirectional because if one side closes,
@ -160,6 +204,7 @@ pub async fn maintain_passthrough(
tokio::select! {
_ = copy(&mut r1, &mut w2) => {},
_ = copy(&mut r2, &mut w1) => {},
_ = kill_receiver.recv() => {},
}
}
(PendingStream::WebSocket(mut socket_1), PendingStream::WebSocket(mut socket_2)) => {
@ -194,6 +239,7 @@ pub async fn maintain_passthrough(
break
}
}
_ = kill_receiver.recv() => break,
}
}
let _ = socket_1.close(None).await;
@ -402,3 +448,11 @@ pub async fn print_loud(print_tx: &PrintSender, content: &str) {
pub async fn print_debug(print_tx: &PrintSender, content: &str) {
Printout::new(2, content).send(print_tx).await;
}
pub fn get_now() -> u64 {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
now
}