diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 730879c0d1..184592f033 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -18,7 +18,7 @@ use axum::{ headers::{Header, HeaderName}, http::StatusCode, middleware, - response::{IntoResponse, Response}, + response::IntoResponse, routing::get, Extension, Router, TypedHeader, }; @@ -27,7 +27,7 @@ use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt, T use lazy_static::lazy_static; use rpc::{ proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, - Connection, ConnectionId, Peer, TypedEnvelope, + Connection, ConnectionId, Peer, Receipt, TypedEnvelope, }; use std::{ any::TypeId, @@ -36,7 +36,10 @@ use std::{ net::SocketAddr, ops::{Deref, DerefMut}, rc::Rc, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, time::Duration, }; use store::{Store, Worktree}; @@ -51,6 +54,20 @@ use tracing::{info_span, instrument, Instrument}; type MessageHandler = Box, Box) -> BoxFuture<'static, ()>>; +struct Response { + server: Arc, + receipt: Receipt, + responded: Arc, +} + +impl Response { + fn send(self, payload: R::Response) -> Result<()> { + self.responded.store(true, SeqCst); + self.server.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + pub struct Server { peer: Arc, store: RwLock, @@ -100,7 +117,7 @@ impl Server { .add_message_handler(Server::unregister_project) .add_request_handler(Server::share_project) .add_message_handler(Server::unshare_project) - .add_sync_request_handler(Server::join_project) + .add_request_handler(Server::join_project) .add_message_handler(Server::leave_project) .add_request_handler(Server::register_worktree) .add_message_handler(Server::unregister_worktree) @@ -179,43 +196,12 @@ impl Server { self } - fn add_request_handler(&mut self, handler: F) -> &mut Self - where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, - Fut: 'static + Send + Future>, - M: RequestMessage, - { - self.add_message_handler(move |server, envelope| { - let receipt = envelope.receipt(); - let response = (handler)(server.clone(), envelope); - async move { - match response.await { - Ok(response) => { - server.peer.respond(receipt, response)?; - Ok(()) - } - Err(error) => { - server.peer.respond_with_error( - receipt, - proto::Error { - message: error.to_string(), - }, - )?; - Err(error) - } - } - } - }) - } - /// Handle a request while holding a lock to the store. This is useful when we're registering /// a connection but we want to respond on the connection before anybody else can send on it. - fn add_sync_request_handler(&mut self, handler: F) -> &mut Self + fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static - + Send - + Sync - + Fn(Arc, &mut Store, TypedEnvelope) -> Result, + F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Response) -> Fut, + Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); @@ -223,12 +209,19 @@ impl Server { let receipt = envelope.receipt(); let handler = handler.clone(); async move { - let mut store = server.state_mut().await; - let response = (handler)(server.clone(), &mut *store, envelope); - match response { - Ok(response) => { - server.peer.respond(receipt, response)?; - Ok(()) + let responded = Arc::new(AtomicBool::default()); + let response = Response { + server: server.clone(), + responded: responded.clone(), + receipt: envelope.receipt(), + }; + match (handler)(server.clone(), envelope, response).await { + Ok(()) => { + if responded.load(std::sync::atomic::Ordering::SeqCst) { + Ok(()) + } else { + Err(anyhow!("handler did not send a response"))? + } } Err(error) => { server.peer.respond_with_error( @@ -364,20 +357,27 @@ impl Server { Ok(()) } - async fn ping(self: Arc, _: TypedEnvelope) -> Result { - Ok(proto::Ack {}) + async fn ping( + self: Arc, + _: TypedEnvelope, + response: Response, + ) -> Result<()> { + response.send(proto::Ack {})?; + Ok(()) } async fn register_project( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let project_id = { let mut state = self.state_mut().await; let user_id = state.user_id_for_connection(request.sender_id)?; state.register_project(request.sender_id, user_id) }; - Ok(proto::RegisterProjectResponse { project_id }) + response.send(proto::RegisterProjectResponse { project_id })?; + Ok(()) } async fn unregister_project( @@ -393,11 +393,13 @@ impl Server { async fn share_project( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let mut state = self.state_mut().await; let project = state.share_project(request.payload.project_id, request.sender_id)?; self.update_contacts_for_users(&mut *state, &project.authorized_user_ids); - Ok(proto::Ack {}) + response.send(proto::Ack {})?; + Ok(()) } async fn unshare_project( @@ -415,15 +417,16 @@ impl Server { Ok(()) } - fn join_project( + async fn join_project( self: Arc, - state: &mut Store, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let project_id = request.payload.project_id; + let state = &mut *self.state_mut().await; let user_id = state.user_id_for_connection(request.sender_id)?; - let (response, connection_ids, contact_user_ids) = state + let (response_payload, connection_ids, contact_user_ids) = state .join_project(request.sender_id, user_id, project_id) .and_then(|joined| { let share = joined.project.share()?; @@ -480,14 +483,15 @@ impl Server { project_id, collaborator: Some(proto::Collaborator { peer_id: request.sender_id.0, - replica_id: response.replica_id, + replica_id: response_payload.replica_id, user_id: user_id.to_proto(), }), }, ) }); self.update_contacts_for_users(state, &contact_user_ids); - Ok(response) + response.send(response_payload)?; + Ok(()) } async fn leave_project( @@ -514,7 +518,8 @@ impl Server { async fn register_worktree( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let mut contact_user_ids = HashSet::default(); for github_login in &request.payload.authorized_logins { let contact_user_id = self.app_state.db.create_user(github_login, false).await?; @@ -545,7 +550,8 @@ impl Server { .forward_send(request.sender_id, connection_id, request.payload.clone()) }); self.update_contacts_for_users(&*state, &contact_user_ids); - Ok(proto::Ack {}) + response.send(proto::Ack {})?; + Ok(()) } async fn unregister_worktree( @@ -573,7 +579,8 @@ impl Server { async fn update_worktree( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let connection_ids = self.state_mut().await.update_worktree( request.sender_id, request.payload.project_id, @@ -587,8 +594,8 @@ impl Server { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) }); - - Ok(proto::Ack {}) + response.send(proto::Ack {})?; + Ok(()) } async fn update_diagnostic_summary( @@ -652,7 +659,8 @@ impl Server { async fn forward_project_request( self: Arc, request: TypedEnvelope, - ) -> Result + response: Response, + ) -> Result<()> where T: EntityMessage + RequestMessage, { @@ -661,22 +669,26 @@ impl Server { .await .read_project(request.payload.remote_entity_id(), request.sender_id)? .host_connection_id; - Ok(self - .peer - .forward_request(request.sender_id, host_connection_id, request.payload) - .await?) + + response.send( + self.peer + .forward_request(request.sender_id, host_connection_id, request.payload) + .await?, + )?; + Ok(()) } async fn save_buffer( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let host = self .state() .await .read_project(request.payload.project_id, request.sender_id)? .host_connection_id; - let response = self + let response_payload = self .peer .forward_request(request.sender_id, host, request.payload.clone()) .await?; @@ -688,16 +700,18 @@ impl Server { .connection_ids(); guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); broadcast(host, guests, |conn_id| { - self.peer.forward_send(host, conn_id, response.clone()) + self.peer + .forward_send(host, conn_id, response_payload.clone()) }); - - Ok(response) + response.send(response_payload)?; + Ok(()) } async fn update_buffer( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let receiver_ids = self .state() .await @@ -706,7 +720,8 @@ impl Server { self.peer .forward_send(request.sender_id, connection_id, request.payload.clone()) }); - Ok(proto::Ack {}) + response.send(proto::Ack {})?; + Ok(()) } async fn update_buffer_file( @@ -757,7 +772,8 @@ impl Server { async fn follow( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_id; if !self @@ -768,14 +784,15 @@ impl Server { { Err(anyhow!("no such peer"))?; } - let mut response = self + let mut response_payload = self .peer .forward_request(request.sender_id, leader_id, request.payload) .await?; - response + response_payload .views .retain(|view| view.leader_id != Some(follower_id.0)); - Ok(response) + response.send(response_payload)?; + Ok(()) } async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { @@ -823,13 +840,14 @@ impl Server { async fn get_channels( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let user_id = self .state() .await .user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; - Ok(proto::GetChannelsResponse { + response.send(proto::GetChannelsResponse { channels: channels .into_iter() .map(|chan| proto::Channel { @@ -837,13 +855,15 @@ impl Server { name: chan.name, }) .collect(), - }) + })?; + Ok(()) } async fn get_users( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let user_ids = request .payload .user_ids @@ -862,13 +882,15 @@ impl Server { github_login: user.github_login, }) .collect(); - Ok(proto::UsersResponse { users }) + response.send(proto::UsersResponse { users })?; + Ok(()) } async fn fuzzy_search_users( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let query = request.payload.query; let db = &self.app_state.db; let users = match query.len() { @@ -888,7 +910,8 @@ impl Server { github_login: user.github_login, }) .collect(); - Ok(proto::UsersResponse { users }) + response.send(proto::UsersResponse { users })?; + Ok(()) } #[instrument(skip(self, state, user_ids))] @@ -917,7 +940,8 @@ impl Server { async fn join_channel( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let user_id = self .state() .await @@ -949,10 +973,11 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - Ok(proto::JoinChannelResponse { + response.send(proto::JoinChannelResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, messages, - }) + })?; + Ok(()) } async fn leave_channel( @@ -983,7 +1008,8 @@ impl Server { async fn send_channel_message( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let channel_id = ChannelId::from_proto(request.payload.channel_id); let user_id; let connection_ids; @@ -1030,15 +1056,17 @@ impl Server { }, ) }); - Ok(proto::SendChannelMessageResponse { + response.send(proto::SendChannelMessageResponse { message: Some(message), - }) + })?; + Ok(()) } async fn get_channel_messages( self: Arc, request: TypedEnvelope, - ) -> Result { + response: Response, + ) -> Result<()> { let user_id = self .state() .await @@ -1071,11 +1099,11 @@ impl Server { nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); - - Ok(proto::GetChannelMessagesResponse { + response.send(proto::GetChannelMessagesResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, messages, - }) + })?; + Ok(()) } async fn state<'a>(self: &'a Arc) -> StoreReadGuard<'a> { @@ -1213,7 +1241,7 @@ pub async fn handle_websocket_request( Extension(server): Extension>, Extension(user_id): Extension, ws: WebSocketUpgrade, -) -> Response { +) -> axum::response::Response { if protocol_version != rpc::PROTOCOL_VERSION { return ( StatusCode::UPGRADE_REQUIRED,