diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index f5d88c2d9a..f630d9c0ee 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -6,6 +6,7 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; +use postage::barrier; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; use std::{fmt, rc::Rc, sync::Arc}; @@ -22,6 +23,7 @@ struct FakeServerState { connection_id: Option, forbid_connections: bool, auth_count: usize, + connection_killer: Option, access_token: usize, } @@ -74,13 +76,15 @@ impl FakeServer { Err(EstablishConnectionError::Unauthorized)? } - let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); + let (client_conn, server_conn, kill) = + Connection::in_memory(cx.background()); let (connection_id, io, incoming) = peer.add_test_connection(server_conn, cx.background()).await; cx.background().spawn(io).detach(); let mut state = state.lock(); state.connection_id = Some(connection_id); state.incoming = Some(incoming); + state.connection_killer = Some(kill); Ok(client_conn) }) } diff --git a/crates/rpc/src/conn.rs b/crates/rpc/src/conn.rs index fb91b72d9f..a97797fc9d 100644 --- a/crates/rpc/src/conn.rs +++ b/crates/rpc/src/conn.rs @@ -1,6 +1,5 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{SinkExt as _, Stream, StreamExt as _}; -use std::{io, task::Poll}; +use futures::{SinkExt as _, StreamExt as _}; pub struct Connection { pub(crate) tx: @@ -36,87 +35,82 @@ impl Connection { #[cfg(any(test, feature = "test-support"))] pub fn in_memory( executor: std::sync::Arc, - ) -> (Self, Self, postage::watch::Sender>) { - let (kill_tx, mut kill_rx) = postage::watch::channel_with(None); - postage::stream::Stream::try_recv(&mut kill_rx).unwrap(); + ) -> (Self, Self, postage::barrier::Sender) { + use postage::prelude::Stream; - let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone()); - let (b_tx, b_rx) = Self::channel(kill_rx, executor); - ( + let (kill_tx, kill_rx) = postage::barrier::channel(); + let (a_tx, a_rx) = channel(kill_rx.clone(), executor.clone()); + let (b_tx, b_rx) = channel(kill_rx, executor); + return ( Self { tx: a_tx, rx: b_rx }, Self { tx: b_tx, rx: a_rx }, kill_tx, - ) - } + ); - #[cfg(any(test, feature = "test-support"))] - fn channel( - kill_rx: postage::watch::Receiver>, - executor: std::sync::Arc, - ) -> ( - Box>, - Box>>, - ) { - use futures::channel::mpsc; - use io::{Error, ErrorKind}; - use std::sync::Arc; + fn channel( + kill_rx: postage::barrier::Receiver, + executor: std::sync::Arc, + ) -> ( + Box>, + Box< + dyn Send + Unpin + futures::Stream>, + >, + ) { + use futures::channel::mpsc; + use std::{ + io::{Error, ErrorKind}, + sync::Arc, + }; - let (tx, rx) = mpsc::unbounded::(); - let tx = tx - .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) - .with({ - let executor = Arc::downgrade(&executor); - let kill_rx = kill_rx.clone(); - move |msg| { + let (tx, rx) = mpsc::unbounded::(); + + let tx = tx + .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) + .with({ let kill_rx = kill_rx.clone(); + let executor = Arc::downgrade(&executor); + move |msg| { + let mut kill_rx = kill_rx.clone(); + let executor = executor.clone(); + Box::pin(async move { + if let Some(executor) = executor.upgrade() { + executor.simulate_random_delay().await; + } + + // Writes to a half-open TCP connection will error. + if kill_rx.try_recv().is_ok() { + std::io::Result::Err( + Error::new(ErrorKind::Other, "connection lost").into(), + )?; + } + + Ok(msg) + }) + } + }); + + let rx = rx.then({ + let kill_rx = kill_rx.clone(); + let executor = Arc::downgrade(&executor); + move |msg| { + let mut kill_rx = kill_rx.clone(); let executor = executor.clone(); Box::pin(async move { if let Some(executor) = executor.upgrade() { executor.simulate_random_delay().await; } - if kill_rx.borrow().is_none() { - Ok(msg) - } else { - Err(Error::new(ErrorKind::Other, "connection killed").into()) + + // Reads from a half-open TCP connection will hang. + if kill_rx.try_recv().is_ok() { + futures::future::pending::<()>().await; } + + Ok(msg) }) } }); - let rx = rx.then(move |msg| { - let executor = Arc::downgrade(&executor); - Box::pin(async move { - if let Some(executor) = executor.upgrade() { - executor.simulate_random_delay().await; - } - msg - }) - }); - let rx = KillableReceiver { kill_rx, rx }; - (Box::new(tx), Box::new(rx)) - } -} - -struct KillableReceiver { - rx: S, - kill_rx: postage::watch::Receiver>, -} - -impl> Stream for KillableReceiver { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) { - Poll::Ready(Some(Err(io::Error::new( - io::ErrorKind::Other, - "connection killed", - ) - .into()))) - } else { - self.rx.poll_next_unpin(cx).map(|value| value.map(Ok)) + (Box::new(tx), Box::new(rx)) } } } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index f9c94cc84d..e9b8d50e68 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -371,7 +371,7 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn, _) = + let (client1_to_server_conn, server_to_client_1_conn, _kill) = Connection::in_memory(cx.background()); let (client1_conn_id, io_task1, client1_incoming) = client1 .add_test_connection(client1_to_server_conn, cx.background()) @@ -380,7 +380,7 @@ mod tests { .add_test_connection(server_to_client_1_conn, cx.background()) .await; - let (client2_to_server_conn, server_to_client_2_conn, _) = + let (client2_to_server_conn, server_to_client_2_conn, _kill) = Connection::in_memory(cx.background()); let (client2_conn_id, io_task3, client2_incoming) = client2 .add_test_connection(client2_to_server_conn, cx.background()) @@ -468,7 +468,7 @@ mod tests { let server = Peer::new(); let client = Peer::new(); - let (client_to_server_conn, server_to_client_conn, _) = + let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); let (client_to_server_conn_id, io_task1, mut client_incoming) = client .add_test_connection(client_to_server_conn, cx.background()) @@ -568,7 +568,7 @@ mod tests { let server = Peer::new(); let client = Peer::new(); - let (client_to_server_conn, server_to_client_conn, _) = + let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); let (client_to_server_conn_id, io_task1, mut client_incoming) = client .add_test_connection(client_to_server_conn, cx.background()) @@ -680,7 +680,7 @@ mod tests { async fn test_disconnect(cx: &mut TestAppContext) { let executor = cx.foreground(); - let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); + let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = client @@ -716,7 +716,7 @@ mod tests { #[gpui::test(iterations = 50)] async fn test_io_error(cx: &mut TestAppContext) { let executor = cx.foreground(); - let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); + let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = client diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 9f812ba104..63ac9c2eb3 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -1030,7 +1030,7 @@ mod tests { }; use lsp; use parking_lot::Mutex; - use postage::{sink::Sink, watch}; + use postage::{barrier, watch}; use project::{ fs::{FakeFs, Fs as _}, search::SearchQuery, @@ -1872,6 +1872,7 @@ mod tests { // Simulate connection loss for client B and ensure client A observes client B leaving the project. server.disconnect_client(client_b.current_user_id(cx_b)); + cx_a.foreground().advance_clock(Duration::from_secs(3)); project_a .condition(&cx_a, |p, _| p.collaborators().len() == 0) .await; @@ -3898,6 +3899,7 @@ mod tests { // Disconnect client B, ensuring we can still access its cached channel data. server.forbid_connections(); server.disconnect_client(client_b.current_user_id(&cx_b)); + cx_b.foreground().advance_clock(Duration::from_secs(3)); while !matches!( status_b.next().await, Some(client::Status::ReconnectionError { .. }) @@ -4388,7 +4390,7 @@ mod tests { server: Arc, foreground: Rc, notifications: mpsc::UnboundedReceiver<()>, - connection_killers: Arc>>>>, + connection_killers: Arc>>, forbid_connections: Arc, _test_db: TestDb, } @@ -4492,9 +4494,7 @@ mod tests { } fn disconnect_client(&self, user_id: UserId) { - if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) { - let _ = kill_conn.try_send(Some(())); - } + self.connection_killers.lock().remove(&user_id); } fn forbid_connections(&self) {