mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-06 00:23:05 +03:00
Refactor add_request_handler
to respond via a Response
struct
This also removes `add_sync_request_handler`. Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
9555b93bca
commit
989b82d664
@ -18,7 +18,7 @@ use axum::{
|
||||
headers::{Header, HeaderName},
|
||||
http::StatusCode,
|
||||
middleware,
|
||||
response::{IntoResponse, Response},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
Extension, Router, TypedHeader,
|
||||
};
|
||||
@ -27,7 +27,7 @@ use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt, T
|
||||
use lazy_static::lazy_static;
|
||||
use rpc::{
|
||||
proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
|
||||
Connection, ConnectionId, Peer, TypedEnvelope,
|
||||
Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
@ -36,7 +36,10 @@ use std::{
|
||||
net::SocketAddr,
|
||||
ops::{Deref, DerefMut},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use store::{Store, Worktree};
|
||||
@ -51,6 +54,20 @@ use tracing::{info_span, instrument, Instrument};
|
||||
type MessageHandler =
|
||||
Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
|
||||
|
||||
struct Response<R> {
|
||||
server: Arc<Server>,
|
||||
receipt: Receipt<R>,
|
||||
responded: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl<R: RequestMessage> Response<R> {
|
||||
fn send(self, payload: R::Response) -> Result<()> {
|
||||
self.responded.store(true, SeqCst);
|
||||
self.server.peer.respond(self.receipt, payload)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Server {
|
||||
peer: Arc<Peer>,
|
||||
store: RwLock<Store>,
|
||||
@ -100,7 +117,7 @@ impl Server {
|
||||
.add_message_handler(Server::unregister_project)
|
||||
.add_request_handler(Server::share_project)
|
||||
.add_message_handler(Server::unshare_project)
|
||||
.add_sync_request_handler(Server::join_project)
|
||||
.add_request_handler(Server::join_project)
|
||||
.add_message_handler(Server::leave_project)
|
||||
.add_request_handler(Server::register_worktree)
|
||||
.add_message_handler(Server::unregister_worktree)
|
||||
@ -179,43 +196,12 @@ impl Server {
|
||||
self
|
||||
}
|
||||
|
||||
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
|
||||
Fut: 'static + Send + Future<Output = Result<M::Response>>,
|
||||
M: RequestMessage,
|
||||
{
|
||||
self.add_message_handler(move |server, envelope| {
|
||||
let receipt = envelope.receipt();
|
||||
let response = (handler)(server.clone(), envelope);
|
||||
async move {
|
||||
match response.await {
|
||||
Ok(response) => {
|
||||
server.peer.respond(receipt, response)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
server.peer.respond_with_error(
|
||||
receipt,
|
||||
proto::Error {
|
||||
message: error.to_string(),
|
||||
},
|
||||
)?;
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle a request while holding a lock to the store. This is useful when we're registering
|
||||
/// a connection but we want to respond on the connection before anybody else can send on it.
|
||||
fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
|
||||
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> Result<M::Response>,
|
||||
F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
|
||||
Fut: Send + Future<Output = Result<()>>,
|
||||
M: RequestMessage,
|
||||
{
|
||||
let handler = Arc::new(handler);
|
||||
@ -223,12 +209,19 @@ impl Server {
|
||||
let receipt = envelope.receipt();
|
||||
let handler = handler.clone();
|
||||
async move {
|
||||
let mut store = server.state_mut().await;
|
||||
let response = (handler)(server.clone(), &mut *store, envelope);
|
||||
match response {
|
||||
Ok(response) => {
|
||||
server.peer.respond(receipt, response)?;
|
||||
Ok(())
|
||||
let responded = Arc::new(AtomicBool::default());
|
||||
let response = Response {
|
||||
server: server.clone(),
|
||||
responded: responded.clone(),
|
||||
receipt: envelope.receipt(),
|
||||
};
|
||||
match (handler)(server.clone(), envelope, response).await {
|
||||
Ok(()) => {
|
||||
if responded.load(std::sync::atomic::Ordering::SeqCst) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("handler did not send a response"))?
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
server.peer.respond_with_error(
|
||||
@ -364,20 +357,27 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ping(self: Arc<Server>, _: TypedEnvelope<proto::Ping>) -> Result<proto::Ack> {
|
||||
Ok(proto::Ack {})
|
||||
async fn ping(
|
||||
self: Arc<Server>,
|
||||
_: TypedEnvelope<proto::Ping>,
|
||||
response: Response<proto::Ping>,
|
||||
) -> Result<()> {
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register_project(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::RegisterProject>,
|
||||
) -> Result<proto::RegisterProjectResponse> {
|
||||
response: Response<proto::RegisterProject>,
|
||||
) -> Result<()> {
|
||||
let project_id = {
|
||||
let mut state = self.state_mut().await;
|
||||
let user_id = state.user_id_for_connection(request.sender_id)?;
|
||||
state.register_project(request.sender_id, user_id)
|
||||
};
|
||||
Ok(proto::RegisterProjectResponse { project_id })
|
||||
response.send(proto::RegisterProjectResponse { project_id })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister_project(
|
||||
@ -393,11 +393,13 @@ impl Server {
|
||||
async fn share_project(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::ShareProject>,
|
||||
) -> Result<proto::Ack> {
|
||||
response: Response<proto::ShareProject>,
|
||||
) -> Result<()> {
|
||||
let mut state = self.state_mut().await;
|
||||
let project = state.share_project(request.payload.project_id, request.sender_id)?;
|
||||
self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
|
||||
Ok(proto::Ack {})
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unshare_project(
|
||||
@ -415,15 +417,16 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn join_project(
|
||||
async fn join_project(
|
||||
self: Arc<Server>,
|
||||
state: &mut Store,
|
||||
request: TypedEnvelope<proto::JoinProject>,
|
||||
) -> Result<proto::JoinProjectResponse> {
|
||||
response: Response<proto::JoinProject>,
|
||||
) -> Result<()> {
|
||||
let project_id = request.payload.project_id;
|
||||
|
||||
let state = &mut *self.state_mut().await;
|
||||
let user_id = state.user_id_for_connection(request.sender_id)?;
|
||||
let (response, connection_ids, contact_user_ids) = state
|
||||
let (response_payload, connection_ids, contact_user_ids) = state
|
||||
.join_project(request.sender_id, user_id, project_id)
|
||||
.and_then(|joined| {
|
||||
let share = joined.project.share()?;
|
||||
@ -480,14 +483,15 @@ impl Server {
|
||||
project_id,
|
||||
collaborator: Some(proto::Collaborator {
|
||||
peer_id: request.sender_id.0,
|
||||
replica_id: response.replica_id,
|
||||
replica_id: response_payload.replica_id,
|
||||
user_id: user_id.to_proto(),
|
||||
}),
|
||||
},
|
||||
)
|
||||
});
|
||||
self.update_contacts_for_users(state, &contact_user_ids);
|
||||
Ok(response)
|
||||
response.send(response_payload)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn leave_project(
|
||||
@ -514,7 +518,8 @@ impl Server {
|
||||
async fn register_worktree(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::RegisterWorktree>,
|
||||
) -> Result<proto::Ack> {
|
||||
response: Response<proto::RegisterWorktree>,
|
||||
) -> Result<()> {
|
||||
let mut contact_user_ids = HashSet::default();
|
||||
for github_login in &request.payload.authorized_logins {
|
||||
let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
|
||||
@ -545,7 +550,8 @@ impl Server {
|
||||
.forward_send(request.sender_id, connection_id, request.payload.clone())
|
||||
});
|
||||
self.update_contacts_for_users(&*state, &contact_user_ids);
|
||||
Ok(proto::Ack {})
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister_worktree(
|
||||
@ -573,7 +579,8 @@ impl Server {
|
||||
async fn update_worktree(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateWorktree>,
|
||||
) -> Result<proto::Ack> {
|
||||
response: Response<proto::UpdateWorktree>,
|
||||
) -> Result<()> {
|
||||
let connection_ids = self.state_mut().await.update_worktree(
|
||||
request.sender_id,
|
||||
request.payload.project_id,
|
||||
@ -587,8 +594,8 @@ impl Server {
|
||||
self.peer
|
||||
.forward_send(request.sender_id, connection_id, request.payload.clone())
|
||||
});
|
||||
|
||||
Ok(proto::Ack {})
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_diagnostic_summary(
|
||||
@ -652,7 +659,8 @@ impl Server {
|
||||
async fn forward_project_request<T>(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<T>,
|
||||
) -> Result<T::Response>
|
||||
response: Response<T>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: EntityMessage + RequestMessage,
|
||||
{
|
||||
@ -661,22 +669,26 @@ impl Server {
|
||||
.await
|
||||
.read_project(request.payload.remote_entity_id(), request.sender_id)?
|
||||
.host_connection_id;
|
||||
Ok(self
|
||||
.peer
|
||||
.forward_request(request.sender_id, host_connection_id, request.payload)
|
||||
.await?)
|
||||
|
||||
response.send(
|
||||
self.peer
|
||||
.forward_request(request.sender_id, host_connection_id, request.payload)
|
||||
.await?,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_buffer(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::SaveBuffer>,
|
||||
) -> Result<proto::BufferSaved> {
|
||||
response: Response<proto::SaveBuffer>,
|
||||
) -> Result<()> {
|
||||
let host = self
|
||||
.state()
|
||||
.await
|
||||
.read_project(request.payload.project_id, request.sender_id)?
|
||||
.host_connection_id;
|
||||
let response = self
|
||||
let response_payload = self
|
||||
.peer
|
||||
.forward_request(request.sender_id, host, request.payload.clone())
|
||||
.await?;
|
||||
@ -688,16 +700,18 @@ impl Server {
|
||||
.connection_ids();
|
||||
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
|
||||
broadcast(host, guests, |conn_id| {
|
||||
self.peer.forward_send(host, conn_id, response.clone())
|
||||
self.peer
|
||||
.forward_send(host, conn_id, response_payload.clone())
|
||||
});
|
||||
|
||||
Ok(response)
|
||||
response.send(response_payload)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_buffer(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::UpdateBuffer>,
|
||||
) -> Result<proto::Ack> {
|
||||
response: Response<proto::UpdateBuffer>,
|
||||
) -> Result<()> {
|
||||
let receiver_ids = self
|
||||
.state()
|
||||
.await
|
||||
@ -706,7 +720,8 @@ impl Server {
|
||||
self.peer
|
||||
.forward_send(request.sender_id, connection_id, request.payload.clone())
|
||||
});
|
||||
Ok(proto::Ack {})
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_buffer_file(
|
||||
@ -757,7 +772,8 @@ impl Server {
|
||||
async fn follow(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::Follow>,
|
||||
) -> Result<proto::FollowResponse> {
|
||||
response: Response<proto::Follow>,
|
||||
) -> Result<()> {
|
||||
let leader_id = ConnectionId(request.payload.leader_id);
|
||||
let follower_id = request.sender_id;
|
||||
if !self
|
||||
@ -768,14 +784,15 @@ impl Server {
|
||||
{
|
||||
Err(anyhow!("no such peer"))?;
|
||||
}
|
||||
let mut response = self
|
||||
let mut response_payload = self
|
||||
.peer
|
||||
.forward_request(request.sender_id, leader_id, request.payload)
|
||||
.await?;
|
||||
response
|
||||
response_payload
|
||||
.views
|
||||
.retain(|view| view.leader_id != Some(follower_id.0));
|
||||
Ok(response)
|
||||
response.send(response_payload)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
|
||||
@ -823,13 +840,14 @@ impl Server {
|
||||
async fn get_channels(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::GetChannels>,
|
||||
) -> Result<proto::GetChannelsResponse> {
|
||||
response: Response<proto::GetChannels>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.state()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
|
||||
Ok(proto::GetChannelsResponse {
|
||||
response.send(proto::GetChannelsResponse {
|
||||
channels: channels
|
||||
.into_iter()
|
||||
.map(|chan| proto::Channel {
|
||||
@ -837,13 +855,15 @@ impl Server {
|
||||
name: chan.name,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_users(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::GetUsers>,
|
||||
) -> Result<proto::UsersResponse> {
|
||||
response: Response<proto::GetUsers>,
|
||||
) -> Result<()> {
|
||||
let user_ids = request
|
||||
.payload
|
||||
.user_ids
|
||||
@ -862,13 +882,15 @@ impl Server {
|
||||
github_login: user.github_login,
|
||||
})
|
||||
.collect();
|
||||
Ok(proto::UsersResponse { users })
|
||||
response.send(proto::UsersResponse { users })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fuzzy_search_users(
|
||||
self: Arc<Server>,
|
||||
request: TypedEnvelope<proto::FuzzySearchUsers>,
|
||||
) -> Result<proto::UsersResponse> {
|
||||
response: Response<proto::FuzzySearchUsers>,
|
||||
) -> Result<()> {
|
||||
let query = request.payload.query;
|
||||
let db = &self.app_state.db;
|
||||
let users = match query.len() {
|
||||
@ -888,7 +910,8 @@ impl Server {
|
||||
github_login: user.github_login,
|
||||
})
|
||||
.collect();
|
||||
Ok(proto::UsersResponse { users })
|
||||
response.send(proto::UsersResponse { users })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self, state, user_ids))]
|
||||
@ -917,7 +940,8 @@ impl Server {
|
||||
async fn join_channel(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::JoinChannel>,
|
||||
) -> Result<proto::JoinChannelResponse> {
|
||||
response: Response<proto::JoinChannel>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.state()
|
||||
.await
|
||||
@ -949,10 +973,11 @@ impl Server {
|
||||
nonce: Some(msg.nonce.as_u128().into()),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Ok(proto::JoinChannelResponse {
|
||||
response.send(proto::JoinChannelResponse {
|
||||
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
|
||||
messages,
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn leave_channel(
|
||||
@ -983,7 +1008,8 @@ impl Server {
|
||||
async fn send_channel_message(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::SendChannelMessage>,
|
||||
) -> Result<proto::SendChannelMessageResponse> {
|
||||
response: Response<proto::SendChannelMessage>,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
let user_id;
|
||||
let connection_ids;
|
||||
@ -1030,15 +1056,17 @@ impl Server {
|
||||
},
|
||||
)
|
||||
});
|
||||
Ok(proto::SendChannelMessageResponse {
|
||||
response.send(proto::SendChannelMessageResponse {
|
||||
message: Some(message),
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_channel_messages(
|
||||
self: Arc<Self>,
|
||||
request: TypedEnvelope<proto::GetChannelMessages>,
|
||||
) -> Result<proto::GetChannelMessagesResponse> {
|
||||
response: Response<proto::GetChannelMessages>,
|
||||
) -> Result<()> {
|
||||
let user_id = self
|
||||
.state()
|
||||
.await
|
||||
@ -1071,11 +1099,11 @@ impl Server {
|
||||
nonce: Some(msg.nonce.as_u128().into()),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(proto::GetChannelMessagesResponse {
|
||||
response.send(proto::GetChannelMessagesResponse {
|
||||
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
|
||||
messages,
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
|
||||
@ -1213,7 +1241,7 @@ pub async fn handle_websocket_request(
|
||||
Extension(server): Extension<Arc<Server>>,
|
||||
Extension(user_id): Extension<UserId>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
) -> axum::response::Response {
|
||||
if protocol_version != rpc::PROTOCOL_VERSION {
|
||||
return (
|
||||
StatusCode::UPGRADE_REQUIRED,
|
||||
|
Loading…
Reference in New Issue
Block a user