From 8944af7406e68266004b05aaaf7e6ff74413490d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 10 Jul 2024 17:36:22 +0200 Subject: [PATCH] Lay the groundwork for collaborating on assistant panel (#13991) This pull request introduces collaboration for the assistant panel by turning `Context` into a CRDT. `ContextStore` is responsible for sending and applying operations, as well as synchronizing missed changes while the connection was lost. Contexts are shared on a per-project basis, and only the host can share them for now. Shared contexts can be accessed via the `History` tab in the assistant panel. image Please note that this doesn't implement following yet, which is scheduled for a subsequent pull request. Release Notes: - N/A --- Cargo.lock | 4 + crates/assistant/Cargo.toml | 11 + crates/assistant/src/assistant.rs | 94 +- crates/assistant/src/assistant_panel.rs | 2173 ++---------- crates/assistant/src/completion_provider.rs | 4 +- .../assistant/src/completion_provider/fake.rs | 1 - crates/assistant/src/context.rs | 3009 +++++++++++++++++ crates/assistant/src/context_store.rs | 629 +++- crates/assistant/src/prompt_library.rs | 4 +- crates/assistant/src/slash_command.rs | 11 +- .../src/assistant_slash_command.rs | 2 +- crates/clock/Cargo.toml | 1 + crates/clock/src/clock.rs | 7 +- crates/collab/Cargo.toml | 1 + crates/collab/src/rpc.rs | 55 + crates/collab/src/tests/integration_tests.rs | 121 + crates/collab/src/tests/test_server.rs | 2 + crates/language/src/buffer.rs | 4 + crates/language/src/proto.rs | 17 +- crates/project/src/project.rs | 14 +- crates/proto/proto/zed.proto | 123 +- crates/proto/src/proto.rs | 16 +- crates/text/src/network.rs | 39 +- crates/text/src/text.rs | 4 + crates/ui/src/components/icon.rs | 6 +- 25 files changed, 4232 insertions(+), 2120 deletions(-) create mode 100644 crates/assistant/src/context.rs diff --git a/Cargo.lock b/Cargo.lock index 42a197fe6e..31307a6848 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -377,6 +377,7 @@ dependencies = [ "cargo_toml", "chrono", "client", + "clock", "collections", "command_palette_hooks", "ctor", @@ -419,6 +420,7 @@ dependencies = [ "telemetry_events", "terminal", "terminal_view", + "text", "theme", "tiktoken-rs", "toml 0.8.10", @@ -2405,6 +2407,7 @@ version = "0.1.0" dependencies = [ "chrono", "parking_lot", + "serde", "smallvec", ] @@ -2463,6 +2466,7 @@ version = "0.44.0" dependencies = [ "anthropic", "anyhow", + "assistant", "async-trait", "async-tungstenite", "audio", diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 32c3eda683..c48da49a0d 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -12,6 +12,14 @@ workspace = true path = "src/assistant.rs" doctest = false +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + [dependencies] anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true @@ -21,6 +29,7 @@ breadcrumbs.workspace = true cargo_toml.workspace = true chrono.workspace = true client.workspace = true +clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true editor.workspace = true @@ -72,7 +81,9 @@ picker.workspace = true ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } log.workspace = true project = { workspace = true, features = ["test-support"] } rand.workspace = true +text = { workspace = true, features = ["test-support"] } unindent.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 9b7d4e3f73..9d6f793650 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,7 +1,8 @@ pub mod assistant_panel; pub mod assistant_settings; mod completion_provider; -mod context_store; +mod context; +pub mod context_store; mod inline_assistant; mod model_selector; mod prompt_library; @@ -16,8 +17,9 @@ use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaMo use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; -pub(crate) use completion_provider::*; -pub(crate) use context_store::*; +pub use completion_provider::*; +pub use context::*; +pub use context_store::*; use fs::Fs; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; use indexed_docs::IndexedDocsRegistry; @@ -57,10 +59,14 @@ actions!( ] ); -#[derive( - Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, -)] -struct MessageId(usize); +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct MessageId(clock::Lamport); + +impl MessageId { + pub fn as_u64(self) -> u64 { + self.0.as_u64() + } +} #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -71,8 +77,26 @@ pub enum Role { } impl Role { - pub fn cycle(&mut self) { - *self = match self { + pub fn from_proto(role: i32) -> Role { + match proto::LanguageModelRole::from_i32(role) { + Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, + Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, + Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, + Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, + None => Role::User, + } + } + + pub fn to_proto(&self) -> proto::LanguageModelRole { + match self { + Role::User => proto::LanguageModelRole::LanguageModelUser, + Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, + Role::System => proto::LanguageModelRole::LanguageModelSystem, + } + } + + pub fn cycle(self) -> Role { + match self { Role::User => Role::Assistant, Role::Assistant => Role::System, Role::System => Role::User, @@ -151,11 +175,7 @@ pub struct LanguageModelRequestMessage { impl LanguageModelRequestMessage { pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { proto::LanguageModelRequestMessage { - role: match self.role { - Role::User => proto::LanguageModelRole::LanguageModelUser, - Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, - Role::System => proto::LanguageModelRole::LanguageModelSystem, - } as i32, + role: self.role.to_proto() as i32, content: self.content.clone(), tool_calls: Vec::new(), tool_call_id: None, @@ -222,19 +242,48 @@ pub struct LanguageModelChoiceDelta { pub finish_reason: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct MessageMetadata { - role: Role, - status: MessageStatus, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum MessageStatus { +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum MessageStatus { Pending, Done, Error(SharedString), } +impl MessageStatus { + pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus { + match status.variant { + Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending, + Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done, + Some(proto::context_message_status::Variant::Error(error)) => { + MessageStatus::Error(error.message.into()) + } + None => MessageStatus::Pending, + } + } + + pub fn to_proto(&self) -> proto::ContextMessageStatus { + match self { + MessageStatus::Pending => proto::ContextMessageStatus { + variant: Some(proto::context_message_status::Variant::Pending( + proto::context_message_status::Pending {}, + )), + }, + MessageStatus::Done => proto::ContextMessageStatus { + variant: Some(proto::context_message_status::Variant::Done( + proto::context_message_status::Done {}, + )), + }, + MessageStatus::Error(message) => proto::ContextMessageStatus { + variant: Some(proto::context_message_status::Variant::Error( + proto::context_message_status::Error { + message: message.to_string(), + }, + )), + }, + } + } +} + /// The state pertaining to the Assistant. #[derive(Default)] struct Assistant { @@ -287,6 +336,7 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { }) .detach(); + context_store::init(&client); prompt_library::init(cx); completion_provider::init(client.clone(), cx); assistant_slash_command::init(cx); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index a0b5179211..9a98f5aed8 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,24 +1,23 @@ -use crate::slash_command::docs_command::{DocsSlashCommand, DocsSlashCommandArgs}; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, - humanize_token_count, + humanize_token_count, parse_next_edit_suggestion, prompt_library::open_prompt_library, search::*, slash_command::{ - default_command::DefaultSlashCommand, SlashCommandCompletionProvider, SlashCommandLine, - SlashCommandRegistry, + default_command::DefaultSlashCommand, + docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, + SlashCommandCompletionProvider, SlashCommandRegistry, }, terminal_inline_assistant::TerminalInlineAssistant, - ApplyEdit, Assist, CompletionProvider, ConfirmCommand, ContextStore, CycleMessageRole, - DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistant, InsertIntoEditor, - LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus, - ModelSelector, QuoteSelection, ResetKey, Role, SavedContext, SavedContextMetadata, - SavedMessage, Split, ToggleFocus, ToggleModelSelector, + ApplyEdit, Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, + ContextStore, CycleMessageRole, DeployHistory, DeployPromptLibrary, EditSuggestion, + InlineAssist, InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, + PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, + ResetKey, Role, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::{anyhow, Result}; -use assistant_slash_command::{SlashCommand, SlashCommandOutput, SlashCommandOutputSection}; +use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use breadcrumbs::Breadcrumbs; -use client::telemetry::Telemetry; use collections::{BTreeSet, HashMap, HashSet}; use editor::{ actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, @@ -30,44 +29,33 @@ use editor::{ }; use editor::{display_map::CreaseId, FoldPlaceholder}; use fs::Fs; -use futures::future::Shared; -use futures::{FutureExt, StreamExt}; use gpui::{ div, percentage, point, Action, Animation, AnimationExt, AnyElement, AnyView, AppContext, - AsyncAppContext, AsyncWindowContext, ClipboardItem, Context as _, DismissEvent, Empty, - EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, ModelContext, - ParentElement, Pixels, Render, SharedString, StatefulInteractiveElement, Styled, Subscription, - Task, Transformation, UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext, + AsyncWindowContext, ClipboardItem, DismissEvent, Empty, EventEmitter, FocusHandle, + FocusableView, InteractiveElement, IntoElement, Model, ParentElement, Pixels, Render, + SharedString, StatefulInteractiveElement, Styled, Subscription, Task, Transformation, + UpdateGlobal, View, ViewContext, VisualContext, WeakView, WindowContext, }; use indexed_docs::IndexedDocsStore; use language::{ - language_settings::SoftWrap, AnchorRangeExt as _, AutoindentMode, Buffer, LanguageRegistry, - LspAdapterDelegate, OffsetRangeExt as _, Point, ToOffset as _, + language_settings::SoftWrap, AutoindentMode, Buffer, LanguageRegistry, LspAdapterDelegate, + OffsetRangeExt as _, Point, ToOffset, }; use multi_buffer::MultiBufferRow; -use paths::contexts_dir; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate, ProjectTransaction}; use search::{buffer_search::DivRegistrar, BufferSearchBar}; use settings::Settings; -use std::{ - cmp::{self, Ordering}, - fmt::Write, - iter, - ops::Range, - path::PathBuf, - sync::Arc, - time::{Duration, Instant}, -}; -use telemetry_events::AssistantKind; +use std::{cmp, fmt::Write, ops::Range, path::PathBuf, sync::Arc, time::Duration}; use terminal_view::{terminal_panel::TerminalPanel, TerminalView}; use theme::ThemeSettings; use ui::{ - prelude::*, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem, + prelude::*, + utils::{format_distance_from_now, DateTimeType}, + Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip, }; -use util::{post_inc, ResultExt, TryFutureExt}; -use uuid::Uuid; +use util::ResultExt; use workspace::{ dock::{DockPosition, Panel, PanelEvent}, item::{BreadcrumbText, Item, ItemHandle}, @@ -106,24 +94,30 @@ pub struct AssistantPanel { workspace: WeakView, width: Option, height: Option, + project: Model, context_store: Model, languages: Arc, - slash_commands: Arc, fs: Arc, - telemetry: Arc, subscriptions: Vec, authentication_prompt: Option, model_selector_menu_handle: PopoverMenuHandle, } +#[derive(Clone)] +enum ContextMetadata { + Remote(RemoteContextMetadata), + Saved(SavedContextMetadata), +} + struct SavedContextPickerDelegate { store: Model, - matches: Vec, + project: Model, + matches: Vec, selected_index: usize, } enum SavedContextPickerEvent { - Confirmed { path: PathBuf }, + Confirmed(ContextMetadata), } enum InlineAssistTarget { @@ -134,8 +128,9 @@ enum InlineAssistTarget { impl EventEmitter for Picker {} impl SavedContextPickerDelegate { - fn new(store: Model) -> Self { + fn new(project: Model, store: Model) -> Self { Self { + project, store, matches: Vec::new(), selected_index: 0, @@ -167,7 +162,13 @@ impl PickerDelegate for SavedContextPickerDelegate { cx.spawn(|this, mut cx| async move { let matches = search.await; this.update(&mut cx, |this, cx| { - this.delegate.matches = matches; + let host_contexts = this.delegate.store.read(cx).host_contexts(); + this.delegate.matches = host_contexts + .iter() + .cloned() + .map(ContextMetadata::Remote) + .chain(matches.into_iter().map(ContextMetadata::Saved)) + .collect(); this.delegate.selected_index = 0; cx.notify(); }) @@ -177,9 +178,7 @@ impl PickerDelegate for SavedContextPickerDelegate { fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext>) { if let Some(metadata) = self.matches.get(self.selected_index) { - cx.emit(SavedContextPickerEvent::Confirmed { - path: metadata.path.clone(), - }) + cx.emit(SavedContextPickerEvent::Confirmed(metadata.clone())); } } @@ -189,26 +188,78 @@ impl PickerDelegate for SavedContextPickerDelegate { &self, ix: usize, selected: bool, - _cx: &mut ViewContext>, + cx: &mut ViewContext>, ) -> Option { let context = self.matches.get(ix)?; + let item = match context { + ContextMetadata::Remote(context) => { + let host_user = self.project.read(cx).host().and_then(|collaborator| { + self.project + .read(cx) + .user_store() + .read(cx) + .get_cached_user(collaborator.user_id) + }); + div() + .flex() + .w_full() + .justify_between() + .gap_2() + .child( + h_flex().flex_1().overflow_x_hidden().child( + Label::new(context.summary.clone().unwrap_or("New Context".into())) + .size(LabelSize::Small), + ), + ) + .child( + h_flex() + .gap_2() + .children(if let Some(host_user) = host_user { + vec![ + Avatar::new(host_user.avatar_uri.clone()) + .shape(AvatarShape::Circle) + .into_any_element(), + Label::new(format!("Shared by @{}", host_user.github_login)) + .color(Color::Muted) + .size(LabelSize::Small) + .into_any_element(), + ] + } else { + vec![Label::new("Shared by host") + .color(Color::Muted) + .size(LabelSize::Small) + .into_any_element()] + }), + ) + } + ContextMetadata::Saved(context) => div() + .flex() + .w_full() + .justify_between() + .gap_2() + .child( + h_flex() + .flex_1() + .child(Label::new(context.title.clone()).size(LabelSize::Small)) + .overflow_x_hidden(), + ) + .child( + Label::new(format_distance_from_now( + DateTimeType::Local(context.mtime), + false, + true, + true, + )) + .color(Color::Muted) + .size(LabelSize::Small), + ), + }; Some( ListItem::new(ix) .inset(true) .spacing(ListItemSpacing::Sparse) .selected(selected) - .child( - div() - .flex() - .w_full() - .gap_2() - .child( - Label::new(context.mtime.format("%F %I:%M%p").to_string()) - .color(Color::Muted) - .size(LabelSize::Small), - ) - .child(Label::new(context.title.clone()).size(LabelSize::Small)), - ), + .child(item), ) } } @@ -219,11 +270,14 @@ impl AssistantPanel { cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { - // TODO: deserialize state. - let fs = workspace.update(&mut cx, |workspace, _| workspace.app_state().fs.clone())?; - let context_store = cx.update(|cx| ContextStore::new(fs.clone(), cx))?.await?; + let context_store = workspace + .update(&mut cx, |workspace, cx| { + ContextStore::new(workspace.project().clone(), cx) + })? + .await?; workspace.update(&mut cx, |workspace, cx| { - cx.new_view(|cx| Self::new(workspace, context_store.clone(), cx)) + // TODO: deserialize state. + cx.new_view(|cx| Self::new(workspace, context_store, cx)) }) }) } @@ -308,11 +362,10 @@ impl AssistantPanel { workspace: workspace.weak_handle(), width: None, height: None, + project: workspace.project().clone(), context_store, languages: workspace.app_state().languages.clone(), - slash_commands: SlashCommandRegistry::global(cx), fs: workspace.app_state().fs.clone(), - telemetry: workspace.client().telemetry().clone(), subscriptions, authentication_prompt: None, model_selector_menu_handle, @@ -519,16 +572,22 @@ impl AssistantPanel { } fn new_context(&mut self, cx: &mut ViewContext) -> Option> { + let context = self.context_store.update(cx, |store, cx| store.create(cx)); let workspace = self.workspace.upgrade()?; + let lsp_adapter_delegate = workspace.update(cx, |workspace, cx| { + make_lsp_adapter_delegate(workspace.project(), cx).log_err() + }); let editor = cx.new_view(|cx| { - ContextEditor::new( - self.languages.clone(), - self.slash_commands.clone(), + let mut editor = ContextEditor::for_context( + context, self.fs.clone(), workspace, + lsp_adapter_delegate, cx, - ) + ); + editor.insert_default_prompt(cx); + editor }); self.show_context(editor.clone(), cx); @@ -577,7 +636,12 @@ impl AssistantPanel { } else { let assistant_panel = cx.view().downgrade(); let history = cx.new_view(|cx| { - ContextHistory::new(self.context_store.clone(), assistant_panel, cx) + ContextHistory::new( + self.project.clone(), + self.context_store.clone(), + assistant_panel, + cx, + ) }); self.pane.update(cx, |pane, cx| { pane.add_item(Box::new(history), true, true, None, cx); @@ -610,10 +674,14 @@ impl AssistantPanel { Some(self.active_context_editor(cx)?.read(cx).context.clone()) } - fn open_context(&mut self, path: PathBuf, cx: &mut ViewContext) -> Task> { + fn open_saved_context( + &mut self, + path: PathBuf, + cx: &mut ViewContext, + ) -> Task> { let existing_context = self.pane.read(cx).items().find_map(|item| { item.downcast::() - .filter(|editor| editor.read(cx).context.read(cx).path.as_ref() == Some(&path)) + .filter(|editor| editor.read(cx).context.read(cx).path() == Some(&path)) }); if let Some(existing_context) = existing_context { return cx.spawn(|this, mut cx| async move { @@ -621,12 +689,11 @@ impl AssistantPanel { }); } - let saved_context = self.context_store.read(cx).load(path.clone(), cx); + let context = self + .context_store + .update(cx, |store, cx| store.open_local_context(path.clone(), cx)); let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let slash_commands = self.slash_commands.clone(); - let languages = self.languages.clone(); - let telemetry = self.telemetry.clone(); let lsp_adapter_delegate = workspace .update(cx, |workspace, cx| { @@ -636,17 +703,51 @@ impl AssistantPanel { .flatten(); cx.spawn(|this, mut cx| async move { - let saved_context = saved_context.await?; - let context = Context::deserialize( - saved_context, - path, - languages, - slash_commands, - Some(telemetry), - &mut cx, - ) - .await?; + let context = context.await?; + this.update(&mut cx, |this, cx| { + let workspace = workspace + .upgrade() + .ok_or_else(|| anyhow!("workspace dropped"))?; + let editor = cx.new_view(|cx| { + ContextEditor::for_context(context, fs, workspace, lsp_adapter_delegate, cx) + }); + this.show_context(editor, cx); + anyhow::Ok(()) + })??; + Ok(()) + }) + } + fn open_remote_context( + &mut self, + id: ContextId, + cx: &mut ViewContext, + ) -> Task> { + let existing_context = self.pane.read(cx).items().find_map(|item| { + item.downcast::() + .filter(|editor| *editor.read(cx).context.read(cx).id() == id) + }); + if let Some(existing_context) = existing_context { + return cx.spawn(|this, mut cx| async move { + this.update(&mut cx, |this, cx| this.show_context(existing_context, cx)) + }); + } + + let context = self + .context_store + .update(cx, |store, cx| store.open_remote_context(id, cx)); + let fs = self.fs.clone(); + let workspace = self.workspace.clone(); + + let lsp_adapter_delegate = workspace + .update(cx, |workspace, cx| { + make_lsp_adapter_delegate(workspace.project(), cx).log_err() + }) + .log_err() + .flatten(); + + cx.spawn(|this, mut cx| async move { + let context = context.await?; this.update(&mut cx, |this, cx| { let workspace = workspace .upgrade() @@ -804,1200 +905,6 @@ impl FocusableView for AssistantPanel { } } -#[derive(Clone)] -enum ContextEvent { - MessagesEdited, - SummaryChanged, - EditSuggestionsChanged, - StreamedCompletion, - PendingSlashCommandsUpdated { - removed: Vec>, - updated: Vec, - }, - SlashCommandFinished { - output_range: Range, - sections: Vec>, - run_commands_in_output: bool, - }, -} - -#[derive(Default)] -struct Summary { - text: String, - done: bool, -} - -pub struct Context { - id: Option, - buffer: Model, - edit_suggestions: Vec, - pending_slash_commands: Vec, - edits_since_last_slash_command_parse: language::Subscription, - slash_command_output_sections: Vec>, - message_anchors: Vec, - messages_metadata: HashMap, - next_message_id: MessageId, - summary: Option, - pending_summary: Task>, - completion_count: usize, - pending_completions: Vec, - token_count: Option, - pending_token_count: Task>, - pending_edit_suggestion_parse: Option>, - pending_save: Task>, - path: Option, - _subscriptions: Vec, - telemetry: Option>, - slash_command_registry: Arc, - language_registry: Arc, -} - -impl EventEmitter for Context {} - -impl Context { - fn new( - language_registry: Arc, - slash_command_registry: Arc, - telemetry: Option>, - cx: &mut ModelContext, - ) -> Self { - let buffer = cx.new_model(|cx| { - let mut buffer = Buffer::local("", cx); - buffer.set_language_registry(language_registry.clone()); - buffer - }); - let edits_since_last_slash_command_parse = - buffer.update(cx, |buffer, _| buffer.subscribe()); - let mut this = Self { - id: Some(Uuid::new_v4().to_string()), - message_anchors: Default::default(), - messages_metadata: Default::default(), - next_message_id: Default::default(), - edit_suggestions: Vec::new(), - pending_slash_commands: Vec::new(), - slash_command_output_sections: Vec::new(), - edits_since_last_slash_command_parse, - summary: None, - pending_summary: Task::ready(None), - completion_count: Default::default(), - pending_completions: Default::default(), - token_count: None, - pending_token_count: Task::ready(None), - pending_edit_suggestion_parse: None, - _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], - pending_save: Task::ready(Ok(())), - path: None, - buffer, - telemetry, - language_registry, - slash_command_registry, - }; - - let message = MessageAnchor { - id: MessageId(post_inc(&mut this.next_message_id.0)), - start: language::Anchor::MIN, - }; - this.message_anchors.push(message.clone()); - this.messages_metadata.insert( - message.id, - MessageMetadata { - role: Role::User, - status: MessageStatus::Done, - }, - ); - - this.set_language(cx); - this.count_remaining_tokens(cx); - this - } - - fn serialize(&self, cx: &AppContext) -> SavedContext { - let buffer = self.buffer.read(cx); - SavedContext { - id: self.id.clone(), - zed: "context".into(), - version: SavedContext::VERSION.into(), - text: buffer.text(), - message_metadata: self.messages_metadata.clone(), - messages: self - .messages(cx) - .map(|message| SavedMessage { - id: message.id, - start: message.offset_range.start, - }) - .collect(), - summary: self - .summary - .as_ref() - .map(|summary| summary.text.clone()) - .unwrap_or_default(), - slash_command_output_sections: self - .slash_command_output_sections - .iter() - .filter_map(|section| { - let range = section.range.to_offset(buffer); - if section.range.start.is_valid(buffer) && !range.is_empty() { - Some(SlashCommandOutputSection { - range, - icon: section.icon, - label: section.label.clone(), - }) - } else { - None - } - }) - .collect(), - } - } - - #[allow(clippy::too_many_arguments)] - async fn deserialize( - saved_context: SavedContext, - path: PathBuf, - language_registry: Arc, - slash_command_registry: Arc, - telemetry: Option>, - cx: &mut AsyncAppContext, - ) -> Result> { - let id = match saved_context.id { - Some(id) => Some(id), - None => Some(Uuid::new_v4().to_string()), - }; - - let markdown = language_registry.language_for_name("Markdown"); - let mut message_anchors = Vec::new(); - let mut next_message_id = MessageId(0); - let buffer = cx.new_model(|cx| { - let mut buffer = Buffer::local(saved_context.text, cx); - for message in saved_context.messages { - message_anchors.push(MessageAnchor { - id: message.id, - start: buffer.anchor_before(message.start), - }); - next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1)); - } - buffer.set_language_registry(language_registry.clone()); - cx.spawn(|buffer, mut cx| async move { - let markdown = markdown.await?; - buffer.update(&mut cx, |buffer: &mut Buffer, cx| { - buffer.set_language(Some(markdown), cx) - })?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - buffer - })?; - - cx.new_model(move |cx| { - let edits_since_last_slash_command_parse = - buffer.update(cx, |buffer, _| buffer.subscribe()); - let mut this = Self { - id, - message_anchors, - messages_metadata: saved_context.message_metadata, - next_message_id, - edit_suggestions: Vec::new(), - pending_slash_commands: Vec::new(), - slash_command_output_sections: saved_context - .slash_command_output_sections - .into_iter() - .map(|section| { - let buffer = buffer.read(cx); - SlashCommandOutputSection { - range: buffer.anchor_after(section.range.start) - ..buffer.anchor_before(section.range.end), - icon: section.icon, - label: section.label, - } - }) - .collect(), - edits_since_last_slash_command_parse, - summary: Some(Summary { - text: saved_context.summary, - done: true, - }), - pending_summary: Task::ready(None), - completion_count: Default::default(), - pending_completions: Default::default(), - token_count: None, - pending_edit_suggestion_parse: None, - pending_token_count: Task::ready(None), - _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], - pending_save: Task::ready(Ok(())), - path: Some(path), - buffer, - telemetry, - language_registry, - slash_command_registry, - }; - this.set_language(cx); - this.reparse_edit_suggestions(cx); - this.count_remaining_tokens(cx); - this - }) - } - - fn set_language(&mut self, cx: &mut ModelContext) { - let markdown = self.language_registry.language_for_name("Markdown"); - cx.spawn(|this, mut cx| async move { - let markdown = markdown.await?; - this.update(&mut cx, |this, cx| { - this.buffer - .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx)); - }) - }) - .detach_and_log_err(cx); - } - - fn handle_buffer_event( - &mut self, - _: Model, - event: &language::Event, - cx: &mut ModelContext, - ) { - if *event == language::Event::Edited { - self.count_remaining_tokens(cx); - self.reparse_edit_suggestions(cx); - self.reparse_slash_commands(cx); - cx.emit(ContextEvent::MessagesEdited); - } - } - - pub(crate) fn token_count(&self) -> Option { - self.token_count - } - - pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { - let request = self.to_completion_request(cx); - self.pending_token_count = cx.spawn(|this, mut cx| { - async move { - cx.background_executor() - .timer(Duration::from_millis(200)) - .await; - - let token_count = cx - .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? - .await?; - - this.update(&mut cx, |this, cx| { - this.token_count = Some(token_count); - cx.notify() - })?; - anyhow::Ok(()) - } - .log_err() - }); - } - - fn reparse_slash_commands(&mut self, cx: &mut ModelContext) { - let buffer = self.buffer.read(cx); - let mut row_ranges = self - .edits_since_last_slash_command_parse - .consume() - .into_iter() - .map(|edit| { - let start_row = buffer.offset_to_point(edit.new.start).row; - let end_row = buffer.offset_to_point(edit.new.end).row + 1; - start_row..end_row - }) - .peekable(); - - let mut removed = Vec::new(); - let mut updated = Vec::new(); - while let Some(mut row_range) = row_ranges.next() { - while let Some(next_row_range) = row_ranges.peek() { - if row_range.end >= next_row_range.start { - row_range.end = next_row_range.end; - row_ranges.next(); - } else { - break; - } - } - - let start = buffer.anchor_before(Point::new(row_range.start, 0)); - let end = buffer.anchor_after(Point::new( - row_range.end - 1, - buffer.line_len(row_range.end - 1), - )); - - let old_range = self.pending_command_indices_for_range(start..end, cx); - - let mut new_commands = Vec::new(); - let mut lines = buffer.text_for_range(start..end).lines(); - let mut offset = lines.offset(); - while let Some(line) = lines.next() { - if let Some(command_line) = SlashCommandLine::parse(line) { - let name = &line[command_line.name.clone()]; - let argument = command_line.argument.as_ref().and_then(|argument| { - (!argument.is_empty()).then_some(&line[argument.clone()]) - }); - if let Some(command) = self.slash_command_registry.command(name) { - if !command.requires_argument() || argument.is_some() { - let start_ix = offset + command_line.name.start - 1; - let end_ix = offset - + command_line - .argument - .map_or(command_line.name.end, |argument| argument.end); - let source_range = - buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix); - let pending_command = PendingSlashCommand { - name: name.to_string(), - argument: argument.map(ToString::to_string), - source_range, - status: PendingSlashCommandStatus::Idle, - }; - updated.push(pending_command.clone()); - new_commands.push(pending_command); - } - } - } - - offset = lines.offset(); - } - - let removed_commands = self.pending_slash_commands.splice(old_range, new_commands); - removed.extend(removed_commands.map(|command| command.source_range)); - } - - if !updated.is_empty() || !removed.is_empty() { - cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated }); - } - } - - fn reparse_edit_suggestions(&mut self, cx: &mut ModelContext) { - self.pending_edit_suggestion_parse = Some(cx.spawn(|this, mut cx| async move { - cx.background_executor() - .timer(Duration::from_millis(200)) - .await; - - this.update(&mut cx, |this, cx| { - this.reparse_edit_suggestions_in_range(0..this.buffer.read(cx).len(), cx); - }) - .ok(); - })); - } - - fn reparse_edit_suggestions_in_range( - &mut self, - range: Range, - cx: &mut ModelContext, - ) { - self.buffer.update(cx, |buffer, _| { - let range_start = buffer.anchor_before(range.start); - let range_end = buffer.anchor_after(range.end); - let start_ix = self - .edit_suggestions - .binary_search_by(|probe| { - probe - .source_range - .end - .cmp(&range_start, buffer) - .then(Ordering::Greater) - }) - .unwrap_err(); - let end_ix = self - .edit_suggestions - .binary_search_by(|probe| { - probe - .source_range - .start - .cmp(&range_end, buffer) - .then(Ordering::Less) - }) - .unwrap_err(); - - let mut new_edit_suggestions = Vec::new(); - let mut message_lines = buffer.as_rope().chunks_in_range(range).lines(); - while let Some(suggestion) = parse_next_edit_suggestion(&mut message_lines) { - let start_anchor = buffer.anchor_after(suggestion.outer_range.start); - let end_anchor = buffer.anchor_before(suggestion.outer_range.end); - new_edit_suggestions.push(EditSuggestion { - source_range: start_anchor..end_anchor, - full_path: suggestion.path, - }); - } - self.edit_suggestions - .splice(start_ix..end_ix, new_edit_suggestions); - }); - cx.emit(ContextEvent::EditSuggestionsChanged); - cx.notify(); - } - - fn pending_command_for_position( - &mut self, - position: language::Anchor, - cx: &mut ModelContext, - ) -> Option<&mut PendingSlashCommand> { - let buffer = self.buffer.read(cx); - match self - .pending_slash_commands - .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer)) - { - Ok(ix) => Some(&mut self.pending_slash_commands[ix]), - Err(ix) => { - let cmd = self.pending_slash_commands.get_mut(ix)?; - if position.cmp(&cmd.source_range.start, buffer).is_ge() - && position.cmp(&cmd.source_range.end, buffer).is_le() - { - Some(cmd) - } else { - None - } - } - } - } - - fn pending_commands_for_range( - &self, - range: Range, - cx: &AppContext, - ) -> &[PendingSlashCommand] { - let range = self.pending_command_indices_for_range(range, cx); - &self.pending_slash_commands[range] - } - - fn pending_command_indices_for_range( - &self, - range: Range, - cx: &AppContext, - ) -> Range { - let buffer = self.buffer.read(cx); - let start_ix = match self - .pending_slash_commands - .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer)) - { - Ok(ix) | Err(ix) => ix, - }; - let end_ix = match self - .pending_slash_commands - .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer)) - { - Ok(ix) => ix + 1, - Err(ix) => ix, - }; - start_ix..end_ix - } - - fn insert_command_output( - &mut self, - command_range: Range, - output: Task>, - insert_trailing_newline: bool, - cx: &mut ModelContext, - ) { - self.reparse_slash_commands(cx); - - let insert_output_task = cx.spawn(|this, mut cx| { - let command_range = command_range.clone(); - async move { - let output = output.await; - this.update(&mut cx, |this, cx| match output { - Ok(mut output) => { - if insert_trailing_newline { - output.text.push('\n'); - } - - let event = this.buffer.update(cx, |buffer, cx| { - let start = command_range.start.to_offset(buffer); - let old_end = command_range.end.to_offset(buffer); - let new_end = start + output.text.len(); - buffer.edit([(start..old_end, output.text)], None, cx); - - let mut sections = output - .sections - .into_iter() - .map(|section| SlashCommandOutputSection { - range: buffer.anchor_after(start + section.range.start) - ..buffer.anchor_before(start + section.range.end), - icon: section.icon, - label: section.label, - }) - .collect::>(); - sections.sort_by(|a, b| a.range.cmp(&b.range, buffer)); - - this.slash_command_output_sections - .extend(sections.iter().cloned()); - this.slash_command_output_sections - .sort_by(|a, b| a.range.cmp(&b.range, buffer)); - - ContextEvent::SlashCommandFinished { - output_range: buffer.anchor_after(start) - ..buffer.anchor_before(new_end), - sections, - run_commands_in_output: output.run_commands_in_text, - } - }); - cx.emit(event); - } - Err(error) => { - if let Some(pending_command) = - this.pending_command_for_position(command_range.start, cx) - { - pending_command.status = - PendingSlashCommandStatus::Error(error.to_string()); - cx.emit(ContextEvent::PendingSlashCommandsUpdated { - removed: vec![pending_command.source_range.clone()], - updated: vec![pending_command.clone()], - }); - } - } - }) - .ok(); - } - }); - - if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) { - pending_command.status = PendingSlashCommandStatus::Running { - _task: insert_output_task.shared(), - }; - cx.emit(ContextEvent::PendingSlashCommandsUpdated { - removed: vec![pending_command.source_range.clone()], - updated: vec![pending_command.clone()], - }); - } - } - - fn completion_provider_changed(&mut self, cx: &mut ModelContext) { - self.count_remaining_tokens(cx); - } - - fn assist( - &mut self, - selected_messages: HashSet, - cx: &mut ModelContext, - ) -> Vec { - let mut user_messages = Vec::new(); - - let last_message_id = if let Some(last_message_id) = - self.message_anchors.iter().rev().find_map(|message| { - message - .start - .is_valid(self.buffer.read(cx)) - .then_some(message.id) - }) { - last_message_id - } else { - return Default::default(); - }; - - let mut should_assist = false; - for selected_message_id in selected_messages { - let selected_message_role = - if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { - metadata.role - } else { - continue; - }; - - if selected_message_role == Role::Assistant { - if let Some(user_message) = self.insert_message_after( - selected_message_id, - Role::User, - MessageStatus::Done, - cx, - ) { - user_messages.push(user_message); - } - } else { - should_assist = true; - } - } - - if should_assist { - if !CompletionProvider::global(cx).is_authenticated() { - log::info!("completion provider has no credentials"); - return Default::default(); - } - - let request = self.to_completion_request(cx); - let response = CompletionProvider::global(cx).complete(request, cx); - let assistant_message = self - .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) - .unwrap(); - - // Queue up the user's next reply. - let user_message = self - .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx) - .unwrap(); - user_messages.push(user_message); - - let task = cx.spawn({ - |this, mut cx| async move { - let response = response.await; - let assistant_message_id = assistant_message.id; - let mut response_latency = None; - let stream_completion = async { - let request_start = Instant::now(); - let mut messages = response.inner.await?; - - while let Some(message) = messages.next().await { - if response_latency.is_none() { - response_latency = Some(request_start.elapsed()); - } - let text = message?; - - this.update(&mut cx, |this, cx| { - let message_ix = this - .message_anchors - .iter() - .position(|message| message.id == assistant_message_id)?; - let message_range = this.buffer.update(cx, |buffer, cx| { - let message_start_offset = - this.message_anchors[message_ix].start.to_offset(buffer); - let message_old_end_offset = this.message_anchors - [message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message.start.to_offset(buffer).saturating_sub(1) - }); - let message_new_end_offset = - message_old_end_offset + text.len(); - buffer.edit( - [(message_old_end_offset..message_old_end_offset, text)], - None, - cx, - ); - message_start_offset..message_new_end_offset - }); - this.reparse_edit_suggestions_in_range(message_range, cx); - cx.emit(ContextEvent::StreamedCompletion); - - Some(()) - })?; - smol::future::yield_now().await; - } - - this.update(&mut cx, |this, cx| { - this.pending_completions - .retain(|completion| completion.id != this.completion_count); - this.summarize(cx); - })?; - - anyhow::Ok(()) - }; - - let result = stream_completion.await; - - this.update(&mut cx, |this, cx| { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - let error_message = result - .err() - .map(|error| error.to_string().trim().to_string()); - if let Some(error_message) = error_message.as_ref() { - metadata.status = - MessageStatus::Error(SharedString::from(error_message.clone())); - } else { - metadata.status = MessageStatus::Done; - } - - if let Some(telemetry) = this.telemetry.as_ref() { - let model = CompletionProvider::global(cx).model(); - telemetry.report_assistant_event( - this.id.clone(), - AssistantKind::Panel, - model.telemetry_id(), - response_latency, - error_message, - ); - } - - cx.emit(ContextEvent::MessagesEdited); - } - }) - .ok(); - } - }); - - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); - } - - user_messages - } - - pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest { - let messages = self - .messages(cx) - .filter(|message| matches!(message.status, MessageStatus::Done)) - .map(|message| message.to_request_message(self.buffer.read(cx))); - - LanguageModelRequest { - model: CompletionProvider::global(cx).model(), - messages: messages.collect(), - stop: vec![], - temperature: 1.0, - } - } - - fn cancel_last_assist(&mut self) -> bool { - self.pending_completions.pop().is_some() - } - - fn cycle_message_roles(&mut self, ids: HashSet, cx: &mut ModelContext) { - for id in ids { - if let Some(metadata) = self.messages_metadata.get_mut(&id) { - metadata.role.cycle(); - cx.emit(ContextEvent::MessagesEdited); - cx.notify(); - } - } - } - - fn insert_message_after( - &mut self, - message_id: MessageId, - role: Role, - status: MessageStatus, - cx: &mut ModelContext, - ) -> Option { - if let Some(prev_message_ix) = self - .message_anchors - .iter() - .position(|message| message.id == message_id) - { - // Find the next valid message after the one we were given. - let mut next_message_ix = prev_message_ix + 1; - while let Some(next_message) = self.message_anchors.get(next_message_ix) { - if next_message.start.is_valid(self.buffer.read(cx)) { - break; - } - next_message_ix += 1; - } - - let start = self.buffer.update(cx, |buffer, cx| { - let offset = self - .message_anchors - .get(next_message_ix) - .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1); - buffer.edit([(offset..offset, "\n")], None, cx); - buffer.anchor_before(offset + 1) - }); - let message = MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start, - }; - self.message_anchors - .insert(next_message_ix, message.clone()); - self.messages_metadata - .insert(message.id, MessageMetadata { role, status }); - cx.emit(ContextEvent::MessagesEdited); - Some(message) - } else { - None - } - } - - fn split_message( - &mut self, - range: Range, - cx: &mut ModelContext, - ) -> (Option, Option) { - let start_message = self.message_for_offset(range.start, cx); - let end_message = self.message_for_offset(range.end, cx); - if let Some((start_message, end_message)) = start_message.zip(end_message) { - // Prevent splitting when range spans multiple messages. - if start_message.id != end_message.id { - return (None, None); - } - - let message = start_message; - let role = message.role; - let mut edited_buffer = false; - - let mut suffix_start = None; - if range.start > message.offset_range.start && range.end < message.offset_range.end - 1 - { - if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { - suffix_start = Some(range.end + 1); - } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { - suffix_start = Some(range.end); - } - } - - let suffix = if let Some(suffix_start) = suffix_start { - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(suffix_start), - } - } else { - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.end..range.end, "\n")], None, cx); - }); - edited_buffer = true; - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.end + 1), - } - }; - - self.message_anchors - .insert(message.index_range.end + 1, suffix.clone()); - self.messages_metadata.insert( - suffix.id, - MessageMetadata { - role, - status: MessageStatus::Done, - }, - ); - - let new_messages = - if range.start == range.end || range.start == message.offset_range.start { - (None, Some(suffix)) - } else { - let mut prefix_end = None; - if range.start > message.offset_range.start - && range.end < message.offset_range.end - 1 - { - if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { - prefix_end = Some(range.start + 1); - } else if self.buffer.read(cx).reversed_chars_at(range.start).next() - == Some('\n') - { - prefix_end = Some(range.start); - } - } - - let selection = if let Some(prefix_end) = prefix_end { - cx.emit(ContextEvent::MessagesEdited); - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(prefix_end), - } - } else { - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.start..range.start, "\n")], None, cx) - }); - edited_buffer = true; - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.end + 1), - } - }; - - self.message_anchors - .insert(message.index_range.end + 1, selection.clone()); - self.messages_metadata.insert( - selection.id, - MessageMetadata { - role, - status: MessageStatus::Done, - }, - ); - (Some(selection), Some(suffix)) - }; - - if !edited_buffer { - cx.emit(ContextEvent::MessagesEdited); - } - new_messages - } else { - (None, None) - } - } - - fn summarize(&mut self, cx: &mut ModelContext) { - if self.message_anchors.len() >= 2 && self.summary.is_none() { - if !CompletionProvider::global(cx).is_authenticated() { - return; - } - - let messages = self - .messages(cx) - .map(|message| message.to_request_message(self.buffer.read(cx))) - .chain(Some(LanguageModelRequestMessage { - role: Role::User, - content: "Summarize the context into a short title without punctuation.".into(), - })); - let request = LanguageModelRequest { - model: CompletionProvider::global(cx).model(), - messages: messages.collect(), - stop: vec![], - temperature: 1.0, - }; - - let response = CompletionProvider::global(cx).complete(request, cx); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let response = response.await; - let mut messages = response.inner.await?; - - while let Some(message) = messages.next().await { - let text = message?; - let mut lines = text.lines(); - this.update(&mut cx, |this, cx| { - let summary = this.summary.get_or_insert(Default::default()); - summary.text.extend(lines.next()); - cx.emit(ContextEvent::SummaryChanged); - })?; - - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; - } - } - - this.update(&mut cx, |this, cx| { - if let Some(summary) = this.summary.as_mut() { - summary.done = true; - cx.emit(ContextEvent::SummaryChanged); - } - })?; - - anyhow::Ok(()) - } - .log_err() - }); - } - } - - fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option { - self.messages_for_offsets([offset], cx).pop() - } - - fn messages_for_offsets( - &self, - offsets: impl IntoIterator, - cx: &AppContext, - ) -> Vec { - let mut result = Vec::new(); - - let mut messages = self.messages(cx).peekable(); - let mut offsets = offsets.into_iter().peekable(); - let mut current_message = messages.next(); - while let Some(offset) = offsets.next() { - // Locate the message that contains the offset. - while current_message.as_ref().map_or(false, |message| { - !message.offset_range.contains(&offset) && messages.peek().is_some() - }) { - current_message = messages.next(); - } - let Some(message) = current_message.as_ref() else { - break; - }; - - // Skip offsets that are in the same message. - while offsets.peek().map_or(false, |offset| { - message.offset_range.contains(offset) || messages.peek().is_none() - }) { - offsets.next(); - } - - result.push(message.clone()); - } - result - } - - fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { - let buffer = self.buffer.read(cx); - let mut message_anchors = self.message_anchors.iter().enumerate().peekable(); - iter::from_fn(move || { - if let Some((start_ix, message_anchor)) = message_anchors.next() { - let metadata = self.messages_metadata.get(&message_anchor.id)?; - let message_start = message_anchor.start.to_offset(buffer); - let mut message_end = None; - let mut end_ix = start_ix; - while let Some((_, next_message)) = message_anchors.peek() { - if next_message.start.is_valid(buffer) { - message_end = Some(next_message.start); - break; - } else { - end_ix += 1; - message_anchors.next(); - } - } - let message_end = message_end - .unwrap_or(language::Anchor::MAX) - .to_offset(buffer); - - return Some(Message { - index_range: start_ix..end_ix, - offset_range: message_start..message_end, - id: message_anchor.id, - anchor: message_anchor.start, - role: metadata.role, - status: metadata.status.clone(), - }); - } - None - }) - } - - fn save( - &mut self, - debounce: Option, - fs: Arc, - cx: &mut ModelContext, - ) { - self.pending_save = cx.spawn(|this, mut cx| async move { - if let Some(debounce) = debounce { - cx.background_executor().timer(debounce).await; - } - - let (old_path, summary) = this.read_with(&cx, |this, _| { - let path = this.path.clone(); - let summary = if let Some(summary) = this.summary.as_ref() { - if summary.done { - Some(summary.text.clone()) - } else { - None - } - } else { - None - }; - (path, summary) - })?; - - if let Some(summary) = summary { - let context = this.read_with(&cx, |this, cx| this.serialize(cx))?; - let path = if let Some(old_path) = old_path { - old_path - } else { - let mut discriminant = 1; - let mut new_path; - loop { - new_path = contexts_dir().join(&format!( - "{} - {}.zed.json", - summary.trim(), - discriminant - )); - if fs.is_file(&new_path).await { - discriminant += 1; - } else { - break; - } - } - new_path - }; - - fs.create_dir(contexts_dir().as_ref()).await?; - fs.atomic_write(path.clone(), serde_json::to_string(&context).unwrap()) - .await?; - this.update(&mut cx, |this, _| this.path = Some(path))?; - } - - Ok(()) - }); - } -} - -#[derive(Debug)] -enum EditParsingState { - None, - InOldText { - path: PathBuf, - start_offset: usize, - old_text_start_offset: usize, - }, - InNewText { - path: PathBuf, - start_offset: usize, - old_text_range: Range, - new_text_start_offset: usize, - }, -} - -#[derive(Clone, Debug, PartialEq)] -struct EditSuggestion { - source_range: Range, - full_path: PathBuf, -} - -struct ParsedEditSuggestion { - path: PathBuf, - outer_range: Range, - old_text_range: Range, - new_text_range: Range, -} - -fn parse_next_edit_suggestion(lines: &mut rope::Lines) -> Option { - let mut state = EditParsingState::None; - loop { - let offset = lines.offset(); - let message_line = lines.next()?; - match state { - EditParsingState::None => { - if let Some(rest) = message_line.strip_prefix("```edit ") { - let path = rest.trim(); - if !path.is_empty() { - state = EditParsingState::InOldText { - path: PathBuf::from(path), - start_offset: offset, - old_text_start_offset: lines.offset(), - }; - } - } - } - EditParsingState::InOldText { - path, - start_offset, - old_text_start_offset, - } => { - if message_line == "---" { - state = EditParsingState::InNewText { - path, - start_offset, - old_text_range: old_text_start_offset..offset, - new_text_start_offset: lines.offset(), - }; - } else { - state = EditParsingState::InOldText { - path, - start_offset, - old_text_start_offset, - }; - } - } - EditParsingState::InNewText { - path, - start_offset, - old_text_range, - new_text_start_offset, - } => { - if message_line == "```" { - return Some(ParsedEditSuggestion { - path, - outer_range: start_offset..offset + "```".len(), - old_text_range, - new_text_range: new_text_start_offset..offset, - }); - } else { - state = EditParsingState::InNewText { - path, - start_offset, - old_text_range, - new_text_start_offset, - }; - } - } - } - } -} - -#[derive(Clone)] -struct PendingSlashCommand { - name: String, - argument: Option, - status: PendingSlashCommandStatus, - source_range: Range, -} - -#[derive(Clone)] -enum PendingSlashCommandStatus { - Idle, - Running { _task: Shared> }, - Error(String), -} - -struct PendingCompletion { - id: usize, - _task: Task<()>, -} - pub enum ContextEditorEvent { Edited, TabContentChanged, @@ -2013,7 +920,6 @@ pub struct ContextEditor { context: Model, fs: Arc, workspace: WeakView, - slash_command_registry: Arc, lsp_adapter_delegate: Option>, editor: View, blocks: HashSet, @@ -2026,31 +932,6 @@ pub struct ContextEditor { impl ContextEditor { const MAX_TAB_TITLE_LEN: usize = 16; - fn new( - language_registry: Arc, - slash_command_registry: Arc, - fs: Arc, - workspace: View, - cx: &mut ViewContext, - ) -> Self { - let telemetry = workspace.read(cx).client().telemetry().clone(); - let project = workspace.read(cx).project().clone(); - let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err(); - - let context = cx.new_model(|cx| { - Context::new( - language_registry, - slash_command_registry, - Some(telemetry), - cx, - ) - }); - - let mut this = Self::for_context(context, fs, workspace, lsp_adapter_delegate, cx); - this.insert_default_prompt(cx); - this - } - fn for_context( context: Model, fs: Arc, @@ -2058,16 +939,13 @@ impl ContextEditor { lsp_adapter_delegate: Option>, cx: &mut ViewContext, ) -> Self { - let slash_command_registry = context.read(cx).slash_command_registry.clone(); - let completion_provider = SlashCommandCompletionProvider::new( - slash_command_registry.clone(), Some(cx.view().downgrade()), Some(workspace.downgrade()), ); let editor = cx.new_view(|cx| { - let mut editor = Editor::for_buffer(context.read(cx).buffer.clone(), None, cx); + let mut editor = Editor::for_buffer(context.read(cx).buffer().clone(), None, cx); editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); editor.set_show_line_numbers(false, cx); editor.set_show_git_diff_gutter(false, cx); @@ -2086,11 +964,10 @@ impl ContextEditor { cx.subscribe(&editor, Self::handle_editor_search_event), ]; - let sections = context.read(cx).slash_command_output_sections.clone(); + let sections = context.read(cx).slash_command_output_sections().to_vec(); let mut this = Self { context, editor, - slash_command_registry, lsp_adapter_delegate, blocks: Default::default(), scroll_position: None, @@ -2112,13 +989,12 @@ impl ContextEditor { }); self.split(&Split, cx); let command = self.context.update(cx, |context, cx| { - context - .messages_metadata - .get_mut(&MessageId::default()) - .unwrap() - .role = Role::System; + let first_message_id = context.messages(cx).next().unwrap().id; + context.update_metadata(first_message_id, cx, |metadata| { + metadata.role = Role::System; + }); context.reparse_slash_commands(cx); - context.pending_slash_commands[0].clone() + context.pending_slash_commands()[0].clone() }); self.run_command( @@ -2147,7 +1023,7 @@ impl ContextEditor { .map(|message| { let cursor = message .start - .to_offset(self.context.read(cx).buffer.read(cx)); + .to_offset(self.context.read(cx).buffer().read(cx)); cursor..cursor }) .collect::>(); @@ -2194,7 +1070,7 @@ impl ContextEditor { } fn insert_command(&mut self, name: &str, cx: &mut ViewContext) { - if let Some(command) = self.slash_command_registry.command(name) { + if let Some(command) = SlashCommandRegistry::global(cx).command(name) { self.editor.update(cx, |editor, cx| { editor.transact(cx, |editor, cx| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| s.try_cancel()); @@ -2271,7 +1147,7 @@ impl ContextEditor { workspace: WeakView, cx: &mut ViewContext, ) { - if let Some(command) = self.slash_command_registry.command(name) { + if let Some(command) = SlashCommandRegistry::global(cx).command(name) { if let Some(lsp_adapter_delegate) = self.lsp_adapter_delegate.clone() { let argument = argument.map(ToString::to_string); let output = command.run(argument.as_deref(), workspace, lsp_adapter_delegate, cx); @@ -2308,7 +1184,7 @@ impl ContextEditor { let excerpt_id = *buffer.as_singleton().unwrap().0; let context = self.context.read(cx); let highlighted_rows = context - .edit_suggestions + .edit_suggestions() .iter() .map(|suggestion| { let start = buffer @@ -2511,6 +1387,7 @@ impl ContextEditor { } } } + ContextEvent::Operation(_) => {} } } @@ -2681,7 +1558,7 @@ impl ContextEditor { }); h_flex() - .id(("message_header", message_id.0)) + .id(("message_header", message_id.as_u64())) .pl(cx.gutter_dimensions.full_width()) .h_11() .w_full() @@ -2832,7 +1709,7 @@ impl ContextEditor { if !range.is_empty() { spanned_messages += 1; write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); - for chunk in context.buffer.read(cx).text_for_range(range) { + for chunk in context.buffer().read(cx).text_for_range(range) { copied_text.push_str(chunk); } copied_text.push('\n'); @@ -2874,13 +1751,13 @@ impl ContextEditor { } let context = self.context.read(cx); - let context_buffer = context.buffer.read(cx); + let context_buffer = context.buffer().read(cx); let context_buffer_snapshot = context_buffer.snapshot(); let selections = self.editor.read(cx).selections.disjoint_anchors(); let mut selections = selections.iter().peekable(); let selected_suggestions = context - .edit_suggestions + .edit_suggestions() .iter() .filter(|suggestion| { while let Some(selection) = selections.peek() { @@ -3037,8 +1914,7 @@ impl ContextEditor { fn title(&self, cx: &AppContext) -> String { self.context .read(cx) - .summary - .as_ref() + .summary() .map(|summary| summary.text.clone()) .unwrap_or_else(|| "New Context".into()) } @@ -3424,14 +2300,18 @@ pub struct ContextHistory { impl ContextHistory { fn new( + project: Model, context_store: Model, assistant_panel: WeakView, cx: &mut ViewContext, ) -> Self { let picker = cx.new_view(|cx| { - Picker::uniform_list(SavedContextPickerDelegate::new(context_store.clone()), cx) - .modal(false) - .max_height(None) + Picker::uniform_list( + SavedContextPickerDelegate::new(project, context_store.clone()), + cx, + ) + .modal(false) + .max_height(None) }); let _subscriptions = vec![ @@ -3454,12 +2334,19 @@ impl ContextHistory { event: &SavedContextPickerEvent, cx: &mut ViewContext, ) { - let SavedContextPickerEvent::Confirmed { path } = event; + let SavedContextPickerEvent::Confirmed(context) = event; self.assistant_panel - .update(cx, |assistant_panel, cx| { - assistant_panel - .open_context(path.clone(), cx) - .detach_and_log_err(cx); + .update(cx, |assistant_panel, cx| match context { + ContextMetadata::Remote(metadata) => { + assistant_panel + .open_remote_context(metadata.id.clone(), cx) + .detach_and_log_err(cx); + } + ContextMetadata::Saved(metadata) => { + assistant_panel + .open_saved_context(metadata.path.clone(), cx) + .detach_and_log_err(cx); + } }) .ok(); } @@ -3496,31 +2383,6 @@ impl Item for ContextHistory { } } -#[derive(Clone, Debug)] -struct MessageAnchor { - id: MessageId, - start: language::Anchor, -} - -#[derive(Clone, Debug)] -pub struct Message { - offset_range: Range, - index_range: Range, - id: MessageId, - anchor: language::Anchor, - role: Role, - status: MessageStatus, -} - -impl Message { - fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage { - LanguageModelRequestMessage { - role: self.role, - content: buffer.text_for_range(self.offset_range.clone()).collect(), - } - } -} - type ToggleFold = Arc; fn render_slash_command_output_toggle( @@ -3624,600 +2486,3 @@ fn slash_command_error_block_renderer(message: String) -> RenderBlock { .into_any() }) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - slash_command::{active_command, file_command}, - FakeCompletionProvider, MessageId, - }; - use fs::FakeFs; - use gpui::{AppContext, TestAppContext}; - use rope::Rope; - use serde_json::json; - use settings::SettingsStore; - use std::{cell::RefCell, path::Path, rc::Rc}; - use unindent::Unindent; - use util::test::marked_text_ranges; - - #[gpui::test] - fn test_inserting_and_removing_messages(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - FakeCompletionProvider::setup_test(cx); - cx.set_global(settings_store); - init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - - let context = cx.new_model(|cx| Context::new(registry, Default::default(), None, cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - let message_2 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..1), - (message_2.id, Role::Assistant, 1..1) - ] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..3) - ] - ); - - let message_3 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_3.id, Role::User, 4..4) - ] - ); - - let message_4 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..5), - (message_3.id, Role::User, 5..5), - ] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..6), - (message_3.id, Role::User, 6..7), - ] - ); - - // Deleting across message boundaries merges the messages. - buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_3.id, Role::User, 3..4), - ] - ); - - // Undoing the deletion should also undo the merge. - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..2), - (message_2.id, Role::Assistant, 2..4), - (message_4.id, Role::User, 4..6), - (message_3.id, Role::User, 6..7), - ] - ); - - // Redoing the deletion should also redo the merge. - buffer.update(cx, |buffer, cx| buffer.redo(cx)); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_3.id, Role::User, 3..4), - ] - ); - - // Ensure we can still insert after a merged message. - let message_5 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_5.id, Role::System, 3..4), - (message_3.id, Role::User, 4..5) - ] - ); - } - - #[gpui::test] - fn test_message_splitting(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - FakeCompletionProvider::setup_test(cx); - init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - - let context = cx.new_model(|cx| Context::new(registry, Default::default(), None, cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) - }); - - let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx)); - let message_2 = message_2.unwrap(); - - // We recycle newlines in the middle of a split message - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..16), - ] - ); - - let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx)); - let message_3 = message_3.unwrap(); - - // We don't recycle newlines at the end of a split message - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..17), - ] - ); - - let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx)); - let message_4 = message_4.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..17), - ] - ); - - let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx)); - let message_5 = message_5.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..10), - (message_5.id, Role::User, 10..18), - ] - ); - - let (message_6, message_7) = - context.update(cx, |context, cx| context.split_message(14..16, cx)); - let message_6 = message_6.unwrap(); - let message_7 = message_7.unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_3.id, Role::User, 4..5), - (message_2.id, Role::User, 5..9), - (message_4.id, Role::User, 9..10), - (message_5.id, Role::User, 10..14), - (message_6.id, Role::User, 14..17), - (message_7.id, Role::User, 17..19), - ] - ); - } - - #[gpui::test] - fn test_messages_for_offsets(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - FakeCompletionProvider::setup_test(cx); - cx.set_global(settings_store); - init(cx); - let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let context = cx.new_model(|cx| Context::new(registry, Default::default(), None, cx)); - let buffer = context.read(cx).buffer.clone(); - - let message_1 = context.read(cx).message_anchors[0].clone(); - assert_eq!( - messages(&context, cx), - vec![(message_1.id, Role::User, 0..0)] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); - let message_2 = context - .update(cx, |context, cx| { - context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); - - let message_3 = context - .update(cx, |context, cx| { - context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); - - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..8), - (message_3.id, Role::User, 8..11) - ] - ); - - assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 9], cx), - [message_1.id, message_2.id, message_3.id] - ); - assert_eq!( - message_ids_for_offsets(&context, &[0, 1, 11], cx), - [message_1.id, message_3.id] - ); - - let message_4 = context - .update(cx, |context, cx| { - context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) - }) - .unwrap(); - assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n"); - assert_eq!( - messages(&context, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_2.id, Role::User, 4..8), - (message_3.id, Role::User, 8..12), - (message_4.id, Role::User, 12..12) - ] - ); - assert_eq!( - message_ids_for_offsets(&context, &[0, 4, 8, 12], cx), - [message_1.id, message_2.id, message_3.id, message_4.id] - ); - - fn message_ids_for_offsets( - context: &Model, - offsets: &[usize], - cx: &AppContext, - ) -> Vec { - context - .read(cx) - .messages_for_offsets(offsets.iter().copied(), cx) - .into_iter() - .map(|message| message.id) - .collect() - } - } - - #[gpui::test] - async fn test_slash_commands(cx: &mut TestAppContext) { - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(|cx| FakeCompletionProvider::setup_test(cx)); - - cx.update(Project::init_settings); - cx.update(init); - let fs = FakeFs::new(cx.background_executor.clone()); - - fs.insert_tree( - "/test", - json!({ - "src": { - "lib.rs": "fn one() -> usize { 1 }", - "main.rs": " - use crate::one; - fn main() { one(); } - ".unindent(), - } - }), - ) - .await; - - let slash_command_registry = SlashCommandRegistry::new(); - slash_command_registry.register_command(file_command::FileSlashCommand, false); - slash_command_registry.register_command(active_command::ActiveSlashCommand, false); - - let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = - cx.new_model(|cx| Context::new(registry.clone(), slash_command_registry, None, cx)); - - let output_ranges = Rc::new(RefCell::new(HashSet::default())); - context.update(cx, |_, cx| { - cx.subscribe(&context, { - let ranges = output_ranges.clone(); - move |_, _, event, _| match event { - ContextEvent::PendingSlashCommandsUpdated { removed, updated } => { - for range in removed { - ranges.borrow_mut().remove(range); - } - for command in updated { - ranges.borrow_mut().insert(command.source_range.clone()); - } - } - _ => {} - } - }) - .detach(); - }); - - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - - // Insert a slash command - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "/file src/lib.rs")], None, cx); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - «/file src/lib.rs» - " - .unindent() - .trim_end(), - cx, - ); - - // Edit the argument of the slash command. - buffer.update(cx, |buffer, cx| { - let edit_offset = buffer.text().find("lib.rs").unwrap(); - buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - «/file src/main.rs» - " - .unindent() - .trim_end(), - cx, - ); - - // Edit the name of the slash command, using one that doesn't exist. - buffer.update(cx, |buffer, cx| { - let edit_offset = buffer.text().find("/file").unwrap(); - buffer.edit( - [(edit_offset..edit_offset + "/file".len(), "/unknown")], - None, - cx, - ); - }); - assert_text_and_output_ranges( - &buffer, - &output_ranges.borrow(), - " - /unknown src/main.rs - " - .unindent() - .trim_end(), - cx, - ); - - #[track_caller] - fn assert_text_and_output_ranges( - buffer: &Model, - ranges: &HashSet>, - expected_marked_text: &str, - cx: &mut TestAppContext, - ) { - let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false); - let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| { - let mut ranges = ranges - .iter() - .map(|range| range.to_offset(buffer)) - .collect::>(); - ranges.sort_by_key(|a| a.start); - (buffer.text(), ranges) - }); - - assert_eq!(actual_text, expected_text); - assert_eq!(actual_ranges, expected_ranges); - } - } - - #[test] - fn test_parse_next_edit_suggestion() { - let text = " - some output: - - ```edit src/foo.rs - let a = 1; - let b = 2; - --- - let w = 1; - let x = 2; - let y = 3; - let z = 4; - ``` - - some more output: - - ```edit src/foo.rs - let c = 1; - --- - ``` - - and the conclusion. - " - .unindent(); - - let rope = Rope::from(text.as_str()); - let mut lines = rope.chunks().lines(); - let mut suggestions = vec![]; - while let Some(suggestion) = parse_next_edit_suggestion(&mut lines) { - suggestions.push(( - suggestion.path.clone(), - text[suggestion.old_text_range].to_string(), - text[suggestion.new_text_range].to_string(), - )); - } - - assert_eq!( - suggestions, - vec![ - ( - Path::new("src/foo.rs").into(), - [ - " let a = 1;", // - " let b = 2;", - "", - ] - .join("\n"), - [ - " let w = 1;", - " let x = 2;", - " let y = 3;", - " let z = 4;", - "", - ] - .join("\n"), - ), - ( - Path::new("src/foo.rs").into(), - [ - " let c = 1;", // - "", - ] - .join("\n"), - String::new(), - ) - ] - ); - } - - #[gpui::test] - async fn test_serialization(cx: &mut TestAppContext) { - let settings_store = cx.update(SettingsStore::test); - cx.set_global(settings_store); - cx.update(FakeCompletionProvider::setup_test); - cx.update(init); - let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let context = - cx.new_model(|cx| Context::new(registry.clone(), Default::default(), None, cx)); - let buffer = context.read_with(cx, |context, _| context.buffer.clone()); - let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); - let message_1 = context.update(cx, |context, cx| { - context - .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) - .unwrap() - }); - let message_2 = context.update(cx, |context, cx| { - context - .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); - buffer.finalize_last_transaction(); - }); - let _message_3 = context.update(cx, |context, cx| { - context - .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) - .unwrap() - }); - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n"); - assert_eq!( - cx.read(|cx| messages(&context, cx)), - [ - (message_0, Role::User, 0..2), - (message_1.id, Role::Assistant, 2..6), - (message_2.id, Role::System, 6..6), - ] - ); - - let deserialized_context = Context::deserialize( - context.read_with(cx, |context, cx| context.serialize(cx)), - Default::default(), - registry.clone(), - Default::default(), - None, - &mut cx.to_async(), - ) - .await - .unwrap(); - let deserialized_buffer = - deserialized_context.read_with(cx, |context, _| context.buffer.clone()); - assert_eq!( - deserialized_buffer.read_with(cx, |buffer, _| buffer.text()), - "a\nb\nc\n" - ); - assert_eq!( - cx.read(|cx| messages(&deserialized_context, cx)), - [ - (message_0, Role::User, 0..2), - (message_1.id, Role::Assistant, 2..6), - (message_2.id, Role::System, 6..6), - ] - ); - } - - fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, Range)> { - context - .read(cx) - .messages(cx) - .map(|message| (message.id, message.role, message.offset_range)) - .collect() - } -} diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 36a5bc883e..a51d3256e2 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/assistant/src/completion_provider.rs @@ -1,13 +1,13 @@ mod anthropic; mod cloud; -#[cfg(test)] +#[cfg(any(test, feature = "test-support"))] mod fake; mod ollama; mod open_ai; pub use anthropic::*; pub use cloud::*; -#[cfg(test)] +#[cfg(any(test, feature = "test-support"))] pub use fake::*; pub use ollama::*; pub use open_ai::*; diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/assistant/src/completion_provider/fake.rs index f07a3befd2..434e584d00 100644 --- a/crates/assistant/src/completion_provider/fake.rs +++ b/crates/assistant/src/completion_provider/fake.rs @@ -13,7 +13,6 @@ pub struct FakeCompletionProvider { } impl FakeCompletionProvider { - #[cfg(test)] pub fn setup_test(cx: &mut AppContext) -> Self { use crate::CompletionProvider; use parking_lot::RwLock; diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs new file mode 100644 index 0000000000..e3dcd36313 --- /dev/null +++ b/crates/assistant/src/context.rs @@ -0,0 +1,3009 @@ +use crate::{ + slash_command::SlashCommandLine, CompletionProvider, LanguageModelRequest, + LanguageModelRequestMessage, MessageId, MessageStatus, Role, +}; +use anyhow::{anyhow, Context as _, Result}; +use assistant_slash_command::{ + SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, +}; +use client::{proto, telemetry::Telemetry}; +use clock::ReplicaId; +use collections::{HashMap, HashSet}; +use fs::Fs; +use futures::{future::Shared, FutureExt, StreamExt}; +use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task}; +use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset}; +use open_ai::Model as OpenAiModel; +use paths::contexts_dir; +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Ordering, + iter, mem, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, + time::{Duration, Instant}, +}; +use telemetry_events::AssistantKind; +use ui::SharedString; +use util::{post_inc, TryFutureExt}; +use uuid::Uuid; + +#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct ContextId(String); + +impl ContextId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string()) + } + + pub fn from_proto(id: String) -> Self { + Self(id) + } + + pub fn to_proto(&self) -> String { + self.0.clone() + } +} + +#[derive(Clone, Debug)] +pub enum ContextOperation { + InsertMessage { + anchor: MessageAnchor, + metadata: MessageMetadata, + version: clock::Global, + }, + UpdateMessage { + message_id: MessageId, + metadata: MessageMetadata, + version: clock::Global, + }, + UpdateSummary { + summary: ContextSummary, + version: clock::Global, + }, + SlashCommandFinished { + id: SlashCommandId, + output_range: Range, + sections: Vec>, + version: clock::Global, + }, + BufferOperation(language::Operation), +} + +impl ContextOperation { + pub fn from_proto(op: proto::ContextOperation) -> Result { + match op.variant.context("invalid variant")? { + proto::context_operation::Variant::InsertMessage(insert) => { + let message = insert.message.context("invalid message")?; + let id = MessageId(language::proto::deserialize_timestamp( + message.id.context("invalid id")?, + )); + Ok(Self::InsertMessage { + anchor: MessageAnchor { + id, + start: language::proto::deserialize_anchor( + message.start.context("invalid anchor")?, + ) + .context("invalid anchor")?, + }, + metadata: MessageMetadata { + role: Role::from_proto(message.role), + status: MessageStatus::from_proto( + message.status.context("invalid status")?, + ), + timestamp: id.0, + }, + version: language::proto::deserialize_version(&insert.version), + }) + } + proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage { + message_id: MessageId(language::proto::deserialize_timestamp( + update.message_id.context("invalid message id")?, + )), + metadata: MessageMetadata { + role: Role::from_proto(update.role), + status: MessageStatus::from_proto(update.status.context("invalid status")?), + timestamp: language::proto::deserialize_timestamp( + update.timestamp.context("invalid timestamp")?, + ), + }, + version: language::proto::deserialize_version(&update.version), + }), + proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary { + summary: ContextSummary { + text: update.summary, + done: update.done, + timestamp: language::proto::deserialize_timestamp( + update.timestamp.context("invalid timestamp")?, + ), + }, + version: language::proto::deserialize_version(&update.version), + }), + proto::context_operation::Variant::SlashCommandFinished(finished) => { + Ok(Self::SlashCommandFinished { + id: SlashCommandId(language::proto::deserialize_timestamp( + finished.id.context("invalid id")?, + )), + output_range: language::proto::deserialize_anchor_range( + finished.output_range.context("invalid range")?, + )?, + sections: finished + .sections + .into_iter() + .map(|section| { + Ok(SlashCommandOutputSection { + range: language::proto::deserialize_anchor_range( + section.range.context("invalid range")?, + )?, + icon: section.icon_name.parse()?, + label: section.label.into(), + }) + }) + .collect::>>()?, + version: language::proto::deserialize_version(&finished.version), + }) + } + proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation( + language::proto::deserialize_operation( + op.operation.context("invalid buffer operation")?, + )?, + )), + } + } + + pub fn to_proto(&self) -> proto::ContextOperation { + match self { + Self::InsertMessage { + anchor, + metadata, + version, + } => proto::ContextOperation { + variant: Some(proto::context_operation::Variant::InsertMessage( + proto::context_operation::InsertMessage { + message: Some(proto::ContextMessage { + id: Some(language::proto::serialize_timestamp(anchor.id.0)), + start: Some(language::proto::serialize_anchor(&anchor.start)), + role: metadata.role.to_proto() as i32, + status: Some(metadata.status.to_proto()), + }), + version: language::proto::serialize_version(version), + }, + )), + }, + Self::UpdateMessage { + message_id, + metadata, + version, + } => proto::ContextOperation { + variant: Some(proto::context_operation::Variant::UpdateMessage( + proto::context_operation::UpdateMessage { + message_id: Some(language::proto::serialize_timestamp(message_id.0)), + role: metadata.role.to_proto() as i32, + status: Some(metadata.status.to_proto()), + timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)), + version: language::proto::serialize_version(version), + }, + )), + }, + Self::UpdateSummary { summary, version } => proto::ContextOperation { + variant: Some(proto::context_operation::Variant::UpdateSummary( + proto::context_operation::UpdateSummary { + summary: summary.text.clone(), + done: summary.done, + timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)), + version: language::proto::serialize_version(version), + }, + )), + }, + Self::SlashCommandFinished { + id, + output_range, + sections, + version, + } => proto::ContextOperation { + variant: Some(proto::context_operation::Variant::SlashCommandFinished( + proto::context_operation::SlashCommandFinished { + id: Some(language::proto::serialize_timestamp(id.0)), + output_range: Some(language::proto::serialize_anchor_range( + output_range.clone(), + )), + sections: sections + .iter() + .map(|section| { + let icon_name: &'static str = section.icon.into(); + proto::SlashCommandOutputSection { + range: Some(language::proto::serialize_anchor_range( + section.range.clone(), + )), + icon_name: icon_name.to_string(), + label: section.label.to_string(), + } + }) + .collect(), + version: language::proto::serialize_version(version), + }, + )), + }, + Self::BufferOperation(operation) => proto::ContextOperation { + variant: Some(proto::context_operation::Variant::BufferOperation( + proto::context_operation::BufferOperation { + operation: Some(language::proto::serialize_operation(operation)), + }, + )), + }, + } + } + + fn timestamp(&self) -> clock::Lamport { + match self { + Self::InsertMessage { anchor, .. } => anchor.id.0, + Self::UpdateMessage { metadata, .. } => metadata.timestamp, + Self::UpdateSummary { summary, .. } => summary.timestamp, + Self::SlashCommandFinished { id, .. } => id.0, + Self::BufferOperation(_) => { + panic!("reading the timestamp of a buffer operation is not supported") + } + } + } + + /// Returns the current version of the context operation. + pub fn version(&self) -> &clock::Global { + match self { + Self::InsertMessage { version, .. } + | Self::UpdateMessage { version, .. } + | Self::UpdateSummary { version, .. } + | Self::SlashCommandFinished { version, .. } => version, + Self::BufferOperation(_) => { + panic!("reading the version of a buffer operation is not supported") + } + } + } +} + +#[derive(Clone)] +pub enum ContextEvent { + MessagesEdited, + SummaryChanged, + EditSuggestionsChanged, + StreamedCompletion, + PendingSlashCommandsUpdated { + removed: Vec>, + updated: Vec, + }, + SlashCommandFinished { + output_range: Range, + sections: Vec>, + run_commands_in_output: bool, + }, + Operation(ContextOperation), +} + +#[derive(Clone, Default, Debug)] +pub struct ContextSummary { + pub text: String, + done: bool, + timestamp: clock::Lamport, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MessageAnchor { + pub id: MessageId, + pub start: language::Anchor, +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct MessageMetadata { + pub role: Role, + status: MessageStatus, + timestamp: clock::Lamport, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Message { + pub offset_range: Range, + pub index_range: Range, + pub id: MessageId, + pub anchor: language::Anchor, + pub role: Role, + pub status: MessageStatus, +} + +impl Message { + fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage { + LanguageModelRequestMessage { + role: self.role, + content: buffer.text_for_range(self.offset_range.clone()).collect(), + } + } +} + +struct PendingCompletion { + id: usize, + _task: Task<()>, +} + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +pub struct SlashCommandId(clock::Lamport); + +pub struct Context { + id: ContextId, + timestamp: clock::Lamport, + version: clock::Global, + pending_ops: Vec, + operations: Vec, + buffer: Model, + edit_suggestions: Vec, + pending_slash_commands: Vec, + edits_since_last_slash_command_parse: language::Subscription, + finished_slash_commands: HashSet, + slash_command_output_sections: Vec>, + message_anchors: Vec, + messages_metadata: HashMap, + summary: Option, + pending_summary: Task>, + completion_count: usize, + pending_completions: Vec, + token_count: Option, + pending_token_count: Task>, + pending_edit_suggestion_parse: Option>, + pending_save: Task>, + path: Option, + _subscriptions: Vec, + telemetry: Option>, + language_registry: Arc, +} + +impl EventEmitter for Context {} + +impl Context { + pub fn local( + language_registry: Arc, + telemetry: Option>, + cx: &mut ModelContext, + ) -> Self { + Self::new( + ContextId::new(), + ReplicaId::default(), + language::Capability::ReadWrite, + language_registry, + telemetry, + cx, + ) + } + + pub fn new( + id: ContextId, + replica_id: ReplicaId, + capability: language::Capability, + language_registry: Arc, + telemetry: Option>, + cx: &mut ModelContext, + ) -> Self { + let buffer = cx.new_model(|_cx| { + let mut buffer = Buffer::remote( + language::BufferId::new(1).unwrap(), + replica_id, + capability, + "", + ); + buffer.set_language_registry(language_registry.clone()); + buffer + }); + let edits_since_last_slash_command_parse = + buffer.update(cx, |buffer, _| buffer.subscribe()); + let mut this = Self { + id, + timestamp: clock::Lamport::new(replica_id), + version: clock::Global::new(), + pending_ops: Vec::new(), + operations: Vec::new(), + message_anchors: Default::default(), + messages_metadata: Default::default(), + edit_suggestions: Vec::new(), + pending_slash_commands: Vec::new(), + finished_slash_commands: HashSet::default(), + slash_command_output_sections: Vec::new(), + edits_since_last_slash_command_parse, + summary: None, + pending_summary: Task::ready(None), + completion_count: Default::default(), + pending_completions: Default::default(), + token_count: None, + pending_token_count: Task::ready(None), + pending_edit_suggestion_parse: None, + _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], + pending_save: Task::ready(Ok(())), + path: None, + buffer, + telemetry, + language_registry, + }; + + let first_message_id = MessageId(clock::Lamport { + replica_id: 0, + value: 0, + }); + let message = MessageAnchor { + id: first_message_id, + start: language::Anchor::MIN, + }; + this.messages_metadata.insert( + first_message_id, + MessageMetadata { + role: Role::User, + status: MessageStatus::Done, + timestamp: first_message_id.0, + }, + ); + this.message_anchors.push(message); + + this.set_language(cx); + this.count_remaining_tokens(cx); + this + } + + fn serialize(&self, cx: &AppContext) -> SavedContext { + let buffer = self.buffer.read(cx); + SavedContext { + id: Some(self.id.clone()), + zed: "context".into(), + version: SavedContext::VERSION.into(), + text: buffer.text(), + messages: self + .messages(cx) + .map(|message| SavedMessage { + id: message.id, + start: message.offset_range.start, + metadata: self.messages_metadata[&message.id].clone(), + }) + .collect(), + summary: self + .summary + .as_ref() + .map(|summary| summary.text.clone()) + .unwrap_or_default(), + slash_command_output_sections: self + .slash_command_output_sections + .iter() + .filter_map(|section| { + let range = section.range.to_offset(buffer); + if section.range.start.is_valid(buffer) && !range.is_empty() { + Some(assistant_slash_command::SlashCommandOutputSection { + range, + icon: section.icon, + label: section.label.clone(), + }) + } else { + None + } + }) + .collect(), + } + } + + #[allow(clippy::too_many_arguments)] + pub fn deserialize( + saved_context: SavedContext, + path: PathBuf, + language_registry: Arc, + telemetry: Option>, + cx: &mut ModelContext, + ) -> Self { + let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new()); + let mut this = Self::new( + id, + ReplicaId::default(), + language::Capability::ReadWrite, + language_registry, + telemetry, + cx, + ); + this.path = Some(path); + this.buffer.update(cx, |buffer, cx| { + buffer.set_text(saved_context.text.as_str(), cx) + }); + let operations = saved_context.into_ops(&this.buffer, cx); + this.apply_ops(operations, cx).unwrap(); + this + } + + pub fn id(&self) -> &ContextId { + &self.id + } + + pub fn replica_id(&self) -> ReplicaId { + self.timestamp.replica_id + } + + pub fn version(&self, cx: &AppContext) -> ContextVersion { + ContextVersion { + context: self.version.clone(), + buffer: self.buffer.read(cx).version(), + } + } + + pub fn set_capability( + &mut self, + capability: language::Capability, + cx: &mut ModelContext, + ) { + self.buffer + .update(cx, |buffer, cx| buffer.set_capability(capability, cx)); + } + + fn next_timestamp(&mut self) -> clock::Lamport { + let timestamp = self.timestamp.tick(); + self.version.observe(timestamp); + timestamp + } + + pub fn serialize_ops( + &self, + since: &ContextVersion, + cx: &AppContext, + ) -> Task> { + let buffer_ops = self + .buffer + .read(cx) + .serialize_ops(Some(since.buffer.clone()), cx); + + let mut context_ops = self + .operations + .iter() + .filter(|op| !since.context.observed(op.timestamp())) + .cloned() + .collect::>(); + context_ops.extend(self.pending_ops.iter().cloned()); + + cx.background_executor().spawn(async move { + let buffer_ops = buffer_ops.await; + context_ops.sort_unstable_by_key(|op| op.timestamp()); + buffer_ops + .into_iter() + .map(|op| proto::ContextOperation { + variant: Some(proto::context_operation::Variant::BufferOperation( + proto::context_operation::BufferOperation { + operation: Some(op), + }, + )), + }) + .chain(context_ops.into_iter().map(|op| op.to_proto())) + .collect() + }) + } + + pub fn apply_ops( + &mut self, + ops: impl IntoIterator, + cx: &mut ModelContext, + ) -> Result<()> { + let mut buffer_ops = Vec::new(); + for op in ops { + match op { + ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op), + op @ _ => self.pending_ops.push(op), + } + } + self.buffer + .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?; + self.flush_ops(cx); + + Ok(()) + } + + fn flush_ops(&mut self, cx: &mut ModelContext) { + let mut messages_changed = false; + let mut summary_changed = false; + + self.pending_ops.sort_unstable_by_key(|op| op.timestamp()); + for op in mem::take(&mut self.pending_ops) { + if !self.can_apply_op(&op, cx) { + self.pending_ops.push(op); + continue; + } + + let timestamp = op.timestamp(); + match op.clone() { + ContextOperation::InsertMessage { + anchor, metadata, .. + } => { + if self.messages_metadata.contains_key(&anchor.id) { + // We already applied this operation. + } else { + self.insert_message(anchor, metadata, cx); + messages_changed = true; + } + } + ContextOperation::UpdateMessage { + message_id, + metadata: new_metadata, + .. + } => { + let metadata = self.messages_metadata.get_mut(&message_id).unwrap(); + if new_metadata.timestamp > metadata.timestamp { + *metadata = new_metadata; + messages_changed = true; + } + } + ContextOperation::UpdateSummary { + summary: new_summary, + .. + } => { + if self + .summary + .as_ref() + .map_or(true, |summary| new_summary.timestamp > summary.timestamp) + { + self.summary = Some(new_summary); + summary_changed = true; + } + } + ContextOperation::SlashCommandFinished { + id, + output_range, + sections, + .. + } => { + if self.finished_slash_commands.insert(id) { + let buffer = self.buffer.read(cx); + self.slash_command_output_sections + .extend(sections.iter().cloned()); + self.slash_command_output_sections + .sort_by(|a, b| a.range.cmp(&b.range, buffer)); + cx.emit(ContextEvent::SlashCommandFinished { + output_range, + sections, + run_commands_in_output: false, + }); + } + } + ContextOperation::BufferOperation(_) => unreachable!(), + } + + self.version.observe(timestamp); + self.timestamp.observe(timestamp); + self.operations.push(op); + } + + if messages_changed { + cx.emit(ContextEvent::MessagesEdited); + cx.notify(); + } + + if summary_changed { + cx.emit(ContextEvent::SummaryChanged); + cx.notify(); + } + } + + fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool { + if !self.version.observed_all(op.version()) { + return false; + } + + match op { + ContextOperation::InsertMessage { anchor, .. } => self + .buffer + .read(cx) + .version + .observed(anchor.start.timestamp), + ContextOperation::UpdateMessage { message_id, .. } => { + self.messages_metadata.contains_key(message_id) + } + ContextOperation::UpdateSummary { .. } => true, + ContextOperation::SlashCommandFinished { + output_range, + sections, + .. + } => { + let version = &self.buffer.read(cx).version; + sections + .iter() + .map(|section| §ion.range) + .chain([output_range]) + .all(|range| { + let observed_start = range.start == language::Anchor::MIN + || range.start == language::Anchor::MAX + || version.observed(range.start.timestamp); + let observed_end = range.end == language::Anchor::MIN + || range.end == language::Anchor::MAX + || version.observed(range.end.timestamp); + observed_start && observed_end + }) + } + ContextOperation::BufferOperation(_) => { + panic!("buffer operations should always be applied") + } + } + } + + fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext) { + self.operations.push(op.clone()); + cx.emit(ContextEvent::Operation(op)); + } + + pub fn buffer(&self) -> &Model { + &self.buffer + } + + pub fn path(&self) -> Option<&Path> { + self.path.as_deref() + } + + pub fn summary(&self) -> Option<&ContextSummary> { + self.summary.as_ref() + } + + pub fn edit_suggestions(&self) -> &[EditSuggestion] { + &self.edit_suggestions + } + + pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] { + &self.pending_slash_commands + } + + pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection] { + &self.slash_command_output_sections + } + + fn set_language(&mut self, cx: &mut ModelContext) { + let markdown = self.language_registry.language_for_name("Markdown"); + cx.spawn(|this, mut cx| async move { + let markdown = markdown.await?; + this.update(&mut cx, |this, cx| { + this.buffer + .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx)); + }) + }) + .detach_and_log_err(cx); + } + + fn handle_buffer_event( + &mut self, + _: Model, + event: &language::Event, + cx: &mut ModelContext, + ) { + match event { + language::Event::Operation(operation) => cx.emit(ContextEvent::Operation( + ContextOperation::BufferOperation(operation.clone()), + )), + language::Event::Edited => { + self.count_remaining_tokens(cx); + self.reparse_edit_suggestions(cx); + self.reparse_slash_commands(cx); + cx.emit(ContextEvent::MessagesEdited); + } + _ => {} + } + } + + pub(crate) fn token_count(&self) -> Option { + self.token_count + } + + pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { + let request = self.to_completion_request(cx); + self.pending_token_count = cx.spawn(|this, mut cx| { + async move { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + + let token_count = cx + .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? + .await?; + + this.update(&mut cx, |this, cx| { + this.token_count = Some(token_count); + cx.notify() + })?; + anyhow::Ok(()) + } + .log_err() + }); + } + + pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext) { + let buffer = self.buffer.read(cx); + let mut row_ranges = self + .edits_since_last_slash_command_parse + .consume() + .into_iter() + .map(|edit| { + let start_row = buffer.offset_to_point(edit.new.start).row; + let end_row = buffer.offset_to_point(edit.new.end).row + 1; + start_row..end_row + }) + .peekable(); + + let mut removed = Vec::new(); + let mut updated = Vec::new(); + while let Some(mut row_range) = row_ranges.next() { + while let Some(next_row_range) = row_ranges.peek() { + if row_range.end >= next_row_range.start { + row_range.end = next_row_range.end; + row_ranges.next(); + } else { + break; + } + } + + let start = buffer.anchor_before(Point::new(row_range.start, 0)); + let end = buffer.anchor_after(Point::new( + row_range.end - 1, + buffer.line_len(row_range.end - 1), + )); + + let old_range = self.pending_command_indices_for_range(start..end, cx); + + let mut new_commands = Vec::new(); + let mut lines = buffer.text_for_range(start..end).lines(); + let mut offset = lines.offset(); + while let Some(line) = lines.next() { + if let Some(command_line) = SlashCommandLine::parse(line) { + let name = &line[command_line.name.clone()]; + let argument = command_line.argument.as_ref().and_then(|argument| { + (!argument.is_empty()).then_some(&line[argument.clone()]) + }); + if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if !command.requires_argument() || argument.is_some() { + let start_ix = offset + command_line.name.start - 1; + let end_ix = offset + + command_line + .argument + .map_or(command_line.name.end, |argument| argument.end); + let source_range = + buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix); + let pending_command = PendingSlashCommand { + name: name.to_string(), + argument: argument.map(ToString::to_string), + source_range, + status: PendingSlashCommandStatus::Idle, + }; + updated.push(pending_command.clone()); + new_commands.push(pending_command); + } + } + } + + offset = lines.offset(); + } + + let removed_commands = self.pending_slash_commands.splice(old_range, new_commands); + removed.extend(removed_commands.map(|command| command.source_range)); + } + + if !updated.is_empty() || !removed.is_empty() { + cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated }); + } + } + + fn reparse_edit_suggestions(&mut self, cx: &mut ModelContext) { + self.pending_edit_suggestion_parse = Some(cx.spawn(|this, mut cx| async move { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + + this.update(&mut cx, |this, cx| { + this.reparse_edit_suggestions_in_range(0..this.buffer.read(cx).len(), cx); + }) + .ok(); + })); + } + + fn reparse_edit_suggestions_in_range( + &mut self, + range: Range, + cx: &mut ModelContext, + ) { + self.buffer.update(cx, |buffer, _| { + let range_start = buffer.anchor_before(range.start); + let range_end = buffer.anchor_after(range.end); + let start_ix = self + .edit_suggestions + .binary_search_by(|probe| { + probe + .source_range + .end + .cmp(&range_start, buffer) + .then(Ordering::Greater) + }) + .unwrap_err(); + let end_ix = self + .edit_suggestions + .binary_search_by(|probe| { + probe + .source_range + .start + .cmp(&range_end, buffer) + .then(Ordering::Less) + }) + .unwrap_err(); + + let mut new_edit_suggestions = Vec::new(); + let mut message_lines = buffer.as_rope().chunks_in_range(range).lines(); + while let Some(suggestion) = parse_next_edit_suggestion(&mut message_lines) { + let start_anchor = buffer.anchor_after(suggestion.outer_range.start); + let end_anchor = buffer.anchor_before(suggestion.outer_range.end); + new_edit_suggestions.push(EditSuggestion { + source_range: start_anchor..end_anchor, + full_path: suggestion.path, + }); + } + self.edit_suggestions + .splice(start_ix..end_ix, new_edit_suggestions); + }); + cx.emit(ContextEvent::EditSuggestionsChanged); + cx.notify(); + } + + pub fn pending_command_for_position( + &mut self, + position: language::Anchor, + cx: &mut ModelContext, + ) -> Option<&mut PendingSlashCommand> { + let buffer = self.buffer.read(cx); + match self + .pending_slash_commands + .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer)) + { + Ok(ix) => Some(&mut self.pending_slash_commands[ix]), + Err(ix) => { + let cmd = self.pending_slash_commands.get_mut(ix)?; + if position.cmp(&cmd.source_range.start, buffer).is_ge() + && position.cmp(&cmd.source_range.end, buffer).is_le() + { + Some(cmd) + } else { + None + } + } + } + } + + pub fn pending_commands_for_range( + &self, + range: Range, + cx: &AppContext, + ) -> &[PendingSlashCommand] { + let range = self.pending_command_indices_for_range(range, cx); + &self.pending_slash_commands[range] + } + + fn pending_command_indices_for_range( + &self, + range: Range, + cx: &AppContext, + ) -> Range { + let buffer = self.buffer.read(cx); + let start_ix = match self + .pending_slash_commands + .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer)) + { + Ok(ix) | Err(ix) => ix, + }; + let end_ix = match self + .pending_slash_commands + .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer)) + { + Ok(ix) => ix + 1, + Err(ix) => ix, + }; + start_ix..end_ix + } + + pub fn insert_command_output( + &mut self, + command_range: Range, + output: Task>, + insert_trailing_newline: bool, + cx: &mut ModelContext, + ) { + self.reparse_slash_commands(cx); + + let insert_output_task = cx.spawn(|this, mut cx| { + let command_range = command_range.clone(); + async move { + let output = output.await; + this.update(&mut cx, |this, cx| match output { + Ok(mut output) => { + if insert_trailing_newline { + output.text.push('\n'); + } + + let version = this.version.clone(); + let command_id = SlashCommandId(this.next_timestamp()); + let (operation, event) = this.buffer.update(cx, |buffer, cx| { + let start = command_range.start.to_offset(buffer); + let old_end = command_range.end.to_offset(buffer); + let new_end = start + output.text.len(); + buffer.edit([(start..old_end, output.text)], None, cx); + + let mut sections = output + .sections + .into_iter() + .map(|section| SlashCommandOutputSection { + range: buffer.anchor_after(start + section.range.start) + ..buffer.anchor_before(start + section.range.end), + icon: section.icon, + label: section.label, + }) + .collect::>(); + sections.sort_by(|a, b| a.range.cmp(&b.range, buffer)); + + this.slash_command_output_sections + .extend(sections.iter().cloned()); + this.slash_command_output_sections + .sort_by(|a, b| a.range.cmp(&b.range, buffer)); + + let output_range = + buffer.anchor_after(start)..buffer.anchor_before(new_end); + this.finished_slash_commands.insert(command_id); + + ( + ContextOperation::SlashCommandFinished { + id: command_id, + output_range: output_range.clone(), + sections: sections.clone(), + version, + }, + ContextEvent::SlashCommandFinished { + output_range, + sections, + run_commands_in_output: output.run_commands_in_text, + }, + ) + }); + + this.push_op(operation, cx); + cx.emit(event); + } + Err(error) => { + if let Some(pending_command) = + this.pending_command_for_position(command_range.start, cx) + { + pending_command.status = + PendingSlashCommandStatus::Error(error.to_string()); + cx.emit(ContextEvent::PendingSlashCommandsUpdated { + removed: vec![pending_command.source_range.clone()], + updated: vec![pending_command.clone()], + }); + } + } + }) + .ok(); + } + }); + + if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) { + pending_command.status = PendingSlashCommandStatus::Running { + _task: insert_output_task.shared(), + }; + cx.emit(ContextEvent::PendingSlashCommandsUpdated { + removed: vec![pending_command.source_range.clone()], + updated: vec![pending_command.clone()], + }); + } + } + + pub fn completion_provider_changed(&mut self, cx: &mut ModelContext) { + self.count_remaining_tokens(cx); + } + + pub fn assist( + &mut self, + selected_messages: HashSet, + cx: &mut ModelContext, + ) -> Vec { + let mut user_messages = Vec::new(); + + let last_message_id = if let Some(last_message_id) = + self.message_anchors.iter().rev().find_map(|message| { + message + .start + .is_valid(self.buffer.read(cx)) + .then_some(message.id) + }) { + last_message_id + } else { + return Default::default(); + }; + + let mut should_assist = false; + for selected_message_id in selected_messages { + let selected_message_role = + if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { + metadata.role + } else { + continue; + }; + + if selected_message_role == Role::Assistant { + if let Some(user_message) = self.insert_message_after( + selected_message_id, + Role::User, + MessageStatus::Done, + cx, + ) { + user_messages.push(user_message); + } + } else { + should_assist = true; + } + } + + if should_assist { + if !CompletionProvider::global(cx).is_authenticated() { + log::info!("completion provider has no credentials"); + return Default::default(); + } + + let request = self.to_completion_request(cx); + let stream = CompletionProvider::global(cx).complete(request, cx); + let assistant_message = self + .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) + .unwrap(); + + // Queue up the user's next reply. + let user_message = self + .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx) + .unwrap(); + user_messages.push(user_message); + + let task = cx.spawn({ + |this, mut cx| async move { + let assistant_message_id = assistant_message.id; + let mut response_latency = None; + let stream_completion = async { + let request_start = Instant::now(); + let mut messages = stream.await.inner.await?; + + while let Some(message) = messages.next().await { + if response_latency.is_none() { + response_latency = Some(request_start.elapsed()); + } + let text = message?; + + this.update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + let message_range = this.buffer.update(cx, |buffer, cx| { + let message_start_offset = + this.message_anchors[message_ix].start.to_offset(buffer); + let message_old_end_offset = this.message_anchors + [message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + let message_new_end_offset = + message_old_end_offset + text.len(); + buffer.edit( + [(message_old_end_offset..message_old_end_offset, text)], + None, + cx, + ); + message_start_offset..message_new_end_offset + }); + this.reparse_edit_suggestions_in_range(message_range, cx); + cx.emit(ContextEvent::StreamedCompletion); + + Some(()) + })?; + smol::future::yield_now().await; + } + + this.update(&mut cx, |this, cx| { + this.pending_completions + .retain(|completion| completion.id != this.completion_count); + this.summarize(cx); + })?; + + anyhow::Ok(()) + }; + + let result = stream_completion.await; + + this.update(&mut cx, |this, cx| { + let error_message = result + .err() + .map(|error| error.to_string().trim().to_string()); + + this.update_metadata(assistant_message_id, cx, |metadata| { + if let Some(error_message) = error_message.as_ref() { + metadata.status = + MessageStatus::Error(SharedString::from(error_message.clone())); + } else { + metadata.status = MessageStatus::Done; + } + }); + + if let Some(telemetry) = this.telemetry.as_ref() { + let model = CompletionProvider::global(cx).model(); + telemetry.report_assistant_event( + Some(this.id.0.clone()), + AssistantKind::Panel, + model.telemetry_id(), + response_latency, + error_message, + ); + } + }) + .ok(); + } + }); + + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _task: task, + }); + } + + user_messages + } + + pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest { + let messages = self + .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) + .map(|message| message.to_request_message(self.buffer.read(cx))); + + LanguageModelRequest { + model: CompletionProvider::global(cx).model(), + messages: messages.collect(), + stop: vec![], + temperature: 1.0, + } + } + + pub fn cancel_last_assist(&mut self) -> bool { + self.pending_completions.pop().is_some() + } + + pub fn cycle_message_roles(&mut self, ids: HashSet, cx: &mut ModelContext) { + for id in ids { + if let Some(metadata) = self.messages_metadata.get(&id) { + let role = metadata.role.cycle(); + self.update_metadata(id, cx, |metadata| metadata.role = role); + } + } + } + + pub fn update_metadata( + &mut self, + id: MessageId, + cx: &mut ModelContext, + f: impl FnOnce(&mut MessageMetadata), + ) { + let version = self.version.clone(); + let timestamp = self.next_timestamp(); + if let Some(metadata) = self.messages_metadata.get_mut(&id) { + f(metadata); + metadata.timestamp = timestamp; + let operation = ContextOperation::UpdateMessage { + message_id: id, + metadata: metadata.clone(), + version, + }; + self.push_op(operation, cx); + cx.emit(ContextEvent::MessagesEdited); + cx.notify(); + } + } + + fn insert_message_after( + &mut self, + message_id: MessageId, + role: Role, + status: MessageStatus, + cx: &mut ModelContext, + ) -> Option { + if let Some(prev_message_ix) = self + .message_anchors + .iter() + .position(|message| message.id == message_id) + { + // Find the next valid message after the one we were given. + let mut next_message_ix = prev_message_ix + 1; + while let Some(next_message) = self.message_anchors.get(next_message_ix) { + if next_message.start.is_valid(self.buffer.read(cx)) { + break; + } + next_message_ix += 1; + } + + let start = self.buffer.update(cx, |buffer, cx| { + let offset = self + .message_anchors + .get(next_message_ix) + .map_or(buffer.len(), |message| { + buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left) + }); + buffer.edit([(offset..offset, "\n")], None, cx); + buffer.anchor_before(offset + 1) + }); + + let version = self.version.clone(); + let anchor = MessageAnchor { + id: MessageId(self.next_timestamp()), + start, + }; + let metadata = MessageMetadata { + role, + status, + timestamp: anchor.id.0, + }; + self.insert_message(anchor.clone(), metadata.clone(), cx); + self.push_op( + ContextOperation::InsertMessage { + anchor: anchor.clone(), + metadata, + version, + }, + cx, + ); + Some(anchor) + } else { + None + } + } + + pub fn split_message( + &mut self, + range: Range, + cx: &mut ModelContext, + ) -> (Option, Option) { + let start_message = self.message_for_offset(range.start, cx); + let end_message = self.message_for_offset(range.end, cx); + if let Some((start_message, end_message)) = start_message.zip(end_message) { + // Prevent splitting when range spans multiple messages. + if start_message.id != end_message.id { + return (None, None); + } + + let message = start_message; + let role = message.role; + let mut edited_buffer = false; + + let mut suffix_start = None; + if range.start > message.offset_range.start && range.end < message.offset_range.end - 1 + { + if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end); + } + } + + let version = self.version.clone(); + let suffix = if let Some(suffix_start) = suffix_start { + MessageAnchor { + id: MessageId(self.next_timestamp()), + start: self.buffer.read(cx).anchor_before(suffix_start), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.end..range.end, "\n")], None, cx); + }); + edited_buffer = true; + MessageAnchor { + id: MessageId(self.next_timestamp()), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } + }; + + let suffix_metadata = MessageMetadata { + role, + status: MessageStatus::Done, + timestamp: suffix.id.0, + }; + self.insert_message(suffix.clone(), suffix_metadata.clone(), cx); + self.push_op( + ContextOperation::InsertMessage { + anchor: suffix.clone(), + metadata: suffix_metadata, + version, + }, + cx, + ); + + let new_messages = + if range.start == range.end || range.start == message.offset_range.start { + (None, Some(suffix)) + } else { + let mut prefix_end = None; + if range.start > message.offset_range.start + && range.end < message.offset_range.end - 1 + { + if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { + prefix_end = Some(range.start + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.start).next() + == Some('\n') + { + prefix_end = Some(range.start); + } + } + + let version = self.version.clone(); + let selection = if let Some(prefix_end) = prefix_end { + MessageAnchor { + id: MessageId(self.next_timestamp()), + start: self.buffer.read(cx).anchor_before(prefix_end), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + edited_buffer = true; + MessageAnchor { + id: MessageId(self.next_timestamp()), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } + }; + + let selection_metadata = MessageMetadata { + role, + status: MessageStatus::Done, + timestamp: selection.id.0, + }; + self.insert_message(selection.clone(), selection_metadata.clone(), cx); + self.push_op( + ContextOperation::InsertMessage { + anchor: selection.clone(), + metadata: selection_metadata, + version, + }, + cx, + ); + + (Some(selection), Some(suffix)) + }; + + if !edited_buffer { + cx.emit(ContextEvent::MessagesEdited); + } + new_messages + } else { + (None, None) + } + } + + fn insert_message( + &mut self, + new_anchor: MessageAnchor, + new_metadata: MessageMetadata, + cx: &mut ModelContext, + ) { + cx.emit(ContextEvent::MessagesEdited); + + self.messages_metadata.insert(new_anchor.id, new_metadata); + + let buffer = self.buffer.read(cx); + let insertion_ix = self + .message_anchors + .iter() + .position(|anchor| { + let comparison = new_anchor.start.cmp(&anchor.start, buffer); + comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id) + }) + .unwrap_or(self.message_anchors.len()); + self.message_anchors.insert(insertion_ix, new_anchor); + } + + fn summarize(&mut self, cx: &mut ModelContext) { + if self.message_anchors.len() >= 2 && self.summary.is_none() { + if !CompletionProvider::global(cx).is_authenticated() { + return; + } + + let messages = self + .messages(cx) + .map(|message| message.to_request_message(self.buffer.read(cx))) + .chain(Some(LanguageModelRequestMessage { + role: Role::User, + content: "Summarize the context into a short title without punctuation.".into(), + })); + let request = LanguageModelRequest { + model: CompletionProvider::global(cx).model(), + messages: messages.collect(), + stop: vec![], + temperature: 1.0, + }; + + let stream = CompletionProvider::global(cx).complete(request, cx); + self.pending_summary = cx.spawn(|this, mut cx| { + async move { + let mut messages = stream.await.inner.await?; + + while let Some(message) = messages.next().await { + let text = message?; + let mut lines = text.lines(); + this.update(&mut cx, |this, cx| { + let version = this.version.clone(); + let timestamp = this.next_timestamp(); + let summary = this.summary.get_or_insert(Default::default()); + summary.text.extend(lines.next()); + summary.timestamp = timestamp; + let operation = ContextOperation::UpdateSummary { + summary: summary.clone(), + version, + }; + this.push_op(operation, cx); + cx.emit(ContextEvent::SummaryChanged); + })?; + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + + this.update(&mut cx, |this, cx| { + let version = this.version.clone(); + let timestamp = this.next_timestamp(); + if let Some(summary) = this.summary.as_mut() { + summary.done = true; + summary.timestamp = timestamp; + let operation = ContextOperation::UpdateSummary { + summary: summary.clone(), + version, + }; + this.push_op(operation, cx); + cx.emit(ContextEvent::SummaryChanged); + } + })?; + + anyhow::Ok(()) + } + .log_err() + }); + } + } + + fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option { + self.messages_for_offsets([offset], cx).pop() + } + + pub fn messages_for_offsets( + &self, + offsets: impl IntoIterator, + cx: &AppContext, + ) -> Vec { + let mut result = Vec::new(); + + let mut messages = self.messages(cx).peekable(); + let mut offsets = offsets.into_iter().peekable(); + let mut current_message = messages.next(); + while let Some(offset) = offsets.next() { + // Locate the message that contains the offset. + while current_message.as_ref().map_or(false, |message| { + !message.offset_range.contains(&offset) && messages.peek().is_some() + }) { + current_message = messages.next(); + } + let Some(message) = current_message.as_ref() else { + break; + }; + + // Skip offsets that are in the same message. + while offsets.peek().map_or(false, |offset| { + message.offset_range.contains(offset) || messages.peek().is_none() + }) { + offsets.next(); + } + + result.push(message.clone()); + } + result + } + + pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { + let buffer = self.buffer.read(cx); + let mut message_anchors = self.message_anchors.iter().enumerate().peekable(); + iter::from_fn(move || { + if let Some((start_ix, message_anchor)) = message_anchors.next() { + let metadata = self.messages_metadata.get(&message_anchor.id)?; + let message_start = message_anchor.start.to_offset(buffer); + let mut message_end = None; + let mut end_ix = start_ix; + while let Some((_, next_message)) = message_anchors.peek() { + if next_message.start.is_valid(buffer) { + message_end = Some(next_message.start); + break; + } else { + end_ix += 1; + message_anchors.next(); + } + } + let message_end = message_end + .unwrap_or(language::Anchor::MAX) + .to_offset(buffer); + + return Some(Message { + index_range: start_ix..end_ix, + offset_range: message_start..message_end, + id: message_anchor.id, + anchor: message_anchor.start, + role: metadata.role, + status: metadata.status.clone(), + }); + } + None + }) + } + + pub fn save( + &mut self, + debounce: Option, + fs: Arc, + cx: &mut ModelContext, + ) { + if self.replica_id() != ReplicaId::default() { + // Prevent saving a remote context for now. + return; + } + + self.pending_save = cx.spawn(|this, mut cx| async move { + if let Some(debounce) = debounce { + cx.background_executor().timer(debounce).await; + } + + let (old_path, summary) = this.read_with(&cx, |this, _| { + let path = this.path.clone(); + let summary = if let Some(summary) = this.summary.as_ref() { + if summary.done { + Some(summary.text.clone()) + } else { + None + } + } else { + None + }; + (path, summary) + })?; + + if let Some(summary) = summary { + let context = this.read_with(&cx, |this, cx| this.serialize(cx))?; + let path = if let Some(old_path) = old_path { + old_path + } else { + let mut discriminant = 1; + let mut new_path; + loop { + new_path = contexts_dir().join(&format!( + "{} - {}.zed.json", + summary.trim(), + discriminant + )); + if fs.is_file(&new_path).await { + discriminant += 1; + } else { + break; + } + } + new_path + }; + + fs.create_dir(contexts_dir().as_ref()).await?; + fs.atomic_write(path.clone(), serde_json::to_string(&context).unwrap()) + .await?; + this.update(&mut cx, |this, _| this.path = Some(path))?; + } + + Ok(()) + }); + } +} + +#[derive(Debug, Default)] +pub struct ContextVersion { + context: clock::Global, + buffer: clock::Global, +} + +impl ContextVersion { + pub fn from_proto(proto: &proto::ContextVersion) -> Self { + Self { + context: language::proto::deserialize_version(&proto.context_version), + buffer: language::proto::deserialize_version(&proto.buffer_version), + } + } + + pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion { + proto::ContextVersion { + context_id: context_id.to_proto(), + context_version: language::proto::serialize_version(&self.context), + buffer_version: language::proto::serialize_version(&self.buffer), + } + } +} + +#[derive(Debug)] +enum EditParsingState { + None, + InOldText { + path: PathBuf, + start_offset: usize, + old_text_start_offset: usize, + }, + InNewText { + path: PathBuf, + start_offset: usize, + old_text_range: Range, + new_text_start_offset: usize, + }, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct EditSuggestion { + pub source_range: Range, + pub full_path: PathBuf, +} + +pub struct ParsedEditSuggestion { + pub path: PathBuf, + pub outer_range: Range, + pub old_text_range: Range, + pub new_text_range: Range, +} + +pub fn parse_next_edit_suggestion(lines: &mut rope::Lines) -> Option { + let mut state = EditParsingState::None; + loop { + let offset = lines.offset(); + let message_line = lines.next()?; + match state { + EditParsingState::None => { + if let Some(rest) = message_line.strip_prefix("```edit ") { + let path = rest.trim(); + if !path.is_empty() { + state = EditParsingState::InOldText { + path: PathBuf::from(path), + start_offset: offset, + old_text_start_offset: lines.offset(), + }; + } + } + } + EditParsingState::InOldText { + path, + start_offset, + old_text_start_offset, + } => { + if message_line == "---" { + state = EditParsingState::InNewText { + path, + start_offset, + old_text_range: old_text_start_offset..offset, + new_text_start_offset: lines.offset(), + }; + } else { + state = EditParsingState::InOldText { + path, + start_offset, + old_text_start_offset, + }; + } + } + EditParsingState::InNewText { + path, + start_offset, + old_text_range, + new_text_start_offset, + } => { + if message_line == "```" { + return Some(ParsedEditSuggestion { + path, + outer_range: start_offset..offset + "```".len(), + old_text_range, + new_text_range: new_text_start_offset..offset, + }); + } else { + state = EditParsingState::InNewText { + path, + start_offset, + old_text_range, + new_text_start_offset, + }; + } + } + } + } +} + +#[derive(Clone)] +pub struct PendingSlashCommand { + pub name: String, + pub argument: Option, + pub status: PendingSlashCommandStatus, + pub source_range: Range, +} + +#[derive(Clone)] +pub enum PendingSlashCommandStatus { + Idle, + Running { _task: Shared> }, + Error(String), +} + +#[derive(Serialize, Deserialize)] +pub struct SavedMessage { + pub id: MessageId, + pub start: usize, + pub metadata: MessageMetadata, +} + +#[derive(Serialize, Deserialize)] +pub struct SavedContext { + pub id: Option, + pub zed: String, + pub version: String, + pub text: String, + pub messages: Vec, + pub summary: String, + pub slash_command_output_sections: + Vec>, +} + +impl SavedContext { + pub const VERSION: &'static str = "0.4.0"; + + pub fn from_json(json: &str) -> Result { + let saved_context_json = serde_json::from_str::(json)?; + match saved_context_json + .get("version") + .ok_or_else(|| anyhow!("version not found"))? + { + serde_json::Value::String(version) => match version.as_str() { + SavedContext::VERSION => { + Ok(serde_json::from_value::(saved_context_json)?) + } + SavedContextV0_3_0::VERSION => { + let saved_context = + serde_json::from_value::(saved_context_json)?; + Ok(saved_context.upgrade()) + } + SavedContextV0_2_0::VERSION => { + let saved_context = + serde_json::from_value::(saved_context_json)?; + Ok(saved_context.upgrade()) + } + SavedContextV0_1_0::VERSION => { + let saved_context = + serde_json::from_value::(saved_context_json)?; + Ok(saved_context.upgrade()) + } + _ => Err(anyhow!("unrecognized saved context version: {}", version)), + }, + _ => Err(anyhow!("version not found on saved context")), + } + } + + fn into_ops( + self, + buffer: &Model, + cx: &mut ModelContext, + ) -> Vec { + let mut operations = Vec::new(); + let mut version = clock::Global::new(); + let mut next_timestamp = clock::Lamport::new(ReplicaId::default()); + + let mut first_message_metadata = None; + for message in self.messages { + if message.id == MessageId(clock::Lamport::default()) { + first_message_metadata = Some(message.metadata); + } else { + operations.push(ContextOperation::InsertMessage { + anchor: MessageAnchor { + id: message.id, + start: buffer.read(cx).anchor_before(message.start), + }, + metadata: MessageMetadata { + role: message.metadata.role, + status: message.metadata.status, + timestamp: message.metadata.timestamp, + }, + version: version.clone(), + }); + version.observe(message.id.0); + next_timestamp.observe(message.id.0); + } + } + + if let Some(metadata) = first_message_metadata { + let timestamp = next_timestamp.tick(); + operations.push(ContextOperation::UpdateMessage { + message_id: MessageId(clock::Lamport::default()), + metadata: MessageMetadata { + role: metadata.role, + status: metadata.status, + timestamp, + }, + version: version.clone(), + }); + version.observe(timestamp); + } + + let timestamp = next_timestamp.tick(); + operations.push(ContextOperation::SlashCommandFinished { + id: SlashCommandId(timestamp), + output_range: language::Anchor::MIN..language::Anchor::MAX, + sections: self + .slash_command_output_sections + .into_iter() + .map(|section| { + let buffer = buffer.read(cx); + SlashCommandOutputSection { + range: buffer.anchor_after(section.range.start) + ..buffer.anchor_before(section.range.end), + icon: section.icon, + label: section.label, + } + }) + .collect(), + version: version.clone(), + }); + version.observe(timestamp); + + let timestamp = next_timestamp.tick(); + operations.push(ContextOperation::UpdateSummary { + summary: ContextSummary { + text: self.summary, + done: true, + timestamp, + }, + version: version.clone(), + }); + version.observe(timestamp); + + operations + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +struct SavedMessageIdPreV0_4_0(usize); + +#[derive(Serialize, Deserialize)] +struct SavedMessagePreV0_4_0 { + id: SavedMessageIdPreV0_4_0, + start: usize, +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +struct SavedMessageMetadataPreV0_4_0 { + role: Role, + status: MessageStatus, +} + +#[derive(Serialize, Deserialize)] +struct SavedContextV0_3_0 { + id: Option, + zed: String, + version: String, + text: String, + messages: Vec, + message_metadata: HashMap, + summary: String, + slash_command_output_sections: Vec>, +} + +impl SavedContextV0_3_0 { + const VERSION: &'static str = "0.3.0"; + + fn upgrade(self) -> SavedContext { + SavedContext { + id: self.id, + zed: self.zed, + version: SavedContext::VERSION.into(), + text: self.text, + messages: self + .messages + .into_iter() + .filter_map(|message| { + let metadata = self.message_metadata.get(&message.id)?; + let timestamp = clock::Lamport { + replica_id: ReplicaId::default(), + value: message.id.0 as u32, + }; + Some(SavedMessage { + id: MessageId(timestamp), + start: message.start, + metadata: MessageMetadata { + role: metadata.role, + status: metadata.status.clone(), + timestamp, + }, + }) + }) + .collect(), + summary: self.summary, + slash_command_output_sections: self.slash_command_output_sections, + } + } +} + +#[derive(Serialize, Deserialize)] +struct SavedContextV0_2_0 { + id: Option, + zed: String, + version: String, + text: String, + messages: Vec, + message_metadata: HashMap, + summary: String, +} + +impl SavedContextV0_2_0 { + const VERSION: &'static str = "0.2.0"; + + fn upgrade(self) -> SavedContext { + SavedContextV0_3_0 { + id: self.id, + zed: self.zed, + version: SavedContextV0_3_0::VERSION.to_string(), + text: self.text, + messages: self.messages, + message_metadata: self.message_metadata, + summary: self.summary, + slash_command_output_sections: Vec::new(), + } + .upgrade() + } +} + +#[derive(Serialize, Deserialize)] +struct SavedContextV0_1_0 { + id: Option, + zed: String, + version: String, + text: String, + messages: Vec, + message_metadata: HashMap, + summary: String, + api_url: Option, + model: OpenAiModel, +} + +impl SavedContextV0_1_0 { + const VERSION: &'static str = "0.1.0"; + + fn upgrade(self) -> SavedContext { + SavedContextV0_2_0 { + id: self.id, + zed: self.zed, + version: SavedContextV0_2_0::VERSION.to_string(), + text: self.text, + messages: self.messages, + message_metadata: self.message_metadata, + summary: self.summary, + } + .upgrade() + } +} + +#[derive(Clone)] +pub struct SavedContextMetadata { + pub title: String, + pub path: PathBuf, + pub mtime: chrono::DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + assistant_panel, + slash_command::{active_command, file_command}, + FakeCompletionProvider, MessageId, + }; + use assistant_slash_command::{ArgumentCompletion, SlashCommand}; + use fs::FakeFs; + use gpui::{AppContext, TestAppContext, WeakView}; + use language::LspAdapterDelegate; + use parking_lot::Mutex; + use project::Project; + use rand::prelude::*; + use rope::Rope; + use serde_json::json; + use settings::SettingsStore; + use std::{cell::RefCell, env, path::Path, rc::Rc, sync::atomic::AtomicBool}; + use text::network::Network; + use ui::WindowContext; + use unindent::Unindent; + use util::{test::marked_text_ranges, RandomCharIter}; + use workspace::Workspace; + + #[gpui::test] + fn test_inserting_and_removing_messages(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + FakeCompletionProvider::setup_test(cx); + cx.set_global(settings_store); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + + let context = cx.new_model(|cx| Context::local(registry, None, cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + let message_2 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..1), + (message_2.id, Role::Assistant, 1..1) + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "1"), (1..1, "2")], None, cx) + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..3) + ] + ); + + let message_3 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_3.id, Role::User, 4..4) + ] + ); + + let message_4 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..5), + (message_3.id, Role::User, 5..5), + ] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(4..4, "C"), (5..5, "D")], None, cx) + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..6), + (message_3.id, Role::User, 6..7), + ] + ); + + // Deleting across message boundaries merges the messages. + buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_3.id, Role::User, 3..4), + ] + ); + + // Undoing the deletion should also undo the merge. + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..2), + (message_2.id, Role::Assistant, 2..4), + (message_4.id, Role::User, 4..6), + (message_3.id, Role::User, 6..7), + ] + ); + + // Redoing the deletion should also redo the merge. + buffer.update(cx, |buffer, cx| buffer.redo(cx)); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_3.id, Role::User, 3..4), + ] + ); + + // Ensure we can still insert after a merged message. + let message_5 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_5.id, Role::System, 3..4), + (message_3.id, Role::User, 4..5) + ] + ); + } + + #[gpui::test] + fn test_message_splitting(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + FakeCompletionProvider::setup_test(cx); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + + let context = cx.new_model(|cx| Context::local(registry, None, cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) + }); + + let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let message_2 = message_2.unwrap(); + + // We recycle newlines in the middle of a split message + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..16), + ] + ); + + let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx)); + let message_3 = message_3.unwrap(); + + // We don't recycle newlines at the end of a split message + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..17), + ] + ); + + let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let message_4 = message_4.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..17), + ] + ); + + let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx)); + let message_5 = message_5.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..18), + ] + ); + + let (message_6, message_7) = + context.update(cx, |context, cx| context.split_message(14..16, cx)); + let message_6 = message_6.unwrap(); + let message_7 = message_7.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..14), + (message_6.id, Role::User, 14..17), + (message_7.id, Role::User, 17..19), + ] + ); + } + + #[gpui::test] + fn test_messages_for_offsets(cx: &mut AppContext) { + let settings_store = SettingsStore::test(cx); + FakeCompletionProvider::setup_test(cx); + cx.set_global(settings_store); + assistant_panel::init(cx); + let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); + let context = cx.new_model(|cx| Context::local(registry, None, cx)); + let buffer = context.read(cx).buffer.clone(); + + let message_1 = context.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&context, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = context + .update(cx, |context, cx| { + context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); + + let message_3 = context + .update(cx, |context, cx| { + context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); + + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..11) + ] + ); + + assert_eq!( + message_ids_for_offsets(&context, &[0, 4, 9], cx), + [message_1.id, message_2.id, message_3.id] + ); + assert_eq!( + message_ids_for_offsets(&context, &[0, 1, 11], cx), + [message_1.id, message_3.id] + ); + + let message_4 = context + .update(cx, |context, cx| { + context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx) + }) + .unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n"); + assert_eq!( + messages(&context, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..12), + (message_4.id, Role::User, 12..12) + ] + ); + assert_eq!( + message_ids_for_offsets(&context, &[0, 4, 8, 12], cx), + [message_1.id, message_2.id, message_3.id, message_4.id] + ); + + fn message_ids_for_offsets( + context: &Model, + offsets: &[usize], + cx: &AppContext, + ) -> Vec { + context + .read(cx) + .messages_for_offsets(offsets.iter().copied(), cx) + .into_iter() + .map(|message| message.id) + .collect() + } + } + + #[gpui::test] + async fn test_slash_commands(cx: &mut TestAppContext) { + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(FakeCompletionProvider::setup_test); + cx.update(Project::init_settings); + cx.update(assistant_panel::init); + let fs = FakeFs::new(cx.background_executor.clone()); + + fs.insert_tree( + "/test", + json!({ + "src": { + "lib.rs": "fn one() -> usize { 1 }", + "main.rs": " + use crate::one; + fn main() { one(); } + ".unindent(), + } + }), + ) + .await; + + let slash_command_registry = cx.update(SlashCommandRegistry::default_global); + slash_command_registry.register_command(file_command::FileSlashCommand, false); + slash_command_registry.register_command(active_command::ActiveSlashCommand, false); + + let registry = Arc::new(LanguageRegistry::test(cx.executor())); + let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx)); + + let output_ranges = Rc::new(RefCell::new(HashSet::default())); + context.update(cx, |_, cx| { + cx.subscribe(&context, { + let ranges = output_ranges.clone(); + move |_, _, event, _| match event { + ContextEvent::PendingSlashCommandsUpdated { removed, updated } => { + for range in removed { + ranges.borrow_mut().remove(range); + } + for command in updated { + ranges.borrow_mut().insert(command.source_range.clone()); + } + } + _ => {} + } + }) + .detach(); + }); + + let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + + // Insert a slash command + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "/file src/lib.rs")], None, cx); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + «/file src/lib.rs» + " + .unindent() + .trim_end(), + cx, + ); + + // Edit the argument of the slash command. + buffer.update(cx, |buffer, cx| { + let edit_offset = buffer.text().find("lib.rs").unwrap(); + buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + «/file src/main.rs» + " + .unindent() + .trim_end(), + cx, + ); + + // Edit the name of the slash command, using one that doesn't exist. + buffer.update(cx, |buffer, cx| { + let edit_offset = buffer.text().find("/file").unwrap(); + buffer.edit( + [(edit_offset..edit_offset + "/file".len(), "/unknown")], + None, + cx, + ); + }); + assert_text_and_output_ranges( + &buffer, + &output_ranges.borrow(), + " + /unknown src/main.rs + " + .unindent() + .trim_end(), + cx, + ); + + #[track_caller] + fn assert_text_and_output_ranges( + buffer: &Model, + ranges: &HashSet>, + expected_marked_text: &str, + cx: &mut TestAppContext, + ) { + let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false); + let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| { + let mut ranges = ranges + .iter() + .map(|range| range.to_offset(buffer)) + .collect::>(); + ranges.sort_by_key(|a| a.start); + (buffer.text(), ranges) + }); + + assert_eq!(actual_text, expected_text); + assert_eq!(actual_ranges, expected_ranges); + } + } + + #[test] + fn test_parse_next_edit_suggestion() { + let text = " + some output: + + ```edit src/foo.rs + let a = 1; + let b = 2; + --- + let w = 1; + let x = 2; + let y = 3; + let z = 4; + ``` + + some more output: + + ```edit src/foo.rs + let c = 1; + --- + ``` + + and the conclusion. + " + .unindent(); + + let rope = Rope::from(text.as_str()); + let mut lines = rope.chunks().lines(); + let mut suggestions = vec![]; + while let Some(suggestion) = parse_next_edit_suggestion(&mut lines) { + suggestions.push(( + suggestion.path.clone(), + text[suggestion.old_text_range].to_string(), + text[suggestion.new_text_range].to_string(), + )); + } + + assert_eq!( + suggestions, + vec![ + ( + Path::new("src/foo.rs").into(), + [ + " let a = 1;", // + " let b = 2;", + "", + ] + .join("\n"), + [ + " let w = 1;", + " let x = 2;", + " let y = 3;", + " let z = 4;", + "", + ] + .join("\n"), + ), + ( + Path::new("src/foo.rs").into(), + [ + " let c = 1;", // + "", + ] + .join("\n"), + String::new(), + ) + ] + ); + } + + #[gpui::test] + async fn test_serialization(cx: &mut TestAppContext) { + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(FakeCompletionProvider::setup_test); + cx.update(assistant_panel::init); + let registry = Arc::new(LanguageRegistry::test(cx.executor())); + let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx)); + let buffer = context.read_with(cx, |context, _| context.buffer.clone()); + let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); + let message_1 = context.update(cx, |context, cx| { + context + .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) + .unwrap() + }); + let message_2 = context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); + buffer.finalize_last_transaction(); + }); + let _message_3 = context.update(cx, |context, cx| { + context + .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n"); + assert_eq!( + cx.read(|cx| messages(&context, cx)), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); + + let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx)); + let deserialized_context = cx.new_model(|cx| { + Context::deserialize( + serialized_context, + Default::default(), + registry.clone(), + None, + cx, + ) + }); + let deserialized_buffer = + deserialized_context.read_with(cx, |context, _| context.buffer.clone()); + assert_eq!( + deserialized_buffer.read_with(cx, |buffer, _| buffer.text()), + "a\nb\nc\n" + ); + assert_eq!( + cx.read(|cx| messages(&deserialized_context, cx)), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); + } + + #[gpui::test(iterations = 100)] + async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) { + let min_peers = env::var("MIN_PEERS") + .map(|i| i.parse().expect("invalid `MIN_PEERS` variable")) + .unwrap_or(2); + let max_peers = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(5); + let operations = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(50); + + let settings_store = cx.update(SettingsStore::test); + cx.set_global(settings_store); + cx.update(FakeCompletionProvider::setup_test); + cx.update(assistant_panel::init); + let slash_commands = cx.update(SlashCommandRegistry::default_global); + slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false); + slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false); + slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false); + + let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone())); + let network = Arc::new(Mutex::new(Network::new(rng.clone()))); + let mut contexts = Vec::new(); + + let num_peers = rng.gen_range(min_peers..=max_peers); + let context_id = ContextId::new(); + for i in 0..num_peers { + let context = cx.new_model(|cx| { + Context::new( + context_id.clone(), + i as ReplicaId, + language::Capability::ReadWrite, + registry.clone(), + None, + cx, + ) + }); + + cx.update(|cx| { + cx.subscribe(&context, { + let network = network.clone(); + move |_, event, _| { + if let ContextEvent::Operation(op) = event { + network + .lock() + .broadcast(i as ReplicaId, vec![op.to_proto()]); + } + } + }) + .detach(); + }); + + contexts.push(context); + network.lock().add_peer(i as ReplicaId); + } + + let mut mutation_count = operations; + + while mutation_count > 0 + || !network.lock().is_idle() + || network.lock().contains_disconnected_peers() + { + let context_index = rng.gen_range(0..contexts.len()); + let context = &contexts[context_index]; + + match rng.gen_range(0..100) { + 0..=29 if mutation_count > 0 => { + log::info!("Context {}: edit buffer", context_index); + context.update(cx, |context, cx| { + context + .buffer + .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx)); + }); + mutation_count -= 1; + } + 30..=44 if mutation_count > 0 => { + context.update(cx, |context, cx| { + let range = context.buffer.read(cx).random_byte_range(0, &mut rng); + log::info!("Context {}: split message at {:?}", context_index, range); + context.split_message(range, cx); + }); + mutation_count -= 1; + } + 45..=59 if mutation_count > 0 => { + context.update(cx, |context, cx| { + if let Some(message) = context.messages(cx).choose(&mut rng) { + let role = *[Role::User, Role::Assistant, Role::System] + .choose(&mut rng) + .unwrap(); + log::info!( + "Context {}: insert message after {:?} with {:?}", + context_index, + message.id, + role + ); + context.insert_message_after(message.id, role, MessageStatus::Done, cx); + } + }); + mutation_count -= 1; + } + 60..=74 if mutation_count > 0 => { + context.update(cx, |context, cx| { + let command_text = "/".to_string() + + slash_commands + .command_names() + .choose(&mut rng) + .unwrap() + .clone() + .as_ref(); + + let command_range = context.buffer.update(cx, |buffer, cx| { + let offset = buffer.random_byte_range(0, &mut rng).start; + buffer.edit( + [(offset..offset, format!("\n{}\n", command_text))], + None, + cx, + ); + offset + 1..offset + 1 + command_text.len() + }); + + let output_len = rng.gen_range(1..=10); + let output_text = RandomCharIter::new(&mut rng) + .filter(|c| *c != '\r') + .take(output_len) + .collect::(); + + let num_sections = rng.gen_range(0..=3); + let mut sections = Vec::with_capacity(num_sections); + for _ in 0..num_sections { + let section_start = rng.gen_range(0..output_len); + let section_end = rng.gen_range(section_start..=output_len); + sections.push(SlashCommandOutputSection { + range: section_start..section_end, + icon: ui::IconName::Ai, + label: "section".into(), + }); + } + + log::info!( + "Context {}: insert slash command output at {:?} with {:?}", + context_index, + command_range, + sections + ); + + let command_range = + context.buffer.read(cx).anchor_after(command_range.start) + ..context.buffer.read(cx).anchor_after(command_range.end); + context.insert_command_output( + command_range, + Task::ready(Ok(SlashCommandOutput { + text: output_text, + sections, + run_commands_in_text: false, + })), + true, + cx, + ); + }); + cx.run_until_parked(); + mutation_count -= 1; + } + 75..=84 if mutation_count > 0 => { + context.update(cx, |context, cx| { + if let Some(message) = context.messages(cx).choose(&mut rng) { + let new_status = match rng.gen_range(0..3) { + 0 => MessageStatus::Done, + 1 => MessageStatus::Pending, + _ => MessageStatus::Error(SharedString::from("Random error")), + }; + log::info!( + "Context {}: update message {:?} status to {:?}", + context_index, + message.id, + new_status + ); + context.update_metadata(message.id, cx, |metadata| { + metadata.status = new_status; + }); + } + }); + mutation_count -= 1; + } + _ => { + let replica_id = context_index as ReplicaId; + if network.lock().is_disconnected(replica_id) { + network.lock().reconnect_peer(replica_id, 0); + + let (ops_to_send, ops_to_receive) = cx.read(|cx| { + let host_context = &contexts[0].read(cx); + let guest_context = context.read(cx); + ( + guest_context.serialize_ops(&host_context.version(cx), cx), + host_context.serialize_ops(&guest_context.version(cx), cx), + ) + }); + let ops_to_send = ops_to_send.await; + let ops_to_receive = ops_to_receive + .await + .into_iter() + .map(ContextOperation::from_proto) + .collect::>>() + .unwrap(); + log::info!( + "Context {}: reconnecting. Sent {} operations, received {} operations", + context_index, + ops_to_send.len(), + ops_to_receive.len() + ); + + network.lock().broadcast(replica_id, ops_to_send); + context + .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx)) + .unwrap(); + } else if rng.gen_bool(0.1) && replica_id != 0 { + log::info!("Context {}: disconnecting", context_index); + network.lock().disconnect_peer(replica_id); + } else if network.lock().has_unreceived(replica_id) { + log::info!("Context {}: applying operations", context_index); + let ops = network.lock().receive(replica_id); + let ops = ops + .into_iter() + .map(ContextOperation::from_proto) + .collect::>>() + .unwrap(); + context + .update(cx, |context, cx| context.apply_ops(ops, cx)) + .unwrap(); + } + } + } + } + + cx.read(|cx| { + let first_context = contexts[0].read(cx); + for context in &contexts[1..] { + let context = context.read(cx); + assert!(context.pending_ops.is_empty()); + assert_eq!( + context.buffer.read(cx).text(), + first_context.buffer.read(cx).text(), + "Context {} text != Context 0 text", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.message_anchors, + first_context.message_anchors, + "Context {} messages != Context 0 messages", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.messages_metadata, + first_context.messages_metadata, + "Context {} message metadata != Context 0 message metadata", + context.buffer.read(cx).replica_id() + ); + assert_eq!( + context.slash_command_output_sections, + first_context.slash_command_output_sections, + "Context {} slash command output sections != Context 0 slash command output sections", + context.buffer.read(cx).replica_id() + ); + } + }); + } + + fn messages(context: &Model, cx: &AppContext) -> Vec<(MessageId, Role, Range)> { + context + .read(cx) + .messages(cx) + .map(|message| (message.id, message.role, message.offset_range)) + .collect() + } + + #[derive(Clone)] + struct FakeSlashCommand(String); + + impl SlashCommand for FakeSlashCommand { + fn name(&self) -> String { + self.0.clone() + } + + fn description(&self) -> String { + format!("Fake slash command: {}", self.0) + } + + fn menu_text(&self) -> String { + format!("Run fake command: {}", self.0) + } + + fn complete_argument( + self: Arc, + _query: String, + _cancel: Arc, + _workspace: Option>, + _cx: &mut AppContext, + ) -> Task>> { + Task::ready(Ok(vec![])) + } + + fn requires_argument(&self) -> bool { + false + } + + fn run( + self: Arc, + _argument: Option<&str>, + _workspace: WeakView, + _delegate: Arc, + _cx: &mut WindowContext, + ) -> Task> { + Task::ready(Ok(SlashCommandOutput { + text: format!("Executed fake command: {}", self.0), + sections: vec![], + run_commands_in_text: false, + })) + } + } +} diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index 9e76caea15..9cf60fc014 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,97 +1,117 @@ -use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata}; -use anyhow::{anyhow, Result}; -use assistant_slash_command::SlashCommandOutputSection; -use collections::HashMap; +use crate::{ + Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, + SavedContextMetadata, +}; +use anyhow::{anyhow, Context as _, Result}; +use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; +use clock::ReplicaId; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; -use gpui::{AppContext, Model, ModelContext, Task}; +use gpui::{AppContext, AsyncAppContext, Context as _, Model, ModelContext, Task, WeakModel}; +use language::LanguageRegistry; use paths::contexts_dir; +use project::Project; use regex::Regex; -use serde::{Deserialize, Serialize}; -use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc, time::Duration}; -use ui::Context; +use std::{ + cmp::Reverse, + ffi::OsStr, + mem, + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; use util::{ResultExt, TryFutureExt}; -#[derive(Serialize, Deserialize)] -pub struct SavedMessage { - pub id: MessageId, - pub start: usize, -} - -#[derive(Serialize, Deserialize)] -pub struct SavedContext { - pub id: Option, - pub zed: String, - pub version: String, - pub text: String, - pub messages: Vec, - pub message_metadata: HashMap, - pub summary: String, - pub slash_command_output_sections: Vec>, -} - -impl SavedContext { - pub const VERSION: &'static str = "0.3.0"; -} - -#[derive(Serialize, Deserialize)] -pub struct SavedContextV0_2_0 { - pub id: Option, - pub zed: String, - pub version: String, - pub text: String, - pub messages: Vec, - pub message_metadata: HashMap, - pub summary: String, -} - -#[derive(Serialize, Deserialize)] -struct SavedContextV0_1_0 { - id: Option, - zed: String, - version: String, - text: String, - messages: Vec, - message_metadata: HashMap, - summary: String, - api_url: Option, - model: OpenAiModel, +pub fn init(client: &Arc) { + client.add_model_message_handler(ContextStore::handle_advertise_contexts); + client.add_model_request_handler(ContextStore::handle_open_context); + client.add_model_message_handler(ContextStore::handle_update_context); + client.add_model_request_handler(ContextStore::handle_synchronize_contexts); } #[derive(Clone)] -pub struct SavedContextMetadata { - pub title: String, - pub path: PathBuf, - pub mtime: chrono::DateTime, +pub struct RemoteContextMetadata { + pub id: ContextId, + pub summary: Option, } pub struct ContextStore { + contexts: Vec, contexts_metadata: Vec, + host_contexts: Vec, fs: Arc, + languages: Arc, + telemetry: Arc, _watch_updates: Task>, + client: Arc, + project: Model, + project_is_shared: bool, + client_subscription: Option, + _project_subscriptions: Vec, +} + +enum ContextHandle { + Weak(WeakModel), + Strong(Model), +} + +impl ContextHandle { + fn upgrade(&self) -> Option> { + match self { + ContextHandle::Weak(weak) => weak.upgrade(), + ContextHandle::Strong(strong) => Some(strong.clone()), + } + } + + fn downgrade(&self) -> WeakModel { + match self { + ContextHandle::Weak(weak) => weak.clone(), + ContextHandle::Strong(strong) => strong.downgrade(), + } + } } impl ContextStore { - pub fn new(fs: Arc, cx: &mut AppContext) -> Task>> { + pub fn new(project: Model, cx: &mut AppContext) -> Task>> { + let fs = project.read(cx).fs().clone(); + let languages = project.read(cx).languages().clone(); + let telemetry = project.read(cx).client().telemetry().clone(); cx.spawn(|mut cx| async move { const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100); let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; - let this = cx.new_model(|cx: &mut ModelContext| Self { - contexts_metadata: Vec::new(), - fs, - _watch_updates: cx.spawn(|this, mut cx| { - async move { - while events.next().await.is_some() { - this.update(&mut cx, |this, cx| this.reload(cx))? - .await - .log_err(); + let this = cx.new_model(|cx: &mut ModelContext| { + let mut this = Self { + contexts: Vec::new(), + contexts_metadata: Vec::new(), + host_contexts: Vec::new(), + fs, + languages, + telemetry, + _watch_updates: cx.spawn(|this, mut cx| { + async move { + while events.next().await.is_some() { + this.update(&mut cx, |this, cx| this.reload(cx))? + .await + .log_err(); + } + anyhow::Ok(()) } - anyhow::Ok(()) - } - .log_err() - }), + .log_err() + }), + client_subscription: None, + _project_subscriptions: vec![ + cx.observe(&project, Self::handle_project_changed), + cx.subscribe(&project, Self::handle_project_event), + ], + project_is_shared: false, + client: project.read(cx).client(), + project: project.clone(), + }; + this.handle_project_changed(project, cx); + this.synchronize_contexts(cx); + this })?; this.update(&mut cx, |this, cx| this.reload(cx))? .await @@ -100,54 +120,433 @@ impl ContextStore { }) } - pub fn load(&self, path: PathBuf, cx: &AppContext) -> Task> { + async fn handle_advertise_contexts( + this: Model, + envelope: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + this.host_contexts = envelope + .payload + .contexts + .into_iter() + .map(|context| RemoteContextMetadata { + id: ContextId::from_proto(context.context_id), + summary: context.summary, + }) + .collect(); + cx.notify(); + }) + } + + async fn handle_open_context( + this: Model, + envelope: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result { + let context_id = ContextId::from_proto(envelope.payload.context_id); + let operations = this.update(&mut cx, |this, cx| { + if this.project.read(cx).is_remote() { + return Err(anyhow!("only the host contexts can be opened")); + } + + let context = this + .loaded_context_for_id(&context_id, cx) + .context("context not found")?; + if context.read(cx).replica_id() != ReplicaId::default() { + return Err(anyhow!("context must be opened via the host")); + } + + anyhow::Ok( + context + .read(cx) + .serialize_ops(&ContextVersion::default(), cx), + ) + })??; + let operations = operations.await; + Ok(proto::OpenContextResponse { + context: Some(proto::Context { operations }), + }) + } + + async fn handle_update_context( + this: Model, + envelope: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + let context_id = ContextId::from_proto(envelope.payload.context_id); + if let Some(context) = this.loaded_context_for_id(&context_id, cx) { + let operation_proto = envelope.payload.operation.context("invalid operation")?; + let operation = ContextOperation::from_proto(operation_proto)?; + context.update(cx, |context, cx| context.apply_ops([operation], cx))?; + } + Ok(()) + })? + } + + async fn handle_synchronize_contexts( + this: Model, + envelope: TypedEnvelope, + mut cx: AsyncAppContext, + ) -> Result { + this.update(&mut cx, |this, cx| { + if this.project.read(cx).is_remote() { + return Err(anyhow!("only the host can synchronize contexts")); + } + + let mut local_versions = Vec::new(); + for remote_version_proto in envelope.payload.contexts { + let remote_version = ContextVersion::from_proto(&remote_version_proto); + let context_id = ContextId::from_proto(remote_version_proto.context_id); + if let Some(context) = this.loaded_context_for_id(&context_id, cx) { + let context = context.read(cx); + let operations = context.serialize_ops(&remote_version, cx); + local_versions.push(context.version(cx).to_proto(context_id.clone())); + let client = this.client.clone(); + let project_id = envelope.payload.project_id; + cx.background_executor() + .spawn(async move { + let operations = operations.await; + for operation in operations { + client.send(proto::UpdateContext { + project_id, + context_id: context_id.to_proto(), + operation: Some(operation), + })?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + } + + this.advertise_contexts(cx); + + anyhow::Ok(proto::SynchronizeContextsResponse { + contexts: local_versions, + }) + })? + } + + fn handle_project_changed(&mut self, _: Model, cx: &mut ModelContext) { + let is_shared = self.project.read(cx).is_shared(); + let was_shared = mem::replace(&mut self.project_is_shared, is_shared); + if is_shared == was_shared { + return; + } + + if is_shared { + self.contexts.retain_mut(|context| { + if let Some(strong_context) = context.upgrade() { + *context = ContextHandle::Strong(strong_context); + true + } else { + false + } + }); + let remote_id = self.project.read(cx).remote_id().unwrap(); + self.client_subscription = self + .client + .subscribe_to_entity(remote_id) + .log_err() + .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async())); + self.advertise_contexts(cx); + } else { + self.client_subscription = None; + } + } + + fn handle_project_event( + &mut self, + _: Model, + event: &project::Event, + cx: &mut ModelContext, + ) { + match event { + project::Event::Reshared => { + self.advertise_contexts(cx); + } + project::Event::HostReshared | project::Event::Rejoined => { + self.synchronize_contexts(cx); + } + project::Event::DisconnectedFromHost => { + self.contexts.retain_mut(|context| { + if let Some(strong_context) = context.upgrade() { + *context = ContextHandle::Weak(context.downgrade()); + strong_context.update(cx, |context, cx| { + if context.replica_id() != ReplicaId::default() { + context.set_capability(language::Capability::ReadOnly, cx); + } + }); + true + } else { + false + } + }); + self.host_contexts.clear(); + cx.notify(); + } + _ => {} + } + } + + pub fn create(&mut self, cx: &mut ModelContext) -> Model { + let context = cx.new_model(|cx| { + Context::local(self.languages.clone(), Some(self.telemetry.clone()), cx) + }); + self.register_context(&context, cx); + context + } + + pub fn open_local_context( + &mut self, + path: PathBuf, + cx: &ModelContext, + ) -> Task>> { + if let Some(existing_context) = self.loaded_context_for_path(&path, cx) { + return Task::ready(Ok(existing_context)); + } + let fs = self.fs.clone(); - cx.background_executor().spawn(async move { - let saved_context = fs.load(&path).await?; - let saved_context_json = serde_json::from_str::(&saved_context)?; - match saved_context_json - .get("version") - .ok_or_else(|| anyhow!("version not found"))? - { - serde_json::Value::String(version) => match version.as_str() { - SavedContext::VERSION => { - Ok(serde_json::from_value::(saved_context_json)?) - } - "0.2.0" => { - let saved_context = - serde_json::from_value::(saved_context_json)?; - Ok(SavedContext { - id: saved_context.id, - zed: saved_context.zed, - version: saved_context.version, - text: saved_context.text, - messages: saved_context.messages, - message_metadata: saved_context.message_metadata, - summary: saved_context.summary, - slash_command_output_sections: Vec::new(), - }) - } - "0.1.0" => { - let saved_context = - serde_json::from_value::(saved_context_json)?; - Ok(SavedContext { - id: saved_context.id, - zed: saved_context.zed, - version: saved_context.version, - text: saved_context.text, - messages: saved_context.messages, - message_metadata: saved_context.message_metadata, - summary: saved_context.summary, - slash_command_output_sections: Vec::new(), - }) - } - _ => Err(anyhow!("unrecognized saved context version: {}", version)), - }, - _ => Err(anyhow!("version not found on saved context")), + let languages = self.languages.clone(); + let telemetry = self.telemetry.clone(); + let load = cx.background_executor().spawn({ + let path = path.clone(); + async move { + let saved_context = fs.load(&path).await?; + SavedContext::from_json(&saved_context) + } + }); + + cx.spawn(|this, mut cx| async move { + let saved_context = load.await?; + let context = cx.new_model(|cx| { + Context::deserialize(saved_context, path.clone(), languages, Some(telemetry), cx) + })?; + this.update(&mut cx, |this, cx| { + if let Some(existing_context) = this.loaded_context_for_path(&path, cx) { + existing_context + } else { + this.register_context(&context, cx); + context + } + }) + }) + } + + fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option> { + self.contexts.iter().find_map(|context| { + let context = context.upgrade()?; + if context.read(cx).path() == Some(path) { + Some(context) + } else { + None } }) } + fn loaded_context_for_id(&self, id: &ContextId, cx: &AppContext) -> Option> { + self.contexts.iter().find_map(|context| { + let context = context.upgrade()?; + if context.read(cx).id() == id { + Some(context) + } else { + None + } + }) + } + + pub fn open_remote_context( + &mut self, + context_id: ContextId, + cx: &mut ModelContext, + ) -> Task>> { + let project = self.project.read(cx); + let Some(project_id) = project.remote_id() else { + return Task::ready(Err(anyhow!("project was not remote"))); + }; + if project.is_local() { + return Task::ready(Err(anyhow!("cannot open remote contexts as the host"))); + } + + if let Some(context) = self.loaded_context_for_id(&context_id, cx) { + return Task::ready(Ok(context)); + } + + let replica_id = project.replica_id(); + let capability = project.capability(); + let language_registry = self.languages.clone(); + let telemetry = self.telemetry.clone(); + let request = self.client.request(proto::OpenContext { + project_id, + context_id: context_id.to_proto(), + }); + cx.spawn(|this, mut cx| async move { + let response = request.await?; + let context_proto = response.context.context("invalid context")?; + let context = cx.new_model(|cx| { + Context::new( + context_id.clone(), + replica_id, + capability, + language_registry, + Some(telemetry), + cx, + ) + })?; + let operations = cx + .background_executor() + .spawn(async move { + context_proto + .operations + .into_iter() + .map(|op| ContextOperation::from_proto(op)) + .collect::>>() + }) + .await?; + context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))??; + this.update(&mut cx, |this, cx| { + if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { + existing_context + } else { + this.register_context(&context, cx); + this.synchronize_contexts(cx); + context + } + }) + }) + } + + fn register_context(&mut self, context: &Model, cx: &mut ModelContext) { + let handle = if self.project_is_shared { + ContextHandle::Strong(context.clone()) + } else { + ContextHandle::Weak(context.downgrade()) + }; + self.contexts.push(handle); + self.advertise_contexts(cx); + cx.subscribe(context, Self::handle_context_event).detach(); + } + + fn handle_context_event( + &mut self, + context: Model, + event: &ContextEvent, + cx: &mut ModelContext, + ) { + let Some(project_id) = self.project.read(cx).remote_id() else { + return; + }; + + match event { + ContextEvent::SummaryChanged => { + self.advertise_contexts(cx); + } + ContextEvent::Operation(operation) => { + let context_id = context.read(cx).id().to_proto(); + let operation = operation.to_proto(); + self.client + .send(proto::UpdateContext { + project_id, + context_id, + operation: Some(operation), + }) + .log_err(); + } + _ => {} + } + } + + fn advertise_contexts(&self, cx: &AppContext) { + let Some(project_id) = self.project.read(cx).remote_id() else { + return; + }; + + // For now, only the host can advertise their open contexts. + if self.project.read(cx).is_remote() { + return; + } + + let contexts = self + .contexts + .iter() + .rev() + .filter_map(|context| { + let context = context.upgrade()?.read(cx); + if context.replica_id() == ReplicaId::default() { + Some(proto::ContextMetadata { + context_id: context.id().to_proto(), + summary: context.summary().map(|summary| summary.text.clone()), + }) + } else { + None + } + }) + .collect(); + self.client + .send(proto::AdvertiseContexts { + project_id, + contexts, + }) + .ok(); + } + + fn synchronize_contexts(&mut self, cx: &mut ModelContext) { + let Some(project_id) = self.project.read(cx).remote_id() else { + return; + }; + + let contexts = self + .contexts + .iter() + .filter_map(|context| { + let context = context.upgrade()?.read(cx); + if context.replica_id() != ReplicaId::default() { + Some(context.version(cx).to_proto(context.id().clone())) + } else { + None + } + }) + .collect(); + + let client = self.client.clone(); + let request = self.client.request(proto::SynchronizeContexts { + project_id, + contexts, + }); + cx.spawn(|this, cx| async move { + let response = request.await?; + + let mut context_ids = Vec::new(); + let mut operations = Vec::new(); + this.read_with(&cx, |this, cx| { + for context_version_proto in response.contexts { + let context_version = ContextVersion::from_proto(&context_version_proto); + let context_id = ContextId::from_proto(context_version_proto.context_id); + if let Some(context) = this.loaded_context_for_id(&context_id, cx) { + context_ids.push(context_id); + operations.push(context.read(cx).serialize_ops(&context_version, cx)); + } + } + })?; + + let operations = futures::future::join_all(operations).await; + for (context_id, operations) in context_ids.into_iter().zip(operations) { + for operation in operations { + client.send(proto::UpdateContext { + project_id, + context_id: context_id.to_proto(), + operation: Some(operation), + })?; + } + } + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + pub fn search(&self, query: String, cx: &AppContext) -> Task> { let metadata = self.contexts_metadata.clone(); let executor = cx.background_executor().clone(); @@ -178,6 +577,10 @@ impl ContextStore { }) } + pub fn host_contexts(&self) -> &[RemoteContextMetadata] { + &self.host_contexts + } + fn reload(&mut self, cx: &mut ModelContext) -> Task> { let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 1ddb24b02e..9adadd1608 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -3,7 +3,6 @@ use crate::{ InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use anyhow::{anyhow, Result}; -use assistant_slash_command::SlashCommandRegistry; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet}; use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle}; @@ -448,7 +447,6 @@ impl PromptLibrary { self.set_active_prompt(Some(prompt_id), cx); } else if let Some(prompt_metadata) = self.store.metadata(prompt_id) { let language_registry = self.language_registry.clone(); - let commands = SlashCommandRegistry::global(cx); let prompt = self.store.load(prompt_id); self.pending_load = cx.spawn(|this, mut cx| async move { let prompt = prompt.await; @@ -477,7 +475,7 @@ impl PromptLibrary { editor.set_use_modal_editing(false); editor.set_current_line_highlight(Some(CurrentLineHighlight::None)); editor.set_completion_provider(Box::new( - SlashCommandCompletionProvider::new(commands, None, None), + SlashCommandCompletionProvider::new(None, None), )); if focus { editor.focus(cx); diff --git a/crates/assistant/src/slash_command.rs b/crates/assistant/src/slash_command.rs index ebb563313a..6cec386168 100644 --- a/crates/assistant/src/slash_command.rs +++ b/crates/assistant/src/slash_command.rs @@ -31,7 +31,6 @@ pub mod tabs_command; pub mod term_command; pub(crate) struct SlashCommandCompletionProvider { - commands: Arc, cancel_flag: Mutex>, editor: Option>, workspace: Option>, @@ -46,14 +45,12 @@ pub(crate) struct SlashCommandLine { impl SlashCommandCompletionProvider { pub fn new( - commands: Arc, editor: Option>, workspace: Option>, ) -> Self { Self { cancel_flag: Mutex::new(Arc::new(AtomicBool::new(false))), editor, - commands, workspace, } } @@ -65,8 +62,8 @@ impl SlashCommandCompletionProvider { name_range: Range, cx: &mut WindowContext, ) -> Task>> { - let candidates = self - .commands + let commands = SlashCommandRegistry::global(cx); + let candidates = commands .command_names() .into_iter() .enumerate() @@ -76,7 +73,6 @@ impl SlashCommandCompletionProvider { char_bag: def.as_ref().into(), }) .collect::>(); - let commands = self.commands.clone(); let command_name = command_name.to_string(); let editor = self.editor.clone(); let workspace = self.workspace.clone(); @@ -155,7 +151,8 @@ impl SlashCommandCompletionProvider { flag.store(true, SeqCst); *flag = new_cancel_flag.clone(); - if let Some(command) = self.commands.command(command_name) { + let commands = SlashCommandRegistry::global(cx); + if let Some(command) = commands.command(command_name) { let completions = command.complete_argument( argument, new_cancel_flag.clone(), diff --git a/crates/assistant_slash_command/src/assistant_slash_command.rs b/crates/assistant_slash_command/src/assistant_slash_command.rs index d361f49d42..5f917363a2 100644 --- a/crates/assistant_slash_command/src/assistant_slash_command.rs +++ b/crates/assistant_slash_command/src/assistant_slash_command.rs @@ -67,7 +67,7 @@ pub struct SlashCommandOutput { pub run_commands_in_text: bool, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct SlashCommandOutputSection { pub range: Range, pub icon: IconName, diff --git a/crates/clock/Cargo.toml b/crates/clock/Cargo.toml index d1fb21747b..699a50e70d 100644 --- a/crates/clock/Cargo.toml +++ b/crates/clock/Cargo.toml @@ -18,4 +18,5 @@ test-support = ["dep:parking_lot"] [dependencies] chrono.workspace = true parking_lot = { workspace = true, optional = true } +serde.workspace = true smallvec.workspace = true diff --git a/crates/clock/src/clock.rs b/crates/clock/src/clock.rs index 48d9928bcb..f7d36ed4a8 100644 --- a/crates/clock/src/clock.rs +++ b/crates/clock/src/clock.rs @@ -1,5 +1,6 @@ mod system_clock; +use serde::{Deserialize, Serialize}; use smallvec::SmallVec; use std::{ cmp::{self, Ordering}, @@ -16,7 +17,7 @@ pub type Seq = u32; /// A [Lamport timestamp](https://en.wikipedia.org/wiki/Lamport_timestamp), /// used to determine the ordering of events in the editor. -#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct Lamport { pub replica_id: ReplicaId, pub value: Seq, @@ -161,6 +162,10 @@ impl Lamport { } } + pub fn as_u64(self) -> u64 { + ((self.value as u64) << 32) | (self.replica_id as u64) + } + pub fn tick(&mut self) -> Self { let timestamp = *self; self.value += 1; diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index a95682ede0..4f1e1151b6 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -71,6 +71,7 @@ util.workspace = true uuid.workspace = true [dev-dependencies] +assistant = { workspace = true, features = ["test-support"] } async-trait.workspace = true audio.workspace = true call = { workspace = true, features = ["test-support"] } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 61aaa98144..42e5c7e94f 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -595,6 +595,14 @@ impl Server { .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_request_handler(user_handler(get_supermaven_api_key)) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_message_handler(broadcast_project_message_from_host::) + .add_message_handler(update_context) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -3056,6 +3064,53 @@ async fn update_buffer( Ok(()) } +async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(message.project_id); + + let operation = message.operation.as_ref().context("invalid operation")?; + let capability = match operation.variant.as_ref() { + Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => { + if let Some(buffer_op) = buffer_op.operation.as_ref() { + match buffer_op.variant { + None | Some(proto::operation::Variant::UpdateSelections(_)) => { + Capability::ReadOnly + } + _ => Capability::ReadWrite, + } + } else { + Capability::ReadWrite + } + } + Some(_) => Capability::ReadWrite, + None => Capability::ReadOnly, + }; + + let guard = session + .db() + .await + .connections_for_buffer_update( + project_id, + session.principal_id(), + session.connection_id, + capability, + ) + .await?; + + let (host, guests) = &*guard; + + broadcast( + Some(session.connection_id), + guests.iter().chain([host]).copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, message.clone()) + }, + ); + + Ok(()) +} + /// Notify other participants that a project has been updated. async fn broadcast_project_message_from_host>( request: T, diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f8e5c483af..c4f44e4dad 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,6 +6,7 @@ use crate::{ }, }; use anyhow::{anyhow, Result}; +use assistant::ContextStore; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; use collections::{HashMap, HashSet}; @@ -6449,3 +6450,123 @@ async fn test_preview_tabs(cx: &mut TestAppContext) { assert!(!pane.can_navigate_forward()); }); } + +#[gpui::test(iterations = 10)] +async fn test_context_collaboration_with_reconnect( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a.fs().insert_tree("/a", Default::default()).await; + let (project_a, _) = client_a.build_local_project("/a", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_dev_server_project(project_id, cx_b).await; + + // Client A sees that a guest has joined. + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + project_b.read_with(cx_b, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + + let context_store_a = cx_a + .update(|cx| ContextStore::new(project_a.clone(), cx)) + .await + .unwrap(); + let context_store_b = cx_b + .update(|cx| ContextStore::new(project_b.clone(), cx)) + .await + .unwrap(); + + // Client A creates a new context. + let context_a = context_store_a.update(cx_a, |store, cx| store.create(cx)); + executor.run_until_parked(); + + // Client B retrieves host's contexts and joins one. + let context_b = context_store_b + .update(cx_b, |store, cx| { + let host_contexts = store.host_contexts().to_vec(); + assert_eq!(host_contexts.len(), 1); + store.open_remote_context(host_contexts[0].id.clone(), cx) + }) + .await + .unwrap(); + + // Host and guest make changes + context_a.update(cx_a, |context, cx| { + context.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "Host change\n")], None, cx) + }) + }); + context_b.update(cx_b, |context, cx| { + context.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "Guest change\n")], None, cx) + }) + }); + executor.run_until_parked(); + assert_eq!( + context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + "Guest change\nHost change\n" + ); + assert_eq!( + context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + "Guest change\nHost change\n" + ); + + // Disconnect client A and make some changes while disconnected. + server.disconnect_client(client_a.peer_id().unwrap()); + server.forbid_connections(); + context_a.update(cx_a, |context, cx| { + context.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "Host offline change\n")], None, cx) + }) + }); + context_b.update(cx_b, |context, cx| { + context.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "Guest offline change\n")], None, cx) + }) + }); + executor.run_until_parked(); + assert_eq!( + context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + "Host offline change\nGuest change\nHost change\n" + ); + assert_eq!( + context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + "Guest offline change\nGuest change\nHost change\n" + ); + + // Allow client A to reconnect and verify that contexts converge. + server.allow_connections(); + executor.advance_clock(RECEIVE_TIMEOUT); + assert_eq!( + context_a.read_with(cx_a, |context, cx| context.buffer().read(cx).text()), + "Guest offline change\nHost offline change\nGuest change\nHost change\n" + ); + assert_eq!( + context_b.read_with(cx_b, |context, cx| context.buffer().read(cx).text()), + "Guest offline change\nHost offline change\nGuest change\nHost change\n" + ); + + // Client A disconnects without being able to reconnect. Context B becomes readonly. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + context_b.read_with(cx_b, |context, cx| { + assert!(context.buffer().read(cx).read_only()); + }); +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 231dd08e7c..e9e6ade1a1 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -294,6 +294,8 @@ impl TestServer { menu::init(); dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); + assistant::FakeCompletionProvider::setup_test(cx); + assistant::context_store::init(&client); }); client diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index adba227948..efa0a09b09 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1903,6 +1903,10 @@ impl Buffer { self.deferred_ops.insert(deferred_ops); } + pub fn has_deferred_ops(&self) -> bool { + !self.deferred_ops.is_empty() || self.text.has_deferred_ops() + } + fn can_apply_op(&self, operation: &Operation) -> bool { match operation { Operation::Buffer(_) => { diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 30c209db55..f2634b4285 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -1,7 +1,7 @@ //! Handles conversions of `language` items to and from the [`rpc`] protocol. use crate::{diagnostic_set::DiagnosticEntry, CursorShape, Diagnostic}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; use clock::ReplicaId; use lsp::{DiagnosticSeverity, LanguageServerId}; use rpc::proto; @@ -231,6 +231,21 @@ pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor { } } +pub fn serialize_anchor_range(range: Range) -> proto::AnchorRange { + proto::AnchorRange { + start: Some(serialize_anchor(&range.start)), + end: Some(serialize_anchor(&range.end)), + } +} + +/// Deserializes an [`Range`] from the RPC representation. +pub fn deserialize_anchor_range(range: proto::AnchorRange) -> Result> { + Ok( + deserialize_anchor(range.start.context("invalid anchor")?).context("invalid anchor")? + ..deserialize_anchor(range.end.context("invalid anchor")?).context("invalid anchor")?, + ) +} + // This behavior is currently copied in the collab database, for snapshotting channel notes /// Deserializes an [`crate::Operation`] from the RPC representation. pub fn deserialize_operation(message: proto::Operation) -> Result { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 7f10b11bd2..4d99454299 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -355,6 +355,9 @@ pub enum Event { }, CollaboratorJoined(proto::PeerId), CollaboratorLeft(proto::PeerId), + HostReshared, + Reshared, + Rejoined, RefreshInlayHints, RevealInProjectPanel(ProjectEntryId), SnippetEdit(BufferId, Vec<(lsp::Range, Snippet)>), @@ -1716,6 +1719,7 @@ impl Project { self.shared_buffers.clear(); self.set_collaborators_from_proto(message.collaborators, cx)?; self.metadata_changed(cx); + cx.emit(Event::Reshared); Ok(()) } @@ -1753,6 +1757,7 @@ impl Project { .collect(); self.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync) .unwrap(); + cx.emit(Event::Rejoined); cx.notify(); Ok(()) } @@ -1805,9 +1810,11 @@ impl Project { } } - self.client.send(proto::UnshareProject { - project_id: remote_id, - })?; + self.client + .send(proto::UnshareProject { + project_id: remote_id, + }) + .ok(); Ok(()) } else { @@ -8810,6 +8817,7 @@ impl Project { .retain(|_, buffer| !matches!(buffer, OpenBuffer::Operations(_))); this.enqueue_buffer_ordered_message(BufferOrderedMessage::Resync) .unwrap(); + cx.emit(Event::HostReshared); } cx.emit(Event::CollaboratorUpdated { diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 3b27c5b536..d410da716e 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -255,7 +255,14 @@ message Envelope { TaskTemplates task_templates = 206; LinkedEditingRange linked_editing_range = 209; - LinkedEditingRangeResponse linked_editing_range_response = 210; // current max + LinkedEditingRangeResponse linked_editing_range_response = 210; + + AdvertiseContexts advertise_contexts = 211; + OpenContext open_context = 212; + OpenContextResponse open_context_response = 213; + UpdateContext update_context = 214; + SynchronizeContexts synchronize_contexts = 215; + SynchronizeContextsResponse synchronize_contexts_response = 216; // current max } reserved 158 to 161; @@ -2222,3 +2229,117 @@ message TaskSourceKind { string name = 1; } } + +message ContextMessageStatus { + oneof variant { + Done done = 1; + Pending pending = 2; + Error error = 3; + } + + message Done {} + + message Pending {} + + message Error { + string message = 1; + } +} + +message ContextMessage { + LamportTimestamp id = 1; + Anchor start = 2; + LanguageModelRole role = 3; + ContextMessageStatus status = 4; +} + +message SlashCommandOutputSection { + AnchorRange range = 1; + string icon_name = 2; + string label = 3; +} + +message ContextOperation { + oneof variant { + InsertMessage insert_message = 1; + UpdateMessage update_message = 2; + UpdateSummary update_summary = 3; + SlashCommandFinished slash_command_finished = 4; + BufferOperation buffer_operation = 5; + } + + message InsertMessage { + ContextMessage message = 1; + repeated VectorClockEntry version = 2; + } + + message UpdateMessage { + LamportTimestamp message_id = 1; + LanguageModelRole role = 2; + ContextMessageStatus status = 3; + LamportTimestamp timestamp = 4; + repeated VectorClockEntry version = 5; + } + + message UpdateSummary { + string summary = 1; + bool done = 2; + LamportTimestamp timestamp = 3; + repeated VectorClockEntry version = 4; + } + + message SlashCommandFinished { + LamportTimestamp id = 1; + AnchorRange output_range = 2; + repeated SlashCommandOutputSection sections = 3; + repeated VectorClockEntry version = 4; + } + + message BufferOperation { + Operation operation = 1; + } +} + +message Context { + repeated ContextOperation operations = 1; +} + +message ContextMetadata { + string context_id = 1; + optional string summary = 2; +} + +message AdvertiseContexts { + uint64 project_id = 1; + repeated ContextMetadata contexts = 2; +} + +message OpenContext { + uint64 project_id = 1; + string context_id = 2; +} + +message OpenContextResponse { + Context context = 1; +} + +message UpdateContext { + uint64 project_id = 1; + string context_id = 2; + ContextOperation operation = 3; +} + +message ContextVersion { + string context_id = 1; + repeated VectorClockEntry context_version = 2; + repeated VectorClockEntry buffer_version = 3; +} + +message SynchronizeContexts { + uint64 project_id = 1; + repeated ContextVersion contexts = 2; +} + +message SynchronizeContextsResponse { + repeated ContextVersion contexts = 1; +} diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 6457f3fa0d..f733278231 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -337,7 +337,13 @@ messages!( (OpenNewBuffer, Foreground), (RestartLanguageServers, Foreground), (LinkedEditingRange, Background), - (LinkedEditingRangeResponse, Background) + (LinkedEditingRangeResponse, Background), + (AdvertiseContexts, Foreground), + (OpenContext, Foreground), + (OpenContextResponse, Foreground), + (UpdateContext, Foreground), + (SynchronizeContexts, Foreground), + (SynchronizeContextsResponse, Foreground), ); request_messages!( @@ -449,7 +455,9 @@ request_messages!( (DeleteDevServerProject, Ack), (RegenerateDevServerToken, RegenerateDevServerTokenResponse), (RenameDevServer, Ack), - (RestartLanguageServers, Ack) + (RestartLanguageServers, Ack), + (OpenContext, OpenContextResponse), + (SynchronizeContexts, SynchronizeContextsResponse), ); entity_messages!( @@ -511,6 +519,10 @@ entity_messages!( UpdateWorktree, UpdateWorktreeSettings, LspExtExpandMacro, + AdvertiseContexts, + OpenContext, + UpdateContext, + SynchronizeContexts, ); entity_messages!( diff --git a/crates/text/src/network.rs b/crates/text/src/network.rs index 2f49756ca3..f22bb52d20 100644 --- a/crates/text/src/network.rs +++ b/crates/text/src/network.rs @@ -1,12 +1,15 @@ +use std::fmt::Debug; + use clock::ReplicaId; +use collections::{BTreeMap, HashSet}; pub struct Network { - inboxes: std::collections::BTreeMap>>, - all_messages: Vec, + inboxes: BTreeMap>>, + disconnected_peers: HashSet, rng: R, } -#[derive(Clone)] +#[derive(Clone, Debug)] struct Envelope { message: T, } @@ -14,8 +17,8 @@ struct Envelope { impl Network { pub fn new(rng: R) -> Self { Network { - inboxes: Default::default(), - all_messages: Vec::new(), + inboxes: BTreeMap::default(), + disconnected_peers: HashSet::default(), rng, } } @@ -24,6 +27,24 @@ impl Network { self.inboxes.insert(id, Vec::new()); } + pub fn disconnect_peer(&mut self, id: ReplicaId) { + self.disconnected_peers.insert(id); + self.inboxes.get_mut(&id).unwrap().clear(); + } + + pub fn reconnect_peer(&mut self, id: ReplicaId, replicate_from: ReplicaId) { + assert!(self.disconnected_peers.remove(&id)); + self.replicate(replicate_from, id); + } + + pub fn is_disconnected(&self, id: ReplicaId) -> bool { + self.disconnected_peers.contains(&id) + } + + pub fn contains_disconnected_peers(&self) -> bool { + !self.disconnected_peers.is_empty() + } + pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) { self.inboxes .insert(new_replica_id, self.inboxes[&old_replica_id].clone()); @@ -34,8 +55,13 @@ impl Network { } pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec) { + // Drop messages from disconnected peers. + if self.disconnected_peers.contains(&sender) { + return; + } + for (replica, inbox) in self.inboxes.iter_mut() { - if *replica != sender { + if *replica != sender && !self.disconnected_peers.contains(replica) { for message in &messages { // Insert one or more duplicates of this message, potentially *before* the previous // message sent by this peer to simulate out-of-order delivery. @@ -51,7 +77,6 @@ impl Network { } } } - self.all_messages.extend(messages); } pub fn has_unreceived(&self, receiver: ReplicaId) -> bool { diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index a59d4785d4..945483a848 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -1265,6 +1265,10 @@ impl Buffer { } } + pub fn has_deferred_ops(&self) -> bool { + !self.deferred_ops.is_empty() + } + pub fn peek_undo_stack(&self) -> Option<&HistoryEntry> { self.history.undo_stack.last() } diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index 332e30a14d..e8c4ad31ee 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -1,6 +1,6 @@ use gpui::{svg, AnimationElement, Hsla, IntoElement, Rems, Transformation}; use serde::{Deserialize, Serialize}; -use strum::EnumIter; +use strum::{EnumIter, EnumString, IntoStaticStr}; use crate::{prelude::*, Indicator}; @@ -90,7 +90,9 @@ impl IconSize { } } -#[derive(Debug, PartialEq, Copy, Clone, EnumIter, Serialize, Deserialize)] +#[derive( + Debug, Eq, PartialEq, Copy, Clone, EnumIter, EnumString, IntoStaticStr, Serialize, Deserialize, +)] pub enum IconName { Ai, ArrowCircle,