mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Re-register message handlers in RPC server
This commit is contained in:
parent
d6412fdbde
commit
d398b96f56
@ -35,23 +35,26 @@ use zrpc::{
|
||||
|
||||
type ReplicaId = u16;
|
||||
|
||||
type Handler = Box<
|
||||
type MessageHandler = Box<
|
||||
dyn Send
|
||||
+ Sync
|
||||
+ Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
|
||||
+ Fn(
|
||||
&mut Option<Box<dyn Any + Send + Sync>>,
|
||||
Arc<Server>,
|
||||
) -> Option<BoxFuture<'static, tide::Result<()>>>,
|
||||
>;
|
||||
|
||||
#[derive(Default)]
|
||||
struct ServerBuilder {
|
||||
handlers: Vec<Handler>,
|
||||
handlers: Vec<MessageHandler>,
|
||||
handler_types: HashSet<TypeId>,
|
||||
}
|
||||
|
||||
impl ServerBuilder {
|
||||
pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
|
||||
Fut: 'static + Send + Future<Output = ()>,
|
||||
Fut: 'static + Send + Future<Output = tide::Result<()>>,
|
||||
M: EnvelopedMessage,
|
||||
{
|
||||
if self.handler_types.insert(TypeId::of::<M>()) {
|
||||
@ -87,7 +90,7 @@ impl ServerBuilder {
|
||||
pub struct Server {
|
||||
rpc: Arc<Peer>,
|
||||
state: Arc<AppState>,
|
||||
handlers: Vec<Handler>,
|
||||
handlers: Vec<MessageHandler>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@ -119,10 +122,16 @@ impl Server {
|
||||
futures::select_biased! {
|
||||
message = next_message => {
|
||||
if let Some(message) = message {
|
||||
let start_time = Instant::now();
|
||||
log::info!("RPC message received");
|
||||
let mut message = Some(message);
|
||||
for handler in &this.handlers {
|
||||
if let Some(future) = (handler)(&mut message, this.clone()) {
|
||||
future.await;
|
||||
if let Err(err) = future.await {
|
||||
log::error!("error handling message: {:?}", err);
|
||||
} else {
|
||||
log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -336,26 +345,24 @@ impl State {
|
||||
|
||||
pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
|
||||
ServerBuilder::default()
|
||||
// .on_message(share_worktree)
|
||||
// .on_message(join_worktree)
|
||||
// .on_message(update_worktree)
|
||||
// .on_message(close_worktree)
|
||||
// .on_message(open_buffer)
|
||||
// .on_message(close_buffer)
|
||||
// .on_message(update_buffer)
|
||||
// .on_message(buffer_saved)
|
||||
// .on_message(save_buffer)
|
||||
// .on_message(get_channels)
|
||||
// .on_message(get_users)
|
||||
// .on_message(join_channel)
|
||||
// .on_message(send_channel_message)
|
||||
.on_message(share_worktree)
|
||||
.on_message(join_worktree)
|
||||
.on_message(update_worktree)
|
||||
.on_message(close_worktree)
|
||||
.on_message(open_buffer)
|
||||
.on_message(close_buffer)
|
||||
.on_message(update_buffer)
|
||||
.on_message(buffer_saved)
|
||||
.on_message(save_buffer)
|
||||
.on_message(get_channels)
|
||||
.on_message(get_users)
|
||||
.on_message(join_channel)
|
||||
.on_message(send_channel_message)
|
||||
.build(rpc, state)
|
||||
}
|
||||
|
||||
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||
let server = build_server(app.state(), rpc);
|
||||
|
||||
let rpc = rpc.clone();
|
||||
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
|
||||
let user_id = request.ext::<UserId>().copied();
|
||||
let server = server.clone();
|
||||
@ -399,11 +406,10 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||
}
|
||||
|
||||
async fn share_worktree(
|
||||
mut request: TypedEnvelope<proto::ShareWorktree>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let mut state = state.rpc.write().await;
|
||||
let mut state = server.state.rpc.write().await;
|
||||
let worktree_id = state.next_worktree_id;
|
||||
state.next_worktree_id += 1;
|
||||
let access_token = random_token();
|
||||
@ -428,26 +434,27 @@ async fn share_worktree(
|
||||
},
|
||||
);
|
||||
|
||||
rpc.respond(
|
||||
request.receipt(),
|
||||
proto::ShareWorktreeResponse {
|
||||
worktree_id,
|
||||
access_token,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
server
|
||||
.rpc
|
||||
.respond(
|
||||
request.receipt(),
|
||||
proto::ShareWorktreeResponse {
|
||||
worktree_id,
|
||||
access_token,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn join_worktree(
|
||||
request: TypedEnvelope<proto::OpenWorktree>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::OpenWorktree>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let access_token = &request.payload.access_token;
|
||||
|
||||
let mut state = state.rpc.write().await;
|
||||
let mut state = server.state.rpc.write().await;
|
||||
if let Some((peer_replica_id, worktree)) =
|
||||
state.join_worktree(request.sender_id, worktree_id, access_token)
|
||||
{
|
||||
@ -468,7 +475,7 @@ async fn join_worktree(
|
||||
}
|
||||
|
||||
broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
|
||||
rpc.send(
|
||||
server.rpc.send(
|
||||
conn_id,
|
||||
proto::AddPeer {
|
||||
worktree_id,
|
||||
@ -480,42 +487,45 @@ async fn join_worktree(
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
rpc.respond(
|
||||
request.receipt(),
|
||||
proto::OpenWorktreeResponse {
|
||||
worktree_id,
|
||||
worktree: Some(proto::Worktree {
|
||||
root_name: worktree.root_name.clone(),
|
||||
entries: worktree.entries.values().cloned().collect(),
|
||||
}),
|
||||
replica_id: peer_replica_id as u32,
|
||||
peers,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
server
|
||||
.rpc
|
||||
.respond(
|
||||
request.receipt(),
|
||||
proto::OpenWorktreeResponse {
|
||||
worktree_id,
|
||||
worktree: Some(proto::Worktree {
|
||||
root_name: worktree.root_name.clone(),
|
||||
entries: worktree.entries.values().cloned().collect(),
|
||||
}),
|
||||
replica_id: peer_replica_id as u32,
|
||||
peers,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
rpc.respond(
|
||||
request.receipt(),
|
||||
proto::OpenWorktreeResponse {
|
||||
worktree_id,
|
||||
worktree: None,
|
||||
replica_id: 0,
|
||||
peers: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
server
|
||||
.rpc
|
||||
.respond(
|
||||
request.receipt(),
|
||||
proto::OpenWorktreeResponse {
|
||||
worktree_id,
|
||||
worktree: None,
|
||||
replica_id: 0,
|
||||
peers: Vec::new(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_worktree(
|
||||
request: TypedEnvelope<proto::UpdateWorktree>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::UpdateWorktree>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
{
|
||||
let mut state = state.rpc.write().await;
|
||||
let mut state = server.state.rpc.write().await;
|
||||
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
|
||||
for entry_id in &request.payload.removed_entries {
|
||||
worktree.entries.remove(&entry_id);
|
||||
@ -526,18 +536,17 @@ async fn update_worktree(
|
||||
}
|
||||
}
|
||||
|
||||
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
|
||||
broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close_worktree(
|
||||
request: TypedEnvelope<proto::CloseWorktree>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::CloseWorktree>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let connection_ids;
|
||||
{
|
||||
let mut state = state.rpc.write().await;
|
||||
let mut state = server.state.rpc.write().await;
|
||||
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
|
||||
connection_ids = worktree.connection_ids();
|
||||
if worktree.host_connection_id == Some(request.sender_id) {
|
||||
@ -548,7 +557,7 @@ async fn close_worktree(
|
||||
}
|
||||
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
rpc.send(
|
||||
server.rpc.send(
|
||||
conn_id,
|
||||
proto::RemovePeer {
|
||||
worktree_id: request.payload.worktree_id,
|
||||
@ -562,53 +571,55 @@ async fn close_worktree(
|
||||
}
|
||||
|
||||
async fn open_buffer(
|
||||
request: TypedEnvelope<proto::OpenBuffer>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::OpenBuffer>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let receipt = request.receipt();
|
||||
let worktree_id = request.payload.worktree_id;
|
||||
let host_connection_id = state
|
||||
let host_connection_id = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
.read_worktree(worktree_id, request.sender_id)?
|
||||
.host_connection_id()?;
|
||||
|
||||
let response = rpc
|
||||
let response = server
|
||||
.rpc
|
||||
.forward_request(request.sender_id, host_connection_id, request.payload)
|
||||
.await?;
|
||||
rpc.respond(receipt, response).await?;
|
||||
server.rpc.respond(receipt, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close_buffer(
|
||||
request: TypedEnvelope<proto::CloseBuffer>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::CloseBuffer>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let host_connection_id = state
|
||||
let host_connection_id = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
.read_worktree(request.payload.worktree_id, request.sender_id)?
|
||||
.host_connection_id()?;
|
||||
|
||||
rpc.forward_send(request.sender_id, host_connection_id, request.payload)
|
||||
server
|
||||
.rpc
|
||||
.forward_send(request.sender_id, host_connection_id, request.payload)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_buffer(
|
||||
request: TypedEnvelope<proto::SaveBuffer>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::SaveBuffer>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let host;
|
||||
let guests;
|
||||
{
|
||||
let state = state.rpc.read().await;
|
||||
let state = server.state.rpc.read().await;
|
||||
let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
|
||||
host = worktree.host_connection_id()?;
|
||||
guests = worktree
|
||||
@ -620,17 +631,19 @@ async fn save_buffer(
|
||||
|
||||
let sender = request.sender_id;
|
||||
let receipt = request.receipt();
|
||||
let response = rpc
|
||||
let response = server
|
||||
.rpc
|
||||
.forward_request(sender, host, request.payload.clone())
|
||||
.await?;
|
||||
|
||||
broadcast(host, guests, |conn_id| {
|
||||
let response = response.clone();
|
||||
let server = &server;
|
||||
async move {
|
||||
if conn_id == sender {
|
||||
rpc.respond(receipt, response).await
|
||||
server.rpc.respond(receipt, response).await
|
||||
} else {
|
||||
rpc.forward_send(host, conn_id, response).await
|
||||
server.rpc.forward_send(host, conn_id, response).await
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -640,61 +653,62 @@ async fn save_buffer(
|
||||
}
|
||||
|
||||
async fn update_buffer(
|
||||
request: TypedEnvelope<proto::UpdateBuffer>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::UpdateBuffer>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
|
||||
broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
|
||||
}
|
||||
|
||||
async fn buffer_saved(
|
||||
request: TypedEnvelope<proto::BufferSaved>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::BufferSaved>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
|
||||
broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
|
||||
}
|
||||
|
||||
async fn get_channels(
|
||||
request: TypedEnvelope<proto::GetChannels>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::GetChannels>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = state
|
||||
let user_id = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channels = state.db.get_channels_for_user(user_id).await?;
|
||||
rpc.respond(
|
||||
request.receipt(),
|
||||
proto::GetChannelsResponse {
|
||||
channels: channels
|
||||
.into_iter()
|
||||
.map(|chan| proto::Channel {
|
||||
id: chan.id.to_proto(),
|
||||
name: chan.name,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let channels = server.state.db.get_channels_for_user(user_id).await?;
|
||||
server
|
||||
.rpc
|
||||
.respond(
|
||||
request.receipt(),
|
||||
proto::GetChannelsResponse {
|
||||
channels: channels
|
||||
.into_iter()
|
||||
.map(|chan| proto::Channel {
|
||||
id: chan.id.to_proto(),
|
||||
name: chan.name,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_users(
|
||||
request: TypedEnvelope<proto::GetUsers>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::GetUsers>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = state
|
||||
let user_id = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let receipt = request.receipt();
|
||||
let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
|
||||
let users = state
|
||||
let users = server
|
||||
.state
|
||||
.db
|
||||
.get_users_by_ids(user_id, user_ids)
|
||||
.await?
|
||||
@ -705,23 +719,26 @@ async fn get_users(
|
||||
avatar_url: String::new(),
|
||||
})
|
||||
.collect();
|
||||
rpc.respond(receipt, proto::GetUsersResponse { users })
|
||||
server
|
||||
.rpc
|
||||
.respond(receipt, proto::GetUsersResponse { users })
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn join_channel(
|
||||
request: TypedEnvelope<proto::JoinChannel>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::JoinChannel>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let user_id = state
|
||||
let user_id = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
.user_id_for_connection(request.sender_id)?;
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
if !state
|
||||
if !server
|
||||
.state
|
||||
.db
|
||||
.can_user_access_channel(user_id, channel_id)
|
||||
.await?
|
||||
@ -729,12 +746,14 @@ async fn join_channel(
|
||||
Err(anyhow!("access denied"))?;
|
||||
}
|
||||
|
||||
state
|
||||
server
|
||||
.state
|
||||
.rpc
|
||||
.write()
|
||||
.await
|
||||
.join_channel(request.sender_id, channel_id);
|
||||
let messages = state
|
||||
let messages = server
|
||||
.state
|
||||
.db
|
||||
.get_recent_channel_messages(channel_id, 50)
|
||||
.await?
|
||||
@ -746,21 +765,22 @@ async fn join_channel(
|
||||
sender_id: msg.sender_id.to_proto(),
|
||||
})
|
||||
.collect();
|
||||
rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
|
||||
server
|
||||
.rpc
|
||||
.respond(request.receipt(), proto::JoinChannelResponse { messages })
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_channel_message(
|
||||
request: TypedEnvelope<proto::SendChannelMessage>,
|
||||
peer: &Arc<Peer>,
|
||||
app: &Arc<AppState>,
|
||||
request: Box<TypedEnvelope<proto::SendChannelMessage>>,
|
||||
server: Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.payload.channel_id);
|
||||
let user_id;
|
||||
let connection_ids;
|
||||
{
|
||||
let state = app.rpc.read().await;
|
||||
let state = server.state.rpc.read().await;
|
||||
user_id = state.user_id_for_connection(request.sender_id)?;
|
||||
if let Some(channel) = state.channels.get(&channel_id) {
|
||||
connection_ids = channel.connection_ids();
|
||||
@ -770,7 +790,8 @@ async fn send_channel_message(
|
||||
}
|
||||
|
||||
let timestamp = OffsetDateTime::now_utc();
|
||||
let message_id = app
|
||||
let message_id = server
|
||||
.state
|
||||
.db
|
||||
.create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
|
||||
.await?;
|
||||
@ -784,7 +805,7 @@ async fn send_channel_message(
|
||||
}),
|
||||
};
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
peer.send(conn_id, message.clone())
|
||||
server.rpc.send(conn_id, message.clone())
|
||||
})
|
||||
.await?;
|
||||
|
||||
@ -793,11 +814,11 @@ async fn send_channel_message(
|
||||
|
||||
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
|
||||
worktree_id: u64,
|
||||
request: TypedEnvelope<T>,
|
||||
rpc: &Arc<Peer>,
|
||||
state: &Arc<AppState>,
|
||||
request: &TypedEnvelope<T>,
|
||||
server: &Arc<Server>,
|
||||
) -> tide::Result<()> {
|
||||
let connection_ids = state
|
||||
let connection_ids = server
|
||||
.state
|
||||
.rpc
|
||||
.read()
|
||||
.await
|
||||
@ -805,7 +826,9 @@ async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
|
||||
.connection_ids();
|
||||
|
||||
broadcast(request.sender_id, connection_ids, |conn_id| {
|
||||
rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
|
||||
server
|
||||
.rpc
|
||||
.forward_send(request.sender_id, conn_id, request.payload.clone())
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user