From a87d4db1555eed7bd3b71091556f588e2dc514b4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Jun 2021 19:59:46 +0200 Subject: [PATCH] Change `RpcClient` methods to take shared references This will make it easier to spawn a future on gpui's executors when calling `RpcClient` methods. Co-Authored-By: Max Brunsfeld --- Cargo.lock | 15 +++- zed/Cargo.toml | 2 +- zed/src/rpc_client.rs | 183 ++++++++++++++++++++++-------------------- zed/src/workspace.rs | 7 +- zed/src/worktree.rs | 4 +- 5 files changed, 115 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 571bcfb97e..28ef3d4335 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1350,6 +1350,7 @@ checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1372,6 +1373,17 @@ version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65" +[[package]] +name = "futures-executor" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.12" @@ -1423,6 +1435,7 @@ version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -4304,7 +4317,7 @@ dependencies = [ "easy-parallel", "env_logger", "fsevent", - "futures-core", + "futures", "gpui", "http-auth-basic", "ignore", diff --git a/zed/Cargo.toml b/zed/Cargo.toml index 2df4039a72..0504578992 100644 --- a/zed/Cargo.toml +++ b/zed/Cargo.toml @@ -21,7 +21,7 @@ ctor = "0.1.20" dirs = "3.0" easy-parallel = "3.1.0" fsevent = { path = "../fsevent" } -futures-core = "0.3" +futures = "0.3" gpui = { path = "../gpui" } http-auth-basic = "0.1.3" ignore = "0.4" diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 2358c791a0..44e7c28407 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -1,106 +1,112 @@ use anyhow::{anyhow, Result}; +use futures::FutureExt; use gpui::executor::Background; use parking_lot::Mutex; use postage::{ - mpsc, oneshot, + mpsc, prelude::{Sink, Stream}, }; -use smol::{ - future::FutureExt, - io::WriteHalf, - prelude::{AsyncRead, AsyncWrite}, +use smol::prelude::{AsyncRead, AsyncWrite}; +use std::{ + collections::HashMap, + io, + sync::{ + atomic::{self, AtomicI32}, + Arc, + }, }; -use std::{collections::HashMap, sync::Arc}; use zed_rpc::proto::{ self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage, }; -pub struct RpcClient { - stream: MessageStream>, +pub struct RpcClient { response_channels: Arc, bool)>>>, - next_message_id: i32, - _drop_tx: oneshot::Sender<()>, + outgoing_tx: mpsc::Sender, + next_message_id: AtomicI32, } -impl RpcClient -where - Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - pub fn new(conn: Conn, executor: Arc) -> Self { - let (conn_rx, conn_tx) = smol::io::split(conn); - let (drop_tx, mut drop_rx) = oneshot::channel(); +impl RpcClient { + pub fn new(conn: Conn, executor: Arc) -> Self + where + Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { let response_channels = Arc::new(Mutex::new(HashMap::new())); - let client = Self { - next_message_id: 0, - stream: MessageStream::new(conn_tx), - response_channels: response_channels.clone(), - _drop_tx: drop_tx, - }; + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32); - executor - .spawn::, _>(async move { - enum Message { - Message(proto::FromServer), - ClientDropped, - } - - let mut stream = MessageStream::new(conn_rx); - let client_dropped = async move { - assert!(drop_rx.recv().await.is_none()); - Ok(Message::ClientDropped) as Result<_> - }; - smol::pin!(client_dropped); - loop { - let message = async { - Ok(Message::Message( - stream.read_message::().await?, - )) - }; - - match message.race(&mut client_dropped).await? { - Message::Message(message) => { - if let Some(variant) = message.variant { - if let Some(request_id) = message.request_id { - let channel = response_channels.lock().remove(&request_id); - if let Some((mut tx, oneshot)) = channel { - if tx.send(variant).await.is_ok() { - if !oneshot { - response_channels - .lock() - .insert(request_id, (tx, false)); - } - } - } else { - log::warn!( - "received RPC response to unknown request id {}", - request_id - ); - } + { + let response_channels = response_channels.clone(); + executor + .spawn(async move { + let (conn_rx, conn_tx) = smol::io::split(conn); + let mut stream_tx = MessageStream::new(conn_tx); + let mut stream_rx = MessageStream::new(conn_rx); + loop { + futures::select! { + incoming = stream_rx.read_message::().fuse() => { + Self::handle_incoming(incoming, &response_channels).await; + } + outgoing = outgoing_rx.recv().fuse() => { + if let Some(outgoing) = outgoing { + stream_tx.write_message(&outgoing).await; + } else { + break; } - } else { - log::warn!("received RPC message with no content"); } } - Message::ClientDropped => break Ok(()), } - } - }) - .detach(); + }) + .detach(); + } - client + Self { + response_channels, + outgoing_tx, + next_message_id: AtomicI32::new(0), + } } - pub async fn request(&mut self, req: T) -> Result { - let message_id = self.next_message_id; - self.next_message_id += 1; + async fn handle_incoming( + incoming: io::Result, + response_channels: &Mutex, bool)>>, + ) { + match incoming { + Ok(incoming) => { + if let Some(variant) = incoming.variant { + if let Some(request_id) = incoming.request_id { + let channel = response_channels.lock().remove(&request_id); + if let Some((mut tx, oneshot)) = channel { + if tx.send(variant).await.is_ok() { + if !oneshot { + response_channels.lock().insert(request_id, (tx, false)); + } + } + } else { + log::warn!( + "received RPC response to unknown request id {}", + request_id + ); + } + } + } else { + log::warn!("received RPC message with no content"); + } + } + Err(error) => log::warn!("invalid incoming RPC message {:?}", error), + } + } + + pub async fn request(&self, req: T) -> Result { + let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); let (tx, mut rx) = mpsc::channel(1); self.response_channels.lock().insert(message_id, (tx, true)); - self.stream - .write_message(&proto::FromClient { + self.outgoing_tx + .clone() + .send(proto::FromClient { id: message_id, variant: Some(req.to_variant()), }) - .await?; + .await + .unwrap(); let response = rx .recv() .await @@ -109,15 +115,16 @@ where .ok_or_else(|| anyhow!("received response of the wrong t")) } - pub async fn send(&mut self, message: T) -> Result<()> { - let message_id = self.next_message_id; - self.next_message_id += 1; - self.stream - .write_message(&proto::FromClient { + pub async fn send(&self, message: T) -> Result<()> { + let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + self.outgoing_tx + .clone() + .send(proto::FromClient { id: message_id, variant: Some(message.to_variant()), }) - .await?; + .await + .unwrap(); Ok(()) } @@ -125,19 +132,19 @@ where &mut self, subscription: T, ) -> Result>> { - let message_id = self.next_message_id; - self.next_message_id += 1; + let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); let (tx, rx) = mpsc::channel(256); self.response_channels .lock() .insert(message_id, (tx, false)); - self.stream - .write_message(&proto::FromClient { + self.outgoing_tx + .clone() + .send(proto::FromClient { id: message_id, variant: Some(subscription.to_variant()), }) - .await?; - + .await + .unwrap(); Ok(rx.map(|event| { T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}")) })) @@ -165,7 +172,7 @@ mod tests { let (server_conn, _) = listener.accept().await.unwrap(); let mut server_stream = MessageStream::new(server_conn); - let mut client = RpcClient::new(client_conn, executor.clone()); + let client = RpcClient::new(client_conn, executor.clone()); let client_req = client.request(proto::from_client::Auth { user_id: 42, diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index 180771ce72..7bf2d7abf6 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -8,7 +8,6 @@ use crate::{ worktree::{FileHandle, Worktree, WorktreeHandle}, AppState, }; -use futures_core::Future; use gpui::{ color::rgbu, elements::*, json::to_string_pretty, keymap::Binding, AnyViewHandle, AppContext, ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, Task, @@ -19,10 +18,10 @@ pub use pane::*; pub use pane_group::*; use postage::watch; use smol::prelude::*; -use std::{collections::HashMap, path::PathBuf}; use std::{ - collections::{hash_map::Entry, HashSet}, - path::Path, + collections::{hash_map::Entry, HashMap, HashSet}, + future::Future, + path::{Path, PathBuf}, sync::Arc, }; diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index 762d98ee72..a43856d7f8 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -1207,7 +1207,7 @@ pub trait WorktreeHandle { fn flush_fs_events<'a>( &self, cx: &'a gpui::TestAppContext, - ) -> futures_core::future::LocalBoxFuture<'a, ()>; + ) -> futures::future::LocalBoxFuture<'a, ()>; } impl WorktreeHandle for ModelHandle { @@ -1268,7 +1268,7 @@ impl WorktreeHandle for ModelHandle { fn flush_fs_events<'a>( &self, cx: &'a gpui::TestAppContext, - ) -> futures_core::future::LocalBoxFuture<'a, ()> { + ) -> futures::future::LocalBoxFuture<'a, ()> { use smol::future::FutureExt; let filename = "fs-event-sentinel";