Lay the groundwork for collaborating on assistant panel (#13991)

This pull request introduces collaboration for the assistant panel by
turning `Context` into a CRDT. `ContextStore` is responsible for sending
and applying operations, as well as synchronizing missed changes while
the connection was lost.

Contexts are shared on a per-project basis, and only the host can share
them for now. Shared contexts can be accessed via the `History` tab in
the assistant panel.

<img width="1819" alt="image"
src="https://github.com/zed-industries/zed/assets/482957/c7ae46d2-cde3-4b03-b74a-6e9b1555c154">


Please note that this doesn't implement following yet, which is
scheduled for a subsequent pull request.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-07-10 17:36:22 +02:00 committed by GitHub
parent 1662993811
commit 8944af7406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 4232 additions and 2120 deletions

4
Cargo.lock generated
View File

@ -377,6 +377,7 @@ dependencies = [
"cargo_toml", "cargo_toml",
"chrono", "chrono",
"client", "client",
"clock",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"ctor", "ctor",
@ -419,6 +420,7 @@ dependencies = [
"telemetry_events", "telemetry_events",
"terminal", "terminal",
"terminal_view", "terminal_view",
"text",
"theme", "theme",
"tiktoken-rs", "tiktoken-rs",
"toml 0.8.10", "toml 0.8.10",
@ -2405,6 +2407,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"parking_lot", "parking_lot",
"serde",
"smallvec", "smallvec",
] ]
@ -2463,6 +2466,7 @@ version = "0.44.0"
dependencies = [ dependencies = [
"anthropic", "anthropic",
"anyhow", "anyhow",
"assistant",
"async-trait", "async-trait",
"async-tungstenite", "async-tungstenite",
"audio", "audio",

View File

@ -12,6 +12,14 @@ workspace = true
path = "src/assistant.rs" path = "src/assistant.rs"
doctest = false doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"project/test-support",
"text/test-support",
]
[dependencies] [dependencies]
anthropic = { workspace = true, features = ["schemars"] } anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true anyhow.workspace = true
@ -21,6 +29,7 @@ breadcrumbs.workspace = true
cargo_toml.workspace = true cargo_toml.workspace = true
chrono.workspace = true chrono.workspace = true
client.workspace = true client.workspace = true
clock.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
editor.workspace = true editor.workspace = true
@ -72,7 +81,9 @@ picker.workspace = true
ctor.workspace = true ctor.workspace = true
editor = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
log.workspace = true log.workspace = true
project = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] }
rand.workspace = true rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true unindent.workspace = true

View File

