From 50c45c7897ce84627437786f3d4837ff263496bb Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Thu, 9 May 2024 15:57:14 -0700 Subject: [PATCH] Streaming tools (#11629) Stream characters in for tool calls to allow rendering partial input. https://github.com/zed-industries/zed/assets/836375/0f023a4b-9c46-4449-ae69-8b6bcab41673 Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld Co-authored-by: Marshall Co-authored-by: Max --- Cargo.lock | 20 +- Cargo.toml | 1 + crates/assistant2/Cargo.toml | 1 - crates/assistant2/src/assistant2.rs | 68 +- .../assistant2/src/attachments/active_file.rs | 4 +- crates/assistant2/src/tools/annotate_code.rs | 259 +++++--- crates/assistant2/src/tools/create_buffer.rs | 100 +-- crates/assistant2/src/tools/project_index.rs | 362 ++++++----- crates/assistant_tooling/Cargo.toml | 2 + .../src/assistant_tooling.rs | 8 +- .../src/attachment_registry.rs | 8 +- crates/assistant_tooling/src/tool_registry.rs | 582 ++++++++---------- crates/multi_buffer/src/multi_buffer.rs | 24 +- 13 files changed, 786 insertions(+), 653 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e0b11d5ca..1bf917f984 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -390,7 +390,6 @@ dependencies = [ "language", "languages", "log", - "nanoid", "node_runtime", "open_ai", "picker", @@ -419,7 +418,9 @@ dependencies = [ "collections", "futures 0.3.28", "gpui", + "log", "project", + "repair_json", "schemars", "serde", "serde_json", @@ -8050,6 +8051,15 @@ dependencies = [ "bytecheck", ] +[[package]] +name = "repair_json" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15" +dependencies = [ + "thiserror", +] + [[package]] name = "reqwest" version = "0.11.20" @@ -10185,18 +10195,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 498ba467df..1f1bfe5a5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -307,6 +307,7 @@ pulldown-cmark = { version = "0.10.0", default-features = false } rand = "0.8.5" refineable = { path = "./crates/refineable" } regex = "1.5" +repair_json = "0.1.0" rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } rust-embed = { version = "8.0", features = ["include-exclude"] } schemars = "0.8" diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 27f1ebcbcc..a263aa836d 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -29,7 +29,6 @@ fuzzy.workspace = true gpui.workspace = true language.workspace = true log.workspace = true -nanoid.workspace = true open_ai.workspace = true picker.workspace = true project.workspace = true diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 41b426e084..35100f9882 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -536,25 +536,27 @@ impl AssistantChat { body.push_str(content); } - for tool_call in delta.tool_calls { - let index = tool_call.index as usize; + for tool_call_delta in delta.tool_calls { + let index = tool_call_delta.index as usize; if index >= message.tool_calls.len() { message.tool_calls.resize_with(index + 1, Default::default); } - let call = &mut message.tool_calls[index]; + let tool_call = &mut message.tool_calls[index]; - if let Some(id) = &tool_call.id { - call.id.push_str(id); + if let Some(id) = &tool_call_delta.id { + tool_call.id.push_str(id); } - match tool_call.variant { - Some(proto::tool_call_delta::Variant::Function(tool_call)) => { - if let Some(name) = &tool_call.name { - call.name.push_str(name); - } - if let Some(arguments) = &tool_call.arguments { - call.arguments.push_str(arguments); - } + match tool_call_delta.variant { + Some(proto::tool_call_delta::Variant::Function( + tool_call_delta, + )) => { + this.tool_registry.update_tool_call( + tool_call, + tool_call_delta.name.as_deref(), + tool_call_delta.arguments.as_deref(), + cx, + ); } None => {} } @@ -587,34 +589,20 @@ impl AssistantChat { } else { if let Some(current_message) = messages.last_mut() { for tool_call in current_message.tool_calls.iter() { - tool_tasks.push(this.tool_registry.call(tool_call, cx)); + tool_tasks + .extend(this.tool_registry.execute_tool_call(&tool_call, cx)); } } } } })?; + // This ends recursion on calling for responses after tools if tool_tasks.is_empty() { return Ok(()); } - let tools = join_all(tool_tasks.into_iter()).await; - // If the WindowContext went away for any tool's view we don't include it - // especially since the below call would fail for the same reason. - let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect(); - - this.update(cx, |this, cx| { - if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) = - this.messages.last_mut() - { - if let Some(current_message) = messages.last_mut() { - current_message.tool_calls = tools; - cx.notify(); - } else { - unreachable!() - } - } - })?; + join_all(tool_tasks.into_iter()).await; } } @@ -948,13 +936,11 @@ impl AssistantChat { for tool_call in &message.tool_calls { // Every tool call _must_ have a result by ID, otherwise OpenAI will error. - let content = match &tool_call.result { - Some(result) => { - result.generate(&tool_call.name, &mut project_context, cx) - } - None => "".to_string(), - }; - + let content = self.tool_registry.content_for_tool_call( + tool_call, + &mut project_context, + cx, + ); completion_messages.push(CompletionMessage::Tool { content, tool_call_id: tool_call.id.clone(), @@ -1003,7 +989,11 @@ impl AssistantChat { tool_calls: message .tool_calls .iter() - .map(|tool_call| self.tool_registry.serialize_tool_call(tool_call)) + .filter_map(|tool_call| { + self.tool_registry + .serialize_tool_call(tool_call, cx) + .log_err() + }) .collect(), }) .collect(), diff --git a/crates/assistant2/src/attachments/active_file.rs b/crates/assistant2/src/attachments/active_file.rs index 811eb4219c..744d92689f 100644 --- a/crates/assistant2/src/attachments/active_file.rs +++ b/crates/assistant2/src/attachments/active_file.rs @@ -1,7 +1,7 @@ use std::{path::PathBuf, sync::Arc}; use anyhow::{anyhow, Result}; -use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput}; +use assistant_tooling::{AttachmentOutput, LanguageModelAttachment, ProjectContext}; use editor::Editor; use gpui::{Render, Task, View, WeakModel, WeakView}; use language::Buffer; @@ -52,7 +52,7 @@ impl Render for FileAttachmentView { } } -impl ToolOutput for FileAttachmentView { +impl AttachmentOutput for FileAttachmentView { fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { if let Some(path) = &self.project_path { project.add_file(path.clone()); diff --git a/crates/assistant2/src/tools/annotate_code.rs b/crates/assistant2/src/tools/annotate_code.rs index f2427bd440..29a7d8cb96 100644 --- a/crates/assistant2/src/tools/annotate_code.rs +++ b/crates/assistant2/src/tools/annotate_code.rs @@ -4,7 +4,8 @@ use editor::{ display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle}, Editor, MultiBuffer, }; -use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView}; +use futures::{channel::mpsc::UnboundedSender, StreamExt as _}; +use gpui::{prelude::*, AnyElement, AsyncWindowContext, Model, Task, View, WeakView}; use language::ToPoint; use project::{search::SearchQuery, Project, ProjectPath}; use schemars::JsonSchema; @@ -25,14 +26,19 @@ impl AnnotationTool { } } -#[derive(Debug, Deserialize, JsonSchema, Clone)] +#[derive(Default, Debug, Deserialize, JsonSchema, Clone)] pub struct AnnotationInput { /// Name for this set of annotations + #[serde(default = "default_title")] title: String, /// Excerpts from the file to show to the user. excerpts: Vec, } +fn default_title() -> String { + "Untitled".to_string() +} + #[derive(Debug, Deserialize, JsonSchema, Clone)] struct Excerpt { /// Path to the file @@ -44,8 +50,6 @@ struct Excerpt { } impl LanguageModelTool for AnnotationTool { - type Input = AnnotationInput; - type Output = String; type View = AnnotationResultView; fn name(&self) -> String { @@ -56,67 +60,100 @@ impl LanguageModelTool for AnnotationTool { "Dynamically annotate symbols in the current codebase. Opens a buffer in a panel in their editor, to the side of the conversation. The annotations are shown in the editor as a block decoration.".to_string() } - fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task> { - let workspace = self.workspace.clone(); - let project = self.project.clone(); - let excerpts = input.excerpts.clone(); - let title = input.title.clone(); + fn view(&self, cx: &mut WindowContext) -> View { + cx.new_view(|cx| { + let (tx, mut rx) = futures::channel::mpsc::unbounded(); + cx.spawn(|view, mut cx| async move { + while let Some(excerpt) = rx.next().await { + AnnotationResultView::add_excerpt(view.clone(), excerpt, &mut cx).await?; + } + anyhow::Ok(()) + }) + .detach(); + + AnnotationResultView { + project: self.project.clone(), + workspace: self.workspace.clone(), + tx, + pending_excerpt: None, + added_editor_to_workspace: false, + editor: None, + error: None, + rendered_excerpt_count: 0, + } + }) + } +} + +pub struct AnnotationResultView { + workspace: WeakView, + project: Model, + pending_excerpt: Option, + added_editor_to_workspace: bool, + editor: Option>, + tx: UnboundedSender, + error: Option, + rendered_excerpt_count: usize, +} + +impl AnnotationResultView { + async fn add_excerpt( + this: WeakView, + excerpt: Excerpt, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let project = this.update(cx, |this, _cx| this.project.clone())?; let worktree_id = project.update(cx, |project, cx| { let worktree = project.worktrees().next()?; let worktree_id = worktree.read(cx).id(); Some(worktree_id) - }); + })?; let worktree_id = if let Some(worktree_id) = worktree_id { worktree_id } else { - return Task::ready(Err(anyhow::anyhow!("No worktree found"))); + return Err(anyhow::anyhow!("No worktree found")); }; - let buffer_tasks = project.update(cx, |project, cx| { - excerpts - .iter() - .map(|excerpt| { - project.open_buffer( - ProjectPath { - worktree_id, - path: Path::new(&excerpt.path).into(), - }, - cx, - ) + let buffer_task = project.update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id, + path: Path::new(&excerpt.path).into(), + }, + cx, + ) + })?; + + let buffer = match buffer_task.await { + Ok(buffer) => buffer, + Err(error) => { + return this.update(cx, |this, cx| { + this.error = Some(error); + cx.notify(); }) - .collect::>() - }); + } + }; - cx.spawn(move |mut cx| async move { - let buffers = futures::future::try_join_all(buffer_tasks).await?; + let snapshot = buffer.update(cx, |buffer, _cx| buffer.snapshot())?; + let query = SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?; + let matches = query.search(&snapshot, None).await; + let Some(first_match) = matches.first() else { + log::warn!( + "text {:?} does not appear in '{}'", + excerpt.text_passage, + excerpt.path + ); + return Ok(()); + }; - let multibuffer = cx.new_model(|_cx| { - MultiBuffer::new(0, language::Capability::ReadWrite).with_title(title) - })?; - let editor = - cx.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), cx))?; + this.update(cx, |this, cx| { + let mut start = first_match.start.to_point(&snapshot); + start.column = 0; - for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) { - let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?; - - let query = - SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?; - - let matches = query.search(&snapshot, None).await; - let Some(first_match) = matches.first() else { - log::warn!( - "text {:?} does not appear in '{}'", - excerpt.text_passage, - excerpt.path - ); - continue; - }; - let mut start = first_match.start.to_point(&snapshot); - start.column = 0; - - editor.update(&mut cx, |editor, cx| { + if let Some(editor) = &this.editor { + editor.update(cx, |editor, cx| { let ranges = editor.buffer().update(cx, |multibuffer, cx| { multibuffer.push_excerpts_with_context_lines( buffer.clone(), @@ -125,7 +162,8 @@ impl LanguageModelTool for AnnotationTool { cx, ) }); - let annotation = SharedString::from(excerpt.annotation.clone()); + + let annotation = SharedString::from(excerpt.annotation); editor.insert_blocks( [BlockProperties { position: ranges[0].start, @@ -137,30 +175,22 @@ impl LanguageModelTool for AnnotationTool { None, cx, ); - })?; + }); + + if !this.added_editor_to_workspace { + this.added_editor_to_workspace = true; + this.workspace + .update(cx, |workspace, cx| { + workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx); + }) + .log_err(); + } } + })?; - workspace - .update(&mut cx, |workspace, cx| { - workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx); - }) - .log_err(); - - anyhow::Ok("showed comments to users in a new view".into()) - }) + Ok(()) } - fn view( - &self, - _: Self::Input, - output: Result, - cx: &mut WindowContext, - ) -> View { - cx.new_view(|_cx| AnnotationResultView { output }) - } -} - -impl AnnotationTool { fn render_note_block(explanation: &SharedString, cx: &mut BlockContext) -> AnyElement { let anchor_x = cx.anchor_x; let gutter_width = cx.gutter_dimensions.width; @@ -186,24 +216,89 @@ impl AnnotationTool { } } -pub struct AnnotationResultView { - output: Result, -} - impl Render for AnnotationResultView { fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { - match &self.output { - Ok(output) => div().child(output.clone().into_any_element()), - Err(error) => div().child(format!("failed to open path: {:?}", error)), + if let Some(error) = &self.error { + ui::Label::new(error.to_string()).into_any_element() + } else { + ui::Label::new(SharedString::from(format!( + "Opened a buffer with {} excerpts", + self.rendered_excerpt_count + ))) + .into_any_element() } } } impl ToolOutput for AnnotationResultView { - fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String { - match &self.output { - Ok(output) => output.clone(), - Err(err) => format!("Failed to create buffer: {err:?}"), + type Input = AnnotationInput; + type SerializedState = Option; + + fn generate(&self, _: &mut ProjectContext, _: &mut ViewContext) -> String { + if let Some(error) = &self.error { + format!("Failed to create buffer: {error:?}") + } else { + format!( + "opened {} excerpts in a buffer", + self.rendered_excerpt_count + ) } } + + fn set_input(&mut self, mut input: Self::Input, cx: &mut ViewContext) { + let editor = if let Some(editor) = &self.editor { + editor.clone() + } else { + let multibuffer = cx.new_model(|_cx| { + MultiBuffer::new(0, language::Capability::ReadWrite).with_title(String::new()) + }); + let editor = cx.new_view(|cx| { + Editor::for_multibuffer(multibuffer.clone(), Some(self.project.clone()), cx) + }); + + self.editor = Some(editor.clone()); + editor + }; + + editor.update(cx, |editor, cx| { + editor.buffer().update(cx, |multibuffer, cx| { + if multibuffer.title(cx) != input.title { + multibuffer.set_title(input.title.clone(), cx); + } + }); + + self.pending_excerpt = input.excerpts.pop(); + for excerpt in input.excerpts.iter().skip(self.rendered_excerpt_count) { + self.tx.unbounded_send(excerpt.clone()).ok(); + } + self.rendered_excerpt_count = input.excerpts.len(); + }); + + cx.notify(); + } + + fn execute(&mut self, _cx: &mut ViewContext) -> Task> { + if let Some(excerpt) = self.pending_excerpt.take() { + self.rendered_excerpt_count += 1; + self.tx.unbounded_send(excerpt.clone()).ok(); + } + + self.tx.close_channel(); + Task::ready(Ok(())) + } + + fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { + self.error.as_ref().map(|error| error.to_string()) + } + + fn deserialize( + &mut self, + output: Self::SerializedState, + _cx: &mut ViewContext, + ) -> Result<()> { + if let Some(error_message) = output { + self.error = Some(anyhow::anyhow!("{}", error_message)); + } + Ok(()) + } } diff --git a/crates/assistant2/src/tools/create_buffer.rs b/crates/assistant2/src/tools/create_buffer.rs index 5563615bdd..9051b69856 100644 --- a/crates/assistant2/src/tools/create_buffer.rs +++ b/crates/assistant2/src/tools/create_buffer.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput}; use editor::Editor; use gpui::{prelude::*, Model, Task, View, WeakView}; @@ -20,7 +20,7 @@ impl CreateBufferTool { } } -#[derive(Debug, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Deserialize, JsonSchema)] pub struct CreateBufferInput { /// The contents of the buffer. text: String, @@ -32,8 +32,6 @@ pub struct CreateBufferInput { } impl LanguageModelTool for CreateBufferTool { - type Input = CreateBufferInput; - type Output = (); type View = CreateBufferView; fn name(&self) -> String { @@ -44,13 +42,59 @@ impl LanguageModelTool for CreateBufferTool { "Create a new buffer in the current codebase".to_string() } - fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task> { + fn view(&self, cx: &mut WindowContext) -> View { + cx.new_view(|_cx| CreateBufferView { + workspace: self.workspace.clone(), + project: self.project.clone(), + input: None, + error: None, + }) + } +} + +pub struct CreateBufferView { + workspace: WeakView, + project: Model, + input: Option, + error: Option, +} + +impl Render for CreateBufferView { + fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { + div().child("Opening a buffer") + } +} + +impl ToolOutput for CreateBufferView { + type Input = CreateBufferInput; + + type SerializedState = (); + + fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext) -> String { + let Some(input) = self.input.as_ref() else { + return "No input".to_string(); + }; + + match &self.error { + None => format!("Created a new {} buffer", input.language), + Some(err) => format!("Failed to create buffer: {err:?}"), + } + } + + fn set_input(&mut self, input: Self::Input, _cx: &mut ViewContext) { + self.input = Some(input); + } + + fn execute(&mut self, cx: &mut ViewContext) -> Task> { cx.spawn({ let workspace = self.workspace.clone(); let project = self.project.clone(); - let text = input.text.clone(); - let language_name = input.language.clone(); - |mut cx| async move { + let input = self.input.clone(); + |_this, mut cx| async move { + let input = input.ok_or_else(|| anyhow!("no input"))?; + + let text = input.text.clone(); + let language_name = input.language.clone(); let language = cx .update(|cx| { project @@ -86,35 +130,15 @@ impl LanguageModelTool for CreateBufferTool { }) } - fn view( - &self, - input: Self::Input, - output: Result, - cx: &mut WindowContext, - ) -> View { - cx.new_view(|_cx| CreateBufferView { - language: input.language, - output, - }) - } -} - -pub struct CreateBufferView { - language: String, - output: Result<()>, -} - -impl Render for CreateBufferView { - fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { - div().child("Opening a buffer") - } -} - -impl ToolOutput for CreateBufferView { - fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String { - match &self.output { - Ok(_) => format!("Created a new {} buffer", self.language), - Err(err) => format!("Failed to create buffer: {err:?}"), - } + fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { + () + } + + fn deserialize( + &mut self, + _output: Self::SerializedState, + _cx: &mut ViewContext, + ) -> Result<()> { + Ok(()) } } diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs index 0c43ef09b9..36f9e6f962 100644 --- a/crates/assistant2/src/tools/project_index.rs +++ b/crates/assistant2/src/tools/project_index.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::Result; use assistant_tooling::{LanguageModelTool, ToolOutput}; use collections::BTreeMap; use gpui::{prelude::*, Model, Task}; @@ -6,9 +6,8 @@ use project::ProjectPath; use schemars::JsonSchema; use semantic_index::{ProjectIndex, Status}; use serde::{Deserialize, Serialize}; -use serde_json::Value; use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc}; -use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext}; +use ui::{prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext}; const DEFAULT_SEARCH_LIMIT: usize = 20; @@ -16,10 +15,26 @@ pub struct ProjectIndexTool { project_index: Model, } -// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model. -// Any changes or deletions to the `CodebaseQuery` comments will change model behavior. +#[derive(Default)] +enum ProjectIndexToolState { + #[default] + CollectingQuery, + Searching, + Error(anyhow::Error), + Finished { + excerpts: BTreeMap>>, + index_status: Status, + }, +} -#[derive(Deserialize, JsonSchema)] +pub struct ProjectIndexView { + project_index: Model, + input: CodebaseQuery, + expanded_header: bool, + state: ProjectIndexToolState, +} + +#[derive(Default, Deserialize, JsonSchema)] pub struct CodebaseQuery { /// Semantic search query query: String, @@ -27,21 +42,14 @@ pub struct CodebaseQuery { limit: Option, } -pub struct ProjectIndexView { - input: CodebaseQuery, - status: Status, - excerpts: Result>>>, - element_id: ElementId, - expanded_header: bool, -} - #[derive(Serialize, Deserialize)] -pub struct ProjectIndexOutput { - status: Status, +pub struct SerializedState { + index_status: Status, + error_message: Option, worktrees: BTreeMap, WorktreeIndexOutput>, } -#[derive(Serialize, Deserialize)] +#[derive(Default, Serialize, Deserialize)] struct WorktreeIndexOutput { excerpts: BTreeMap, Vec>>, } @@ -56,58 +64,80 @@ impl ProjectIndexView { impl Render for ProjectIndexView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let query = self.input.query.clone(); - let excerpts = match &self.excerpts { - Err(err) => { - return div().child(Label::new(format!("Error: {}", err)).color(Color::Error)); + + let (header_text, content) = match &self.state { + ProjectIndexToolState::Error(error) => { + return format!("failed to search: {error:?}").into_any_element() + } + ProjectIndexToolState::CollectingQuery | ProjectIndexToolState::Searching => { + ("Searching...".to_string(), div()) + } + ProjectIndexToolState::Finished { excerpts, .. } => { + let file_count = excerpts.len(); + + let header_text = format!( + "Read {} {}", + file_count, + if file_count == 1 { "file" } else { "files" } + ); + + let el = v_flex().gap_2().children(excerpts.keys().map(|path| { + h_flex().gap_2().child(Icon::new(IconName::File)).child( + Label::new(path.path.to_string_lossy().to_string()).color(Color::Muted), + ) + })); + + (header_text, el) } - Ok(excerpts) => excerpts, }; - let file_count = excerpts.len(); let header = h_flex() .gap_2() .child(Icon::new(IconName::File)) - .child(format!( - "Read {} {}", - file_count, - if file_count == 1 { "file" } else { "files" } - )); + .child(header_text); - v_flex().gap_3().child( - CollapsibleContainer::new(self.element_id.clone(), self.expanded_header) - .start_slot(header) - .on_click(cx.listener(move |this, _, cx| { - this.toggle_header(cx); - })) - .child( - v_flex() - .gap_3() - .p_3() - .child( - h_flex() - .gap_2() - .child(Icon::new(IconName::MagnifyingGlass)) - .child(Label::new(format!("`{}`", query)).color(Color::Muted)), - ) - .child(v_flex().gap_2().children(excerpts.keys().map(|path| { - h_flex().gap_2().child(Icon::new(IconName::File)).child( - Label::new(path.path.to_string_lossy().to_string()) - .color(Color::Muted), + v_flex() + .gap_3() + .child( + CollapsibleContainer::new("collapsible-container", self.expanded_header) + .start_slot(header) + .on_click(cx.listener(move |this, _, cx| { + this.toggle_header(cx); + })) + .child( + v_flex() + .gap_3() + .p_3() + .child( + h_flex() + .gap_2() + .child(Icon::new(IconName::MagnifyingGlass)) + .child(Label::new(format!("`{}`", query)).color(Color::Muted)), ) - }))), - ), - ) + .child(content), + ), + ) + .into_any_element() } } impl ToolOutput for ProjectIndexView { + type Input = CodebaseQuery; + type SerializedState = SerializedState; + fn generate( &self, context: &mut assistant_tooling::ProjectContext, - _: &mut WindowContext, + _: &mut ViewContext, ) -> String { - match &self.excerpts { - Ok(excerpts) => { + match &self.state { + ProjectIndexToolState::CollectingQuery => String::new(), + ProjectIndexToolState::Searching => String::new(), + ProjectIndexToolState::Error(error) => format!("failed to search: {error:?}"), + ProjectIndexToolState::Finished { + excerpts, + index_status, + } => { let mut body = "found results in the following paths:\n".to_string(); for (project_path, ranges) in excerpts { @@ -115,15 +145,126 @@ impl ToolOutput for ProjectIndexView { writeln!(&mut body, "* {}", &project_path.path.display()).unwrap(); } - if self.status != Status::Idle { + if *index_status != Status::Idle { body.push_str("Still indexing. Results may be incomplete.\n"); } body } - Err(err) => format!("Error: {}", err), } } + + fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext) { + self.input = input; + cx.notify(); + } + + fn execute(&mut self, cx: &mut ViewContext) -> Task> { + self.state = ProjectIndexToolState::Searching; + cx.notify(); + + let project_index = self.project_index.read(cx); + let index_status = project_index.status(); + let search = project_index.search( + self.input.query.clone(), + self.input.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), + cx, + ); + + cx.spawn(|this, mut cx| async move { + let search_result = search.await; + this.update(&mut cx, |this, cx| { + match search_result { + Ok(search_results) => { + let mut excerpts = BTreeMap::>>::new(); + for search_result in search_results { + let project_path = ProjectPath { + worktree_id: search_result.worktree.read(cx).id(), + path: search_result.path, + }; + excerpts + .entry(project_path) + .or_default() + .push(search_result.range); + } + this.state = ProjectIndexToolState::Finished { + excerpts, + index_status, + }; + } + Err(error) => { + this.state = ProjectIndexToolState::Error(error); + } + } + cx.notify(); + }) + }) + } + + fn serialize(&self, cx: &mut ViewContext) -> Self::SerializedState { + let mut serialized = SerializedState { + error_message: None, + index_status: Status::Idle, + worktrees: Default::default(), + }; + match &self.state { + ProjectIndexToolState::Error(err) => serialized.error_message = Some(err.to_string()), + ProjectIndexToolState::Finished { + excerpts, + index_status, + } => { + serialized.index_status = *index_status; + if let Some(project) = self.project_index.read(cx).project().upgrade() { + let project = project.read(cx); + for (project_path, excerpts) in excerpts { + if let Some(worktree) = + project.worktree_for_id(project_path.worktree_id, cx) + { + let worktree_path = worktree.read(cx).abs_path(); + serialized + .worktrees + .entry(worktree_path) + .or_default() + .excerpts + .insert(project_path.path.clone(), excerpts.clone()); + } + } + } + } + _ => {} + } + serialized + } + + fn deserialize( + &mut self, + serialized: Self::SerializedState, + cx: &mut ViewContext, + ) -> Result<()> { + if !serialized.worktrees.is_empty() { + let mut excerpts = BTreeMap::>>::new(); + if let Some(project) = self.project_index.read(cx).project().upgrade() { + let project = project.read(cx); + for (worktree_path, worktree_state) in serialized.worktrees { + if let Some(worktree) = project + .worktrees() + .find(|worktree| worktree.read(cx).abs_path() == worktree_path) + { + let worktree_id = worktree.read(cx).id(); + for (path, serialized_excerpts) in worktree_state.excerpts { + excerpts.insert(ProjectPath { worktree_id, path }, serialized_excerpts); + } + } + } + } + self.state = ProjectIndexToolState::Finished { + excerpts, + index_status: serialized.index_status, + }; + } + cx.notify(); + Ok(()) + } } impl ProjectIndexTool { @@ -133,8 +274,6 @@ impl ProjectIndexTool { } impl LanguageModelTool for ProjectIndexTool { - type Input = CodebaseQuery; - type Output = ProjectIndexOutput; type View = ProjectIndexView; fn name(&self) -> String { @@ -145,109 +284,12 @@ impl LanguageModelTool for ProjectIndexTool { "Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string() } - fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task> { - let project_index = self.project_index.read(cx); - let status = project_index.status(); - let search = project_index.search( - query.query.clone(), - query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), - cx, - ); - - cx.spawn(|mut cx| async move { - let search_results = search.await?; - - cx.update(|cx| { - let mut output = ProjectIndexOutput { - status, - worktrees: Default::default(), - }; - - for search_result in search_results { - let worktree_path = search_result.worktree.read(cx).abs_path(); - let excerpts = &mut output - .worktrees - .entry(worktree_path) - .or_insert(WorktreeIndexOutput { - excerpts: Default::default(), - }) - .excerpts; - - let excerpts_for_path = excerpts.entry(search_result.path).or_default(); - let ix = match excerpts_for_path - .binary_search_by_key(&search_result.range.start, |r| r.start) - { - Ok(ix) | Err(ix) => ix, - }; - excerpts_for_path.insert(ix, search_result.range); - } - - output - }) + fn view(&self, cx: &mut WindowContext) -> gpui::View { + cx.new_view(|_| ProjectIndexView { + state: ProjectIndexToolState::CollectingQuery, + input: Default::default(), + expanded_header: false, + project_index: self.project_index.clone(), }) } - - fn view( - &self, - input: Self::Input, - output: Result, - cx: &mut WindowContext, - ) -> gpui::View { - cx.new_view(|cx| { - let status; - let excerpts; - match output { - Ok(output) => { - status = output.status; - let project_index = self.project_index.read(cx); - if let Some(project) = project_index.project().upgrade() { - let project = project.read(cx); - excerpts = Ok(output - .worktrees - .into_iter() - .filter_map(|(abs_path, output)| { - for worktree in project.worktrees() { - let worktree = worktree.read(cx); - if worktree.abs_path() == abs_path { - return Some((worktree.id(), output.excerpts)); - } - } - None - }) - .flat_map(|(worktree_id, excerpts)| { - excerpts.into_iter().map(move |(path, ranges)| { - (ProjectPath { worktree_id, path }, ranges) - }) - }) - .collect::>()); - } else { - excerpts = Err(anyhow!("project was dropped")); - } - } - Err(err) => { - status = Status::Idle; - excerpts = Err(err); - } - }; - - ProjectIndexView { - input, - status, - excerpts, - element_id: ElementId::Name(nanoid::nanoid!().into()), - expanded_header: false, - } - }) - } - - fn render_running(arguments: &Option, _: &mut WindowContext) -> impl IntoElement { - let text: String = arguments - .as_ref() - .and_then(|arguments| arguments.get("query")) - .and_then(|query| query.as_str()) - .map(|query| format!("Searching for: {}", query)) - .unwrap_or_else(|| "Preparing search...".to_string()); - - CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text) - } } diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml index c7290f9c98..79f41faad2 100644 --- a/crates/assistant_tooling/Cargo.toml +++ b/crates/assistant_tooling/Cargo.toml @@ -16,7 +16,9 @@ anyhow.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true +log.workspace = true project.workspace = true +repair_json.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs index e5aff01edf..dd4dac39e9 100644 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -3,11 +3,11 @@ mod project_context; mod tool_registry; pub use attachment_registry::{ - AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment, + AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, + UserAttachment, }; pub use project_context::ProjectContext; pub use tool_registry::{ - tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, - SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, - ToolOutput, ToolRegistry, + tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState, + ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry, }; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs index 8c82099f4d..e8b52d26f0 100644 --- a/crates/assistant_tooling/src/attachment_registry.rs +++ b/crates/assistant_tooling/src/attachment_registry.rs @@ -1,4 +1,4 @@ -use crate::{ProjectContext, ToolOutput}; +use crate::ProjectContext; use anyhow::{anyhow, Result}; use collections::HashMap; use futures::future::join_all; @@ -18,9 +18,13 @@ pub struct AttachmentRegistry { registered_attachments: HashMap, } +pub trait AttachmentOutput { + fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; +} + pub trait LanguageModelAttachment { type Output: DeserializeOwned + Serialize + 'static; - type View: Render + ToolOutput; + type View: Render + AttachmentOutput; fn name(&self) -> Arc; fn run(&self, cx: &mut WindowContext) -> Task>; diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs index d1a14c4c9d..7f0a8fb296 100644 --- a/crates/assistant_tooling/src/tool_registry.rs +++ b/crates/assistant_tooling/src/tool_registry.rs @@ -1,11 +1,10 @@ use crate::ProjectContext; use anyhow::{anyhow, Result}; -use gpui::{ - div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext, -}; +use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext}; +use repair_json::repair; use schemars::{schema::RootSchema, schema_for, JsonSchema}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::{value::RawValue, Value}; +use serde_json::value::RawValue; use std::{ any::TypeId, collections::HashMap, @@ -15,6 +14,7 @@ use std::{ Arc, }, }; +use ui::ViewContext; pub struct ToolRegistry { registered_tools: HashMap, @@ -25,7 +25,25 @@ pub struct ToolFunctionCall { pub id: String, pub name: String, pub arguments: String, - pub result: Option, + state: ToolFunctionCallState, +} + +#[derive(Default)] +pub enum ToolFunctionCallState { + #[default] + Initializing, + NoSuchTool, + KnownTool(Box), + ExecutedTool(Box), +} + +pub trait ToolView { + fn view(&self) -> AnyView; + fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; + fn set_input(&self, input: &str, cx: &mut WindowContext); + fn execute(&self, cx: &mut WindowContext) -> Task>; + fn serialize_output(&self, cx: &mut WindowContext) -> Result>; + fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; } #[derive(Default, Serialize, Deserialize)] @@ -33,29 +51,19 @@ pub struct SavedToolFunctionCall { pub id: String, pub name: String, pub arguments: String, - pub result: Option, + pub state: SavedToolFunctionCallState, } -pub enum ToolFunctionCallResult { +#[derive(Default, Serialize, Deserialize)] +pub enum SavedToolFunctionCallState { + #[default] + Initializing, NoSuchTool, - ParsingFailed, - Finished { - view: AnyView, - serialized_output: Result, String>, - generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String, - }, + KnownTool, + ExecutedTool(Box), } -#[derive(Serialize, Deserialize)] -pub enum SavedToolFunctionCallResult { - NoSuchTool, - ParsingFailed, - Finished { - serialized_output: Result, String>, - }, -} - -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ToolFunctionDefinition { pub name: String, pub description: String, @@ -63,14 +71,7 @@ pub struct ToolFunctionDefinition { } pub trait LanguageModelTool { - /// The input type that will be passed in to `execute` when the tool is called - /// by the language model. - type Input: DeserializeOwned + JsonSchema; - - /// The output returned by executing the tool. - type Output: DeserializeOwned + Serialize + 'static; - - type View: Render + ToolOutput; + type View: ToolOutput; /// Returns the name of the tool. /// @@ -86,7 +87,7 @@ pub trait LanguageModelTool { /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. fn definition(&self) -> ToolFunctionDefinition { - let root_schema = schema_for!(Self::Input); + let root_schema = schema_for!(::Input); ToolFunctionDefinition { name: self.name(), @@ -95,36 +96,46 @@ pub trait LanguageModelTool { } } - /// Executes the tool with the given input. - fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task>; - /// A view of the output of running the tool, for displaying to the user. - fn view( - &self, - input: Self::Input, - output: Result, - cx: &mut WindowContext, - ) -> View; - - fn render_running(_arguments: &Option, _cx: &mut WindowContext) -> impl IntoElement { - tool_running_placeholder() - } + fn view(&self, cx: &mut WindowContext) -> View; } pub fn tool_running_placeholder() -> AnyElement { ui::Label::new("Researching...").into_any_element() } -pub trait ToolOutput: Sized { - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; +pub fn unknown_tool_placeholder() -> AnyElement { + ui::Label::new("Unknown tool").into_any_element() +} + +pub fn no_such_tool_placeholder() -> AnyElement { + ui::Label::new("No such tool").into_any_element() +} + +pub trait ToolOutput: Render { + /// The input type that will be passed in to `execute` when the tool is called + /// by the language model. + type Input: DeserializeOwned + JsonSchema; + + /// The output returned by executing the tool. + type SerializedState: DeserializeOwned + Serialize; + + fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext) -> String; + fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext); + fn execute(&mut self, cx: &mut ViewContext) -> Task>; + + fn serialize(&self, cx: &mut ViewContext) -> Self::SerializedState; + fn deserialize( + &mut self, + output: Self::SerializedState, + cx: &mut ViewContext, + ) -> Result<()>; } struct RegisteredTool { enabled: AtomicBool, type_id: TypeId, - execute: Box Task>>, - deserialize: Box ToolFunctionCall>, - render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement, + build_view: Box Box>, definition: ToolFunctionDefinition, } @@ -161,63 +172,132 @@ impl ToolRegistry { .collect() } - pub fn render_tool_call( - &self, - tool_call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> AnyElement { - match &tool_call.result { - Some(result) => div() - .p_2() - .child(result.into_any_element(&tool_call.name)) - .into_any_element(), - None => { - let tool = self.registered_tools.get(&tool_call.name); + pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option> { + let tool = self.registered_tools.get(name)?; + Some((tool.build_view)(cx)) + } - if let Some(tool) = tool { - (tool.render_running)(&tool_call, cx) + pub fn update_tool_call( + &self, + call: &mut ToolFunctionCall, + name: Option<&str>, + arguments: Option<&str>, + cx: &mut WindowContext, + ) { + if let Some(name) = name { + call.name.push_str(name); + } + if let Some(arguments) = arguments { + if call.arguments.is_empty() { + if let Some(view) = self.view_for_tool(&call.name, cx) { + call.state = ToolFunctionCallState::KnownTool(view); } else { - tool_running_placeholder() + call.state = ToolFunctionCallState::NoSuchTool; + } + } + call.arguments.push_str(arguments); + + if let ToolFunctionCallState::KnownTool(view) = &call.state { + if let Ok(repaired_arguments) = repair(call.arguments.clone()) { + view.set_input(&repaired_arguments, cx) } } } } - pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall { - SavedToolFunctionCall { + pub fn execute_tool_call( + &self, + tool_call: &ToolFunctionCall, + cx: &mut WindowContext, + ) -> Option>> { + if let ToolFunctionCallState::KnownTool(view) = &tool_call.state { + Some(view.execute(cx)) + } else { + None + } + } + + pub fn render_tool_call( + &self, + tool_call: &ToolFunctionCall, + _cx: &mut WindowContext, + ) -> AnyElement { + match &tool_call.state { + ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(), + ToolFunctionCallState::Initializing => unknown_tool_placeholder(), + ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { + view.view().into_any_element() + } + } + } + + pub fn content_for_tool_call( + &self, + tool_call: &ToolFunctionCall, + project_context: &mut ProjectContext, + cx: &mut WindowContext, + ) -> String { + match &tool_call.state { + ToolFunctionCallState::Initializing => String::new(), + ToolFunctionCallState::NoSuchTool => { + format!("No such tool: {}", tool_call.name) + } + ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { + view.generate(project_context, cx) + } + } + } + + pub fn serialize_tool_call( + &self, + call: &ToolFunctionCall, + cx: &mut WindowContext, + ) -> Result { + Ok(SavedToolFunctionCall { id: call.id.clone(), name: call.name.clone(), arguments: call.arguments.clone(), - result: call.result.as_ref().map(|result| match result { - ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool, - ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed, - ToolFunctionCallResult::Finished { - serialized_output, .. - } => SavedToolFunctionCallResult::Finished { - serialized_output: match serialized_output { - Ok(value) => Ok(value.clone()), - Err(e) => Err(e.to_string()), - }, - }, - }), - } + state: match &call.state { + ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing, + ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool, + ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool, + ToolFunctionCallState::ExecutedTool(view) => { + SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?) + } + }, + }) } pub fn deserialize_tool_call( &self, call: &SavedToolFunctionCall, cx: &mut WindowContext, - ) -> ToolFunctionCall { - if let Some(tool) = &self.registered_tools.get(&call.name) { - (tool.deserialize)(call, cx) - } else { - ToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - result: Some(ToolFunctionCallResult::NoSuchTool), - } - } + ) -> Result { + let Some(tool) = self.registered_tools.get(&call.name) else { + return Err(anyhow!("no such tool {}", call.name)); + }; + + Ok(ToolFunctionCall { + id: call.id.clone(), + name: call.name.clone(), + arguments: call.arguments.clone(), + state: match &call.state { + SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing, + SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool, + SavedToolFunctionCallState::KnownTool => { + log::error!("Deserialized tool that had not executed"); + let view = (tool.build_view)(cx); + view.set_input(&call.arguments, cx); + ToolFunctionCallState::KnownTool(view) + } + SavedToolFunctionCallState::ExecutedTool(output) => { + let view = (tool.build_view)(cx); + view.set_input(&call.arguments, cx); + view.deserialize_output(output, cx)?; + ToolFunctionCallState::ExecutedTool(view) + } + }, + }) } pub fn register( @@ -231,114 +311,7 @@ impl ToolRegistry { type_id: TypeId::of::(), definition: tool.definition(), enabled: AtomicBool::new(true), - deserialize: Box::new({ - let tool = tool.clone(); - move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| { - let id = tool_call.id.clone(); - let name = tool_call.name.clone(); - let arguments = tool_call.arguments.clone(); - - let Ok(input) = serde_json::from_str::(&tool_call.arguments) else { - return ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::ParsingFailed), - }; - }; - - let result = match &tool_call.result { - Some(result) => match result { - SavedToolFunctionCallResult::NoSuchTool => { - Some(ToolFunctionCallResult::NoSuchTool) - } - SavedToolFunctionCallResult::ParsingFailed => { - Some(ToolFunctionCallResult::ParsingFailed) - } - SavedToolFunctionCallResult::Finished { serialized_output } => { - let output = match serialized_output { - Ok(value) => { - match serde_json::from_str::(value.get()) { - Ok(value) => Ok(value), - Err(_) => { - return ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some( - ToolFunctionCallResult::ParsingFailed, - ), - }; - } - } - } - Err(e) => Err(anyhow!("{e}")), - }; - - let view = tool.view(input, output, cx).into(); - Some(ToolFunctionCallResult::Finished { - serialized_output: serialized_output.clone(), - generate_fn: generate::, - view, - }) - } - }, - None => None, - }; - - ToolFunctionCall { - id: tool_call.id.clone(), - name: name.clone(), - arguments: tool_call.arguments.clone(), - result, - } - } - }), - execute: Box::new({ - let tool = tool.clone(); - move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { - let id = tool_call.id.clone(); - let name = tool_call.name.clone(); - let arguments = tool_call.arguments.clone(); - - let Ok(input) = serde_json::from_str::(&arguments) else { - return Task::ready(Ok(ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::ParsingFailed), - })); - }; - - let result = tool.execute(&input, cx); - let tool = tool.clone(); - cx.spawn(move |mut cx| async move { - let result = result.await; - let serialized_output = result - .as_ref() - .map_err(ToString::to_string) - .and_then(|output| { - Ok(RawValue::from_string( - serde_json::to_string(output).map_err(|e| e.to_string())?, - ) - .unwrap()) - }); - let view = cx.update(|cx| tool.view(input, result, cx))?; - - Ok(ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::Finished { - serialized_output, - view: view.into(), - generate_fn: generate::, - }), - }) - }) - } - }), - render_running: render_running::, + build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))), }; let previous = self.registered_tools.insert(name.clone(), registered_tool); @@ -347,83 +320,40 @@ impl ToolRegistry { } return Ok(()); - - fn render_running( - tool_call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> AnyElement { - // Attempt to parse the string arguments that are JSON as a JSON value - let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok(); - - T::render_running(&maybe_arguments, cx).into_any_element() - } - - fn generate( - view: AnyView, - project: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - view.downcast::() - .unwrap() - .update(cx, |view, cx| T::View::generate(view, project, cx)) - } - } - - /// Task yields an error if the window for the given WindowContext is closed before the task completes. - pub fn call( - &self, - tool_call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> Task> { - let name = tool_call.name.clone(); - let arguments = tool_call.arguments.clone(); - let id = tool_call.id.clone(); - - let tool = match self.registered_tools.get(&name) { - Some(tool) => tool, - None => { - let name = name.clone(); - return Task::ready(Ok(ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::NoSuchTool), - })); - } - }; - - (tool.execute)(tool_call, cx) } } -impl ToolFunctionCallResult { - pub fn generate( - &self, - name: &String, - project: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - match self { - ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"), - ToolFunctionCallResult::ParsingFailed => { - format!("Unable to parse arguments for {name}") - } - ToolFunctionCallResult::Finished { - generate_fn, view, .. - } => (generate_fn)(view.clone(), project, cx), +impl ToolView for View { + fn view(&self) -> AnyView { + self.clone().into() + } + + fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { + self.update(cx, |view, cx| view.generate(project, cx)) + } + + fn set_input(&self, input: &str, cx: &mut WindowContext) { + if let Ok(input) = serde_json::from_str::(input) { + self.update(cx, |view, cx| { + view.set_input(input, cx); + cx.notify(); + }); } } - fn into_any_element(&self, name: &String) -> AnyElement { - match self { - ToolFunctionCallResult::NoSuchTool => { - format!("Language Model attempted to call {name}").into_any_element() - } - ToolFunctionCallResult::ParsingFailed => { - format!("Language Model called {name} with bad arguments").into_any_element() - } - ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(), - } + fn execute(&self, cx: &mut WindowContext) -> Task> { + self.update(cx, |view, cx| view.execute(cx)) + } + + fn serialize_output(&self, cx: &mut WindowContext) -> Result> { + let output = self.update(cx, |view, cx| view.serialize(cx)); + Ok(RawValue::from_string(serde_json::to_string(&output)?)?) + } + + fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> { + let state = serde_json::from_str::(output.get())?; + self.update(cx, |view, cx| view.deserialize(state, cx))?; + Ok(()) } } @@ -453,10 +383,6 @@ mod test { unit: String, } - struct WeatherTool { - current_weather: WeatherResult, - } - #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] struct WeatherResult { location: String, @@ -465,24 +391,81 @@ mod test { } struct WeatherView { - result: WeatherResult, + input: Option, + result: Option, + + // Fake API call + current_weather: WeatherResult, + } + + #[derive(Clone, Serialize)] + struct WeatherTool { + current_weather: WeatherResult, + } + + impl WeatherView { + fn new(current_weather: WeatherResult) -> Self { + Self { + input: None, + result: None, + current_weather, + } + } } impl Render for WeatherView { fn render(&mut self, _cx: &mut gpui::ViewContext) -> impl IntoElement { - div().child(format!("temperature: {}", self.result.temperature)) + match self.result { + Some(ref result) => div() + .child(format!("temperature: {}", result.temperature)) + .into_any_element(), + None => div().child("Calculating weather...").into_any_element(), + } } } impl ToolOutput for WeatherView { - fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String { + type Input = WeatherQuery; + + type SerializedState = WeatherResult; + + fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext) -> String { serde_json::to_string(&self.result).unwrap() } + + fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext) { + self.input = Some(input); + cx.notify(); + } + + fn execute(&mut self, _cx: &mut ViewContext) -> Task> { + let input = self.input.as_ref().unwrap(); + + let _location = input.location.clone(); + let _unit = input.unit.clone(); + + let weather = self.current_weather.clone(); + + self.result = Some(weather); + + Task::ready(Ok(())) + } + + fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { + self.current_weather.clone() + } + + fn deserialize( + &mut self, + output: Self::SerializedState, + _cx: &mut ViewContext, + ) -> Result<()> { + self.current_weather = output; + Ok(()) + } } impl LanguageModelTool for WeatherTool { - type Input = WeatherQuery; - type Output = WeatherResult; type View = WeatherView; fn name(&self) -> String { @@ -493,29 +476,8 @@ mod test { "Fetches the current weather for a given location.".to_string() } - fn execute( - &self, - input: &Self::Input, - _cx: &mut WindowContext, - ) -> Task> { - let _location = input.location.clone(); - let _unit = input.unit.clone(); - - let weather = self.current_weather.clone(); - - Task::ready(Ok(weather)) - } - - fn view( - &self, - _input: Self::Input, - result: Result, - cx: &mut WindowContext, - ) -> View { - cx.new_view(|_cx| { - let result = result.unwrap(); - WeatherView { result } - }) + fn view(&self, cx: &mut WindowContext) -> View { + cx.new_view(|_cx| WeatherView::new(self.current_weather.clone())) } } @@ -564,18 +526,14 @@ mod test { }) ); - let args = json!({ - "location": "San Francisco", - "unit": "Celsius" + let view = cx.update(|cx| tool.view(cx)); + + cx.update(|cx| { + view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx); }); - let query: WeatherQuery = serde_json::from_value(args).unwrap(); + let finished = cx.update(|cx| view.execute(cx)).await; - let result = cx.update(|cx| tool.execute(&query, cx)).await; - - assert!(result.is_ok()); - let result = result.unwrap(); - - assert_eq!(result, tool.current_weather); + assert!(finished.is_ok()); } } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index f32f2b2424..70343beb5e 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -1603,6 +1603,11 @@ impl MultiBuffer { "untitled".into() } + pub fn set_title(&mut self, title: String, cx: &mut ModelContext) { + self.title = Some(title); + cx.notify(); + } + #[cfg(any(test, feature = "test-support"))] pub fn is_parsing(&self, cx: &AppContext) -> bool { self.as_singleton().unwrap().read(cx).is_parsing() @@ -3151,10 +3156,10 @@ impl MultiBufferSnapshot { .redacted_ranges(excerpt.range.context.clone()) .map(move |mut redacted_range| { // Re-base onto the excerpts coordinates in the multibuffer - redacted_range.start = - excerpt_offset + (redacted_range.start - excerpt_buffer_start); - redacted_range.end = - excerpt_offset + (redacted_range.end - excerpt_buffer_start); + redacted_range.start = excerpt_offset + + redacted_range.start.saturating_sub(excerpt_buffer_start); + redacted_range.end = excerpt_offset + + redacted_range.end.saturating_sub(excerpt_buffer_start); redacted_range }) @@ -3179,10 +3184,13 @@ impl MultiBufferSnapshot { .runnable_ranges(excerpt.range.context.clone()) .map(move |mut runnable| { // Re-base onto the excerpts coordinates in the multibuffer - runnable.run_range.start = - excerpt_offset + (runnable.run_range.start - excerpt_buffer_start); - runnable.run_range.end = - excerpt_offset + (runnable.run_range.end - excerpt_buffer_start); + runnable.run_range.start = excerpt_offset + + runnable + .run_range + .start + .saturating_sub(excerpt_buffer_start); + runnable.run_range.end = excerpt_offset + + runnable.run_range.end.saturating_sub(excerpt_buffer_start); runnable }) .skip_while(move |runnable| runnable.run_range.end < range.start)