mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-14 02:43:19 +03:00
Allow peers to receive individual messages before starting message loop
Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
4d28d03e3f
commit
05a662b35e
@ -21,12 +21,14 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
|
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
|
||||||
|
type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||||
pub struct ConnectionId(u32);
|
pub struct ConnectionId(u32);
|
||||||
|
|
||||||
struct Connection {
|
struct Connection {
|
||||||
writer: Mutex<MessageStream<BoxedWriter>>,
|
writer: Mutex<MessageStream<BoxedWriter>>,
|
||||||
|
reader: Mutex<MessageStream<BoxedReader>>,
|
||||||
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
|
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
|
||||||
next_message_id: AtomicU32,
|
next_message_id: AtomicU32,
|
||||||
}
|
}
|
||||||
@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct Peer {
|
pub struct Peer {
|
||||||
connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
|
connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
|
||||||
|
connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
|
||||||
message_handlers: RwLock<Vec<MessageHandler>>,
|
message_handlers: RwLock<Vec<MessageHandler>>,
|
||||||
handler_types: Mutex<HashSet<TypeId>>,
|
handler_types: Mutex<HashSet<TypeId>>,
|
||||||
next_connection_id: AtomicU32,
|
next_connection_id: AtomicU32,
|
||||||
@ -62,6 +65,7 @@ impl Peer {
|
|||||||
pub fn new() -> Arc<Self> {
|
pub fn new() -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
connections: Default::default(),
|
connections: Default::default(),
|
||||||
|
connection_close_barriers: Default::default(),
|
||||||
message_handlers: Default::default(),
|
message_handlers: Default::default(),
|
||||||
handler_types: Default::default(),
|
handler_types: Default::default(),
|
||||||
next_connection_id: Default::default(),
|
next_connection_id: Default::default(),
|
||||||
@ -102,10 +106,7 @@ impl Peer {
|
|||||||
rx
|
rx
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connection<Conn>(
|
pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
|
||||||
self: &Arc<Self>,
|
|
||||||
conn: Conn,
|
|
||||||
) -> (ConnectionId, impl Future<Output = Result<()>>)
|
|
||||||
where
|
where
|
||||||
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
@ -113,26 +114,44 @@ impl Peer {
|
|||||||
self.next_connection_id
|
self.next_connection_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||||
);
|
);
|
||||||
let (close_tx, mut close_rx) = barrier::channel();
|
self.connections.write().await.insert(
|
||||||
let connection = Arc::new(Connection {
|
connection_id,
|
||||||
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
Arc::new(Connection {
|
||||||
response_channels: Default::default(),
|
reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||||
next_message_id: Default::default(),
|
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
|
||||||
});
|
response_channels: Default::default(),
|
||||||
|
next_message_id: Default::default(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
connection_id
|
||||||
|
}
|
||||||
|
|
||||||
self.connections
|
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||||
|
self.connections.write().await.remove(&connection_id);
|
||||||
|
self.connection_close_barriers
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert(connection_id, (connection.clone(), close_tx));
|
.remove(&connection_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_messages(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
) -> impl Future<Output = Result<()>> + 'static {
|
||||||
|
let (close_tx, mut close_rx) = barrier::channel();
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let handler_future = async move {
|
async move {
|
||||||
|
this.connection_close_barriers
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(connection_id, close_tx);
|
||||||
|
let connection = this.connection(connection_id).await?;
|
||||||
let closed = close_rx.recv();
|
let closed = close_rx.recv();
|
||||||
futures::pin_mut!(closed);
|
futures::pin_mut!(closed);
|
||||||
|
|
||||||
let mut stream = MessageStream::new(conn);
|
|
||||||
loop {
|
loop {
|
||||||
let read_message = stream.read_message();
|
let mut reader = connection.reader.lock().await;
|
||||||
|
let read_message = reader.read_message();
|
||||||
futures::pin_mut!(read_message);
|
futures::pin_mut!(read_message);
|
||||||
|
|
||||||
match futures::future::select(read_message, &mut closed).await {
|
match futures::future::select(read_message, &mut closed).await {
|
||||||
@ -181,13 +200,23 @@ impl Peer {
|
|||||||
Either::Right(_) => return Ok(()),
|
Either::Right(_) => return Ok(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
(connection_id, handler_future)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
pub async fn receive<M: EnvelopedMessage>(
|
||||||
self.connections.write().await.remove(&connection_id);
|
self: &Arc<Self>,
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
) -> Result<TypedEnvelope<M>> {
|
||||||
|
let connection = self.connection(connection_id).await?;
|
||||||
|
let envelope = connection.reader.lock().await.read_message().await?;
|
||||||
|
let id = envelope.id;
|
||||||
|
let payload =
|
||||||
|
M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
|
||||||
|
Ok(TypedEnvelope {
|
||||||
|
id,
|
||||||
|
connection_id,
|
||||||
|
payload,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn request<T: RequestMessage>(
|
pub fn request<T: RequestMessage>(
|
||||||
@ -271,7 +300,6 @@ impl Peer {
|
|||||||
.await
|
.await
|
||||||
.get(&id)
|
.get(&id)
|
||||||
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
|
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
|
||||||
.0
|
|
||||||
.clone())
|
.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -298,22 +326,22 @@ mod tests {
|
|||||||
let server = Peer::new();
|
let server = Peer::new();
|
||||||
let client1 = Peer::new();
|
let client1 = Peer::new();
|
||||||
let client2 = Peer::new();
|
let client2 = Peer::new();
|
||||||
let (client1_conn_id, f1) = client1
|
let client1_conn_id = client1
|
||||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||||
.await;
|
.await;
|
||||||
let (client2_conn_id, f2) = client2
|
let client2_conn_id = client2
|
||||||
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
|
||||||
.await;
|
.await;
|
||||||
let (_, f3) = server
|
let server_conn_id1 = server
|
||||||
.add_connection(listener.accept().await.unwrap().0)
|
.add_connection(listener.accept().await.unwrap().0)
|
||||||
.await;
|
.await;
|
||||||
let (_, f4) = server
|
let server_conn_id2 = server
|
||||||
.add_connection(listener.accept().await.unwrap().0)
|
.add_connection(listener.accept().await.unwrap().0)
|
||||||
.await;
|
.await;
|
||||||
smol::spawn(f1).detach();
|
smol::spawn(client1.handle_messages(client1_conn_id)).detach();
|
||||||
smol::spawn(f2).detach();
|
smol::spawn(client2.handle_messages(client2_conn_id)).detach();
|
||||||
smol::spawn(f3).detach();
|
smol::spawn(server.handle_messages(server_conn_id1)).detach();
|
||||||
smol::spawn(f4).detach();
|
smol::spawn(server.handle_messages(server_conn_id2)).detach();
|
||||||
|
|
||||||
// define the expected requests and responses
|
// define the expected requests and responses
|
||||||
let request1 = proto::OpenWorktree {
|
let request1 = proto::OpenWorktree {
|
||||||
@ -428,21 +456,21 @@ mod tests {
|
|||||||
let (mut server_conn, _) = listener.accept().await.unwrap();
|
let (mut server_conn, _) = listener.accept().await.unwrap();
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
let connection_id = client.add_connection(client_conn).await;
|
||||||
smol::spawn(handler).detach();
|
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
|
||||||
|
barrier::channel();
|
||||||
|
let handle_messages = client.handle_messages(connection_id);
|
||||||
|
smol::spawn(async move {
|
||||||
|
handle_messages.await.unwrap();
|
||||||
|
incoming_messages_ended_tx.send(()).await.unwrap();
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
client.disconnect(connection_id).await;
|
client.disconnect(connection_id).await;
|
||||||
|
|
||||||
// Try sending an empty payload over and over, until the client is dropped and hangs up.
|
incoming_messages_ended_rx.recv().await;
|
||||||
loop {
|
|
||||||
match server_conn.write(&[]).await {
|
let err = server_conn.write(&[]).await.unwrap_err();
|
||||||
Ok(_) => {}
|
assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
|
||||||
Err(err) => {
|
|
||||||
if err.kind() == io::ErrorKind::BrokenPipe {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,8 +484,8 @@ mod tests {
|
|||||||
client_conn.close().await.unwrap();
|
client_conn.close().await.unwrap();
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
let (connection_id, handler) = client.add_connection(client_conn).await;
|
let connection_id = client.add_connection(client_conn).await;
|
||||||
smol::spawn(handler).detach();
|
smol::spawn(client.handle_messages(connection_id)).detach();
|
||||||
|
|
||||||
let err = client
|
let err = client
|
||||||
.request(
|
.request(
|
||||||
|
@ -691,8 +691,8 @@ impl Workspace {
|
|||||||
// a TLS stream using `native-tls`.
|
// a TLS stream using `native-tls`.
|
||||||
let stream = smol::net::TcpStream::connect(rpc_address).await?;
|
let stream = smol::net::TcpStream::connect(rpc_address).await?;
|
||||||
|
|
||||||
let (connection_id, handler) = rpc.add_connection(stream).await;
|
let connection_id = rpc.add_connection(stream).await;
|
||||||
executor.spawn(handler).detach();
|
executor.spawn(rpc.handle_messages(connection_id)).detach();
|
||||||
|
|
||||||
let auth_response = rpc
|
let auth_response = rpc
|
||||||
.request(
|
.request(
|
||||||
|
Loading…
Reference in New Issue
Block a user