@ -1,7 +1,8 @@
pub mod assistant_panel; pub mod assistant_panel;
pub mod assistant_settings; pub mod assistant_settings;
mod completion_provider; mod completion_provider;
mod context_store; mod context;
pub mod context_store;
mod inline_assistant; mod inline_assistant;
mod model_selector; mod model_selector;
mod prompt_library; mod prompt_library;
@ -16,8 +17,9 @@ use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaMo
use assistant_slash_command::SlashCommandRegistry; use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client}; use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*; pub use completion_provider::*;
pub(crate) use context_store::*; pub use context::*;
pub use context_store::*;
use fs::Fs; use fs::Fs;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry; use indexed_docs::IndexedDocsRegistry;
@ -57,10 +59,14 @@ actions!(
] ]
); );
#[derive( #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, pub struct MessageId(clock::Lamport);
)]
struct MessageId(usize); impl MessageId {
pub fn as_u64(self) -> u64 {
self.0.as_u64()
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
@ -71,8 +77,26 @@ pub enum Role {
} }
impl Role { impl Role {
pub fn cycle(&mut self) { pub fn from_proto(role: i32) -> Role {
*self = match self { match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant, Role::User => Role::Assistant,
Role::Assistant => Role::System, Role::Assistant => Role::System,
Role::System => Role::User, Role::System => Role::User,
@ -151,11 +175,7 @@ pub struct LanguageModelRequestMessage {
impl LanguageModelRequestMessage { impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage { proto::LanguageModelRequestMessage {
role: match self.role { role: self.role.to_proto() as i32,
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(), content: self.content.clone(),
tool_calls: Vec::new(), tool_calls: Vec::new(),
tool_call_id: None, tool_call_id: None,
@ -222,19 +242,48 @@ pub struct LanguageModelChoiceDelta {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct MessageMetadata { pub enum MessageStatus {
role: Role,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending, Pending,
Done, Done,
Error(SharedString), Error(SharedString),
} }
impl MessageStatus {
pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
match status.variant {
Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
Some(proto::context_message_status::Variant::Error(error)) => {
MessageStatus::Error(error.message.into())
}
None => MessageStatus::Pending,
}
}
pub fn to_proto(&self) -> proto::ContextMessageStatus {
match self {
MessageStatus::Pending => proto::ContextMessageStatus {
variant: Some(proto::context_message_status::Variant::Pending(
proto::context_message_status::Pending {},
)),
},
MessageStatus::Done => proto::ContextMessageStatus {
variant: Some(proto::context_message_status::Variant::Done(
proto::context_message_status::Done {},
)),
},
MessageStatus::Error(message) => proto::ContextMessageStatus {
variant: Some(proto::context_message_status::Variant::Error(
proto::context_message_status::Error {
message: message.to_string(),
},
)),
},
}
}
}
/// The state pertaining to the Assistant. /// The state pertaining to the Assistant.
#[derive(Default)] #[derive(Default)]
struct Assistant { struct Assistant {
@ -287,6 +336,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
}) })
.detach(); .detach();
context_store::init(&client);
prompt_library::init(cx); prompt_library::init(cx);
completion_provider::init(client.clone(), cx); completion_provider::init(client.clone(), cx);
assistant_slash_command::init(cx); assistant_slash_command::init(cx);

File diff suppressed because it is too large Load Diff

View File

@ -1,13 +1,13 @@
mod anthropic; mod anthropic;
mod cloud; mod cloud;
#[cfg(test)] #[cfg(any(test, feature = "test-support"))]
mod fake; mod fake;
mod ollama; mod ollama;
mod open_ai; mod open_ai;
pub use anthropic::*; pub use anthropic::*;
pub use cloud::*; pub use cloud::*;
#[cfg(test)] #[cfg(any(test, feature = "test-support"))]
pub use fake::*; pub use fake::*;
pub use ollama::*; pub use ollama::*;
pub use open_ai::*; pub use open_ai::*;

View File

@ -13,7 +13,6 @@ pub struct FakeCompletionProvider {
} }
impl FakeCompletionProvider { impl FakeCompletionProvider {
#[cfg(test)]
pub fn setup_test(cx: &mut AppContext) -> Self { pub fn setup_test(cx: &mut AppContext) -> Self {
use crate::CompletionProvider; use crate::CompletionProvider;
use parking_lot::RwLock; use parking_lot::RwLock;

File diff suppressed because it is too large Load Diff

View File

@ -1,97 +1,117 @@
use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata}; use crate::{
use anyhow::{anyhow, Result}; Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext,
use assistant_slash_command::SlashCommandOutputSection; SavedContextMetadata,
use collections::HashMap; };
use anyhow::{anyhow, Context as _, Result};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
use clock::ReplicaId;
use fs::Fs; use fs::Fs;
use futures::StreamExt; use futures::StreamExt;
use fuzzy::StringMatchCandidate; use fuzzy::StringMatchCandidate;
use gpui::{AppContext, Model, ModelContext, Task}; use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel};
use language::LanguageRegistry;
use paths::contexts_dir; use paths::contexts_dir;
use project::Project;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use std::{
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration}; cmp::Reverse,
use ui::Context; ffi::OsStr,
mem,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use util::{ResultExt, TryFutureExt}; use util::{ResultExt, TryFutureExt};
#[derive(Serialize, Deserialize)] pub fn init(client: &Arc<Client>) {
pub struct SavedMessage { client.add_model_message_handler(ContextStore::handle_advertise_contexts);
pub id: MessageId, client.add_model_request_handler(ContextStore::handle_open_context);
pub start: usize, client.add_model_message_handler(ContextStore::handle_update_context);
} client.add_model_request_handler(ContextStore::handle_synchronize_contexts);
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
pub id: Option<String>,
pub zed: String,
pub version: String,
pub text: String,
pub messages: Vec<SavedMessage>,
pub message_metadata: HashMap<MessageId, MessageMetadata>,
pub summary: String,
pub slash_command_output_sections: Vec<SlashCommandOutputSection<usize>>,
}
impl SavedContext {
pub const VERSION: &'static str = "0.3.0";
}
#[derive(Serialize, Deserialize)]
pub struct SavedContextV0_2_0 {
pub id: Option<String>,
pub zed: String,
pub version: String,
pub text: String,
pub messages: Vec<SavedMessage>,
pub message_metadata: HashMap<MessageId, MessageMetadata>,
pub summary: String,
}
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct SavedContextMetadata { pub struct RemoteContextMetadata {
pub title: String, pub id: ContextId,
pub path: PathBuf, pub summary: Option<String>,
pub mtime: chrono::DateTime<chrono::Local>,
} }
pub struct ContextStore { pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>, contexts_metadata: Vec<SavedContextMetadata>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
telemetry: Arc<Telemetry>,
_watch_updates: Task<Option<()>>, _watch_updates: Task<Option<()>>,
client: Arc<Client>,
project: Model<Project>,
project_is_shared: bool,
client_subscription: Option<client::Subscription>,
_project_subscriptions: Vec<gpui::Subscription>,
}
enum ContextHandle {
Weak(WeakModel<Context>),
Strong(Model<Context>),
}
impl ContextHandle {
fn upgrade(&self) -> Option<Model<Context>> {
match self {
ContextHandle::Weak(weak) => weak.upgrade(),
ContextHandle::Strong(strong) => Some(strong.clone()),
}
}
fn downgrade(&self) -> WeakModel<Context> {
match self {
ContextHandle::Weak(weak) => weak.clone(),
ContextHandle::Strong(strong) => strong.downgrade(),
}
}
} }
impl ContextStore { impl ContextStore {
pub fn new(fs: Arc<dyn Fs>, cx: &mut AppContext) -> Task<Result<Model<Self>>> { pub fn new(project: Model<Project>, cx: &mut AppContext) -> Task<Result<Model<Self>>> {
let fs = project.read(cx).fs().clone();
let languages = project.read(cx).languages().clone();
let telemetry = project.read(cx).client().telemetry().clone();
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100); const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new_model(|cx: &mut ModelContext<Self>| Self { let this = cx.new_model(|cx: &mut ModelContext<Self>| {
contexts_metadata: Vec::new(), let mut this = Self {
fs, contexts: Vec::new(),
_watch_updates: cx.spawn(|this, mut cx| { contexts_metadata: Vec::new(),
async move { host_contexts: Vec::new(),
while events.next().await.is_some() { fs,
this.update(&mut cx, |this, cx| this.reload(cx))? languages,
.await telemetry,
.log_err(); _watch_updates: cx.spawn(|this, mut cx| {
async move {
while events.next().await.is_some() {
this.update(&mut cx, |this, cx| this.reload(cx))?
.await
.log_err();
}
anyhow::Ok(())
} }
anyhow::Ok(()) .log_err()
} }),
.log_err() client_subscription: None,
}), _project_subscriptions: vec![
cx.observe(&project, Self::handle_project_changed),
cx.subscribe(&project, Self::handle_project_event),
],
project_is_shared: false,
client: project.read(cx).client(),
project: project.clone(),
};
this.handle_project_changed(project, cx);
this.synchronize_contexts(cx);
this
})?; })?;
this.update(&mut cx, |this, cx| this.reload(cx))? this.update(&mut cx, |this, cx| this.reload(cx))?
.await .await
@ -100,54 +120,433 @@ impl ContextStore {
}) })
} }
pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task<Result<SavedContext>> { async fn handle_advertise_contexts(
this: Model<Self>,
envelope: TypedEnvelope<proto::AdvertiseContexts>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.host_contexts = envelope
.payload
.contexts
.into_iter()
.map(|context| RemoteContextMetadata {
id: ContextId::from_proto(context.context_id),
summary: context.summary,
})
.collect();
cx.notify();
})
}
async fn handle_open_context(
this: Model<Self>,
envelope: TypedEnvelope<proto::OpenContext>,
mut cx: AsyncAppContext,
) -> Result<proto::OpenContextResponse> {
let context_id = ContextId::from_proto(envelope.payload.context_id);
let operations = this.update(&mut cx, |this, cx| {
if this.project.read(cx).is_remote() {
return Err(anyhow!("only the host contexts can be opened"));
}
let context = this
.loaded_context_for_id(&context_id, cx)
.context("context not found")?;
if context.read(cx).replica_id() != ReplicaId::default() {
return Err(anyhow!("context must be opened via the host"));
}
anyhow::Ok(
context
.read(cx)
.serialize_ops(&ContextVersion::default(), cx),
)
})??;
let operations = operations.await;
Ok(proto::OpenContextResponse {
context: Some(proto::Context { operations }),
})
}
async fn handle_update_context(
this: Model<Self>,
envelope: TypedEnvelope<proto::UpdateContext>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
let context_id = ContextId::from_proto(envelope.payload.context_id);
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
let operation_proto = envelope.payload.operation.context("invalid operation")?;
let operation = ContextOperation::from_proto(operation_proto)?;
context.update(cx, |context, cx| context.apply_ops([operation], cx))?;
}
Ok(())
})?
}
async fn handle_synchronize_contexts(
this: Model<Self>,
envelope: TypedEnvelope<proto::SynchronizeContexts>,
mut cx: AsyncAppContext,
) -> Result<proto::SynchronizeContextsResponse> {
this.update(&mut cx, |this, cx| {
if this.project.read(cx).is_remote() {
return Err(anyhow!("only the host can synchronize contexts"));
}
let mut local_versions = Vec::new();
for remote_version_proto in envelope.payload.contexts {
let remote_version = ContextVersion::from_proto(&remote_version_proto);
let context_id = ContextId::from_proto(remote_version_proto.context_id);
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
let context = context.read(cx);
let operations = context.serialize_ops(&remote_version, cx);
local_versions.push(context.version(cx).to_proto(context_id.clone()));
let client = this.client.clone();
let project_id = envelope.payload.project_id;
cx.background_executor()
.spawn(async move {
let operations = operations.await;
for operation in operations {
client.send(proto::UpdateContext {
project_id,
context_id: context_id.to_proto(),
operation: Some(operation),
})?;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
this.advertise_contexts(cx);
anyhow::Ok(proto::SynchronizeContextsResponse {
contexts: local_versions,
})
})?
}
fn handle_project_changed(&mut self, _: Model<Project>, cx: &mut ModelContext<Self>) {
let is_shared = self.project.read(cx).is_shared();
let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
if is_shared == was_shared {
return;
}
if is_shared {
self.contexts.retain_mut(|context| {
if let Some(strong_context) = context.upgrade() {
*context = ContextHandle::Strong(strong_context);
true
} else {
false
}
});
let remote_id = self.project.read(cx).remote_id().unwrap();
self.client_subscription = self
.client
.subscribe_to_entity(remote_id)
.log_err()
.map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async()));
self.advertise_contexts(cx);
} else {
self.client_subscription = None;
}
}
fn handle_project_event(
&mut self,
_: Model<Project>,
event: &project::Event,
cx: &mut ModelContext<Self>,
) {
match event {
project::Event::Reshared => {
self.advertise_contexts(cx);
}
project::Event::HostReshared | project::Event::Rejoined => {
self.synchronize_contexts(cx);
}
project::Event::DisconnectedFromHost => {
self.contexts.retain_mut(|context| {
if let Some(strong_context) = context.upgrade() {
*context = ContextHandle::Weak(context.downgrade());
strong_context.update(cx, |context, cx| {
if context.replica_id() != ReplicaId::default() {
context.set_capability(language::Capability::ReadOnly, cx);
}
});
true
} else {
false
}
});
self.host_contexts.clear();
cx.notify();
}
_ => {}
}
}
pub fn create(&mut self, cx: &mut ModelContext<Self>) -> Model<Context> {
let context = cx.new_model(|cx| {
Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx)
});
self.register_context(&context, cx);
context
}
pub fn open_local_context(
&mut self,
path: PathBuf,
cx: &ModelContext<Self>,
) -> Task<Result<Model<Context>>> {
if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
return Task::ready(Ok(existing_context));
}
let fs = self.fs.clone(); let fs = self.fs.clone();
cx.background_executor().spawn(async move { let languages = self.languages.clone();
let saved_context = fs.load(&path).await?; let telemetry = self.telemetry.clone();
let saved_context_json = serde_json::from_str::<serde_json::Value>(&saved_context)?; let load = cx.background_executor().spawn({
match saved_context_json let path = path.clone();
.get("version") async move {
.ok_or_else(|| anyhow!("version not found"))? let saved_context = fs.load(&path).await?;
{ SavedContext::from_json(&saved_context)
serde_json::Value::String(version) => match version.as_str() { }
SavedContext::VERSION => { });
Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
} cx.spawn(|this, mut cx| async move {
"0.2.0" => { let saved_context = load.await?;
let saved_context = let context = cx.new_model(|cx| {
serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?; Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx)
Ok(SavedContext { })?;
id: saved_context.id, this.update(&mut cx, |this, cx| {
zed: saved_context.zed, if let Some(existing_context) = this.loaded_context_for_path(&path, cx) {
version: saved_context.version, existing_context
text: saved_context.text, } else {
messages: saved_context.messages, this.register_context(&context, cx);
message_metadata: saved_context.message_metadata, context
summary: saved_context.summary, }
slash_command_output_sections: Vec::new(), })
}) })
} }
"0.1.0" => {
let saved_context = fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option<Model<Context>> {
serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?; self.contexts.iter().find_map(|context| {
Ok(SavedContext { let context = context.upgrade()?;
id: saved_context.id, if context.read(cx).path() == Some(path) {
zed: saved_context.zed, Some(context)
version: saved_context.version, } else {
text: saved_context.text, None
messages: saved_context.messages,
message_metadata: saved_context.message_metadata,
summary: saved_context.summary,
slash_command_output_sections: Vec::new(),
})
}
_ => Err(anyhow!("unrecognized saved context version: {}", version)),
},
_ => Err(anyhow!("version not found on saved context")),
} }
}) })
} }
fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option<Model<Context>> {
self.contexts.iter().find_map(|context| {
let context = context.upgrade()?;
if context.read(cx).id() == id {
Some(context)
} else {
None
}
})
}
pub fn open_remote_context(
&mut self,
context_id: ContextId,
cx: &mut ModelContext<Self>,
) -> Task<Result<Model<Context>>> {
let project = self.project.read(cx);
let Some(project_id) = project.remote_id() else {
return Task::ready(Err(anyhow!("project was not remote")));
};
if project.is_local() {
return Task::ready(Err(anyhow!("cannot open remote contexts as the host")));
}
if let Some(context) = self.loaded_context_for_id(&context_id, cx) {
return Task::ready(Ok(context));
}
let replica_id = project.replica_id();
let capability = project.capability();
let language_registry = self.languages.clone();
let telemetry = self.telemetry.clone();
let request = self.client.request(proto::OpenContext {
project_id,
context_id: context_id.to_proto(),
});
cx.spawn(|this, mut cx| async move {
let response = request.await?;
let context_proto = response.context.context("invalid context")?;
let context = cx.new_model(|cx| {
Context::new(
context_id.clone(),
replica_id,
capability,
language_registry,
Some(telemetry),
cx,
)
})?;
let operations = cx
.background_executor()
.spawn(async move {
context_proto
.operations
.into_iter()
.map(|op| ContextOperation::from_proto(op))
.collect::<Result<Vec<_>>>()
})
.await?;
context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??;
this.update(&mut cx, |this, cx| {
if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) {
existing_context
} else {
this.register_context(&context, cx);
this.synchronize_contexts(cx);
context
}
})
})
}
fn register_context(&mut self, context: &Model<Context>, cx: &mut ModelContext<Self>) {
let handle = if self.project_is_shared {
ContextHandle::Strong(context.clone())
} else {
ContextHandle::Weak(context.downgrade())
};
self.contexts.push(handle);
self.advertise_contexts(cx);
cx.subscribe(context, Self::handle_context_event).detach();
}
fn handle_context_event(
&mut self,
context: Model<Context>,
event: &ContextEvent,
cx: &mut ModelContext<Self>,
) {
let Some(project_id) = self.project.read(cx).remote_id() else {
return;
};
match event {
ContextEvent::SummaryChanged => {
self.advertise_contexts(cx);
}
ContextEvent::Operation(operation) => {
let context_id = context.read(cx).id().to_proto();
let operation = operation.to_proto();
self.client
.send(proto::UpdateContext {
project_id,
context_id,
operation: Some(operation),
})
.log_err();
}
_ => {}
}
}
fn advertise_contexts(&self, cx: &AppContext) {
let Some(project_id) = self.project.read(cx).remote_id() else {
return;
};
// For now, only the host can advertise their open contexts.
if self.project.read(cx).is_remote() {
return;
}
let contexts = self
.contexts
.iter()
.rev()
.filter_map(|context| {
let context = context.upgrade()?.read(cx);
if context.replica_id() == ReplicaId::default() {
Some(proto::ContextMetadata {
context_id: context.id().to_proto(),
summary: context.summary().map(|summary| summary.text.clone()),
})
} else {
None
}
})
.collect();
self.client
.send(proto::AdvertiseContexts {
project_id,
contexts,
})
.ok();
}
fn synchronize_contexts(&mut self, cx: &mut ModelContext<Self>) {
let Some(project_id) = self.project.read(cx).remote_id() else {
return;
};
let contexts = self
.contexts
.iter()
.filter_map(|context| {
let context = context.upgrade()?.read(cx);
if context.replica_id() != ReplicaId::default() {
Some(context.version(cx).to_proto(context.id().clone()))
} else {
None
}
})
.collect();
let client = self.client.clone();
let request = self.client.request(proto::SynchronizeContexts {
project_id,
contexts,
});
cx.spawn(|this, cx| async move {
let response = request.await?;
let mut context_ids = Vec::new();
let mut operations = Vec::new();
this.read_with(&cx, |this, cx| {
for context_version_proto in response.contexts {
let context_version = ContextVersion::from_proto(&context_version_proto);
let context_id = ContextId::from_proto(context_version_proto.context_id);
if let Some(context) = this.loaded_context_for_id(&context_id, cx) {
context_ids.push(context_id);
operations.push(context.read(cx).serialize_ops(&context_version, cx));
}
}
})?;
let operations = futures::future::join_all(operations).await;
for (context_id, operations) in context_ids.into_iter().zip(operations) {
for operation in operations {
client.send(proto::UpdateContext {
project_id,
context_id: context_id.to_proto(),
operation: Some(operation),
})?;
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> { pub fn search(&self, query: String, cx: &AppContext) -> Task<Vec<SavedContextMetadata>> {
let metadata = self.contexts_metadata.clone(); let metadata = self.contexts_metadata.clone();
let executor = cx.background_executor().clone(); let executor = cx.background_executor().clone();
@ -178,6 +577,10 @@ impl ContextStore {
}) })
} }
pub fn host_contexts(&self) -> &[RemoteContextMetadata] {
&self.host_contexts
}
fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> { fn reload(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let fs = self.fs.clone(); let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {

View File

@ -3,7 +3,6 @@ use crate::{
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_slash_command::SlashCommandRegistry;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle};
@ -448,7 +447,6 @@ impl PromptLibrary {
self.set_active_prompt(Some(prompt_id), cx); self.set_active_prompt(Some(prompt_id), cx);
} else if let Some(prompt_metadata) = self.store.metadata(prompt_id) { } else if let Some(prompt_metadata) = self.store.metadata(prompt_id) {
let language_registry = self.language_registry.clone(); let language_registry = self.language_registry.clone();
let commands = SlashCommandRegistry::global(cx);
let prompt = self.store.load(prompt_id); let prompt = self.store.load(prompt_id);
self.pending_load = cx.spawn(|this, mut cx| async move { self.pending_load = cx.spawn(|this, mut cx| async move {
let prompt = prompt.await; let prompt = prompt.await;
@ -477,7 +475,7 @@ impl PromptLibrary {
editor.set_use_modal_editing(false); editor.set_use_modal_editing(false);
editor.set_current_line_highlight(Some(CurrentLineHighlight::None)); editor.set_current_line_highlight(Some(CurrentLineHighlight::None));
editor.set_completion_provider(Box::new( editor.set_completion_provider(Box::new(
SlashCommandCompletionProvider::new(commands, None, None), SlashCommandCompletionProvider::new(None, None),
)); ));
if focus { if focus {
editor.focus(cx); editor.focus(cx);

View File

@ -31,7 +31,6 @@ pub mod tabs_command;
pub mod term_command; pub mod term_command;
pub(crate) struct SlashCommandCompletionProvider { pub(crate) struct SlashCommandCompletionProvider {
commands: Arc<SlashCommandRegistry>,
cancel_flag: Mutex<Arc<AtomicBool>>, cancel_flag: Mutex<Arc<AtomicBool>>,
editor: Option<WeakView<ContextEditor>>, editor: Option<WeakView<ContextEditor>>,
workspace: Option<WeakView<Workspace>>, workspace: Option<WeakView<Workspace>>,
@ -46,14 +45,12 @@ pub(crate) struct SlashCommandLine {
impl SlashCommandCompletionProvider { impl SlashCommandCompletionProvider {
pub fn new( pub fn new(
commands: Arc<SlashCommandRegistry>,
editor: Option<WeakView<ContextEditor>>, editor: Option<WeakView<ContextEditor>>,
workspace: Option<WeakView<Workspace>>, workspace: Option<WeakView<Workspace>>,
) -> Self { ) -> Self {
Self { Self {
cancel_flag: Mutex::new(Arc::new(AtomicBool::new(false))), cancel_flag: Mutex::new(Arc::new(AtomicBool::new(false))),
editor, editor,
commands,
workspace, workspace,
} }
} }
@ -65,8 +62,8 @@ impl SlashCommandCompletionProvider {
name_range: Range<Anchor>, name_range: Range<Anchor>,
cx: &mut WindowContext, cx: &mut WindowContext,
) -> Task<Result<Vec<project::Completion>>> { ) -> Task<Result<Vec<project::Completion>>> {
let candidates = self let commands = SlashCommandRegistry::global(cx);
.commands let candidates = commands
.command_names() .command_names()
.into_iter() .into_iter()
.enumerate() .enumerate()
@ -76,7 +73,6 @@ impl SlashCommandCompletionProvider {
char_bag: def.as_ref().into(), char_bag: def.as_ref().into(),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let commands = self.commands.clone();
let command_name = command_name.to_string(); let command_name = command_name.to_string();
let editor = self.editor.clone(); let editor = self.editor.clone();
let workspace = self.workspace.clone(); let workspace = self.workspace.clone();
@ -155,7 +151,8 @@ impl SlashCommandCompletionProvider {
flag.store(true, SeqCst); flag.store(true, SeqCst);
*flag = new_cancel_flag.clone(); *flag = new_cancel_flag.clone();
if let Some(command) = self.commands.command(command_name) { let commands = SlashCommandRegistry::global(cx);
if let Some(command) = commands.command(command_name) {
let completions = command.complete_argument( let completions = command.complete_argument(
argument, argument,
new_cancel_flag.clone(), new_cancel_flag.clone(),

View File

@ -67,7 +67,7 @@ pub struct SlashCommandOutput {
pub run_commands_in_text: bool, pub run_commands_in_text: bool,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> { pub struct SlashCommandOutputSection<T> {
pub range: Range<T>, pub range: Range<T>,
pub icon: IconName, pub icon: IconName,

View File

@ -18,4 +18,5 @@ test-support = ["dep:parking_lot"]
[dependencies] [dependencies]
chrono.workspace = true chrono.workspace = true
parking_lot = { workspace = true, optional = true } parking_lot = { workspace = true, optional = true }
serde.workspace = true
smallvec.workspace = true smallvec.workspace = true

View File

@ -1,5 +1,6 @@
mod system_clock; mod system_clock;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::{ use std::{
cmp::{self, Ordering}, cmp::{self, Ordering},
@ -16,7 +17,7 @@ pub type Seq = u32;
/// A [Lamport timestamp](https://en.wikipedia.org/wiki/Lamport_timestamp), /// A [Lamport timestamp](https://en.wikipedia.org/wiki/Lamport_timestamp),
/// used to determine the ordering of events in the editor. /// used to determine the ordering of events in the editor.
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)] #[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct Lamport { pub struct Lamport {
pub replica_id: ReplicaId, pub replica_id: ReplicaId,
pub value: Seq, pub value: Seq,
@ -161,6 +162,10 @@ impl Lamport {
} }
} }
pub fn as_u64(self) -> u64 {
((self.value as u64) << 32) | (self.replica_id as u64)
}
pub fn tick(&mut self) -> Self { pub fn tick(&mut self) -> Self {
let timestamp = *self; let timestamp = *self;
self.value += 1; self.value += 1;

View File

@ -71,6 +71,7 @@ util.workspace = true
uuid.workspace = true uuid.workspace = true
[dev-dependencies] [dev-dependencies]
assistant = { workspace = true, features = ["test-support"] }
async-trait.workspace = true async-trait.workspace = true
audio.workspace = true audio.workspace = true
call = { workspace = true, features = ["test-support"] } call = { workspace = true, features = ["test-support"] }

View File

@ -595,6 +595,14 @@ impl Server {
.add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key)) .add_request_handler(user_handler(get_supermaven_api_key))
.add_request_handler(user_handler(
forward_mutating_project_request::<proto::OpenContext>,
))
.add_request_handler(user_handler(
forward_mutating_project_request::<proto::SynchronizeContexts>,
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_streaming_request_handler({ .add_streaming_request_handler({
let app_state = app_state.clone(); let app_state = app_state.clone();
move |request, response, session| { move |request, response, session| {
@ -3056,6 +3064,53 @@ async fn update_buffer(
Ok(()) Ok(())
} }
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id);
let operation = message.operation.as_ref().context("invalid operation")?;
let capability = match operation.variant.as_ref() {
Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
if let Some(buffer_op) = buffer_op.operation.as_ref() {
match buffer_op.variant {
None | Some(proto::operation::Variant::UpdateSelections(_)) => {
Capability::ReadOnly
}
_ => Capability::ReadWrite,
}
} else {
Capability::ReadWrite
}
}
Some(_) => Capability::ReadWrite,
None => Capability::ReadOnly,
};
let guard = session
.db()
.await
.connections_for_buffer_update(
project_id,
session.principal_id(),
session.connection_id,
capability,
)
.await?;
let (host, guests) = &*guard;
broadcast(
Some(session.connection_id),
guests.iter().chain([host]).copied(),
|connection_id| {
session
.peer
.forward_send(session.connection_id, connection_id, message.clone())
},
);
Ok(())
}
/// Notify other participants that a project has been updated. /// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>( async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
request: T, request: T,

View File

@ -6,6 +6,7 @@ use crate::{
}, },
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant::ContextStore;
use call::{room, ActiveCall, ParticipantLocation, Room}; use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{User, RECEIVE_TIMEOUT}; use client::{User, RECEIVE_TIMEOUT};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
@ -6449,3 +6450,123 @@ async fn test_preview_tabs(cx: &mut TestAppContext) {
assert!(!pane.can_navigate_forward()); assert!(!pane.can_navigate_forward());
}); });
} }
#[gpui::test(iterations = 10)]
async fn test_context_collaboration_with_reconnect(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
.create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
.await;
let active_call_a = cx_a.read(ActiveCall::global);
client_a.fs().insert_tree("/a", Default::default()).await;
let (project_a, _) = client_a.build_local_project("/a", cx_a).await;
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
.unwrap();
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
// Client A sees that a guest has joined.
executor.run_until_parked();
project_a.read_with(cx_a, |project, _| {
assert_eq!(project.collaborators().len(), 1);
});
project_b.read_with(cx_b, |project, _| {
assert_eq!(project.collaborators().len(), 1);
});
let context_store_a = cx_a
.update(|cx| ContextStore::new(project_a.clone(), cx))
.await
.unwrap();
let context_store_b = cx_b
.update(|cx| ContextStore::new(project_b.clone(), cx))
.await
.unwrap();
// Client A creates a new context.
let context_a = context_store_a.update(cx_a, |store, cx| store.create(cx));
executor.run_until_parked();
// Client B retrieves host's contexts and joins one.
let context_b = context_store_b
.update(cx_b, |store, cx| {
let host_contexts = store.host_contexts().to_vec();
assert_eq!(host_contexts.len(), 1);
store.open_remote_context(host_contexts[0].id.clone(), cx)
})
.await
.unwrap();
// Host and guest make changes
context_a.update(cx_a, |context, cx| {
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "Host change\n")], None, cx)
})
});
context_b.update(cx_b, |context, cx| {
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "Guest change\n")], None, cx)
})
});
executor.run_until_parked();
assert_eq!(
context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
"Guest change\nHost change\n"
);
assert_eq!(
context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
"Guest change\nHost change\n"
);
// Disconnect client A and make some changes while disconnected.
server.disconnect_client(client_a.peer_id().unwrap());
server.forbid_connections();
context_a.update(cx_a, |context, cx| {
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "Host offline change\n")], None, cx)
})
});
context_b.update(cx_b, |context, cx| {
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "Guest offline change\n")], None, cx)
})
});
executor.run_until_parked();
assert_eq!(
context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
"Host offline change\nGuest change\nHost change\n"
);
assert_eq!(
context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
"Guest offline change\nGuest change\nHost change\n"
);
// Allow client A to reconnect and verify that contexts converge.
server.allow_connections();
executor.advance_clock(RECEIVE_TIMEOUT);
assert_eq!(
context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()),
"Guest offline change\nHost offline change\nGuest change\nHost change\n"
);
assert_eq!(
context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()),
"Guest offline change\nHost offline change\nGuest change\nHost change\n"
);
// Client A disconnects without being able to reconnect. Context B becomes readonly.
server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap());
executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
context_b.read_with(cx_b, |context, cx| {
assert!(context.buffer().read(cx).read_only());
});
}

View File

@ -294,6 +294,8 @@ impl TestServer {
menu::init(); menu::init();
dev_server_projects::init(client.clone(), cx); dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
assistant::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client);
}); });
client client

