diff --git a/Cargo.lock b/Cargo.lock index ae226ee437..0cb1d880b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8512,6 +8512,8 @@ dependencies = [ "anyhow", "collections", "futures 0.3.30", + "gpui", + "parking_lot", "prost", "prost-build", "serde", diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 9557eb6e3d..95e91a93dd 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -209,7 +209,7 @@ pub fn init( }) .detach(); - context_store::init(&client); + context_store::init(&client.clone().into()); prompt_library::init(cx); init_language_model_settings(cx); assistant_slash_command::init(cx); diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index 95bd958be0..b584c008bc 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -2,6 +2,7 @@ use crate::{ prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata, }; +use ::proto::AnyProtoClient; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use clock::ReplicaId; @@ -25,7 +26,7 @@ use std::{ }; use util::{ResultExt, TryFutureExt}; -pub fn init(client: &Arc) { +pub fn init(client: &AnyProtoClient) { client.add_model_message_handler(ContextStore::handle_advertise_contexts); client.add_model_request_handler(ContextStore::handle_open_context); client.add_model_request_handler(ContextStore::handle_create_context); diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index aee92d0f6c..b9547bef1a 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -18,6 +18,6 @@ mod channel_store_tests; pub fn init(client: &Arc, user_store: Model, cx: &mut AppContext) { channel_store::init(client, user_store, cx); - channel_buffer::init(client); - channel_chat::init(client); + channel_buffer::init(&client.clone().into()); + channel_chat::init(&client.clone().into()); } diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 7ce291ef4a..acf08612d1 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -5,7 +5,7 @@ use collections::HashMap; use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task}; use language::proto::serialize_version; use rpc::{ - proto::{self, PeerId}, + proto::{self, AnyProtoClient, PeerId}, TypedEnvelope, }; use std::{sync::Arc, time::Duration}; @@ -14,7 +14,7 @@ use util::ResultExt; pub const ACKNOWLEDGE_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(250); -pub(crate) fn init(client: &Arc) { +pub(crate) fn init(client: &AnyProtoClient) { client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer); client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborators); } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 8a1250fd69..a186e31cc6 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -11,6 +11,7 @@ use gpui::{ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel, }; use rand::prelude::*; +use rpc::proto::AnyProtoClient; use std::{ ops::{ControlFlow, Range}, sync::Arc, @@ -95,7 +96,7 @@ pub enum ChannelChatEvent { } impl EventEmitter for ChannelChat {} -pub fn init(client: &Arc) { +pub fn init(client: &AnyProtoClient) { client.add_model_message_handler(ChannelChat::handle_message_sent); client.add_model_message_handler(ChannelChat::handle_message_removed); client.add_model_message_handler(ChannelChat::handle_message_updated); diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 9d7fe7545e..ab5f6bb394 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -14,22 +14,18 @@ use async_tungstenite::tungstenite::{ }; use chrono::{DateTime, Utc}; use clock::SystemClock; -use collections::HashMap; use futures::{ - channel::oneshot, - future::{BoxFuture, LocalBoxFuture}, - AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, -}; -use gpui::{ - actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel, + channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, + TryFutureExt as _, TryStreamExt, }; +use gpui::{actions, AppContext, AsyncAppContext, Global, Model, Task, WeakModel}; use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; use parking_lot::RwLock; use postage::watch; -use proto::ProtoClient; +use proto::{AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet}; use rand::prelude::*; use release_channel::{AppVersion, ReleaseChannel}; -use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage}; +use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; @@ -208,6 +204,7 @@ pub struct Client { telemetry: Arc, credentials_provider: Arc, state: RwLock, + handler_set: parking_lot::Mutex, #[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] @@ -304,30 +301,7 @@ impl Status { struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), - entity_id_extractors: HashMap u64>, _reconnect_task: Option>, - entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, - models_by_message_type: HashMap, - entity_types_by_message_type: HashMap, - #[allow(clippy::type_complexity)] - message_handlers: HashMap< - TypeId, - Arc< - dyn Send - + Sync - + Fn( - AnyModel, - Box, - &Arc, - AsyncAppContext, - ) -> LocalBoxFuture<'static, Result<()>>, - >, - >, -} - -enum WeakSubscriber { - Entity { handle: AnyWeakModel }, - Pending(Vec>), } #[derive(Clone, Debug, Eq, PartialEq)] @@ -379,12 +353,7 @@ impl Default for ClientState { Self { credentials: None, status: watch::channel_with(Status::SignedOut), - entity_id_extractors: Default::default(), _reconnect_task: None, - models_by_message_type: Default::default(), - entities_by_type_and_remote_id: Default::default(), - entity_types_by_message_type: Default::default(), - message_handlers: Default::default(), } } } @@ -405,13 +374,13 @@ impl Drop for Subscription { match self { Subscription::Entity { client, id } => { if let Some(client) = client.upgrade() { - let mut state = client.state.write(); + let mut state = client.handler_set.lock(); let _ = state.entities_by_type_and_remote_id.remove(id); } } Subscription::Message { client, id } => { if let Some(client) = client.upgrade() { - let mut state = client.state.write(); + let mut state = client.handler_set.lock(); let _ = state.entity_types_by_message_type.remove(id); let _ = state.message_handlers.remove(id); } @@ -430,21 +399,21 @@ pub struct PendingEntitySubscription { impl PendingEntitySubscription { pub fn set_model(mut self, model: &Model, cx: &mut AsyncAppContext) -> Subscription { self.consumed = true; - let mut state = self.client.state.write(); + let mut handlers = self.client.handler_set.lock(); let id = (TypeId::of::(), self.remote_id); - let Some(WeakSubscriber::Pending(messages)) = - state.entities_by_type_and_remote_id.remove(&id) + let Some(EntityMessageSubscriber::Pending(messages)) = + handlers.entities_by_type_and_remote_id.remove(&id) else { unreachable!() }; - state.entities_by_type_and_remote_id.insert( + handlers.entities_by_type_and_remote_id.insert( id, - WeakSubscriber::Entity { + EntityMessageSubscriber::Entity { handle: model.downgrade().into(), }, ); - drop(state); + drop(handlers); for message in messages { let client_id = self.client.id(); let type_name = message.payload_type_name(); @@ -467,8 +436,8 @@ impl PendingEntitySubscription { impl Drop for PendingEntitySubscription { fn drop(&mut self) { if !self.consumed { - let mut state = self.client.state.write(); - if let Some(WeakSubscriber::Pending(messages)) = state + let mut state = self.client.handler_set.lock(); + if let Some(EntityMessageSubscriber::Pending(messages)) = state .entities_by_type_and_remote_id .remove(&(TypeId::of::(), self.remote_id)) { @@ -549,6 +518,7 @@ impl Client { http, credentials_provider, state: Default::default(), + handler_set: Default::default(), #[cfg(any(test, feature = "test-support"))] authenticate: Default::default(), @@ -592,10 +562,7 @@ impl Client { pub fn teardown(&self) { let mut state = self.state.write(); state._reconnect_task.take(); - state.message_handlers.clear(); - state.models_by_message_type.clear(); - state.entities_by_type_and_remote_id.clear(); - state.entity_id_extractors.clear(); + self.handler_set.lock().clear(); self.peer.teardown(); } @@ -708,14 +675,14 @@ impl Client { { let id = (TypeId::of::(), remote_id); - let mut state = self.state.write(); + let mut state = self.handler_set.lock(); if state.entities_by_type_and_remote_id.contains_key(&id) { return Err(anyhow!("already subscribed to entity")); } state .entities_by_type_and_remote_id - .insert(id, WeakSubscriber::Pending(Default::default())); + .insert(id, EntityMessageSubscriber::Pending(Default::default())); Ok(PendingEntitySubscription { client: self.clone(), @@ -752,13 +719,13 @@ impl Client { E: 'static, H: 'static + Sync - + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F + + Fn(Model, TypedEnvelope, AnyProtoClient, AsyncAppContext) -> F + Send + Sync, F: 'static + Future>, { let message_type_id = TypeId::of::(); - let mut state = self.state.write(); + let mut state = self.handler_set.lock(); state .models_by_message_type .insert(message_type_id, entity.into()); @@ -803,85 +770,18 @@ impl Client { }) } - pub fn add_model_message_handler(self: &Arc, handler: H) - where - M: EntityMessage, - E: 'static, - H: 'static + Fn(Model, TypedEnvelope, AsyncAppContext) -> F + Send + Sync, - F: 'static + Future>, - { - self.add_entity_message_handler::(move |subscriber, message, _, cx| { - handler(subscriber.downcast::().unwrap(), message, cx) - }) - } - - fn add_entity_message_handler(self: &Arc, handler: H) - where - M: EntityMessage, - E: 'static, - H: 'static + Fn(AnyModel, TypedEnvelope, Arc, AsyncAppContext) -> F + Send + Sync, - F: 'static + Future>, - { - let model_type_id = TypeId::of::(); - let message_type_id = TypeId::of::(); - - let mut state = self.state.write(); - state - .entity_types_by_message_type - .insert(message_type_id, model_type_id); - state - .entity_id_extractors - .entry(message_type_id) - .or_insert_with(|| { - |envelope| { - envelope - .as_any() - .downcast_ref::>() - .unwrap() - .payload - .remote_entity_id() - } - }); - let prev_handler = state.message_handlers.insert( - message_type_id, - Arc::new(move |handle, envelope, client, cx| { - let envelope = envelope.into_any().downcast::>().unwrap(); - handler(handle, *envelope, client.clone(), cx).boxed_local() - }), - ); - if prev_handler.is_some() { - panic!("registered handler for the same message twice"); - } - } - - pub fn add_model_request_handler(self: &Arc, handler: H) - where - M: EntityMessage + RequestMessage, - E: 'static, - H: 'static + Fn(Model, TypedEnvelope, AsyncAppContext) -> F + Send + Sync, - F: 'static + Future>, - { - self.add_entity_message_handler::(move |entity, envelope, client, cx| { - Self::respond_to_request::( - envelope.receipt(), - handler(entity.downcast::().unwrap(), envelope, cx), - client, - ) - }) - } - async fn respond_to_request>>( receipt: Receipt, response: F, - client: Arc, + client: AnyProtoClient, ) -> Result<()> { match response.await { Ok(response) => { - client.respond(receipt, response)?; + client.send_response(receipt.message_id, response)?; Ok(()) } Err(error) => { - client.respond_with_error(receipt, error.to_proto())?; + client.send_response(receipt.message_id, error.to_proto())?; Err(error) } } @@ -1541,16 +1441,6 @@ impl Client { self.peer.send(self.connection_id()?, message) } - pub fn send_dynamic( - &self, - envelope: proto::Envelope, - message_type: &'static str, - ) -> Result<()> { - log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type); - let connection_id = self.connection_id()?; - self.peer.send_dynamic(connection_id, envelope) - } - pub fn request( &self, request: T, @@ -1632,115 +1522,56 @@ impl Client { } } - fn respond(&self, receipt: Receipt, response: T::Response) -> Result<()> { - log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME); - self.peer.respond(receipt, response) - } - - fn respond_with_error( - &self, - receipt: Receipt, - error: proto::Error, - ) -> Result<()> { - log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME); - self.peer.respond_with_error(receipt, error) - } - fn handle_message( self: &Arc, message: Box, cx: &AsyncAppContext, ) { - let mut state = self.state.write(); + let sender_id = message.sender_id(); + let request_id = message.message_id(); let type_name = message.payload_type_name(); - let payload_type_id = message.payload_type_id(); - let sender_id = message.original_sender_id(); + let original_sender_id = message.original_sender_id(); - let mut subscriber = None; - - if let Some(handle) = state - .models_by_message_type - .get(&payload_type_id) - .and_then(|handle| handle.upgrade()) - { - subscriber = Some(handle); - } else if let Some((extract_entity_id, entity_type_id)) = - state.entity_id_extractors.get(&payload_type_id).zip( - state - .entity_types_by_message_type - .get(&payload_type_id) - .copied(), - ) - { - let entity_id = (extract_entity_id)(message.as_ref()); - - match state - .entities_by_type_and_remote_id - .get_mut(&(entity_type_id, entity_id)) - { - Some(WeakSubscriber::Pending(pending)) => { - pending.push(message); - return; - } - Some(weak_subscriber) => match weak_subscriber { - WeakSubscriber::Entity { handle } => { - subscriber = handle.upgrade(); - } - - WeakSubscriber::Pending(_) => {} - }, - _ => {} - } - } - - let subscriber = if let Some(subscriber) = subscriber { - subscriber - } else { - log::info!("unhandled message {}", type_name); - self.peer.respond_with_unhandled_message(message).log_err(); - return; - }; - - let handler = state.message_handlers.get(&payload_type_id).cloned(); - // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. - // It also ensures we don't hold the lock while yielding back to the executor, as - // that might cause the executor thread driving this future to block indefinitely. - drop(state); - - if let Some(handler) = handler { - let future = handler(subscriber, message, self, cx.clone()); + if let Some(future) = ProtoMessageHandlerSet::handle_message( + &self.handler_set, + message, + self.clone().into(), + cx.clone(), + ) { let client_id = self.id(); log::debug!( "rpc message received. client_id:{}, sender_id:{:?}, type:{}", client_id, - sender_id, + original_sender_id, type_name ); cx.spawn(move |_| async move { - match future.await { - Ok(()) => { - log::debug!( - "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", - client_id, - sender_id, - type_name - ); - } - Err(error) => { - log::error!( - "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", - client_id, - sender_id, - type_name, - error - ); - } + match future.await { + Ok(()) => { + log::debug!( + "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", + client_id, + original_sender_id, + type_name + ); } - }) - .detach(); + Err(error) => { + log::error!( + "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", + client_id, + original_sender_id, + type_name, + error + ); + } + } + }) + .detach(); } else { log::info!("unhandled message {}", type_name); - self.peer.respond_with_unhandled_message(message).log_err(); + self.peer + .respond_with_unhandled_message(sender_id.into(), request_id, type_name) + .log_err(); } } @@ -1759,7 +1590,23 @@ impl ProtoClient for Client { } fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> { - self.send_dynamic(envelope, message_type) + log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type); + let connection_id = self.connection_id()?; + self.peer.send_dynamic(connection_id, envelope) + } + + fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> { + log::debug!( + "rpc respond. client_id:{}, name:{}", + self.id(), + message_type + ); + let connection_id = self.connection_id()?; + self.peer.send_dynamic(connection_id, envelope) + } + + fn message_handler_set(&self) -> &parking_lot::Mutex { + &self.handler_set } } @@ -2103,7 +1950,7 @@ mod tests { let (done_tx1, mut done_rx1) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded(); - client.add_model_message_handler( + AnyProtoClient::from(client.clone()).add_model_message_handler( move |model: Model, _: TypedEnvelope, mut cx| { match model.update(&mut cx, |model, _| model.id).unwrap() { 1 => done_tx1.try_send(()).unwrap(), diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 12c4f3bfcb..6aedfd95db 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -301,7 +301,7 @@ impl TestServer { dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); language_model::LanguageModelRegistry::test(cx); - assistant::context_store::init(&client); + assistant::context_store::init(&client.clone().into()); }); client diff --git a/crates/project/src/buffer_store.rs b/crates/project/src/buffer_store.rs index 428684783f..a9d013f83b 100644 --- a/crates/project/src/buffer_store.rs +++ b/crates/project/src/buffer_store.rs @@ -71,7 +71,7 @@ pub struct ProjectTransaction(pub HashMap, language::Transaction>) impl EventEmitter for BufferStore {} impl BufferStore { - pub fn init(client: &Arc) { + pub fn init(client: &AnyProtoClient) { client.add_model_message_handler(Self::handle_buffer_reloaded); client.add_model_message_handler(Self::handle_buffer_saved); client.add_model_message_handler(Self::handle_update_buffer_file); diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 2082382be2..0fde0ac5ad 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -12,7 +12,7 @@ use crate::{ }; use anyhow::{anyhow, Context as _, Result}; use async_trait::async_trait; -use client::{proto, Client, TypedEnvelope}; +use client::{proto, TypedEnvelope}; use collections::{btree_map, BTreeMap, HashMap, HashSet}; use futures::{ future::{join_all, Shared}, @@ -45,6 +45,7 @@ use lsp::{ use parking_lot::{Mutex, RwLock}; use postage::watch; use rand::prelude::*; + use rpc::proto::AnyProtoClient; use serde::Serialize; use settings::{Settings, SettingsLocation, SettingsStore}; @@ -84,51 +85,7 @@ const SERVER_REINSTALL_DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1); const SERVER_LAUNCHING_BEFORE_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); pub const SERVER_PROGRESS_THROTTLE_TIMEOUT: Duration = Duration::from_millis(100); -#[derive(Clone, Debug)] -pub(crate) struct CoreSymbol { - pub language_server_name: LanguageServerName, - pub source_worktree_id: WorktreeId, - pub path: ProjectPath, - pub name: String, - pub kind: lsp::SymbolKind, - pub range: Range>, - pub signature: [u8; 32], -} - -pub enum LspStoreEvent { - LanguageServerAdded(LanguageServerId), - LanguageServerRemoved(LanguageServerId), - LanguageServerUpdate { - language_server_id: LanguageServerId, - message: proto::update_language_server::Variant, - }, - LanguageServerLog(LanguageServerId, LanguageServerLogType, String), - LanguageServerPrompt(LanguageServerPromptRequest), - Notification(String), - RefreshInlayHints, - DiagnosticsUpdated { - language_server_id: LanguageServerId, - path: ProjectPath, - }, - DiskBasedDiagnosticsStarted { - language_server_id: LanguageServerId, - }, - DiskBasedDiagnosticsFinished { - language_server_id: LanguageServerId, - }, - SnippetEdit { - buffer_id: BufferId, - edits: Vec<(lsp::Range, Snippet)>, - most_recent_edit: clock::Lamport, - }, - StartFormattingLocalBuffer(BufferId), - FinishFormattingLocalBuffer(BufferId), -} - -impl EventEmitter for LspStore {} - pub struct LspStore { - _subscription: gpui::Subscription, downstream_client: Option, upstream_client: Option, project_id: u64, @@ -165,10 +122,60 @@ pub struct LspStore { >, >, yarn: Model, + _subscription: gpui::Subscription, +} + +pub enum LspStoreEvent { + LanguageServerAdded(LanguageServerId), + LanguageServerRemoved(LanguageServerId), + LanguageServerUpdate { + language_server_id: LanguageServerId, + message: proto::update_language_server::Variant, + }, + LanguageServerLog(LanguageServerId, LanguageServerLogType, String), + LanguageServerPrompt(LanguageServerPromptRequest), + Notification(String), + RefreshInlayHints, + DiagnosticsUpdated { + language_server_id: LanguageServerId, + path: ProjectPath, + }, + DiskBasedDiagnosticsStarted { + language_server_id: LanguageServerId, + }, + DiskBasedDiagnosticsFinished { + language_server_id: LanguageServerId, + }, + SnippetEdit { + buffer_id: BufferId, + edits: Vec<(lsp::Range, Snippet)>, + most_recent_edit: clock::Lamport, + }, + StartFormattingLocalBuffer(BufferId), + FinishFormattingLocalBuffer(BufferId), +} + +#[derive(Clone, Debug, Serialize)] +pub struct LanguageServerStatus { + pub name: String, + pub pending_work: BTreeMap, + pub has_pending_diagnostic_updates: bool, + progress_tokens: HashSet, +} + +#[derive(Clone, Debug)] +struct CoreSymbol { + pub language_server_name: LanguageServerName, + pub source_worktree_id: WorktreeId, + pub path: ProjectPath, + pub name: String, + pub kind: lsp::SymbolKind, + pub range: Range>, + pub signature: [u8; 32], } impl LspStore { - pub fn init(client: &Arc) { + pub fn init(client: &AnyProtoClient) { client.add_model_request_handler(Self::handle_multi_lsp_query); client.add_model_request_handler(Self::handle_restart_language_servers); client.add_model_message_handler(Self::handle_start_language_server); @@ -180,6 +187,9 @@ impl LspStore { client.add_model_request_handler(Self::handle_get_project_symbols); client.add_model_request_handler(Self::handle_resolve_inlay_hint); client.add_model_request_handler(Self::handle_open_buffer_for_symbol); + client.add_model_request_handler(Self::handle_refresh_inlay_hints); + client.add_model_request_handler(Self::handle_on_type_formatting); + client.add_model_request_handler(Self::handle_apply_additional_edits_for_completion); client.add_model_request_handler(Self::handle_lsp_command::); client.add_model_request_handler(Self::handle_lsp_command::); client.add_model_request_handler(Self::handle_lsp_command::); @@ -192,10 +202,6 @@ impl LspStore { client.add_model_request_handler(Self::handle_lsp_command::); client.add_model_request_handler(Self::handle_lsp_command::); client.add_model_request_handler(Self::handle_lsp_command::); - - client.add_model_request_handler(Self::handle_refresh_inlay_hints); - client.add_model_request_handler(Self::handle_on_type_formatting); - client.add_model_request_handler(Self::handle_apply_additional_edits_for_completion); } #[allow(clippy::too_many_arguments)] @@ -296,26 +302,6 @@ impl LspStore { Ok(()) } - fn send_lsp_proto_request( - &self, - buffer: Model, - project_id: u64, - request: R, - cx: &mut ModelContext<'_, Self>, - ) -> Task::Response>> { - let Some(upstream_client) = self.upstream_client.clone() else { - return Task::ready(Err(anyhow!("disconnected before completing request"))); - }; - let message = request.to_proto(project_id, buffer.read(cx)); - cx.spawn(move |this, cx| async move { - let response = upstream_client.request(message).await?; - let this = this.upgrade().context("project dropped")?; - request - .response_from_proto(response, this, buffer, cx) - .await - }) - } - pub fn request_lsp( &self, buffer_handle: Model, @@ -416,6 +402,26 @@ impl LspStore { Task::ready(Ok(Default::default())) } + fn send_lsp_proto_request( + &self, + buffer: Model, + project_id: u64, + request: R, + cx: &mut ModelContext<'_, Self>, + ) -> Task::Response>> { + let Some(upstream_client) = self.upstream_client.clone() else { + return Task::ready(Err(anyhow!("disconnected before completing request"))); + }; + let message = request.to_proto(project_id, buffer.read(cx)); + cx.spawn(move |this, cx| async move { + let response = upstream_client.request(message).await?; + let this = this.upgrade().context("project dropped")?; + request + .response_from_proto(response, this, buffer, cx) + .await + }) + } + pub async fn execute_code_actions_on_servers( this: &WeakModel, adapters_and_servers: &Vec<(Arc, Arc)>, @@ -440,7 +446,7 @@ impl LspStore { .await?; for mut action in actions { - LspStore::try_resolve_code_action(&language_server, &mut action) + Self::try_resolve_code_action(&language_server, &mut action) .await .context("resolving a formatting code action")?; @@ -490,7 +496,7 @@ impl LspStore { Ok(()) } - pub async fn try_resolve_code_action( + async fn try_resolve_code_action( lang_server: &LanguageServer, action: &mut CodeAction, ) -> anyhow::Result<()> { @@ -507,63 +513,6 @@ impl LspStore { anyhow::Ok(()) } - pub(crate) fn serialize_completion(completion: &CoreCompletion) -> proto::Completion { - proto::Completion { - old_start: Some(serialize_anchor(&completion.old_range.start)), - old_end: Some(serialize_anchor(&completion.old_range.end)), - new_text: completion.new_text.clone(), - server_id: completion.server_id.0 as u64, - lsp_completion: serde_json::to_vec(&completion.lsp_completion).unwrap(), - } - } - - pub(crate) fn deserialize_completion(completion: proto::Completion) -> Result { - let old_start = completion - .old_start - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("invalid old start"))?; - let old_end = completion - .old_end - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("invalid old end"))?; - let lsp_completion = serde_json::from_slice(&completion.lsp_completion)?; - - Ok(CoreCompletion { - old_range: old_start..old_end, - new_text: completion.new_text, - server_id: LanguageServerId(completion.server_id as usize), - lsp_completion, - }) - } - - // todo: CodeAction.to_proto() - pub fn serialize_code_action(action: &CodeAction) -> proto::CodeAction { - proto::CodeAction { - server_id: action.server_id.0 as u64, - start: Some(serialize_anchor(&action.range.start)), - end: Some(serialize_anchor(&action.range.end)), - lsp_action: serde_json::to_vec(&action.lsp_action).unwrap(), - } - } - - // todo: CodeAction::from__proto() - pub fn deserialize_code_action(action: proto::CodeAction) -> Result { - let start = action - .start - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("invalid start"))?; - let end = action - .end - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("invalid end"))?; - let lsp_action = serde_json::from_slice(&action.lsp_action)?; - Ok(CodeAction { - server_id: LanguageServerId(action.server_id as usize), - range: start..end, - lsp_action, - }) - } - pub fn apply_code_action( &self, buffer_handle: Model, @@ -649,6 +598,66 @@ impl LspStore { } } + pub fn resolve_inlay_hint( + &self, + hint: InlayHint, + buffer_handle: Model, + server_id: LanguageServerId, + cx: &mut ModelContext, + ) -> Task> { + if let Some(upstream_client) = self.upstream_client.clone() { + let request = proto::ResolveInlayHint { + project_id: self.project_id, + buffer_id: buffer_handle.read(cx).remote_id().into(), + language_server_id: server_id.0 as u64, + hint: Some(InlayHints::project_to_proto_hint(hint.clone())), + }; + cx.spawn(move |_, _| async move { + let response = upstream_client + .request(request) + .await + .context("inlay hints proto request")?; + match response.hint { + Some(resolved_hint) => InlayHints::proto_to_project_hint(resolved_hint) + .context("inlay hints proto resolve response conversion"), + None => Ok(hint), + } + }) + } else { + let buffer = buffer_handle.read(cx); + let (_, lang_server) = if let Some((adapter, server)) = + self.language_server_for_buffer(buffer, server_id, cx) + { + (adapter.clone(), server.clone()) + } else { + return Task::ready(Ok(hint)); + }; + if !InlayHints::can_resolve_inlays(&lang_server.capabilities()) { + return Task::ready(Ok(hint)); + } + + let buffer_snapshot = buffer.snapshot(); + cx.spawn(move |_, mut cx| async move { + let resolve_task = lang_server.request::( + InlayHints::project_to_lsp_hint(hint, &buffer_snapshot), + ); + let resolved_hint = resolve_task + .await + .context("inlay hint resolve LSP request")?; + let resolved_hint = InlayHints::lsp_to_project_hint( + resolved_hint, + &buffer_handle, + server_id, + ResolveState::Resolved, + false, + &mut cx, + ) + .await?; + Ok(resolved_hint) + }) + } + } + pub(crate) fn linked_edit( &self, buffer: &Model, @@ -773,7 +782,7 @@ impl LspStore { self.on_type_format_impl(buffer, position, trigger, push_to_history, cx) } - pub fn on_type_format_impl( + fn on_type_format_impl( &mut self, buffer: Model, position: PointUtf16, @@ -1715,36 +1724,6 @@ impl LspStore { } } - pub(crate) fn deserialize_symbol(serialized_symbol: proto::Symbol) -> Result { - let source_worktree_id = WorktreeId::from_proto(serialized_symbol.source_worktree_id); - let worktree_id = WorktreeId::from_proto(serialized_symbol.worktree_id); - let kind = unsafe { mem::transmute::(serialized_symbol.kind) }; - let path = ProjectPath { - worktree_id, - path: PathBuf::from(serialized_symbol.path).into(), - }; - - let start = serialized_symbol - .start - .ok_or_else(|| anyhow!("invalid start"))?; - let end = serialized_symbol - .end - .ok_or_else(|| anyhow!("invalid end"))?; - Ok(CoreSymbol { - language_server_name: LanguageServerName(serialized_symbol.language_server_name.into()), - source_worktree_id, - path, - name: serialized_symbol.name, - range: Unclipped(PointUtf16::new(start.row, start.column)) - ..Unclipped(PointUtf16::new(end.row, end.column)), - kind, - signature: serialized_symbol - .signature - .try_into() - .map_err(|_| anyhow!("invalid signature"))?, - }) - } - pub fn diagnostic_summaries<'a>( &'a self, include_ignored: bool, @@ -2332,7 +2311,7 @@ impl LspStore { if let Some(client) = self.upstream_client.clone() { let request = client.request(proto::OpenBufferForSymbol { project_id: self.project_id, - symbol: Some(serialize_symbol(symbol)), + symbol: Some(Self::serialize_symbol(symbol)), }); cx.spawn(move |this, mut cx| async move { let response = request.await?; @@ -2343,13 +2322,10 @@ impl LspStore { .await }) } else { - let language_server_id = if let Some(id) = self - .language_server_id_for_worktree_and_name( - symbol.source_worktree_id, - symbol.language_server_name.clone(), - ) { - *id - } else { + let Some(&language_server_id) = self.language_server_ids.get(&( + symbol.source_worktree_id, + symbol.language_server_name.clone(), + )) else { return Task::ready(Err(anyhow!( "language server for worktree and language not found" ))); @@ -2587,7 +2563,7 @@ impl LspStore { }); } - pub async fn handle_lsp_command( + async fn handle_lsp_command( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -2629,7 +2605,7 @@ impl LspStore { })? } - pub async fn handle_multi_lsp_query( + async fn handle_multi_lsp_query( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -2769,7 +2745,7 @@ impl LspStore { } } - pub async fn handle_apply_code_action( + async fn handle_apply_code_action( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -2802,7 +2778,7 @@ impl LspStore { }) } - pub async fn handle_update_diagnostic_summary( + async fn handle_update_diagnostic_summary( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -2849,7 +2825,7 @@ impl LspStore { })? } - pub async fn handle_start_language_server( + async fn handle_start_language_server( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -2873,7 +2849,7 @@ impl LspStore { Ok(()) } - pub async fn handle_update_language_server( + async fn handle_update_language_server( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -3094,14 +3070,6 @@ impl LspStore { cx.notify(); } - pub fn language_server_id_for_worktree_and_name( - &self, - worktree_id: WorktreeId, - name: LanguageServerName, - ) -> Option<&LanguageServerId> { - self.language_server_ids.get(&(worktree_id, name)) - } - pub fn language_server_for_id(&self, id: LanguageServerId) -> Option> { if let Some(LanguageServerState::Running { server, .. }) = self.language_servers.get(&id) { Some(server.clone()) @@ -3112,261 +3080,6 @@ impl LspStore { } } - pub async fn deserialize_text_edits( - this: Model, - buffer_to_edit: Model, - edits: Vec, - push_to_history: bool, - _: Arc, - language_server: Arc, - cx: &mut AsyncAppContext, - ) -> Result> { - let edits = this - .update(cx, |this, cx| { - this.edits_from_lsp( - &buffer_to_edit, - edits, - language_server.server_id(), - None, - cx, - ) - })? - .await?; - - let transaction = buffer_to_edit.update(cx, |buffer, cx| { - buffer.finalize_last_transaction(); - buffer.start_transaction(); - for (range, text) in edits { - buffer.edit([(range, text)], None, cx); - } - - if buffer.end_transaction(cx).is_some() { - let transaction = buffer.finalize_last_transaction().unwrap().clone(); - if !push_to_history { - buffer.forget_transaction(transaction.id); - } - Some(transaction) - } else { - None - } - })?; - - Ok(transaction) - } - - pub async fn deserialize_workspace_edit( - this: Model, - edit: lsp::WorkspaceEdit, - push_to_history: bool, - lsp_adapter: Arc, - language_server: Arc, - cx: &mut AsyncAppContext, - ) -> Result { - let fs = this.update(cx, |this, _| this.fs.clone())?; - let mut operations = Vec::new(); - if let Some(document_changes) = edit.document_changes { - match document_changes { - lsp::DocumentChanges::Edits(edits) => { - operations.extend(edits.into_iter().map(lsp::DocumentChangeOperation::Edit)) - } - lsp::DocumentChanges::Operations(ops) => operations = ops, - } - } else if let Some(changes) = edit.changes { - operations.extend(changes.into_iter().map(|(uri, edits)| { - lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit { - text_document: lsp::OptionalVersionedTextDocumentIdentifier { - uri, - version: None, - }, - edits: edits.into_iter().map(Edit::Plain).collect(), - }) - })); - } - - let mut project_transaction = ProjectTransaction::default(); - for operation in operations { - match operation { - lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Create(op)) => { - let abs_path = op - .uri - .to_file_path() - .map_err(|_| anyhow!("can't convert URI to path"))?; - - if let Some(parent_path) = abs_path.parent() { - fs.create_dir(parent_path).await?; - } - if abs_path.ends_with("/") { - fs.create_dir(&abs_path).await?; - } else { - fs.create_file( - &abs_path, - op.options - .map(|options| fs::CreateOptions { - overwrite: options.overwrite.unwrap_or(false), - ignore_if_exists: options.ignore_if_exists.unwrap_or(false), - }) - .unwrap_or_default(), - ) - .await?; - } - } - - lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Rename(op)) => { - let source_abs_path = op - .old_uri - .to_file_path() - .map_err(|_| anyhow!("can't convert URI to path"))?; - let target_abs_path = op - .new_uri - .to_file_path() - .map_err(|_| anyhow!("can't convert URI to path"))?; - fs.rename( - &source_abs_path, - &target_abs_path, - op.options - .map(|options| fs::RenameOptions { - overwrite: options.overwrite.unwrap_or(false), - ignore_if_exists: options.ignore_if_exists.unwrap_or(false), - }) - .unwrap_or_default(), - ) - .await?; - } - - lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Delete(op)) => { - let abs_path = op - .uri - .to_file_path() - .map_err(|_| anyhow!("can't convert URI to path"))?; - let options = op - .options - .map(|options| fs::RemoveOptions { - recursive: options.recursive.unwrap_or(false), - ignore_if_not_exists: options.ignore_if_not_exists.unwrap_or(false), - }) - .unwrap_or_default(); - if abs_path.ends_with("/") { - fs.remove_dir(&abs_path, options).await?; - } else { - fs.remove_file(&abs_path, options).await?; - } - } - - lsp::DocumentChangeOperation::Edit(op) => { - let buffer_to_edit = this - .update(cx, |this, cx| { - this.open_local_buffer_via_lsp( - op.text_document.uri.clone(), - language_server.server_id(), - lsp_adapter.name.clone(), - cx, - ) - })? - .await?; - - let edits = this - .update(cx, |this, cx| { - let path = buffer_to_edit.read(cx).project_path(cx); - let active_entry = this.active_entry; - let is_active_entry = path.clone().map_or(false, |project_path| { - this.worktree_store - .read(cx) - .entry_for_path(&project_path, cx) - .map_or(false, |entry| Some(entry.id) == active_entry) - }); - - let (mut edits, mut snippet_edits) = (vec![], vec![]); - for edit in op.edits { - match edit { - Edit::Plain(edit) => edits.push(edit), - Edit::Annotated(edit) => edits.push(edit.text_edit), - Edit::Snippet(edit) => { - let Ok(snippet) = Snippet::parse(&edit.snippet.value) - else { - continue; - }; - - if is_active_entry { - snippet_edits.push((edit.range, snippet)); - } else { - // Since this buffer is not focused, apply a normal edit. - edits.push(TextEdit { - range: edit.range, - new_text: snippet.text, - }); - } - } - } - } - if !snippet_edits.is_empty() { - if let Some(buffer_version) = op.text_document.version { - let buffer_id = buffer_to_edit.read(cx).remote_id(); - // Check if the edit that triggered that edit has been made by this participant. - let most_recent_edit = this - .buffer_snapshots - .get(&buffer_id) - .and_then(|server_to_snapshots| { - let all_snapshots = server_to_snapshots - .get(&language_server.server_id())?; - all_snapshots - .binary_search_by_key(&buffer_version, |snapshot| { - snapshot.version - }) - .ok() - .and_then(|index| all_snapshots.get(index)) - }) - .and_then(|lsp_snapshot| { - let version = lsp_snapshot.snapshot.version(); - version.iter().max_by_key(|timestamp| timestamp.value) - }); - if let Some(most_recent_edit) = most_recent_edit { - cx.emit(LspStoreEvent::SnippetEdit { - buffer_id, - edits: snippet_edits, - most_recent_edit, - }); - } - } - } - - this.edits_from_lsp( - &buffer_to_edit, - edits, - language_server.server_id(), - op.text_document.version, - cx, - ) - })? - .await?; - - let transaction = buffer_to_edit.update(cx, |buffer, cx| { - buffer.finalize_last_transaction(); - buffer.start_transaction(); - for (range, text) in edits { - buffer.edit([(range, text)], None, cx); - } - let transaction = if buffer.end_transaction(cx).is_some() { - let transaction = buffer.finalize_last_transaction().unwrap().clone(); - if !push_to_history { - buffer.forget_transaction(transaction.id); - } - Some(transaction) - } else { - None - }; - - transaction - })?; - if let Some(transaction) = transaction { - project_transaction.0.insert(buffer_to_edit, transaction); - } - } - } - } - - Ok(project_transaction) - } - async fn on_lsp_workspace_edit( this: WeakModel, params: lsp::ApplyWorkspaceEditParams, @@ -3616,7 +3329,7 @@ impl LspStore { } #[allow(clippy::type_complexity)] - pub fn edits_from_lsp( + pub(crate) fn edits_from_lsp( &mut self, buffer: &Model, lsp_edits: impl 'static + Send + IntoIterator, @@ -3822,7 +3535,7 @@ impl LspStore { Ok(proto::Ack {}) } - pub async fn handle_inlay_hints( + async fn handle_inlay_hints( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -3867,7 +3580,7 @@ impl LspStore { }) } - pub async fn handle_resolve_inlay_hint( + async fn handle_resolve_inlay_hint( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, @@ -3898,66 +3611,6 @@ impl LspStore { }) } - pub fn resolve_inlay_hint( - &self, - hint: InlayHint, - buffer_handle: Model, - server_id: LanguageServerId, - cx: &mut ModelContext, - ) -> Task> { - if let Some(upstream_client) = self.upstream_client.clone() { - let request = proto::ResolveInlayHint { - project_id: self.project_id, - buffer_id: buffer_handle.read(cx).remote_id().into(), - language_server_id: server_id.0 as u64, - hint: Some(InlayHints::project_to_proto_hint(hint.clone())), - }; - cx.spawn(move |_, _| async move { - let response = upstream_client - .request(request) - .await - .context("inlay hints proto request")?; - match response.hint { - Some(resolved_hint) => InlayHints::proto_to_project_hint(resolved_hint) - .context("inlay hints proto resolve response conversion"), - None => Ok(hint), - } - }) - } else { - let buffer = buffer_handle.read(cx); - let (_, lang_server) = if let Some((adapter, server)) = - self.language_server_for_buffer(buffer, server_id, cx) - { - (adapter.clone(), server.clone()) - } else { - return Task::ready(Ok(hint)); - }; - if !InlayHints::can_resolve_inlays(&lang_server.capabilities()) { - return Task::ready(Ok(hint)); - } - - let buffer_snapshot = buffer.snapshot(); - cx.spawn(move |_, mut cx| async move { - let resolve_task = lang_server.request::( - InlayHints::project_to_lsp_hint(hint, &buffer_snapshot), - ); - let resolved_hint = resolve_task - .await - .context("inlay hint resolve LSP request")?; - let resolved_hint = InlayHints::lsp_to_project_hint( - resolved_hint, - &buffer_handle, - server_id, - ResolveState::Resolved, - false, - &mut cx, - ) - .await?; - Ok(resolved_hint) - }) - } - } - async fn handle_open_buffer_for_symbol( this: Model, envelope: TypedEnvelope, @@ -4039,7 +3692,7 @@ impl LspStore { .await?; Ok(proto::GetProjectSymbolsResponse { - symbols: symbols.iter().map(serialize_symbol).collect(), + symbols: symbols.iter().map(Self::serialize_symbol).collect(), }) } @@ -5663,8 +5316,370 @@ impl LspStore { Vec::new() } } + + pub async fn deserialize_text_edits( + this: Model, + buffer_to_edit: Model, + edits: Vec, + push_to_history: bool, + _: Arc, + language_server: Arc, + cx: &mut AsyncAppContext, + ) -> Result> { + let edits = this + .update(cx, |this, cx| { + this.edits_from_lsp( + &buffer_to_edit, + edits, + language_server.server_id(), + None, + cx, + ) + })? + .await?; + + let transaction = buffer_to_edit.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(); + buffer.start_transaction(); + for (range, text) in edits { + buffer.edit([(range, text)], None, cx); + } + + if buffer.end_transaction(cx).is_some() { + let transaction = buffer.finalize_last_transaction().unwrap().clone(); + if !push_to_history { + buffer.forget_transaction(transaction.id); + } + Some(transaction) + } else { + None + } + })?; + + Ok(transaction) + } + + pub async fn deserialize_workspace_edit( + this: Model, + edit: lsp::WorkspaceEdit, + push_to_history: bool, + lsp_adapter: Arc, + language_server: Arc, + cx: &mut AsyncAppContext, + ) -> Result { + let fs = this.update(cx, |this, _| this.fs.clone())?; + let mut operations = Vec::new(); + if let Some(document_changes) = edit.document_changes { + match document_changes { + lsp::DocumentChanges::Edits(edits) => { + operations.extend(edits.into_iter().map(lsp::DocumentChangeOperation::Edit)) + } + lsp::DocumentChanges::Operations(ops) => operations = ops, + } + } else if let Some(changes) = edit.changes { + operations.extend(changes.into_iter().map(|(uri, edits)| { + lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit { + text_document: lsp::OptionalVersionedTextDocumentIdentifier { + uri, + version: None, + }, + edits: edits.into_iter().map(Edit::Plain).collect(), + }) + })); + } + + let mut project_transaction = ProjectTransaction::default(); + for operation in operations { + match operation { + lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Create(op)) => { + let abs_path = op + .uri + .to_file_path() + .map_err(|_| anyhow!("can't convert URI to path"))?; + + if let Some(parent_path) = abs_path.parent() { + fs.create_dir(parent_path).await?; + } + if abs_path.ends_with("/") { + fs.create_dir(&abs_path).await?; + } else { + fs.create_file( + &abs_path, + op.options + .map(|options| fs::CreateOptions { + overwrite: options.overwrite.unwrap_or(false), + ignore_if_exists: options.ignore_if_exists.unwrap_or(false), + }) + .unwrap_or_default(), + ) + .await?; + } + } + + lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Rename(op)) => { + let source_abs_path = op + .old_uri + .to_file_path() + .map_err(|_| anyhow!("can't convert URI to path"))?; + let target_abs_path = op + .new_uri + .to_file_path() + .map_err(|_| anyhow!("can't convert URI to path"))?; + fs.rename( + &source_abs_path, + &target_abs_path, + op.options + .map(|options| fs::RenameOptions { + overwrite: options.overwrite.unwrap_or(false), + ignore_if_exists: options.ignore_if_exists.unwrap_or(false), + }) + .unwrap_or_default(), + ) + .await?; + } + + lsp::DocumentChangeOperation::Op(lsp::ResourceOp::Delete(op)) => { + let abs_path = op + .uri + .to_file_path() + .map_err(|_| anyhow!("can't convert URI to path"))?; + let options = op + .options + .map(|options| fs::RemoveOptions { + recursive: options.recursive.unwrap_or(false), + ignore_if_not_exists: options.ignore_if_not_exists.unwrap_or(false), + }) + .unwrap_or_default(); + if abs_path.ends_with("/") { + fs.remove_dir(&abs_path, options).await?; + } else { + fs.remove_file(&abs_path, options).await?; + } + } + + lsp::DocumentChangeOperation::Edit(op) => { + let buffer_to_edit = this + .update(cx, |this, cx| { + this.open_local_buffer_via_lsp( + op.text_document.uri.clone(), + language_server.server_id(), + lsp_adapter.name.clone(), + cx, + ) + })? + .await?; + + let edits = this + .update(cx, |this, cx| { + let path = buffer_to_edit.read(cx).project_path(cx); + let active_entry = this.active_entry; + let is_active_entry = path.clone().map_or(false, |project_path| { + this.worktree_store + .read(cx) + .entry_for_path(&project_path, cx) + .map_or(false, |entry| Some(entry.id) == active_entry) + }); + + let (mut edits, mut snippet_edits) = (vec![], vec![]); + for edit in op.edits { + match edit { + Edit::Plain(edit) => edits.push(edit), + Edit::Annotated(edit) => edits.push(edit.text_edit), + Edit::Snippet(edit) => { + let Ok(snippet) = Snippet::parse(&edit.snippet.value) + else { + continue; + }; + + if is_active_entry { + snippet_edits.push((edit.range, snippet)); + } else { + // Since this buffer is not focused, apply a normal edit. + edits.push(TextEdit { + range: edit.range, + new_text: snippet.text, + }); + } + } + } + } + if !snippet_edits.is_empty() { + if let Some(buffer_version) = op.text_document.version { + let buffer_id = buffer_to_edit.read(cx).remote_id(); + // Check if the edit that triggered that edit has been made by this participant. + let most_recent_edit = this + .buffer_snapshots + .get(&buffer_id) + .and_then(|server_to_snapshots| { + let all_snapshots = server_to_snapshots + .get(&language_server.server_id())?; + all_snapshots + .binary_search_by_key(&buffer_version, |snapshot| { + snapshot.version + }) + .ok() + .and_then(|index| all_snapshots.get(index)) + }) + .and_then(|lsp_snapshot| { + let version = lsp_snapshot.snapshot.version(); + version.iter().max_by_key(|timestamp| timestamp.value) + }); + if let Some(most_recent_edit) = most_recent_edit { + cx.emit(LspStoreEvent::SnippetEdit { + buffer_id, + edits: snippet_edits, + most_recent_edit, + }); + } + } + } + + this.edits_from_lsp( + &buffer_to_edit, + edits, + language_server.server_id(), + op.text_document.version, + cx, + ) + })? + .await?; + + let transaction = buffer_to_edit.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(); + buffer.start_transaction(); + for (range, text) in edits { + buffer.edit([(range, text)], None, cx); + } + let transaction = if buffer.end_transaction(cx).is_some() { + let transaction = buffer.finalize_last_transaction().unwrap().clone(); + if !push_to_history { + buffer.forget_transaction(transaction.id); + } + Some(transaction) + } else { + None + }; + + transaction + })?; + if let Some(transaction) = transaction { + project_transaction.0.insert(buffer_to_edit, transaction); + } + } + } + } + + Ok(project_transaction) + } + + fn serialize_symbol(symbol: &Symbol) -> proto::Symbol { + proto::Symbol { + language_server_name: symbol.language_server_name.0.to_string(), + source_worktree_id: symbol.source_worktree_id.to_proto(), + worktree_id: symbol.path.worktree_id.to_proto(), + path: symbol.path.path.to_string_lossy().to_string(), + name: symbol.name.clone(), + kind: unsafe { mem::transmute::(symbol.kind) }, + start: Some(proto::PointUtf16 { + row: symbol.range.start.0.row, + column: symbol.range.start.0.column, + }), + end: Some(proto::PointUtf16 { + row: symbol.range.end.0.row, + column: symbol.range.end.0.column, + }), + signature: symbol.signature.to_vec(), + } + } + + fn deserialize_symbol(serialized_symbol: proto::Symbol) -> Result { + let source_worktree_id = WorktreeId::from_proto(serialized_symbol.source_worktree_id); + let worktree_id = WorktreeId::from_proto(serialized_symbol.worktree_id); + let kind = unsafe { mem::transmute::(serialized_symbol.kind) }; + let path = ProjectPath { + worktree_id, + path: PathBuf::from(serialized_symbol.path).into(), + }; + + let start = serialized_symbol + .start + .ok_or_else(|| anyhow!("invalid start"))?; + let end = serialized_symbol + .end + .ok_or_else(|| anyhow!("invalid end"))?; + Ok(CoreSymbol { + language_server_name: LanguageServerName(serialized_symbol.language_server_name.into()), + source_worktree_id, + path, + name: serialized_symbol.name, + range: Unclipped(PointUtf16::new(start.row, start.column)) + ..Unclipped(PointUtf16::new(end.row, end.column)), + kind, + signature: serialized_symbol + .signature + .try_into() + .map_err(|_| anyhow!("invalid signature"))?, + }) + } + + pub(crate) fn serialize_completion(completion: &CoreCompletion) -> proto::Completion { + proto::Completion { + old_start: Some(serialize_anchor(&completion.old_range.start)), + old_end: Some(serialize_anchor(&completion.old_range.end)), + new_text: completion.new_text.clone(), + server_id: completion.server_id.0 as u64, + lsp_completion: serde_json::to_vec(&completion.lsp_completion).unwrap(), + } + } + + pub(crate) fn deserialize_completion(completion: proto::Completion) -> Result { + let old_start = completion + .old_start + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("invalid old start"))?; + let old_end = completion + .old_end + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("invalid old end"))?; + let lsp_completion = serde_json::from_slice(&completion.lsp_completion)?; + + Ok(CoreCompletion { + old_range: old_start..old_end, + new_text: completion.new_text, + server_id: LanguageServerId(completion.server_id as usize), + lsp_completion, + }) + } + + pub(crate) fn serialize_code_action(action: &CodeAction) -> proto::CodeAction { + proto::CodeAction { + server_id: action.server_id.0 as u64, + start: Some(serialize_anchor(&action.range.start)), + end: Some(serialize_anchor(&action.range.end)), + lsp_action: serde_json::to_vec(&action.lsp_action).unwrap(), + } + } + + pub(crate) fn deserialize_code_action(action: proto::CodeAction) -> Result { + let start = action + .start + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("invalid start"))?; + let end = action + .end + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("invalid end"))?; + let lsp_action = serde_json::from_slice(&action.lsp_action)?; + Ok(CodeAction { + server_id: LanguageServerId(action.server_id as usize), + range: start..end, + lsp_action, + }) + } } +impl EventEmitter for LspStore {} + fn remove_empty_hover_blocks(mut hover: Hover) -> Option { hover .contents @@ -5779,14 +5794,6 @@ pub enum LanguageServerState { }, } -#[derive(Clone, Debug, Serialize)] -pub struct LanguageServerStatus { - pub name: String, - pub pending_work: BTreeMap, - pub has_pending_diagnostic_updates: bool, - progress_tokens: HashSet, -} - #[derive(Clone, Debug, Serialize)] pub struct LanguageServerProgress { pub is_disk_based_diagnostics_progress: bool, @@ -6053,26 +6060,6 @@ fn include_text(server: &lsp::LanguageServer) -> Option { } } -fn serialize_symbol(symbol: &Symbol) -> proto::Symbol { - proto::Symbol { - language_server_name: symbol.language_server_name.0.to_string(), - source_worktree_id: symbol.source_worktree_id.to_proto(), - worktree_id: symbol.path.worktree_id.to_proto(), - path: symbol.path.path.to_string_lossy().to_string(), - name: symbol.name.clone(), - kind: unsafe { mem::transmute::(symbol.kind) }, - start: Some(proto::PointUtf16 { - row: symbol.range.start.0.row, - column: symbol.range.start.0.column, - }), - end: Some(proto::PointUtf16 { - row: symbol.range.end.0.row, - column: symbol.range.end.0.column, - }), - signature: symbol.signature.to_vec(), - } -} - #[cfg(test)] #[test] fn test_glob_literal_prefix() { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 46d3929d30..99ebb8ca42 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -65,7 +65,10 @@ use paths::{ use prettier_support::{DefaultPrettier, PrettierInstance}; use project_settings::{LspSettings, ProjectSettings}; use remote::SshSession; -use rpc::{proto::AnyProtoClient, ErrorCode}; +use rpc::{ + proto::{AnyProtoClient, SSH_PROJECT_ID}, + ErrorCode, +}; use search::{SearchQuery, SearchResult}; use search_history::SearchHistory; use settings::{watch_config_file, Settings, SettingsLocation, SettingsStore}; @@ -574,6 +577,7 @@ impl Project { connection_manager::init(client.clone(), cx); Self::init_settings(cx); + let client: AnyProtoClient = client.clone().into(); client.add_model_message_handler(Self::handle_add_collaborator); client.add_model_message_handler(Self::handle_update_project_collaborator); client.add_model_message_handler(Self::handle_remove_collaborator); @@ -594,9 +598,9 @@ impl Project { client.add_model_request_handler(Self::handle_task_templates); client.add_model_message_handler(Self::handle_create_buffer_for_peer); - WorktreeStore::init(client); - BufferStore::init(client); - LspStore::init(client); + WorktreeStore::init(&client); + BufferStore::init(&client); + LspStore::init(&client); } pub fn local( @@ -697,15 +701,19 @@ impl Project { ) -> Model { let this = Self::local(client, node, user_store, languages, fs, None, cx); this.update(cx, |this, cx| { - let buffer_store = this.buffer_store.downgrade(); + let client: AnyProtoClient = ssh.clone().into(); + this.worktree_store.update(cx, |store, _cx| { - store.set_upstream_client(ssh.clone().into()); + store.set_upstream_client(client.clone()); }); - ssh.add_message_handler(cx.weak_model(), Self::handle_update_worktree); - ssh.add_message_handler(cx.weak_model(), Self::handle_create_buffer_for_peer); - ssh.add_message_handler(buffer_store.clone(), BufferStore::handle_update_buffer_file); - ssh.add_message_handler(buffer_store.clone(), BufferStore::handle_update_diff_base); + ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle()); + ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store); + ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store); + client.add_model_message_handler(Self::handle_update_worktree); + client.add_model_message_handler(Self::handle_create_buffer_for_peer); + client.add_model_message_handler(BufferStore::handle_update_buffer_file); + client.add_model_message_handler(BufferStore::handle_update_diff_base); this.ssh_session = Some(ssh); }); diff --git a/crates/project/src/worktree_store.rs b/crates/project/src/worktree_store.rs index c021af4e09..439e02da17 100644 --- a/crates/project/src/worktree_store.rs +++ b/crates/project/src/worktree_store.rs @@ -5,7 +5,7 @@ use std::{ }; use anyhow::{anyhow, Context as _, Result}; -use client::{Client, DevServerProjectId}; +use client::DevServerProjectId; use collections::{HashMap, HashSet}; use fs::Fs; use futures::{ @@ -17,7 +17,7 @@ use gpui::{ }; use postage::oneshot; use rpc::{ - proto::{self, AnyProtoClient}, + proto::{self, AnyProtoClient, SSH_PROJECT_ID}, TypedEnvelope, }; use smol::{ @@ -58,12 +58,12 @@ pub enum WorktreeStoreEvent { impl EventEmitter for WorktreeStore {} impl WorktreeStore { - pub fn init(client: &Arc) { - client.add_model_request_handler(WorktreeStore::handle_create_project_entry); - client.add_model_request_handler(WorktreeStore::handle_rename_project_entry); - client.add_model_request_handler(WorktreeStore::handle_copy_project_entry); - client.add_model_request_handler(WorktreeStore::handle_delete_project_entry); - client.add_model_request_handler(WorktreeStore::handle_expand_project_entry); + pub fn init(client: &AnyProtoClient) { + client.add_model_request_handler(Self::handle_create_project_entry); + client.add_model_request_handler(Self::handle_rename_project_entry); + client.add_model_request_handler(Self::handle_copy_project_entry); + client.add_model_request_handler(Self::handle_delete_project_entry); + client.add_model_request_handler(Self::handle_expand_project_entry); } pub fn new(retain_worktrees: bool, fs: Arc) -> Self { @@ -188,7 +188,10 @@ impl WorktreeStore { let path = abs_path.to_string_lossy().to_string(); cx.spawn(|this, mut cx| async move { let response = client - .request(proto::AddWorktree { path: path.clone() }) + .request(proto::AddWorktree { + project_id: SSH_PROJECT_ID, + path: path.clone(), + }) .await?; let worktree = cx.update(|cx| { Worktree::remote( diff --git a/crates/proto/Cargo.toml b/crates/proto/Cargo.toml index 5ee2e60aaa..e1d111366c 100644 --- a/crates/proto/Cargo.toml +++ b/crates/proto/Cargo.toml @@ -20,8 +20,10 @@ doctest = false anyhow.workspace = true collections.workspace = true futures.workspace = true +parking_lot.workspace = true prost.workspace = true serde.workspace = true +gpui.workspace = true [build-dependencies] prost-build.workspace = true diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 6f5321a5bd..43278e6a51 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -2484,6 +2484,7 @@ message GetLlmTokenResponse { // Remote FS message AddWorktree { + uint64 project_id = 2; string path = 1; } diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index dabd29f914..b580338320 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -2,14 +2,14 @@ pub mod error; mod macros; +mod proto_client; mod typed_envelope; pub use error::*; +pub use proto_client::*; pub use typed_envelope::*; -use anyhow::anyhow; use collections::HashMap; -use futures::{future::BoxFuture, Future}; pub use prost::{DecodeError, Message}; use serde::Serialize; use std::{ @@ -17,12 +17,14 @@ use std::{ cmp, fmt::{self, Debug}, iter, mem, - sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); +pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 }; +pub const SSH_PROJECT_ID: u64 = 0; + pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static { const NAME: &'static str; const PRIORITY: MessagePriority; @@ -60,51 +62,6 @@ pub enum MessagePriority { Background, } -pub trait ProtoClient: Send + Sync { - fn request( - &self, - envelope: Envelope, - request_type: &'static str, - ) -> BoxFuture<'static, anyhow::Result>; - - fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>; -} - -#[derive(Clone)] -pub struct AnyProtoClient(Arc); - -impl From> for AnyProtoClient -where - T: ProtoClient + 'static, -{ - fn from(client: Arc) -> Self { - Self(client) - } -} - -impl AnyProtoClient { - pub fn new(client: Arc) -> Self { - Self(client) - } - - pub fn request( - &self, - request: T, - ) -> impl Future> { - let envelope = request.into_envelope(0, None, None); - let response = self.0.request(envelope, T::NAME); - async move { - T::Response::from_envelope(response.await?) - .ok_or_else(|| anyhow!("received response of the wrong type")) - } - } - - pub fn send(&self, request: T) -> anyhow::Result<()> { - let envelope = request.into_envelope(0, None, None); - self.0.send(envelope, T::NAME) - } -} - impl AnyTypedEnvelope for TypedEnvelope { fn payload_type_id(&self) -> TypeId { TypeId::of::() @@ -537,11 +494,13 @@ request_messages!( entity_messages!( {project_id, ShareProject}, AddProjectCollaborator, + AddWorktree, ApplyCodeAction, ApplyCompletionAdditionalEdits, BlameBuffer, BufferReloaded, BufferSaved, + CloseBuffer, CopyProjectEntry, CreateBufferForPeer, CreateProjectEntry, diff --git a/crates/proto/src/proto_client.rs b/crates/proto/src/proto_client.rs new file mode 100644 index 0000000000..edcb6417d8 --- /dev/null +++ b/crates/proto/src/proto_client.rs @@ -0,0 +1,277 @@ +use crate::{ + error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage, + RequestMessage, TypedEnvelope, +}; +use anyhow::anyhow; +use collections::HashMap; +use futures::{ + future::{BoxFuture, LocalBoxFuture}, + Future, FutureExt as _, +}; +use gpui::{AnyModel, AnyWeakModel, AsyncAppContext, Model}; +pub use prost::Message; +use std::{any::TypeId, sync::Arc}; + +#[derive(Clone)] +pub struct AnyProtoClient(Arc); + +pub trait ProtoClient: Send + Sync { + fn request( + &self, + envelope: Envelope, + request_type: &'static str, + ) -> BoxFuture<'static, anyhow::Result>; + + fn send(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>; + + fn send_response(&self, envelope: Envelope, message_type: &'static str) -> anyhow::Result<()>; + + fn message_handler_set(&self) -> &parking_lot::Mutex; +} + +#[derive(Default)] +pub struct ProtoMessageHandlerSet { + pub entity_types_by_message_type: HashMap, + pub entities_by_type_and_remote_id: HashMap<(TypeId, u64), EntityMessageSubscriber>, + pub entity_id_extractors: HashMap u64>, + pub models_by_message_type: HashMap, + pub message_handlers: HashMap, +} + +pub type ProtoMessageHandler = Arc< + dyn Send + + Sync + + Fn( + AnyModel, + Box, + AnyProtoClient, + AsyncAppContext, + ) -> LocalBoxFuture<'static, anyhow::Result<()>>, +>; + +impl ProtoMessageHandlerSet { + pub fn clear(&mut self) { + self.message_handlers.clear(); + self.models_by_message_type.clear(); + self.entities_by_type_and_remote_id.clear(); + self.entity_id_extractors.clear(); + } + + fn add_message_handler( + &mut self, + message_type_id: TypeId, + model: gpui::AnyWeakModel, + handler: ProtoMessageHandler, + ) { + self.models_by_message_type.insert(message_type_id, model); + let prev_handler = self.message_handlers.insert(message_type_id, handler); + if prev_handler.is_some() { + panic!("registered handler for the same message twice"); + } + } + + fn add_entity_message_handler( + &mut self, + message_type_id: TypeId, + model_type_id: TypeId, + entity_id_extractor: fn(&dyn AnyTypedEnvelope) -> u64, + handler: ProtoMessageHandler, + ) { + self.entity_id_extractors + .entry(message_type_id) + .or_insert(entity_id_extractor); + self.entity_types_by_message_type + .insert(message_type_id, model_type_id); + let prev_handler = self.message_handlers.insert(message_type_id, handler); + if prev_handler.is_some() { + panic!("registered handler for the same message twice"); + } + } + + pub fn handle_message( + this: &parking_lot::Mutex, + message: Box, + client: AnyProtoClient, + cx: AsyncAppContext, + ) -> Option>> { + let payload_type_id = message.payload_type_id(); + let mut this = this.lock(); + let handler = this.message_handlers.get(&payload_type_id)?.clone(); + let entity = if let Some(entity) = this.models_by_message_type.get(&payload_type_id) { + entity.upgrade()? + } else { + let extract_entity_id = *this.entity_id_extractors.get(&payload_type_id)?; + let entity_type_id = *this.entity_types_by_message_type.get(&payload_type_id)?; + let entity_id = (extract_entity_id)(message.as_ref()); + + match this + .entities_by_type_and_remote_id + .get_mut(&(entity_type_id, entity_id))? + { + EntityMessageSubscriber::Pending(pending) => { + pending.push(message); + return None; + } + EntityMessageSubscriber::Entity { handle } => handle.upgrade()?, + } + }; + drop(this); + Some(handler(entity, message, client, cx)) + } +} + +pub enum EntityMessageSubscriber { + Entity { handle: AnyWeakModel }, + Pending(Vec>), +} + +impl From> for AnyProtoClient +where + T: ProtoClient + 'static, +{ + fn from(client: Arc) -> Self { + Self(client) + } +} + +impl AnyProtoClient { + pub fn new(client: Arc) -> Self { + Self(client) + } + + pub fn request( + &self, + request: T, + ) -> impl Future> { + let envelope = request.into_envelope(0, None, None); + let response = self.0.request(envelope, T::NAME); + async move { + T::Response::from_envelope(response.await?) + .ok_or_else(|| anyhow!("received response of the wrong type")) + } + } + + pub fn send(&self, request: T) -> anyhow::Result<()> { + let envelope = request.into_envelope(0, None, None); + self.0.send(envelope, T::NAME) + } + + pub fn send_response( + &self, + request_id: u32, + request: T, + ) -> anyhow::Result<()> { + let envelope = request.into_envelope(0, Some(request_id), None); + self.0.send(envelope, T::NAME) + } + + pub fn add_request_handler(&self, model: gpui::WeakModel, handler: H) + where + M: RequestMessage, + E: 'static, + H: 'static + Sync + Fn(Model, TypedEnvelope, AsyncAppContext) -> F + Send + Sync, + F: 'static + Future>, + { + self.0.message_handler_set().lock().add_message_handler( + TypeId::of::(), + model.into(), + Arc::new(move |model, envelope, client, cx| { + let model = model.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + let request_id = envelope.message_id(); + handler(model, *envelope, cx) + .then(move |result| async move { + match result { + Ok(response) => { + client.send_response(request_id, response)?; + Ok(()) + } + Err(error) => { + client.send_response(request_id, error.to_proto())?; + Err(error) + } + } + }) + .boxed_local() + }), + ) + } + + pub fn add_model_request_handler(&self, handler: H) + where + M: EnvelopedMessage + RequestMessage + EntityMessage, + E: 'static, + H: 'static + Sync + Send + Fn(gpui::Model, TypedEnvelope, AsyncAppContext) -> F, + F: 'static + Future>, + { + let message_type_id = TypeId::of::(); + let model_type_id = TypeId::of::(); + let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| { + envelope + .as_any() + .downcast_ref::>() + .unwrap() + .payload + .remote_entity_id() + }; + self.0 + .message_handler_set() + .lock() + .add_entity_message_handler( + message_type_id, + model_type_id, + entity_id_extractor, + Arc::new(move |model, envelope, client, cx| { + let model = model.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + let request_id = envelope.message_id(); + handler(model, *envelope, cx) + .then(move |result| async move { + match result { + Ok(response) => { + client.send_response(request_id, response)?; + Ok(()) + } + Err(error) => { + client.send_response(request_id, error.to_proto())?; + Err(error) + } + } + }) + .boxed_local() + }), + ); + } + + pub fn add_model_message_handler(&self, handler: H) + where + M: EnvelopedMessage + EntityMessage, + E: 'static, + H: 'static + Sync + Send + Fn(gpui::Model, TypedEnvelope, AsyncAppContext) -> F, + F: 'static + Future>, + { + let message_type_id = TypeId::of::(); + let model_type_id = TypeId::of::(); + let entity_id_extractor = |envelope: &dyn AnyTypedEnvelope| { + envelope + .as_any() + .downcast_ref::>() + .unwrap() + .payload + .remote_entity_id() + }; + self.0 + .message_handler_set() + .lock() + .add_entity_message_handler( + message_type_id, + model_type_id, + entity_id_extractor, + Arc::new(move |model, envelope, _, cx| { + let model = model.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + handler(model, *envelope, cx).boxed_local() + }), + ); + } +} diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index b913689692..8d76614918 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -8,17 +8,14 @@ use anyhow::{anyhow, Context as _, Result}; use collections::HashMap; use futures::{ channel::{mpsc, oneshot}, - future::{BoxFuture, LocalBoxFuture}, + future::BoxFuture, select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _, }; -use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, WeakModel}; +use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion}; use parking_lot::Mutex; -use rpc::{ - proto::{ - self, build_typed_envelope, AnyTypedEnvelope, Envelope, EnvelopedMessage, PeerId, - ProtoClient, RequestMessage, - }, - TypedEnvelope, +use rpc::proto::{ + self, build_typed_envelope, EntityMessageSubscriber, Envelope, EnvelopedMessage, PeerId, + ProtoClient, ProtoMessageHandlerSet, RequestMessage, }; use smol::{ fs, @@ -48,20 +45,7 @@ pub struct SshSession { outgoing_tx: mpsc::UnboundedSender, spawn_process_tx: mpsc::UnboundedSender, client_socket: Option, - message_handlers: Mutex< - HashMap< - TypeId, - Arc< - dyn Send - + Sync - + Fn( - Box, - Arc, - AsyncAppContext, - ) -> Option>>, - >, - >, - >, + state: Mutex, } struct SshClientState { @@ -330,7 +314,7 @@ impl SshSession { outgoing_tx, spawn_process_tx, client_socket, - message_handlers: Default::default(), + state: Default::default(), }); cx.spawn(|cx| { @@ -351,18 +335,26 @@ impl SshSession { } else if let Some(envelope) = build_typed_envelope(peer_id, Instant::now(), incoming) { - log::debug!( - "ssh message received. name:{}", - envelope.payload_type_name() - ); - let type_id = envelope.payload_type_id(); - let handler = this.message_handlers.lock().get(&type_id).cloned(); - if let Some(handler) = handler { - if let Some(future) = handler(envelope, this.clone(), cx.clone()) { - future.await.ok(); - } else { - this.message_handlers.lock().remove(&type_id); + let type_name = envelope.payload_type_name(); + if let Some(future) = ProtoMessageHandlerSet::handle_message( + &this.state, + envelope, + this.clone().into(), + cx.clone(), + ) { + log::debug!("ssh message received. name:{type_name}"); + match future.await { + Ok(_) => { + log::debug!("ssh message handled. name:{type_name}"); + } + Err(error) => { + log::error!( + "error handling message. type:{type_name}, error:{error:?}", + ); + } } + } else { + log::error!("unhandled ssh message name:{type_name}"); } } } @@ -389,6 +381,7 @@ impl SshSession { } pub fn send(&self, payload: T) -> Result<()> { + log::debug!("ssh send name:{}", T::NAME); self.send_dynamic(payload.into_envelope(0, None, None)) } @@ -412,6 +405,22 @@ impl SshSession { Ok(()) } + pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { + let id = (TypeId::of::(), remote_id); + + let mut state = self.state.lock(); + if state.entities_by_type_and_remote_id.contains_key(&id) { + panic!("already subscribed to entity"); + } + + state.entities_by_type_and_remote_id.insert( + id, + EntityMessageSubscriber::Entity { + handle: entity.downgrade().into(), + }, + ); + } + pub async fn spawn_process(&self, command: String) -> process::Child { let (process_tx, process_rx) = oneshot::channel(); self.spawn_process_tx @@ -426,54 +435,6 @@ impl SshSession { pub fn ssh_args(&self) -> Vec { self.client_socket.as_ref().unwrap().ssh_args() } - - pub fn add_message_handler(&self, entity: WeakModel, handler: H) - where - M: EnvelopedMessage, - E: 'static, - H: 'static + Sync + Send + Fn(Model, TypedEnvelope, AsyncAppContext) -> F, - F: 'static + Future>, - { - let message_type_id = TypeId::of::(); - self.message_handlers.lock().insert( - message_type_id, - Arc::new(move |envelope, _, cx| { - let entity = entity.upgrade()?; - let envelope = envelope.into_any().downcast::>().unwrap(); - Some(handler(entity, *envelope, cx).boxed_local()) - }), - ); - } - - pub fn add_request_handler(&self, entity: WeakModel, handler: H) - where - M: EnvelopedMessage + RequestMessage, - E: 'static, - H: 'static + Sync + Send + Fn(Model, TypedEnvelope, AsyncAppContext) -> F, - F: 'static + Future>, - { - let message_type_id = TypeId::of::(); - self.message_handlers.lock().insert( - message_type_id, - Arc::new(move |envelope, this, cx| { - let entity = entity.upgrade()?; - let envelope = envelope.into_any().downcast::>().unwrap(); - let request_id = envelope.message_id(); - Some( - handler(entity, *envelope, cx) - .then(move |result| async move { - this.outgoing_tx.unbounded_send(result?.into_envelope( - this.next_message_id.fetch_add(1, SeqCst), - Some(request_id), - None, - ))?; - Ok(()) - }) - .boxed_local(), - ) - }), - ); - } } impl ProtoClient for SshSession { @@ -488,6 +449,14 @@ impl ProtoClient for SshSession { fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> { self.send_dynamic(envelope) } + + fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> { + self.send_dynamic(envelope) + } + + fn message_handler_set(&self) -> &Mutex { + &self.state + } } impl SshClientState { diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs index 4f402ae2d4..36738f6694 100644 --- a/crates/remote_server/src/headless_project.rs +++ b/crates/remote_server/src/headless_project.rs @@ -7,7 +7,7 @@ use project::{ }; use remote::SshSession; use rpc::{ - proto::{self, AnyProtoClient, PeerId}, + proto::{self, AnyProtoClient, SSH_PEER_ID, SSH_PROJECT_ID}, TypedEnvelope, }; use settings::{Settings as _, SettingsStore}; @@ -18,9 +18,6 @@ use std::{ }; use worktree::Worktree; -const PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 }; -const PROJECT_ID: u64 = 0; - pub struct HeadlessProject { pub fs: Arc, pub session: AnyProtoClient, @@ -36,48 +33,34 @@ impl HeadlessProject { } pub fn new(session: Arc, fs: Arc, cx: &mut ModelContext) -> Self { - let this = cx.weak_model(); - let worktree_store = cx.new_model(|_| WorktreeStore::new(true, fs.clone())); let buffer_store = cx.new_model(|cx| { - let mut buffer_store = BufferStore::new(worktree_store.clone(), Some(PROJECT_ID), cx); - buffer_store.shared(PROJECT_ID, session.clone().into(), cx); + let mut buffer_store = + BufferStore::new(worktree_store.clone(), Some(SSH_PROJECT_ID), cx); + buffer_store.shared(SSH_PROJECT_ID, session.clone().into(), cx); buffer_store }); - session.add_request_handler(this.clone(), Self::handle_list_remote_directory); - session.add_request_handler(this.clone(), Self::handle_add_worktree); - session.add_request_handler(this.clone(), Self::handle_open_buffer_by_path); - session.add_request_handler(this.clone(), Self::handle_find_search_candidates); + let client: AnyProtoClient = session.clone().into(); - session.add_request_handler(buffer_store.downgrade(), BufferStore::handle_blame_buffer); - session.add_request_handler(buffer_store.downgrade(), BufferStore::handle_update_buffer); - session.add_request_handler(buffer_store.downgrade(), BufferStore::handle_save_buffer); - session.add_message_handler(buffer_store.downgrade(), BufferStore::handle_close_buffer); + session.subscribe_to_entity(SSH_PROJECT_ID, &worktree_store); + session.subscribe_to_entity(SSH_PROJECT_ID, &buffer_store); + session.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle()); - session.add_request_handler( - worktree_store.downgrade(), - WorktreeStore::handle_create_project_entry, - ); - session.add_request_handler( - worktree_store.downgrade(), - WorktreeStore::handle_rename_project_entry, - ); - session.add_request_handler( - worktree_store.downgrade(), - WorktreeStore::handle_copy_project_entry, - ); - session.add_request_handler( - worktree_store.downgrade(), - WorktreeStore::handle_delete_project_entry, - ); - session.add_request_handler( - worktree_store.downgrade(), - WorktreeStore::handle_expand_project_entry, - ); + client.add_request_handler(cx.weak_model(), Self::handle_list_remote_directory); + + client.add_model_request_handler(Self::handle_add_worktree); + client.add_model_request_handler(Self::handle_open_buffer_by_path); + client.add_model_request_handler(Self::handle_find_search_candidates); + + client.add_model_request_handler(BufferStore::handle_update_buffer); + client.add_model_message_handler(BufferStore::handle_close_buffer); + + BufferStore::init(&client); + WorktreeStore::init(&client); HeadlessProject { - session: session.into(), + session: client, fs, worktree_store, buffer_store, @@ -144,7 +127,7 @@ impl HeadlessProject { let buffer_id = buffer.read_with(&cx, |b, _| b.remote_id())?; buffer_store.update(&mut cx, |buffer_store, cx| { buffer_store - .create_buffer_for_peer(&buffer, PEER_ID, cx) + .create_buffer_for_peer(&buffer, SSH_PEER_ID, cx) .detach_and_log_err(cx); })?; @@ -181,7 +164,7 @@ impl HeadlessProject { response.buffer_ids.push(buffer_id.to_proto()); buffer_store .update(&mut cx, |buffer_store, cx| { - buffer_store.create_buffer_for_peer(&buffer, PEER_ID, cx) + buffer_store.create_buffer_for_peer(&buffer, SSH_PEER_ID, cx) })? .await?; } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 18d61527b5..c92776f4a8 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -17,7 +17,7 @@ use smol::stream::StreamExt; use std::{path::Path, sync::Arc}; #[gpui::test] -async fn test_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { +async fn test_basic_remote_editing(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { let (project, _headless, fs) = init_test(cx, server_cx).await; let (worktree, _) = project .update(cx, |project, cx| { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 242b3b9116..4ef52ca77a 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -639,14 +639,13 @@ impl Peer { pub fn respond_with_unhandled_message( &self, - envelope: Box, + sender_id: ConnectionId, + request_message_id: u32, + message_type_name: &'static str, ) -> Result<()> { - let connection = self.connection_state(envelope.sender_id().into())?; + let connection = self.connection_state(sender_id)?; let response = ErrorCode::Internal - .message(format!( - "message {} was not handled", - envelope.payload_type_name() - )) + .message(format!("message {} was not handled", message_type_name)) .to_proto(); let message_id = connection .next_message_id @@ -655,7 +654,7 @@ impl Peer { .outgoing_tx .unbounded_send(proto::Message::Envelope(response.into_envelope( message_id, - Some(envelope.message_id()), + Some(request_message_id), None, )))?; Ok(())