Allow peers to receive individual messages before starting message loop

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-06-17 14:19:15 -07:00
parent 4d28d03e3f
commit 05a662b35e
2 changed files with 75 additions and 47 deletions

View File

@ -21,12 +21,14 @@ use std::{
}; };
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>; type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(u32); pub struct ConnectionId(u32);
struct Connection { struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>, writer: Mutex<MessageStream<BoxedWriter>>,
reader: Mutex<MessageStream<BoxedReader>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>, response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32, next_message_id: AtomicU32,
} }
@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
} }
pub struct Peer { pub struct Peer {
connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>, connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
message_handlers: RwLock<Vec<MessageHandler>>, message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>, handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32, next_connection_id: AtomicU32,
@ -62,6 +65,7 @@ impl Peer {
pub fn new() -> Arc<Self> { pub fn new() -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
connections: Default::default(), connections: Default::default(),
connection_close_barriers: Default::default(),
message_handlers: Default::default(), message_handlers: Default::default(),
handler_types: Default::default(), handler_types: Default::default(),
next_connection_id: Default::default(), next_connection_id: Default::default(),
@ -102,10 +106,7 @@ impl Peer {
rx rx
} }
pub async fn add_connection<Conn>( pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
self: &Arc<Self>,
conn: Conn,
) -> (ConnectionId, impl Future<Output = Result<()>>)
where where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{ {
@ -113,26 +114,44 @@ impl Peer {
self.next_connection_id self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst), .fetch_add(1, atomic::Ordering::SeqCst),
); );
let (close_tx, mut close_rx) = barrier::channel(); self.connections.write().await.insert(
let connection = Arc::new(Connection { connection_id,
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), Arc::new(Connection {
response_channels: Default::default(), reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
next_message_id: Default::default(), writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
}); response_channels: Default::default(),
next_message_id: Default::default(),
}),
);
connection_id
}
self.connections pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
self.connection_close_barriers
.write() .write()
.await .await
.insert(connection_id, (connection.clone(), close_tx)); .remove(&connection_id);
}
pub fn handle_messages(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<()>> + 'static {
let (close_tx, mut close_rx) = barrier::channel();
let this = self.clone(); let this = self.clone();
let handler_future = async move { async move {
this.connection_close_barriers
.write()
.await
.insert(connection_id, close_tx);
let connection = this.connection(connection_id).await?;
let closed = close_rx.recv(); let closed = close_rx.recv();
futures::pin_mut!(closed); futures::pin_mut!(closed);
let mut stream = MessageStream::new(conn);
loop { loop {
let read_message = stream.read_message(); let mut reader = connection.reader.lock().await;
let read_message = reader.read_message();
futures::pin_mut!(read_message); futures::pin_mut!(read_message);
match futures::future::select(read_message, &mut closed).await { match futures::future::select(read_message, &mut closed).await {
@ -181,13 +200,23 @@ impl Peer {
Either::Right(_) => return Ok(()), Either::Right(_) => return Ok(()),
} }
} }
}; }
(connection_id, handler_future)
} }
pub async fn disconnect(&self, connection_id: ConnectionId) { pub async fn receive<M: EnvelopedMessage>(
self.connections.write().await.remove(&connection_id); self: &Arc<Self>,
connection_id: ConnectionId,
) -> Result<TypedEnvelope<M>> {
let connection = self.connection(connection_id).await?;
let envelope = connection.reader.lock().await.read_message().await?;
let id = envelope.id;
let payload =
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
Ok(TypedEnvelope {
id,
connection_id,
payload,
})
} }
pub fn request<T: RequestMessage>( pub fn request<T: RequestMessage>(
@ -271,7 +300,6 @@ impl Peer {
.await .await
.get(&id) .get(&id)
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))? .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
.0
.clone()) .clone())
} }
} }
@ -298,22 +326,22 @@ mod tests {
let server = Peer::new(); let server = Peer::new();
let client1 = Peer::new(); let client1 = Peer::new();
let client2 = Peer::new(); let client2 = Peer::new();
let (client1_conn_id, f1) = client1 let client1_conn_id = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap()) .add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await; .await;
let (client2_conn_id, f2) = client2 let client2_conn_id = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap()) .add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await; .await;
let (_, f3) = server let server_conn_id1 = server
.add_connection(listener.accept().await.unwrap().0) .add_connection(listener.accept().await.unwrap().0)
.await; .await;
let (_, f4) = server let server_conn_id2 = server
.add_connection(listener.accept().await.unwrap().0) .add_connection(listener.accept().await.unwrap().0)
.await; .await;
smol::spawn(f1).detach(); smol::spawn(client1.handle_messages(client1_conn_id)).detach();
smol::spawn(f2).detach(); smol::spawn(client2.handle_messages(client2_conn_id)).detach();
smol::spawn(f3).detach(); smol::spawn(server.handle_messages(server_conn_id1)).detach();
smol::spawn(f4).detach(); smol::spawn(server.handle_messages(server_conn_id2)).detach();
// define the expected requests and responses // define the expected requests and responses
let request1 = proto::OpenWorktree { let request1 = proto::OpenWorktree {
@ -428,21 +456,21 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap(); let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new(); let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await; let connection_id = client.add_connection(client_conn).await;
smol::spawn(handler).detach(); let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
barrier::channel();
let handle_messages = client.handle_messages(connection_id);
smol::spawn(async move {
handle_messages.await.unwrap();
incoming_messages_ended_tx.send(()).await.unwrap();
})
.detach();
client.disconnect(connection_id).await; client.disconnect(connection_id).await;
// Try sending an empty payload over and over, until the client is dropped and hangs up. incoming_messages_ended_rx.recv().await;
loop {
match server_conn.write(&[]).await { let err = server_conn.write(&[]).await.unwrap_err();
Ok(_) => {} assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
Err(err) => {
if err.kind() == io::ErrorKind::BrokenPipe {
break;
}
}
}
}
}); });
} }
@ -456,8 +484,8 @@ mod tests {
client_conn.close().await.unwrap(); client_conn.close().await.unwrap();
let client = Peer::new(); let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await; let connection_id = client.add_connection(client_conn).await;
smol::spawn(handler).detach(); smol::spawn(client.handle_messages(connection_id)).detach();
let err = client let err = client
.request( .request(

View File

@ -691,8 +691,8 @@ impl Workspace {
// a TLS stream using `native-tls`. // a TLS stream using `native-tls`.
let stream = smol::net::TcpStream::connect(rpc_address).await?; let stream = smol::net::TcpStream::connect(rpc_address).await?;
let (connection_id, handler) = rpc.add_connection(stream).await; let connection_id = rpc.add_connection(stream).await;
executor.spawn(handler).detach(); executor.spawn(rpc.handle_messages(connection_id)).detach();
let auth_response = rpc let auth_response = rpc
.request( .request(