Refactor add_request_handler to respond via a Response struct

This also removes `add_sync_request_handler`.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2022-05-06 17:01:27 +02:00
parent 9555b93bca
commit 989b82d664

View File

@ -18,7 +18,7 @@ use axum::{
headers::{Header, HeaderName},
http::StatusCode,
middleware,
response::{IntoResponse, Response},
response::IntoResponse,
routing::get,
Extension, Router, TypedHeader,
};
@ -27,7 +27,7 @@ use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt, T
use lazy_static::lazy_static;
use rpc::{
proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Connection, ConnectionId, Peer, TypedEnvelope,
Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
};
use std::{
any::TypeId,
@ -36,7 +36,10 @@ use std::{
net::SocketAddr,
ops::{Deref, DerefMut},
rc::Rc,
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
time::Duration,
};
use store::{Store, Worktree};
@ -51,6 +54,20 @@ use tracing::{info_span, instrument, Instrument};
type MessageHandler =
Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
struct Response<R> {
server: Arc<Server>,
receipt: Receipt<R>,
responded: Arc<AtomicBool>,
}
impl<R: RequestMessage> Response<R> {
fn send(self, payload: R::Response) -> Result<()> {
self.responded.store(true, SeqCst);
self.server.peer.respond(self.receipt, payload)?;
Ok(())
}
}
pub struct Server {
peer: Arc<Peer>,
store: RwLock<Store>,
@ -100,7 +117,7 @@ impl Server {
.add_message_handler(Server::unregister_project)
.add_request_handler(Server::share_project)
.add_message_handler(Server::unshare_project)
.add_sync_request_handler(Server::join_project)
.add_request_handler(Server::join_project)
.add_message_handler(Server::leave_project)
.add_request_handler(Server::register_worktree)
.add_message_handler(Server::unregister_worktree)
@ -179,43 +196,12 @@ impl Server {
self
}
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
Fut: 'static + Send + Future<Output = Result<M::Response>>,
M: RequestMessage,
{
self.add_message_handler(move |server, envelope| {
let receipt = envelope.receipt();
let response = (handler)(server.clone(), envelope);
async move {
match response.await {
Ok(response) => {
server.peer.respond(receipt, response)?;
Ok(())
}
Err(error) => {
server.peer.respond_with_error(
receipt,
proto::Error {
message: error.to_string(),
},
)?;
Err(error)
}
}
}
})
}
/// Handle a request while holding a lock to the store. This is useful when we're registering
/// a connection but we want to respond on the connection before anybody else can send on it.
fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static
+ Send
+ Sync
+ Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> Result<M::Response>,
F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
let handler = Arc::new(handler);
@ -223,12 +209,19 @@ impl Server {
let receipt = envelope.receipt();
let handler = handler.clone();
async move {
let mut store = server.state_mut().await;
let response = (handler)(server.clone(), &mut *store, envelope);
match response {
Ok(response) => {
server.peer.respond(receipt, response)?;
let responded = Arc::new(AtomicBool::default());
let response = Response {
server: server.clone(),
responded: responded.clone(),
receipt: envelope.receipt(),
};
match (handler)(server.clone(), envelope, response).await {
Ok(()) => {
if responded.load(std::sync::atomic::Ordering::SeqCst) {
Ok(())
} else {
Err(anyhow!("handler did not send a response"))?
}
}
Err(error) => {
server.peer.respond_with_error(
@ -364,20 +357,27 @@ impl Server {
Ok(())
}
async fn ping(self: Arc<Server>, _: TypedEnvelope<proto::Ping>) -> Result<proto::Ack> {
Ok(proto::Ack {})
async fn ping(
self: Arc<Server>,
_: TypedEnvelope<proto::Ping>,
response: Response<proto::Ping>,
) -> Result<()> {
response.send(proto::Ack {})?;
Ok(())
}
async fn register_project(
self: Arc<Server>,
request: TypedEnvelope<proto::RegisterProject>,
) -> Result<proto::RegisterProjectResponse> {
response: Response<proto::RegisterProject>,
) -> Result<()> {
let project_id = {
let mut state = self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?;
state.register_project(request.sender_id, user_id)
};
Ok(proto::RegisterProjectResponse { project_id })
response.send(proto::RegisterProjectResponse { project_id })?;
Ok(())
}
async fn unregister_project(
@ -393,11 +393,13 @@ impl Server {
async fn share_project(
self: Arc<Server>,
request: TypedEnvelope<proto::ShareProject>,
) -> Result<proto::Ack> {
response: Response<proto::ShareProject>,
) -> Result<()> {
let mut state = self.state_mut().await;
let project = state.share_project(request.payload.project_id, request.sender_id)?;
self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
Ok(proto::Ack {})
response.send(proto::Ack {})?;
Ok(())
}
async fn unshare_project(
@ -415,15 +417,16 @@ impl Server {
Ok(())
}
fn join_project(
async fn join_project(
self: Arc<Server>,
state: &mut Store,
request: TypedEnvelope<proto::JoinProject>,
) -> Result<proto::JoinProjectResponse> {
response: Response<proto::JoinProject>,
) -> Result<()> {
let project_id = request.payload.project_id;
let state = &mut *self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?;
let (response, connection_ids, contact_user_ids) = state
let (response_payload, connection_ids, contact_user_ids) = state
.join_project(request.sender_id, user_id, project_id)
.and_then(|joined| {
let share = joined.project.share()?;
@ -480,14 +483,15 @@ impl Server {
project_id,
collaborator: Some(proto::Collaborator {
peer_id: request.sender_id.0,
replica_id: response.replica_id,
replica_id: response_payload.replica_id,
user_id: user_id.to_proto(),
}),
},
)
});
self.update_contacts_for_users(state, &contact_user_ids);
Ok(response)
response.send(response_payload)?;
Ok(())
}
async fn leave_project(
@ -514,7 +518,8 @@ impl Server {
async fn register_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::RegisterWorktree>,
) -> Result<proto::Ack> {
response: Response<proto::RegisterWorktree>,
) -> Result<()> {
let mut contact_user_ids = HashSet::default();
for github_login in &request.payload.authorized_logins {
let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
@ -545,7 +550,8 @@ impl Server {
.forward_send(request.sender_id, connection_id, request.payload.clone())
});
self.update_contacts_for_users(&*state, &contact_user_ids);
Ok(proto::Ack {})
response.send(proto::Ack {})?;
Ok(())
}
async fn unregister_worktree(
@ -573,7 +579,8 @@ impl Server {
async fn update_worktree(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> Result<proto::Ack> {
response: Response<proto::UpdateWorktree>,
) -> Result<()> {
let connection_ids = self.state_mut().await.update_worktree(
request.sender_id,
request.payload.project_id,
@ -587,8 +594,8 @@ impl Server {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
});
Ok(proto::Ack {})
response.send(proto::Ack {})?;
Ok(())
}
async fn update_diagnostic_summary(
@ -652,7 +659,8 @@ impl Server {
async fn forward_project_request<T>(
self: Arc<Server>,
request: TypedEnvelope<T>,
) -> Result<T::Response>
response: Response<T>,
) -> Result<()>
where
T: EntityMessage + RequestMessage,
{
@ -661,22 +669,26 @@ impl Server {
.await
.read_project(request.payload.remote_entity_id(), request.sender_id)?
.host_connection_id;
Ok(self
.peer
response.send(
self.peer
.forward_request(request.sender_id, host_connection_id, request.payload)
.await?)
.await?,
)?;
Ok(())
}
async fn save_buffer(
self: Arc<Server>,
request: TypedEnvelope<proto::SaveBuffer>,
) -> Result<proto::BufferSaved> {
response: Response<proto::SaveBuffer>,
) -> Result<()> {
let host = self
.state()
.await
.read_project(request.payload.project_id, request.sender_id)?
.host_connection_id;
let response = self
let response_payload = self
.peer
.forward_request(request.sender_id, host, request.payload.clone())
.await?;
@ -688,16 +700,18 @@ impl Server {
.connection_ids();
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
broadcast(host, guests, |conn_id| {
self.peer.forward_send(host, conn_id, response.clone())
self.peer
.forward_send(host, conn_id, response_payload.clone())
});
Ok(response)
response.send(response_payload)?;
Ok(())
}
async fn update_buffer(
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateBuffer>,
) -> Result<proto::Ack> {
response: Response<proto::UpdateBuffer>,
) -> Result<()> {
let receiver_ids = self
.state()
.await
@ -706,7 +720,8 @@ impl Server {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
});
Ok(proto::Ack {})
response.send(proto::Ack {})?;
Ok(())
}
async fn update_buffer_file(
@ -757,7 +772,8 @@ impl Server {
async fn follow(
self: Arc<Self>,
request: TypedEnvelope<proto::Follow>,
) -> Result<proto::FollowResponse> {
response: Response<proto::Follow>,
) -> Result<()> {
let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_id;
if !self
@ -768,14 +784,15 @@ impl Server {
{
Err(anyhow!("no such peer"))?;
}
let mut response = self
let mut response_payload = self
.peer
.forward_request(request.sender_id, leader_id, request.payload)
.await?;
response
response_payload
.views
.retain(|view| view.leader_id != Some(follower_id.0));
Ok(response)
response.send(response_payload)?;
Ok(())
}
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
@ -823,13 +840,14 @@ impl Server {
async fn get_channels(
self: Arc<Server>,
request: TypedEnvelope<proto::GetChannels>,
) -> Result<proto::GetChannelsResponse> {
response: Response<proto::GetChannels>,
) -> Result<()> {
let user_id = self
.state()
.await
.user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
Ok(proto::GetChannelsResponse {
response.send(proto::GetChannelsResponse {
channels: channels
.into_iter()
.map(|chan| proto::Channel {
@ -837,13 +855,15 @@ impl Server {
name: chan.name,
})
.collect(),
})
})?;
Ok(())
}
async fn get_users(
self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>,
) -> Result<proto::UsersResponse> {
response: Response<proto::GetUsers>,
) -> Result<()> {
let user_ids = request
.payload
.user_ids
@ -862,13 +882,15 @@ impl Server {
github_login: user.github_login,
})
.collect();
Ok(proto::UsersResponse { users })
response.send(proto::UsersResponse { users })?;
Ok(())
}
async fn fuzzy_search_users(
self: Arc<Server>,
request: TypedEnvelope<proto::FuzzySearchUsers>,
) -> Result<proto::UsersResponse> {
response: Response<proto::FuzzySearchUsers>,
) -> Result<()> {
let query = request.payload.query;
let db = &self.app_state.db;
let users = match query.len() {
@ -888,7 +910,8 @@ impl Server {
github_login: user.github_login,
})
.collect();
Ok(proto::UsersResponse { users })
response.send(proto::UsersResponse { users })?;
Ok(())
}
#[instrument(skip(self, state, user_ids))]
@ -917,7 +940,8 @@ impl Server {
async fn join_channel(
self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
) -> Result<proto::JoinChannelResponse> {
response: Response<proto::JoinChannel>,
) -> Result<()> {
let user_id = self
.state()
.await
@ -949,10 +973,11 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
Ok(proto::JoinChannelResponse {
response.send(proto::JoinChannelResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})
})?;
Ok(())
}
async fn leave_channel(
@ -983,7 +1008,8 @@ impl Server {
async fn send_channel_message(
self: Arc<Self>,
request: TypedEnvelope<proto::SendChannelMessage>,
) -> Result<proto::SendChannelMessageResponse> {
response: Response<proto::SendChannelMessage>,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id;
let connection_ids;
@ -1030,15 +1056,17 @@ impl Server {
},
)
});
Ok(proto::SendChannelMessageResponse {
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})
})?;
Ok(())
}
async fn get_channel_messages(
self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>,
) -> Result<proto::GetChannelMessagesResponse> {
response: Response<proto::GetChannelMessages>,
) -> Result<()> {
let user_id = self
.state()
.await
@ -1071,11 +1099,11 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()),
})
.collect::<Vec<_>>();
Ok(proto::GetChannelMessagesResponse {
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})
})?;
Ok(())
}
async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
@ -1213,7 +1241,7 @@ pub async fn handle_websocket_request(
Extension(server): Extension<Arc<Server>>,
Extension(user_id): Extension<UserId>,
ws: WebSocketUpgrade,
) -> Response {
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
return (
StatusCode::UPGRADE_REQUIRED,