From d4fe1115e7cdfe79de7fdb86a7a384c3800225cc Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 7 Feb 2022 12:27:13 -0800 Subject: [PATCH] Use an unbounded channel for peer's outgoing messages Using a bounded channel may have blocked the collaboration server from making progress handling RPC traffic. There's no need to apply backpressure to calling code within the same process - suspending a task that is attempting to call `send` has an even greater memory cost than just buffering a protobuf message. We do still want a bounded channel for incoming messages, so that we provide backpressure to noisy peers - blocking their writes as opposed to allowing them to buffer arbitrarily many messages in our server. Co-Authored-By: Antonio Scandurra Co-Authored-By: Nathan Sobo --- crates/client/src/channel.rs | 38 ++-- crates/client/src/client.rs | 17 +- crates/client/src/test.rs | 6 +- crates/project/src/project.rs | 164 +++++++-------- crates/project/src/worktree.rs | 92 +++------ crates/rpc/src/peer.rs | 141 ++++++------- crates/server/src/rpc.rs | 355 ++++++++++++++------------------- 7 files changed, 341 insertions(+), 472 deletions(-) diff --git a/crates/client/src/channel.rs b/crates/client/src/channel.rs index d9555da7dd..f89f578247 100644 --- a/crates/client/src/channel.rs +++ b/crates/client/src/channel.rs @@ -17,7 +17,7 @@ use std::{ }; use sum_tree::{Bias, SumTree}; use time::OffsetDateTime; -use util::{post_inc, TryFutureExt}; +use util::{post_inc, ResultExt as _, TryFutureExt}; pub struct ChannelList { available_channels: Option>, @@ -168,16 +168,12 @@ impl ChannelList { impl Entity for Channel { type Event = ChannelEvent; - fn release(&mut self, cx: &mut MutableAppContext) { - let rpc = self.rpc.clone(); - let channel_id = self.details.id; - cx.foreground() - .spawn(async move { - if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await { - log::error!("error leaving channel: {}", error); - }; + fn release(&mut self, _: &mut MutableAppContext) { + self.rpc + .send(proto::LeaveChannel { + channel_id: self.details.id, }) - .detach() + .log_err(); } } @@ -718,18 +714,16 @@ mod tests { }); // Receive a new message. - server - .send(proto::ChannelMessageSent { - channel_id: channel.read_with(&cx, |channel, _| channel.details.id), - message: Some(proto::ChannelMessage { - id: 12, - body: "c".into(), - timestamp: 1002, - sender_id: 7, - nonce: Some(3.into()), - }), - }) - .await; + server.send(proto::ChannelMessageSent { + channel_id: channel.read_with(&cx, |channel, _| channel.details.id), + message: Some(proto::ChannelMessage { + id: 12, + body: "c".into(), + timestamp: 1002, + sender_id: 7, + nonce: Some(3.into()), + }), + }); // Client requests user for message since they haven't seen them yet let get_users = server.receive::().await.unwrap(); diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index e22cd7cba9..3d70622ad6 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -24,7 +24,6 @@ use std::{ collections::HashMap, convert::TryFrom, fmt::Write as _, - future::Future, sync::{Arc, Weak}, time::{Duration, Instant}, }; @@ -677,8 +676,8 @@ impl Client { } } - pub async fn send(&self, message: T) -> Result<()> { - self.peer.send(self.connection_id()?, message).await + pub fn send(&self, message: T) -> Result<()> { + self.peer.send(self.connection_id()?, message) } pub async fn request(&self, request: T) -> Result { @@ -689,7 +688,7 @@ impl Client { &self, receipt: Receipt, response: T::Response, - ) -> impl Future> { + ) -> Result<()> { self.peer.respond(receipt, response) } @@ -697,7 +696,7 @@ impl Client { &self, receipt: Receipt, error: proto::Error, - ) -> impl Future> { + ) -> Result<()> { self.peer.respond_with_error(receipt, error) } } @@ -860,8 +859,8 @@ mod tests { }); drop(subscription3); - server.send(proto::UnshareProject { project_id: 1 }).await; - server.send(proto::UnshareProject { project_id: 2 }).await; + server.send(proto::UnshareProject { project_id: 1 }); + server.send(proto::UnshareProject { project_id: 2 }); done_rx1.next().await.unwrap(); done_rx2.next().await.unwrap(); } @@ -890,7 +889,7 @@ mod tests { Ok(()) }) }); - server.send(proto::Ping {}).await; + server.send(proto::Ping {}); done_rx2.next().await.unwrap(); } @@ -914,7 +913,7 @@ mod tests { }, )); }); - server.send(proto::Ping {}).await; + server.send(proto::Ping {}); done_rx.next().await.unwrap(); } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 7402417196..c8aca79192 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -118,8 +118,8 @@ impl FakeServer { self.forbid_connections.store(false, SeqCst); } - pub async fn send(&self, message: T) { - self.peer.send(self.connection_id(), message).await.unwrap(); + pub fn send(&self, message: T) { + self.peer.send(self.connection_id(), message).unwrap(); } pub async fn receive(&self) -> Result> { @@ -148,7 +148,7 @@ impl FakeServer { receipt: Receipt, response: T::Response, ) { - self.peer.respond(receipt, response).await.unwrap() + self.peer.respond(receipt, response).unwrap() } fn connection_id(&self) -> ConnectionId { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index aced7b4f09..a9a9470005 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -460,7 +460,7 @@ impl Project { } })?; - rpc.send(proto::UnshareProject { project_id }).await?; + rpc.send(proto::UnshareProject { project_id })?; this.update(&mut cx, |this, cx| { this.collaborators.clear(); this.shared_buffers.clear(); @@ -818,15 +818,13 @@ impl Project { let this = cx.read(|cx| this.upgrade(cx))?; match message { LspEvent::DiagnosticsStart => { - let send = this.update(&mut cx, |this, cx| { + this.update(&mut cx, |this, cx| { this.disk_based_diagnostics_started(cx); - this.remote_id().map(|project_id| { + if let Some(project_id) = this.remote_id() { rpc.send(proto::DiskBasedDiagnosticsUpdating { project_id }) - }) + .log_err(); + } }); - if let Some(send) = send { - send.await.log_err(); - } } LspEvent::DiagnosticsUpdate(mut params) => { language.process_diagnostics(&mut params); @@ -836,15 +834,13 @@ impl Project { }); } LspEvent::DiagnosticsFinish => { - let send = this.update(&mut cx, |this, cx| { + this.update(&mut cx, |this, cx| { this.disk_based_diagnostics_finished(cx); - this.remote_id().map(|project_id| { + if let Some(project_id) = this.remote_id() { rpc.send(proto::DiskBasedDiagnosticsUpdated { project_id }) - }) + .log_err(); + } }); - if let Some(send) = send { - send.await.log_err(); - } } } } @@ -1311,15 +1307,13 @@ impl Project { }; if let Some(project_id) = self.remote_id() { - let client = self.client.clone(); - let message = proto::UpdateBufferFile { - project_id, - buffer_id: *buffer_id as u64, - file: Some(new_file.to_proto()), - }; - cx.foreground() - .spawn(async move { client.send(message).await }) - .detach_and_log_err(cx); + self.client + .send(proto::UpdateBufferFile { + project_id, + buffer_id: *buffer_id as u64, + file: Some(new_file.to_proto()), + }) + .log_err(); } buffer.file_updated(Box::new(new_file), cx).detach(); } @@ -1639,8 +1633,7 @@ impl Project { version: (&version).into(), mtime: Some(mtime.into()), }, - ) - .await?; + )?; Ok(()) } @@ -1669,16 +1662,13 @@ impl Project { // associated with formatting. cx.spawn(|_| async move { match format { - Ok(()) => rpc.respond(receipt, proto::Ack {}).await?, - Err(error) => { - rpc.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - ) - .await? - } + Ok(()) => rpc.respond(receipt, proto::Ack {})?, + Err(error) => rpc.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + )?, } Ok::<_, anyhow::Error>(()) }) @@ -1712,27 +1702,21 @@ impl Project { .update(&mut cx, |buffer, cx| buffer.completions(position, cx)) .await { - Ok(completions) => { - rpc.respond( - receipt, - proto::GetCompletionsResponse { - completions: completions - .iter() - .map(language::proto::serialize_completion) - .collect(), - }, - ) - .await - } - Err(error) => { - rpc.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - ) - .await - } + Ok(completions) => rpc.respond( + receipt, + proto::GetCompletionsResponse { + completions: completions + .iter() + .map(language::proto::serialize_completion) + .collect(), + }, + ), + Err(error) => rpc.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + ), } }) .detach_and_log_err(cx); @@ -1767,30 +1751,24 @@ impl Project { }) .await { - Ok(edit_ids) => { - rpc.respond( - receipt, - proto::ApplyCompletionAdditionalEditsResponse { - additional_edits: edit_ids - .into_iter() - .map(|edit_id| proto::AdditionalEdit { - replica_id: edit_id.replica_id as u32, - local_timestamp: edit_id.value, - }) - .collect(), - }, - ) - .await - } - Err(error) => { - rpc.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - ) - .await - } + Ok(edit_ids) => rpc.respond( + receipt, + proto::ApplyCompletionAdditionalEditsResponse { + additional_edits: edit_ids + .into_iter() + .map(|edit_id| proto::AdditionalEdit { + replica_id: edit_id.replica_id as u32, + local_timestamp: edit_id.value, + }) + .collect(), + }, + ), + Err(error) => rpc.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + ), } }) .detach_and_log_err(cx); @@ -1836,7 +1814,7 @@ impl Project { }); } }); - rpc.respond(receipt, response).await?; + rpc.respond(receipt, response)?; Ok::<_, anyhow::Error>(()) }) .detach_and_log_err(cx); @@ -1872,7 +1850,6 @@ impl Project { buffer: Some(buffer), }, ) - .await } .log_err() }) @@ -2106,28 +2083,21 @@ impl<'a> Iterator for CandidateSetIter<'a> { impl Entity for Project { type Event = Event; - fn release(&mut self, cx: &mut gpui::MutableAppContext) { + fn release(&mut self, _: &mut gpui::MutableAppContext) { match &self.client_state { ProjectClientState::Local { remote_id_rx, .. } => { if let Some(project_id) = *remote_id_rx.borrow() { - let rpc = self.client.clone(); - cx.spawn(|_| async move { - if let Err(err) = rpc.send(proto::UnregisterProject { project_id }).await { - log::error!("error unregistering project: {}", err); - } - }) - .detach(); + self.client + .send(proto::UnregisterProject { project_id }) + .log_err(); } } ProjectClientState::Remote { remote_id, .. } => { - let rpc = self.client.clone(); - let project_id = *remote_id; - cx.spawn(|_| async move { - if let Err(err) = rpc.send(proto::LeaveProject { project_id }).await { - log::error!("error leaving project: {}", err); - } - }) - .detach(); + self.client + .send(proto::LeaveProject { + project_id: *remote_id, + }) + .log_err(); } } } diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 32b4009207..643c26aa71 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -149,7 +149,7 @@ pub enum Event { impl Entity for Worktree { type Event = Event; - fn release(&mut self, cx: &mut MutableAppContext) { + fn release(&mut self, _: &mut MutableAppContext) { if let Some(worktree) = self.as_local_mut() { if let Registration::Done { project_id } = worktree.registration { let client = worktree.client.clone(); @@ -157,12 +157,7 @@ impl Entity for Worktree { project_id, worktree_id: worktree.id().to_proto(), }; - cx.foreground() - .spawn(async move { - client.send(unregister_message).await?; - Ok::<_, anyhow::Error>(()) - }) - .detach_and_log_err(cx); + client.send(unregister_message).log_err(); } } } @@ -596,7 +591,7 @@ impl LocalWorktree { &mut self, worktree_path: Arc, diagnostics: Vec>, - cx: &mut ModelContext, + _: &mut ModelContext, ) -> Result<()> { let summary = DiagnosticSummary::new(&diagnostics); self.diagnostic_summaries @@ -604,30 +599,19 @@ impl LocalWorktree { self.diagnostics.insert(worktree_path.clone(), diagnostics); if let Some(share) = self.share.as_ref() { - cx.foreground() - .spawn({ - let client = self.client.clone(); - let project_id = share.project_id; - let worktree_id = self.id().to_proto(); - let path = worktree_path.to_string_lossy().to_string(); - async move { - client - .send(proto::UpdateDiagnosticSummary { - project_id, - worktree_id, - summary: Some(proto::DiagnosticSummary { - path, - error_count: summary.error_count as u32, - warning_count: summary.warning_count as u32, - info_count: summary.info_count as u32, - hint_count: summary.hint_count as u32, - }), - }) - .await - .log_err() - } + self.client + .send(proto::UpdateDiagnosticSummary { + project_id: share.project_id, + worktree_id: self.id().to_proto(), + summary: Some(proto::DiagnosticSummary { + path: worktree_path.to_string_lossy().to_string(), + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + info_count: summary.info_count as u32, + hint_count: summary.hint_count as u32, + }), }) - .detach(); + .log_err(); } Ok(()) @@ -787,7 +771,7 @@ impl LocalWorktree { while let Ok(snapshot) = snapshots_to_send_rx.recv().await { let message = snapshot.build_update(&prev_snapshot, project_id, worktree_id, false); - match rpc.send(message).await { + match rpc.send(message) { Ok(()) => prev_snapshot = snapshot, Err(err) => log::error!("error sending snapshot diff {}", err), } @@ -1377,8 +1361,7 @@ impl language::File for File { buffer_id, version: (&version).into(), mtime: Some(entry.mtime.into()), - }) - .await?; + })?; } Ok((version, entry.mtime)) }) @@ -1501,23 +1484,15 @@ impl language::File for File { } fn buffer_removed(&self, buffer_id: u64, cx: &mut MutableAppContext) { - self.worktree.update(cx, |worktree, cx| { + self.worktree.update(cx, |worktree, _| { if let Worktree::Remote(worktree) = worktree { - let project_id = worktree.project_id; - let rpc = worktree.client.clone(); - cx.background() - .spawn(async move { - if let Err(error) = rpc - .send(proto::CloseBuffer { - project_id, - buffer_id, - }) - .await - { - log::error!("error closing remote buffer: {}", error); - } + worktree + .client + .send(proto::CloseBuffer { + project_id: worktree.project_id, + buffer_id, }) - .detach(); + .log_err(); } }); } @@ -1563,16 +1538,15 @@ impl language::LocalFile for File { ) { let worktree = self.worktree.read(cx).as_local().unwrap(); if let Some(project_id) = worktree.share.as_ref().map(|share| share.project_id) { - let rpc = worktree.client.clone(); - let message = proto::BufferReloaded { - project_id, - buffer_id, - version: version.into(), - mtime: Some(mtime.into()), - }; - cx.background() - .spawn(async move { rpc.send(message).await }) - .detach_and_log_err(cx); + worktree + .client + .send(proto::BufferReloaded { + project_id, + buffer_id, + version: version.into(), + mtime: Some(mtime.into()), + }) + .log_err(); } } } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index dcfcd2530c..ec9e109dea 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -89,7 +89,7 @@ pub struct Peer { #[derive(Clone)] pub struct ConnectionState { - outgoing_tx: mpsc::Sender, + outgoing_tx: futures::channel::mpsc::UnboundedSender, next_message_id: Arc, response_channels: Arc>>>>, } @@ -112,9 +112,14 @@ impl Peer { impl Future> + Send, BoxStream<'static, Box>, ) { - let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); + // For outgoing messages, use an unbounded channel so that application code + // can always send messages without yielding. For incoming messages, use a + // bounded channel so that other peers will receive backpressure if they send + // messages faster than this peer can process them. let (mut incoming_tx, incoming_rx) = mpsc::channel(64); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); + let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded(); + + let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); let connection_state = ConnectionState { outgoing_tx, next_message_id: Default::default(), @@ -131,6 +136,16 @@ impl Peer { futures::pin_mut!(read_message); loop { futures::select_biased! { + outgoing = outgoing_rx.next().fuse() => match outgoing { + Some(outgoing) => { + match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { + None => break 'outer Err(anyhow!("timed out writing RPC message")), + Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"), + _ => {} + } + } + None => break 'outer Ok(()), + }, incoming = read_message => match incoming { Ok(incoming) => { if incoming_tx.send(incoming).await.is_err() { @@ -142,16 +157,6 @@ impl Peer { break 'outer Err(error).context("received invalid RPC message") } }, - outgoing = outgoing_rx.recv().fuse() => match outgoing { - Some(outgoing) => { - match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { - None => break 'outer Err(anyhow!("timed out writing RPC message")), - Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"), - _ => {} - } - } - None => break 'outer Ok(()), - } } } }; @@ -223,9 +228,9 @@ impl Peer { request: T, ) -> impl Future> { let this = self.clone(); - let (tx, mut rx) = mpsc::channel(1); async move { - let mut connection = this.connection_state(receiver_id)?; + let (tx, mut rx) = mpsc::channel(1); + let connection = this.connection_state(receiver_id)?; let message_id = connection.next_message_id.fetch_add(1, SeqCst); connection .response_channels @@ -235,8 +240,11 @@ impl Peer { .insert(message_id, tx); connection .outgoing_tx - .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0))) - .await + .unbounded_send(request.into_envelope( + message_id, + None, + original_sender_id.map(|id| id.0), + )) .map_err(|_| anyhow!("connection was closed"))?; let response = rx .recv() @@ -255,19 +263,15 @@ impl Peer { self: &Arc, receiver_id: ConnectionId, message: T, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection_state(receiver_id)?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(message.into_envelope(message_id, None, None)) - .await?; - Ok(()) - } + ) -> Result<()> { + let connection = self.connection_state(receiver_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .unbounded_send(message.into_envelope(message_id, None, None))?; + Ok(()) } pub fn forward_send( @@ -275,57 +279,45 @@ impl Peer { sender_id: ConnectionId, receiver_id: ConnectionId, message: T, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection_state(receiver_id)?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(message.into_envelope(message_id, None, Some(sender_id.0))) - .await?; - Ok(()) - } + ) -> Result<()> { + let connection = self.connection_state(receiver_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?; + Ok(()) } pub fn respond( self: &Arc, receipt: Receipt, response: T::Response, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection_state(receipt.sender_id)?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(response.into_envelope(message_id, Some(receipt.message_id), None)) - .await?; - Ok(()) - } + ) -> Result<()> { + let connection = self.connection_state(receipt.sender_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?; + Ok(()) } pub fn respond_with_error( self: &Arc, receipt: Receipt, response: proto::Error, - ) -> impl Future> { - let this = self.clone(); - async move { - let mut connection = this.connection_state(receipt.sender_id)?; - let message_id = connection - .next_message_id - .fetch_add(1, atomic::Ordering::SeqCst); - connection - .outgoing_tx - .send(response.into_envelope(message_id, Some(receipt.message_id), None)) - .await?; - Ok(()) - } + ) -> Result<()> { + let connection = self.connection_state(receipt.sender_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?; + Ok(()) } fn connection_state(&self, connection_id: ConnectionId) -> Result { @@ -447,7 +439,7 @@ mod tests { let envelope = envelope.into_any(); if let Some(envelope) = envelope.downcast_ref::>() { let receipt = envelope.receipt(); - peer.respond(receipt, proto::Ack {}).await? + peer.respond(receipt, proto::Ack {})? } else if let Some(envelope) = envelope.downcast_ref::>() { @@ -475,7 +467,7 @@ mod tests { } }; - peer.respond(receipt, response).await? + peer.respond(receipt, response)? } else { panic!("unknown message type"); } @@ -518,7 +510,6 @@ mod tests { message: "message 1".to_string(), }, ) - .await .unwrap(); server .send( @@ -527,12 +518,8 @@ mod tests { message: "message 2".to_string(), }, ) - .await - .unwrap(); - server - .respond(request.receipt(), proto::Ack {}) - .await .unwrap(); + server.respond(request.receipt(), proto::Ack {}).unwrap(); // Prevent the connection from being dropped server_incoming.next().await; diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 1d9f784f47..9b779a57d1 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -131,7 +131,7 @@ impl Server { } this.state_mut().add_connection(connection_id, user_id); - if let Err(err) = this.update_contacts_for_users(&[user_id]).await { + if let Err(err) = this.update_contacts_for_users(&[user_id]) { log::error!("error updating contacts for {:?}: {}", user_id, err); } @@ -141,6 +141,12 @@ impl Server { let next_message = incoming_rx.next().fuse(); futures::pin_mut!(next_message); futures::select_biased! { + result = handle_io => { + if let Err(err) = result { + log::error!("error handling rpc connection {:?} - {:?}", addr, err); + } + break; + } message = next_message => { if let Some(message) = message { let start_time = Instant::now(); @@ -163,12 +169,6 @@ impl Server { break; } } - handle_io = handle_io => { - if let Err(err) = handle_io { - log::error!("error handling rpc connection {:?} - {:?}", addr, err); - } - break; - } } } @@ -191,8 +191,7 @@ impl Server { self.peer .send(conn_id, proto::UnshareProject { project_id }) }, - ) - .await?; + )?; } } @@ -205,18 +204,15 @@ impl Server { peer_id: connection_id.0, }, ) - }) - .await?; + })?; } - self.update_contacts_for_users(removed_connection.contact_ids.iter()) - .await?; - + self.update_contacts_for_users(removed_connection.contact_ids.iter())?; Ok(()) } async fn ping(self: Arc, request: TypedEnvelope) -> tide::Result<()> { - self.peer.respond(request.receipt(), proto::Ack {}).await?; + self.peer.respond(request.receipt(), proto::Ack {})?; Ok(()) } @@ -229,12 +225,10 @@ impl Server { let user_id = state.user_id_for_connection(request.sender_id)?; state.register_project(request.sender_id, user_id) }; - self.peer - .respond( - request.receipt(), - proto::RegisterProjectResponse { project_id }, - ) - .await?; + self.peer.respond( + request.receipt(), + proto::RegisterProjectResponse { project_id }, + )?; Ok(()) } @@ -246,8 +240,7 @@ impl Server { .state_mut() .unregister_project(request.payload.project_id, request.sender_id) .ok_or_else(|| anyhow!("no such project"))?; - self.update_contacts_for_users(project.authorized_user_ids().iter()) - .await?; + self.update_contacts_for_users(project.authorized_user_ids().iter())?; Ok(()) } @@ -257,7 +250,7 @@ impl Server { ) -> tide::Result<()> { self.state_mut() .share_project(request.payload.project_id, request.sender_id); - self.peer.respond(request.receipt(), proto::Ack {}).await?; + self.peer.respond(request.receipt(), proto::Ack {})?; Ok(()) } @@ -273,11 +266,8 @@ impl Server { broadcast(request.sender_id, project.connection_ids, |conn_id| { self.peer .send(conn_id, proto::UnshareProject { project_id }) - }) - .await?; - self.update_contacts_for_users(&project.authorized_user_ids) - .await?; - + })?; + self.update_contacts_for_users(&project.authorized_user_ids)?; Ok(()) } @@ -351,20 +341,17 @@ impl Server { }), }, ) - }) - .await?; - self.peer.respond(request.receipt(), response).await?; - self.update_contacts_for_users(&contact_user_ids).await?; + })?; + self.peer.respond(request.receipt(), response)?; + self.update_contacts_for_users(&contact_user_ids)?; } Err(error) => { - self.peer - .respond_with_error( - request.receipt(), - proto::Error { - message: error.to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + request.receipt(), + proto::Error { + message: error.to_string(), + }, + )?; } } @@ -387,10 +374,8 @@ impl Server { peer_id: sender_id.0, }, ) - }) - .await?; - self.update_contacts_for_users(&worktree.authorized_user_ids) - .await?; + })?; + self.update_contacts_for_users(&worktree.authorized_user_ids)?; } Ok(()) } @@ -412,8 +397,7 @@ impl Server { Err(err) => { let message = err.to_string(); self.peer - .respond_with_error(receipt, proto::Error { message }) - .await?; + .respond_with_error(receipt, proto::Error { message })?; return Ok(()); } } @@ -432,17 +416,15 @@ impl Server { ); if ok { - self.peer.respond(receipt, proto::Ack {}).await?; - self.update_contacts_for_users(&contact_user_ids).await?; + self.peer.respond(receipt, proto::Ack {})?; + self.update_contacts_for_users(&contact_user_ids)?; } else { - self.peer - .respond_with_error( - receipt, - proto::Error { - message: NO_SUCH_PROJECT.to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + receipt, + proto::Error { + message: NO_SUCH_PROJECT.to_string(), + }, + )?; } Ok(()) @@ -457,7 +439,6 @@ impl Server { let (worktree, guest_connection_ids) = self.state_mut() .unregister_worktree(project_id, worktree_id, request.sender_id)?; - broadcast(request.sender_id, guest_connection_ids, |conn_id| { self.peer.send( conn_id, @@ -466,10 +447,8 @@ impl Server { worktree_id, }, ) - }) - .await?; - self.update_contacts_for_users(&worktree.authorized_user_ids) - .await?; + })?; + self.update_contacts_for_users(&worktree.authorized_user_ids)?; Ok(()) } @@ -511,20 +490,16 @@ impl Server { request.payload.clone(), ) }, - ) - .await?; - self.peer.respond(request.receipt(), proto::Ack {}).await?; - self.update_contacts_for_users(&shared_worktree.authorized_user_ids) - .await?; + )?; + self.peer.respond(request.receipt(), proto::Ack {})?; + self.update_contacts_for_users(&shared_worktree.authorized_user_ids)?; } else { - self.peer - .respond_with_error( - request.receipt(), - proto::Error { - message: "no such worktree".to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + request.receipt(), + proto::Error { + message: "no such worktree".to_string(), + }, + )?; } Ok(()) } @@ -547,8 +522,7 @@ impl Server { broadcast(request.sender_id, connection_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -574,8 +548,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -590,8 +563,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -606,8 +578,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -625,7 +596,7 @@ impl Server { .peer .forward_request(request.sender_id, host_connection_id, request.payload) .await?; - self.peer.respond(receipt, response).await?; + self.peer.respond(receipt, response)?; Ok(()) } @@ -643,7 +614,7 @@ impl Server { .peer .forward_request(request.sender_id, host_connection_id, request.payload) .await?; - self.peer.respond(receipt, response).await?; + self.peer.respond(receipt, response)?; Ok(()) } @@ -657,8 +628,7 @@ impl Server { .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))? .host_connection_id; self.peer - .forward_send(request.sender_id, host_connection_id, request.payload) - .await?; + .forward_send(request.sender_id, host_connection_id, request.payload)?; Ok(()) } @@ -686,16 +656,12 @@ impl Server { broadcast(host, guests, |conn_id| { let response = response.clone(); - let peer = &self.peer; - async move { - if conn_id == sender { - peer.respond(receipt, response).await - } else { - peer.forward_send(host, conn_id, response).await - } + if conn_id == sender { + self.peer.respond(receipt, response) + } else { + self.peer.forward_send(host, conn_id, response) } - }) - .await?; + })?; Ok(()) } @@ -719,7 +685,7 @@ impl Server { .peer .forward_request(sender, host, request.payload.clone()) .await?; - self.peer.respond(receipt, response).await?; + self.peer.respond(receipt, response)?; Ok(()) } @@ -743,8 +709,7 @@ impl Server { .peer .forward_request(sender, host, request.payload.clone()) .await?; - self.peer.respond(receipt, response).await?; - + self.peer.respond(receipt, response)?; Ok(()) } @@ -767,8 +732,7 @@ impl Server { .peer .forward_request(sender, host, request.payload.clone()) .await?; - self.peer.respond(receipt, response).await?; - + self.peer.respond(receipt, response)?; Ok(()) } @@ -783,9 +747,8 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; - self.peer.respond(request.receipt(), proto::Ack {}).await?; + })?; + self.peer.respond(request.receipt(), proto::Ack {})?; Ok(()) } @@ -800,8 +763,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -816,8 +778,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -832,8 +793,7 @@ impl Server { broadcast(request.sender_id, receiver_ids, |connection_id| { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) - }) - .await?; + })?; Ok(()) } @@ -843,20 +803,18 @@ impl Server { ) -> tide::Result<()> { let user_id = self.state().user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; - self.peer - .respond( - request.receipt(), - proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - }, - ) - .await?; + self.peer.respond( + request.receipt(), + proto::GetChannelsResponse { + channels: channels + .into_iter() + .map(|chan| proto::Channel { + id: chan.id.to_proto(), + name: chan.name, + }) + .collect(), + }, + )?; Ok(()) } @@ -879,34 +837,30 @@ impl Server { }) .collect(); self.peer - .respond(receipt, proto::GetUsersResponse { users }) - .await?; + .respond(receipt, proto::GetUsersResponse { users })?; Ok(()) } - async fn update_contacts_for_users<'a>( + fn update_contacts_for_users<'a>( self: &Arc, user_ids: impl IntoIterator, - ) -> tide::Result<()> { - let mut send_futures = Vec::new(); - - { - let state = self.state(); - for user_id in user_ids { - let contacts = state.contacts_for_user(*user_id); - for connection_id in state.connection_ids_for_user(*user_id) { - send_futures.push(self.peer.send( - connection_id, - proto::UpdateContacts { - contacts: contacts.clone(), - }, - )); + ) -> anyhow::Result<()> { + let mut result = Ok(()); + let state = self.state(); + for user_id in user_ids { + let contacts = state.contacts_for_user(*user_id); + for connection_id in state.connection_ids_for_user(*user_id) { + if let Err(error) = self.peer.send( + connection_id, + proto::UpdateContacts { + contacts: contacts.clone(), + }, + ) { + result = Err(error); } } } - futures::future::try_join_all(send_futures).await?; - - Ok(()) + result } async fn join_channel( @@ -939,15 +893,13 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - self.peer - .respond( - request.receipt(), - proto::JoinChannelResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - }, - ) - .await?; + self.peer.respond( + request.receipt(), + proto::JoinChannelResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }, + )?; Ok(()) } @@ -993,25 +945,21 @@ impl Server { // Validate the message body. let body = request.payload.body.trim().to_string(); if body.len() > MAX_MESSAGE_LEN { - self.peer - .respond_with_error( - receipt, - proto::Error { - message: "message is too long".to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + receipt, + proto::Error { + message: "message is too long".to_string(), + }, + )?; return Ok(()); } if body.is_empty() { - self.peer - .respond_with_error( - receipt, - proto::Error { - message: "message can't be blank".to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + receipt, + proto::Error { + message: "message can't be blank".to_string(), + }, + )?; return Ok(()); } @@ -1019,14 +967,12 @@ impl Server { let nonce = if let Some(nonce) = request.payload.nonce { nonce } else { - self.peer - .respond_with_error( - receipt, - proto::Error { - message: "nonce can't be blank".to_string(), - }, - ) - .await?; + self.peer.respond_with_error( + receipt, + proto::Error { + message: "nonce can't be blank".to_string(), + }, + )?; return Ok(()); }; @@ -1051,16 +997,13 @@ impl Server { message: Some(message.clone()), }, ) - }) - .await?; - self.peer - .respond( - receipt, - proto::SendChannelMessageResponse { - message: Some(message), - }, - ) - .await?; + })?; + self.peer.respond( + receipt, + proto::SendChannelMessageResponse { + message: Some(message), + }, + )?; Ok(()) } @@ -1097,15 +1040,13 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - self.peer - .respond( - request.receipt(), - proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - }, - ) - .await?; + self.peer.respond( + request.receipt(), + proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }, + )?; Ok(()) } @@ -1118,21 +1059,25 @@ impl Server { } } -pub async fn broadcast( +fn broadcast( sender_id: ConnectionId, receiver_ids: Vec, mut f: F, ) -> anyhow::Result<()> where - F: FnMut(ConnectionId) -> T, - T: Future>, + F: FnMut(ConnectionId) -> anyhow::Result<()>, { - let futures = receiver_ids - .into_iter() - .filter(|id| *id != sender_id) - .map(|id| f(id)); - futures::future::try_join_all(futures).await?; - Ok(()) + let mut result = Ok(()); + for receiver_id in receiver_ids { + if receiver_id != sender_id { + if let Err(error) = f(receiver_id) { + if result.is_ok() { + result = Err(error); + } + } + } + } + result } pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) {