View File

@ -1903,6 +1903,10 @@ impl Buffer {
self.deferred_ops.insert(deferred_ops); self.deferred_ops.insert(deferred_ops);
} }
pub fn has_deferred_ops(&self) -> bool {
!self.deferred_ops.is_empty() || self.text.has_deferred_ops()
}
fn can_apply_op(&self, operation: &Operation) -> bool { fn can_apply_op(&self, operation: &Operation) -> bool {
match operation { match operation {
Operation::Buffer(_) => { Operation::Buffer(_) => {

View File

@ -1,7 +1,7 @@
//! Handles conversions of `language` items to and from the [`rpc`] protocol. //! Handles conversions of `language` items to and from the [`rpc`] protocol.
use crate::{diagnostic_set::DiagnosticEntry, CursorShape, Diagnostic}; use crate::{diagnostic_set::DiagnosticEntry, CursorShape, Diagnostic};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context as _, Result};
use clock::ReplicaId; use clock::ReplicaId;
use lsp::{DiagnosticSeverity, LanguageServerId}; use lsp::{DiagnosticSeverity, LanguageServerId};
use rpc::proto; use rpc::proto;
@ -231,6 +231,21 @@ pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor {
} }
} }
pub fn serialize_anchor_range(range: Range<Anchor>) -> proto::AnchorRange {
proto::AnchorRange {
start: Some(serialize_anchor(&range.start)),
end: Some(serialize_anchor(&range.end)),
}
}
/// Deserializes an [`Range<Anchor>`] from the RPC representation.
pub fn deserialize_anchor_range(range: proto::AnchorRange) -> Result<Range<Anchor>> {
Ok(
deserialize_anchor(range.start.context("invalid anchor")?).context("invalid anchor")?
..deserialize_anchor(range.end.context("invalid anchor")?).context("invalid anchor")?,
)
}
// This behavior is currently copied in the collab database, for snapshotting channel notes // This behavior is currently copied in the collab database, for snapshotting channel notes
/// Deserializes an [`crate::Operation`] from the RPC representation. /// Deserializes an [`crate::Operation`] from the RPC representation.
pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operation> { pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operation> {

View File

@ -355,6 +355,9 @@ pub enum Event {
}, },
CollaboratorJoined(proto::PeerId), CollaboratorJoined(proto::PeerId),
CollaboratorLeft(proto::PeerId), CollaboratorLeft(proto::PeerId),
HostReshared,
Reshared,
Rejoined,
RefreshInlayHints, RefreshInlayHints,
RevealInProjectPanel(ProjectEntryId), RevealInProjectPanel(ProjectEntryId),
SnippetEdit(BufferId, Vec<(lsp::Range, Snippet)>), SnippetEdit(BufferId, Vec<(lsp::Range, Snippet)>),
@ -1716,6 +1719,7 @@ impl Project {
self.shared_buffers.clear(); self.shared_buffers.clear();
self.set_collaborators_from_proto(message.collaborators, cx)?; self.set_collaborators_from_proto(message.collaborators, cx)?;
self.metadata_changed(cx); self.metadata_changed(cx);
cx.emit(Event::Reshared);
Ok(()) Ok(())
} }
@ -1753,6 +1757,7 @@ impl Project {
.collect(); .collect();
self.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync) self.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync)
.unwrap(); .unwrap();
cx.emit(Event::Rejoined);
cx.notify(); cx.notify();
Ok(()) Ok(())
} }
@ -1805,9 +1810,11 @@ impl Project {
} }
} }
self.client.send(proto::UnshareProject { self.client
project_id: remote_id, .send(proto::UnshareProject {
})?; project_id: remote_id,
})
.ok();
Ok(()) Ok(())
} else { } else {
@ -8810,6 +8817,7 @@ impl Project {
.retain(|_, buffer| !matches!(buffer, OpenBuffer::Operations(_))); .retain(|_, buffer| !matches!(buffer, OpenBuffer::Operations(_)));
this.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync) this.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync)
.unwrap(); .unwrap();
cx.emit(Event::HostReshared);
} }
cx.emit(Event::CollaboratorUpdated { cx.emit(Event::CollaboratorUpdated {

View File

@ -255,7 +255,14 @@ message Envelope {
TaskTemplates task_templates = 206; TaskTemplates task_templates = 206;
LinkedEditingRange linked_editing_range = 209; LinkedEditingRange linked_editing_range = 209;
LinkedEditingRangeResponse linked_editing_range_response = 210; // current max LinkedEditingRangeResponse linked_editing_range_response = 210;
AdvertiseContexts advertise_contexts = 211;
OpenContext open_context = 212;
OpenContextResponse open_context_response = 213;
UpdateContext update_context = 214;
SynchronizeContexts synchronize_contexts = 215;
SynchronizeContextsResponse synchronize_contexts_response = 216; // current max
} }
reserved 158 to 161; reserved 158 to 161;
@ -2222,3 +2229,117 @@ message TaskSourceKind {
string name = 1; string name = 1;
} }
} }
message ContextMessageStatus {
oneof variant {
Done done = 1;
Pending pending = 2;
Error error = 3;
}
message Done {}
message Pending {}
message Error {
string message = 1;
}
}
message ContextMessage {
LamportTimestamp id = 1;
Anchor start = 2;
LanguageModelRole role = 3;
ContextMessageStatus status = 4;
}
message SlashCommandOutputSection {
AnchorRange range = 1;
string icon_name = 2;
string label = 3;
}
message ContextOperation {
oneof variant {
InsertMessage insert_message = 1;
UpdateMessage update_message = 2;
UpdateSummary update_summary = 3;
SlashCommandFinished slash_command_finished = 4;
BufferOperation buffer_operation = 5;
}
message InsertMessage {
ContextMessage message = 1;
repeated VectorClockEntry version = 2;
}
message UpdateMessage {
LamportTimestamp message_id = 1;
LanguageModelRole role = 2;
ContextMessageStatus status = 3;
LamportTimestamp timestamp = 4;
repeated VectorClockEntry version = 5;
}
message UpdateSummary {
string summary = 1;
bool done = 2;
LamportTimestamp timestamp = 3;
repeated VectorClockEntry version = 4;
}
message SlashCommandFinished {
LamportTimestamp id = 1;
AnchorRange output_range = 2;
repeated SlashCommandOutputSection sections = 3;
repeated VectorClockEntry version = 4;
}
message BufferOperation {
Operation operation = 1;
}
}
message Context {
repeated ContextOperation operations = 1;
}
message ContextMetadata {
string context_id = 1;
optional string summary = 2;
}
message AdvertiseContexts {
uint64 project_id = 1;
repeated ContextMetadata contexts = 2;
}
message OpenContext {
uint64 project_id = 1;
string context_id = 2;
}
message OpenContextResponse {
Context context = 1;
}
message UpdateContext {
uint64 project_id = 1;
string context_id = 2;
ContextOperation operation = 3;
}
message ContextVersion {
string context_id = 1;
repeated VectorClockEntry context_version = 2;
repeated VectorClockEntry buffer_version = 3;
}
message SynchronizeContexts {
uint64 project_id = 1;
repeated ContextVersion contexts = 2;
}
message SynchronizeContextsResponse {
repeated ContextVersion contexts = 1;
}

