Allow peers to receive individual messages before starting message loop

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-06-17 14:19:15 -07:00
parent 4d28d03e3f
commit 05a662b35e
2 changed files with 75 additions and 47 deletions

View File

@ -21,12 +21,14 @@ use std::{
};
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(u32);
struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>,
reader: Mutex<MessageStream<BoxedReader>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
}
@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
}
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
@ -62,6 +65,7 @@ impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
connection_close_barriers: Default::default(),
message_handlers: Default::default(),
handler_types: Default::default(),
next_connection_id: Default::default(),
@ -102,10 +106,7 @@ impl Peer {
rx
}
pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
) -> (ConnectionId, impl Future<Output = Result<()>>)
pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@ -113,26 +114,44 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
let (close_tx, mut close_rx) = barrier::channel();
let connection = Arc::new(Connection {
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
response_channels: Default::default(),
next_message_id: Default::default(),
});
self.connections.write().await.insert(
connection_id,
Arc::new(Connection {
reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
response_channels: Default::default(),
next_message_id: Default::default(),
}),
);
connection_id
}
self.connections
pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
self.connection_close_barriers
.write()
.await
.insert(connection_id, (connection.clone(), close_tx));
.remove(&connection_id);
}
pub fn handle_messages(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<()>> + 'static {
let (close_tx, mut close_rx) = barrier::channel();
let this = self.clone();
let handler_future = async move {
async move {
this.connection_close_barriers
.write()
.await
.insert(connection_id, close_tx);
let connection = this.connection(connection_id).await?;
let closed = close_rx.recv();
futures::pin_mut!(closed);
let mut stream = MessageStream::new(conn);
loop {
let read_message = stream.read_message();
let mut reader = connection.reader.lock().await;
let read_message = reader.read_message();
futures::pin_mut!(read_message);
match futures::future::select(read_message, &mut closed).await {
@ -181,13 +200,23 @@ impl Peer {
Either::Right(_) => return Ok(()),
}
}
};
(connection_id, handler_future)
}
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
pub async fn receive<M: EnvelopedMessage>(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> Result<TypedEnvelope<M>> {
let connection = self.connection(connection_id).await?;
let envelope = connection.reader.lock().await.read_message().await?;
let id = envelope.id;
let payload =
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
Ok(TypedEnvelope {
id,
connection_id,
payload,
})
}
pub fn request<T: RequestMessage>(
@ -271,7 +300,6 @@ impl Peer {
.await
.get(&id)
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
.0
.clone())
}
}
@ -298,22 +326,22 @@ mod tests {
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_conn_id, f1) = client1
let client1_conn_id = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let (client2_conn_id, f2) = client2
let client2_conn_id = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
let (_, f3) = server
let server_conn_id1 = server
.add_connection(listener.accept().await.unwrap().0)
.await;
let (_, f4) = server
let server_conn_id2 = server
.add_connection(listener.accept().await.unwrap().0)
.await;
smol::spawn(f1).detach();
smol::spawn(f2).detach();
smol::spawn(f3).detach();
smol::spawn(f4).detach();
smol::spawn(client1.handle_messages(client1_conn_id)).detach();
smol::spawn(client2.handle_messages(client2_conn_id)).detach();
smol::spawn(server.handle_messages(server_conn_id1)).detach();
smol::spawn(server.handle_messages(server_conn_id2)).detach();
// define the expected requests and responses
let request1 = proto::OpenWorktree {
@ -428,21 +456,21 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler).detach();
let connection_id = client.add_connection(client_conn).await;
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
barrier::channel();
let handle_messages = client.handle_messages(connection_id);
smol::spawn(async move {
handle_messages.await.unwrap();
incoming_messages_ended_tx.send(()).await.unwrap();
})
.detach();
client.disconnect(connection_id).await;
// Try sending an empty payload over and over, until the client is dropped and hangs up.
loop {
match server_conn.write(&[]).await {
Ok(_) => {}
Err(err) => {
if err.kind() == io::ErrorKind::BrokenPipe {
break;
}
}
}
}
incoming_messages_ended_rx.recv().await;
let err = server_conn.write(&[]).await.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
});
}
@ -456,8 +484,8 @@ mod tests {
client_conn.close().await.unwrap();
let client = Peer::new();
let (connection_id, handler) = client.add_connection(client_conn).await;
smol::spawn(handler).detach();
let connection_id = client.add_connection(client_conn).await;
smol::spawn(client.handle_messages(connection_id)).detach();
let err = client
.request(

View File

@ -691,8 +691,8 @@ impl Workspace {
// a TLS stream using `native-tls`.
let stream = smol::net::TcpStream::connect(rpc_address).await?;
let (connection_id, handler) = rpc.add_connection(stream).await;
executor.spawn(handler).detach();
let connection_id = rpc.add_connection(stream).await;
executor.spawn(rpc.handle_messages(connection_id)).detach();
let auth_response = rpc
.request(