View File

@ -337,7 +337,13 @@ messages!(
(OpenNewBuffer, Foreground), (OpenNewBuffer, Foreground),
(RestartLanguageServers, Foreground), (RestartLanguageServers, Foreground),
(LinkedEditingRange, Background), (LinkedEditingRange, Background),
(LinkedEditingRangeResponse, Background) (LinkedEditingRangeResponse, Background),
(AdvertiseContexts, Foreground),
(OpenContext, Foreground),
(OpenContextResponse, Foreground),
(UpdateContext, Foreground),
(SynchronizeContexts, Foreground),
(SynchronizeContextsResponse, Foreground),
); );
request_messages!( request_messages!(
@ -449,7 +455,9 @@ request_messages!(
(DeleteDevServerProject, Ack), (DeleteDevServerProject, Ack),
(RegenerateDevServerToken, RegenerateDevServerTokenResponse), (RegenerateDevServerToken, RegenerateDevServerTokenResponse),
(RenameDevServer, Ack), (RenameDevServer, Ack),
(RestartLanguageServers, Ack) (RestartLanguageServers, Ack),
(OpenContext, OpenContextResponse),
(SynchronizeContexts, SynchronizeContextsResponse),
); );
entity_messages!( entity_messages!(
@ -511,6 +519,10 @@ entity_messages!(
UpdateWorktree, UpdateWorktree,
UpdateWorktreeSettings, UpdateWorktreeSettings,
LspExtExpandMacro, LspExtExpandMacro,
AdvertiseContexts,
OpenContext,
UpdateContext,
SynchronizeContexts,
); );
entity_messages!( entity_messages!(

View File

@ -1,12 +1,15 @@
use std::fmt::Debug;
use clock::ReplicaId; use clock::ReplicaId;
use collections::{BTreeMap, HashSet};
pub struct Network<T: Clone, R: rand::Rng> { pub struct Network<T: Clone, R: rand::Rng> {
inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>, inboxes: BTreeMap<ReplicaId, Vec<Envelope<T>>>,
all_messages: Vec<T>, disconnected_peers: HashSet<ReplicaId>,
rng: R, rng: R,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
struct Envelope<T: Clone> { struct Envelope<T: Clone> {
message: T, message: T,
} }
@ -14,8 +17,8 @@ struct Envelope<T: Clone> {
impl<T: Clone, R: rand::Rng> Network<T, R> { impl<T: Clone, R: rand::Rng> Network<T, R> {
pub fn new(rng: R) -> Self { pub fn new(rng: R) -> Self {
Network { Network {
inboxes: Default::default(), inboxes: BTreeMap::default(),
all_messages: Vec::new(), disconnected_peers: HashSet::default(),
rng, rng,
} }
} }
@ -24,6 +27,24 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
self.inboxes.insert(id, Vec::new()); self.inboxes.insert(id, Vec::new());
} }
pub fn disconnect_peer(&mut self, id: ReplicaId) {
self.disconnected_peers.insert(id);
self.inboxes.get_mut(&id).unwrap().clear();
}
pub fn reconnect_peer(&mut self, id: ReplicaId, replicate_from: ReplicaId) {
assert!(self.disconnected_peers.remove(&id));
self.replicate(replicate_from, id);
}
pub fn is_disconnected(&self, id: ReplicaId) -> bool {
self.disconnected_peers.contains(&id)
}
pub fn contains_disconnected_peers(&self) -> bool {
!self.disconnected_peers.is_empty()
}
pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) { pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
self.inboxes self.inboxes
.insert(new_replica_id, self.inboxes[&old_replica_id].clone()); .insert(new_replica_id, self.inboxes[&old_replica_id].clone());
@ -34,8 +55,13 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
} }
pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) { pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
// Drop messages from disconnected peers.
if self.disconnected_peers.contains(&sender) {
return;
}
for (replica, inbox) in self.inboxes.iter_mut() { for (replica, inbox) in self.inboxes.iter_mut() {
if *replica != sender { if *replica != sender && !self.disconnected_peers.contains(replica) {
for message in &messages { for message in &messages {
// Insert one or more duplicates of this message, potentially *before* the previous // Insert one or more duplicates of this message, potentially *before* the previous
// message sent by this peer to simulate out-of-order delivery. // message sent by this peer to simulate out-of-order delivery.
@ -51,7 +77,6 @@ impl<T: Clone, R: rand::Rng> Network<T, R> {
} }
} }
} }
self.all_messages.extend(messages);
} }
pub fn has_unreceived(&self, receiver: ReplicaId) -> bool { pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {

View File

@ -1265,6 +1265,10 @@ impl Buffer {
} }
} }
pub fn has_deferred_ops(&self) -> bool {
!self.deferred_ops.is_empty()
}
pub fn peek_undo_stack(&self) -> Option<&HistoryEntry> { pub fn peek_undo_stack(&self) -> Option<&HistoryEntry> {
self.history.undo_stack.last() self.history.undo_stack.last()
} }

View File

@ -1,6 +1,6 @@
use gpui::{svg, AnimationElement, Hsla, IntoElement, Rems, Transformation}; use gpui::{svg, AnimationElement, Hsla, IntoElement, Rems, Transformation};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strum::EnumIter; use strum::{EnumIter, EnumString, IntoStaticStr};
use crate::{prelude::*, Indicator}; use crate::{prelude::*, Indicator};
@ -90,7 +90,9 @@ impl IconSize {
} }
} }
#[derive(Debug, PartialEq, Copy, Clone, EnumIter, Serialize, Deserialize)] #[derive(
Debug, Eq, PartialEq, Copy, Clone, EnumIter, EnumString, IntoStaticStr, Serialize, Deserialize,
)]
pub enum IconName { pub enum IconName {
Ai, Ai,
ArrowCircle, ArrowCircle,