From a64e20ed9603883492d400b6bf3ab9e1bc6ecafd Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 6 May 2024 17:01:50 -0700 Subject: [PATCH] Centralize project context provided to the assistant (#11471) This PR restructures the way that tools and attachments add information about the current project to a conversation with the assistant. Rather than each tool call or attachment generating a new tool or system message containing information about the project, they can all collectively mutate a new type called a `ProjectContext`, which stores all of the project data that should be sent to the assistant. That data is then formatted in a single place, and passed to the assistant in one system message. This prevents multiple tools/attachments from including redundant context. Release Notes: - N/A --------- Co-authored-by: Kyle --- Cargo.lock | 7 + crates/assistant2/src/assistant2.rs | 101 +++--- crates/assistant2/src/attachments.rs | 206 +++--------- .../assistant2/src/attachments/active_file.rs | 1 + crates/assistant2/src/tools/create_buffer.rs | 39 ++- crates/assistant2/src/tools/project_index.rs | 186 +++++------ .../assistant2/src/ui/active_file_button.rs | 13 +- crates/assistant2/src/ui/composer.rs | 10 +- crates/assistant_tooling/Cargo.toml | 8 + .../src/assistant_tooling.rs | 12 +- .../src/attachment_registry.rs | 148 +++++++++ .../assistant_tooling/src/project_context.rs | 296 ++++++++++++++++++ crates/assistant_tooling/src/tool.rs | 111 ------- .../src/{registry.rs => tool_registry.rs} | 217 ++++++++++--- crates/semantic_index/src/semantic_index.rs | 4 + 15 files changed, 841 insertions(+), 518 deletions(-) create mode 100644 crates/assistant2/src/attachments/active_file.rs create mode 100644 crates/assistant_tooling/src/attachment_registry.rs create mode 100644 crates/assistant_tooling/src/project_context.rs delete mode 100644 crates/assistant_tooling/src/tool.rs rename crates/assistant_tooling/src/{registry.rs => tool_registry.rs} (59%) diff --git a/Cargo.lock b/Cargo.lock index f095d096fb..eee04e4eb1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -411,10 +411,17 @@ name = "assistant_tooling" version = "0.1.0" dependencies = [ "anyhow", + "collections", + "futures 0.3.28", "gpui", + "project", "schemars", "serde", "serde_json", + "settings", + "sum_tree", + "unindent", + "util", ] [[package]] diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 99bf9a6884..3fa8d25dd6 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -4,10 +4,16 @@ mod completion_provider; mod tools; pub mod ui; +use crate::{ + attachments::ActiveEditorAttachmentTool, + tools::{CreateBufferTool, ProjectIndexTool}, + ui::UserOrAssistant, +}; use ::ui::{div, prelude::*, Color, ViewContext}; use anyhow::{Context, Result}; -use assistant_tooling::{ToolFunctionCall, ToolRegistry}; -use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore}; +use assistant_tooling::{ + AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment, +}; use client::{proto, Client, UserStore}; use collections::HashMap; use completion_provider::*; @@ -34,9 +40,6 @@ use workspace::{ pub use assistant_settings::AssistantSettings; -use crate::tools::{CreateBufferTool, ProjectIndexTool}; -use crate::ui::UserOrAssistant; - const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5; #[derive(Eq, PartialEq, Copy, Clone, Deserialize)] @@ -85,10 +88,9 @@ pub fn init(client: Arc, cx: &mut AppContext) { }); workspace.register_action(|workspace, _: &DebugProjectIndex, cx| { if let Some(panel) = workspace.panel::(cx) { - if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() { - let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx)); - workspace.add_item_to_center(Box::new(view), cx); - } + let index = panel.read(cx).chat.read(cx).project_index.clone(); + let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx)); + workspace.add_item_to_center(Box::new(view), cx); } }); }, @@ -122,10 +124,7 @@ impl AssistantPanel { let mut tool_registry = ToolRegistry::new(); tool_registry - .register( - ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()), - cx, - ) + .register(ProjectIndexTool::new(project_index.clone()), cx) .context("failed to register ProjectIndexTool") .log_err(); tool_registry @@ -136,7 +135,7 @@ impl AssistantPanel { .context("failed to register CreateBufferTool") .log_err(); - let mut attachment_store = UserAttachmentStore::new(); + let mut attachment_store = AttachmentRegistry::new(); attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx)); Self::new( @@ -144,7 +143,7 @@ impl AssistantPanel { Arc::new(tool_registry), Arc::new(attachment_store), app_state.user_store.clone(), - Some(project_index), + project_index, workspace, cx, ) @@ -155,9 +154,9 @@ impl AssistantPanel { pub fn new( language_registry: Arc, tool_registry: Arc, - attachment_store: Arc, + attachment_store: Arc, user_store: Model, - project_index: Option>, + project_index: Model, workspace: WeakView, cx: &mut ViewContext, ) -> Self { @@ -241,16 +240,16 @@ pub struct AssistantChat { list_state: ListState, language_registry: Arc, composer_editor: View, - project_index_button: Option>, + project_index_button: View, active_file_button: Option>, user_store: Model, next_message_id: MessageId, collapsed_messages: HashMap, editing_message: Option, pending_completion: Option>, - attachment_store: Arc, tool_registry: Arc, - project_index: Option>, + attachment_registry: Arc, + project_index: Model, } struct EditingMessage { @@ -263,9 +262,9 @@ impl AssistantChat { fn new( language_registry: Arc, tool_registry: Arc, - attachment_store: Arc, + attachment_registry: Arc, user_store: Model, - project_index: Option>, + project_index: Model, workspace: WeakView, cx: &mut ViewContext, ) -> Self { @@ -281,14 +280,14 @@ impl AssistantChat { }, ); - let project_index_button = project_index.clone().map(|project_index| { - cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx)) + let project_index_button = cx.new_view(|cx| { + ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx) }); let active_file_button = match workspace.upgrade() { Some(workspace) => { Some(cx.new_view( - |cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), // + |cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), // )) } _ => None, @@ -313,7 +312,7 @@ impl AssistantChat { editing_message: None, collapsed_messages: HashMap::default(), pending_completion: None, - attachment_store, + attachment_registry, tool_registry, } } @@ -395,7 +394,7 @@ impl AssistantChat { let mode = *mode; self.pending_completion = Some(cx.spawn(move |this, mut cx| async move { let attachments_task = this.update(&mut cx, |this, cx| { - let attachment_store = this.attachment_store.clone(); + let attachment_store = this.attachment_registry.clone(); attachment_store.call_all_attachment_tools(cx) }); @@ -443,7 +442,7 @@ impl AssistantChat { let mut call_count = 0; loop { let complete = async { - let completion = this.update(cx, |this, cx| { + let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| { this.push_new_assistant_message(cx); let definitions = if call_count < limit @@ -455,14 +454,22 @@ impl AssistantChat { }; call_count += 1; - let messages = this.completion_messages(cx); - - CompletionProvider::get(cx).complete( + ( + definitions, this.model.clone(), + this.completion_messages(cx), + ) + })?; + + let messages = messages.await?; + + let completion = cx.update(|cx| { + CompletionProvider::get(cx).complete( + model_name, messages, Vec::new(), 1.0, - definitions, + tool_definitions, ) }); @@ -765,7 +772,12 @@ impl AssistantChat { } } - fn completion_messages(&self, cx: &mut WindowContext) -> Vec { + fn completion_messages(&self, cx: &mut WindowContext) -> Task>> { + let project_index = self.project_index.read(cx); + let project = project_index.project(); + let fs = project_index.fs(); + + let mut project_context = ProjectContext::new(project, fs); let mut completion_messages = Vec::new(); for message in &self.messages { @@ -773,12 +785,11 @@ impl AssistantChat { ChatMessage::User(UserMessage { body, attachments, .. }) => { - completion_messages.extend( - attachments - .into_iter() - .filter_map(|attachment| attachment.message.clone()) - .map(|content| CompletionMessage::System { content }), - ); + for attachment in attachments { + if let Some(content) = attachment.generate(&mut project_context, cx) { + completion_messages.push(CompletionMessage::System { content }); + } + } // Show user's message last so that the assistant is grounded in the user's request completion_messages.push(CompletionMessage::User { @@ -815,7 +826,9 @@ impl AssistantChat { for tool_call in tool_calls { // Every tool call _must_ have a result by ID, otherwise OpenAI will error. let content = match &tool_call.result { - Some(result) => result.format(&tool_call.name), + Some(result) => { + result.generate(&tool_call.name, &mut project_context, cx) + } None => "".to_string(), }; @@ -828,7 +841,13 @@ impl AssistantChat { } } - completion_messages + let system_message = project_context.generate_system_message(cx); + + cx.background_executor().spawn(async move { + let content = system_message.await?; + completion_messages.insert(0, CompletionMessage::System { content }); + Ok(completion_messages) + }) } } diff --git a/crates/assistant2/src/attachments.rs b/crates/assistant2/src/attachments.rs index cddf648163..5da8af7e0d 100644 --- a/crates/assistant2/src/attachments.rs +++ b/crates/assistant2/src/attachments.rs @@ -1,137 +1,18 @@ -use std::{ - any::TypeId, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; +pub mod active_file; use anyhow::{anyhow, Result}; -use collections::HashMap; +use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput}; use editor::Editor; -use futures::future::join_all; -use gpui::{AnyView, Render, Task, View, WeakView}; +use gpui::{Render, Task, View, WeakModel, WeakView}; +use language::Buffer; +use project::ProjectPath; use ui::{prelude::*, ButtonLike, Tooltip, WindowContext}; -use util::{maybe, ResultExt}; +use util::maybe; use workspace::Workspace; -/// A collected attachment from running an attachment tool -pub struct UserAttachment { - pub message: Option, - pub view: AnyView, -} - -pub struct UserAttachmentStore { - attachment_tools: HashMap, -} - -/// Internal representation of an attachment tool to allow us to treat them dynamically -struct DynamicAttachment { - enabled: AtomicBool, - call: Box Task>>, -} - -impl UserAttachmentStore { - pub fn new() -> Self { - Self { - attachment_tools: HashMap::default(), - } - } - - pub fn register(&mut self, attachment: A) { - let call = Box::new(move |cx: &mut WindowContext| { - let result = attachment.run(cx); - - cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let message = A::format(&result); - let view = cx.update(|cx| A::view(result, cx))?; - - Ok(UserAttachment { - message, - view: view.into(), - }) - }) - }); - - self.attachment_tools.insert( - TypeId::of::(), - DynamicAttachment { - call, - enabled: AtomicBool::new(true), - }, - ); - } - - pub fn set_attachment_tool_enabled(&self, is_enabled: bool) { - if let Some(attachment) = self.attachment_tools.get(&TypeId::of::()) { - attachment.enabled.store(is_enabled, SeqCst); - } - } - - pub fn is_attachment_tool_enabled(&self) -> bool { - if let Some(attachment) = self.attachment_tools.get(&TypeId::of::()) { - attachment.enabled.load(SeqCst) - } else { - false - } - } - - pub fn call( - &self, - cx: &mut WindowContext, - ) -> Task> { - let Some(attachment) = self.attachment_tools.get(&TypeId::of::()) else { - return Task::ready(Err(anyhow!("no attachment tool"))); - }; - - (attachment.call)(cx) - } - - pub fn call_all_attachment_tools( - self: Arc, - cx: &mut WindowContext<'_>, - ) -> Task>> { - let this = self.clone(); - cx.spawn(|mut cx| async move { - let attachment_tasks = cx.update(|cx| { - let mut tasks = Vec::new(); - for attachment in this - .attachment_tools - .values() - .filter(|attachment| attachment.enabled.load(SeqCst)) - { - tasks.push((attachment.call)(cx)) - } - - tasks - })?; - - let attachments = join_all(attachment_tasks.into_iter()).await; - - Ok(attachments - .into_iter() - .filter_map(|attachment| attachment.log_err()) - .collect()) - }) - } -} - -pub trait AttachmentTool { - type Output: 'static; - type View: Render; - - fn run(&self, cx: &mut WindowContext) -> Task>; - - fn format(output: &Result) -> Option; - - fn view(output: Result, cx: &mut WindowContext) -> View; -} - pub struct ActiveEditorAttachment { - filename: Arc, - language: Arc, - text: Arc, + buffer: WeakModel, + path: Option, } pub struct FileAttachmentView { @@ -142,7 +23,13 @@ impl Render for FileAttachmentView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { match &self.output { Ok(attachment) => { - let filename = attachment.filename.clone(); + let filename: SharedString = attachment + .path + .as_ref() + .and_then(|p| p.path.file_name()?.to_str()) + .unwrap_or("Untitled") + .to_string() + .into(); // todo!(): make the button link to the actual file to open ButtonLike::new("file-attachment") @@ -152,7 +39,7 @@ impl Render for FileAttachmentView { .bg(cx.theme().colors().editor_background) .rounded_md() .child(ui::Icon::new(IconName::File)) - .child(filename.to_string()), + .child(filename.clone()), ) .tooltip({ move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx) @@ -164,6 +51,20 @@ impl Render for FileAttachmentView { } } +impl ToolOutput for FileAttachmentView { + fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { + if let Ok(result) = &self.output { + if let Some(path) = &result.path { + project.add_file(path.clone()); + return format!("current file: {}", path.path.display()); + } else if let Some(buffer) = result.buffer.upgrade() { + return format!("current untitled buffer text:\n{}", buffer.read(cx).text()); + } + } + String::new() + } +} + pub struct ActiveEditorAttachmentTool { workspace: WeakView, } @@ -174,7 +75,7 @@ impl ActiveEditorAttachmentTool { } } -impl AttachmentTool for ActiveEditorAttachmentTool { +impl LanguageModelAttachment for ActiveEditorAttachmentTool { type Output = ActiveEditorAttachment; type View = FileAttachmentView; @@ -191,47 +92,22 @@ impl AttachmentTool for ActiveEditorAttachmentTool { let buffer = active_buffer.read(cx); - if let Some(singleton) = buffer.as_singleton() { - let singleton = singleton.read(cx); - - let filename = singleton - .file() - .map(|file| file.path().to_string_lossy()) - .unwrap_or("Untitled".into()); - - let text = singleton.text(); - - let language = singleton - .language() - .map(|l| { - let name = l.code_fence_block_name(); - name.to_string() - }) - .unwrap_or_default(); - + if let Some(buffer) = buffer.as_singleton() { + let path = + project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path.clone(), + }); return Ok(ActiveEditorAttachment { - filename: filename.into(), - language: language.into(), - text: text.into(), + buffer: buffer.downgrade(), + path, }); + } else { + Err(anyhow!("no active buffer")) } - - Err(anyhow!("no active buffer")) })) } - fn format(output: &Result) -> Option { - let output = output.as_ref().ok()?; - - let filename = &output.filename; - let language = &output.language; - let text = &output.text; - - Some(format!( - "User's active file `{filename}`:\n\n```{language}\n{text}```\n\n" - )) - } - fn view(output: Result, cx: &mut WindowContext) -> View { cx.new_view(|_cx| FileAttachmentView { output }) } diff --git a/crates/assistant2/src/attachments/active_file.rs b/crates/assistant2/src/attachments/active_file.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/crates/assistant2/src/attachments/active_file.rs @@ -0,0 +1 @@ + diff --git a/crates/assistant2/src/tools/create_buffer.rs b/crates/assistant2/src/tools/create_buffer.rs index b2c6f7d8a9..13e4cc7081 100644 --- a/crates/assistant2/src/tools/create_buffer.rs +++ b/crates/assistant2/src/tools/create_buffer.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use assistant_tooling::LanguageModelTool; +use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput}; use editor::Editor; use gpui::{prelude::*, Model, Task, View, WeakView}; use project::Project; @@ -31,11 +31,9 @@ pub struct CreateBufferInput { language: String, } -pub struct CreateBufferOutput {} - impl LanguageModelTool for CreateBufferTool { type Input = CreateBufferInput; - type Output = CreateBufferOutput; + type Output = (); type View = CreateBufferView; fn name(&self) -> String { @@ -83,32 +81,39 @@ impl LanguageModelTool for CreateBufferTool { }) .log_err(); - Ok(CreateBufferOutput {}) + Ok(()) } }) } - fn format(input: &Self::Input, output: &Result) -> String { - match output { - Ok(_) => format!("Created a new {} buffer", input.language), - Err(err) => format!("Failed to create buffer: {err:?}"), - } - } - fn output_view( - _tool_call_id: String, - _input: Self::Input, - _output: Result, + input: Self::Input, + output: Result, cx: &mut WindowContext, ) -> View { - cx.new_view(|_cx| CreateBufferView {}) + cx.new_view(|_cx| CreateBufferView { + language: input.language, + output, + }) } } -pub struct CreateBufferView {} +pub struct CreateBufferView { + language: String, + output: Result<()>, +} impl Render for CreateBufferView { fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { div().child("Opening a buffer") } } + +impl ToolOutput for CreateBufferView { + fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String { + match &self.output { + Ok(_) => format!("Created a new {} buffer", self.language), + Err(err) => format!("Failed to create buffer: {err:?}"), + } + } +} diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs index 7ccbae79d3..c67c9216c1 100644 --- a/crates/assistant2/src/tools/project_index.rs +++ b/crates/assistant2/src/tools/project_index.rs @@ -1,25 +1,18 @@ use anyhow::Result; -use assistant_tooling::LanguageModelTool; +use assistant_tooling::{LanguageModelTool, ToolOutput}; +use collections::BTreeMap; use gpui::{prelude::*, Model, Task}; -use project::Fs; +use project::ProjectPath; use schemars::JsonSchema; use semantic_index::{ProjectIndex, Status}; use serde::Deserialize; -use std::{collections::HashSet, sync::Arc}; - -use ui::{ - div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString, - WindowContext, -}; -use util::ResultExt as _; +use std::{fmt::Write as _, ops::Range}; +use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext}; const DEFAULT_SEARCH_LIMIT: usize = 20; -#[derive(Clone)] -pub struct CodebaseExcerpt { - path: SharedString, - text: SharedString, - score: f32, +pub struct ProjectIndexTool { + project_index: Model, } // Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model. @@ -40,6 +33,11 @@ pub struct ProjectIndexView { expanded_header: bool, } +pub struct ProjectIndexOutput { + status: Status, + excerpts: BTreeMap>>, +} + impl ProjectIndexView { fn new(input: CodebaseQuery, output: Result) -> Self { let element_id = ElementId::Name(nanoid::nanoid!().into()); @@ -71,19 +69,15 @@ impl Render for ProjectIndexView { Ok(output) => output, }; - let num_files_searched = output.files_searched.len(); + let file_count = output.excerpts.len(); let header = h_flex() .gap_2() .child(Icon::new(IconName::File)) .child(format!( "Read {} {}", - num_files_searched, - if num_files_searched == 1 { - "file" - } else { - "files" - } + file_count, + if file_count == 1 { "file" } else { "files" } )); v_flex().gap_3().child( @@ -102,36 +96,50 @@ impl Render for ProjectIndexView { .child(Icon::new(IconName::MagnifyingGlass)) .child(Label::new(format!("`{}`", query)).color(Color::Muted)), ) - .child(v_flex().gap_2().children(output.files_searched.iter().map( - |path| { - h_flex() - .gap_2() - .child(Icon::new(IconName::File)) - .child(Label::new(path.clone()).color(Color::Muted)) - }, - ))), + .child( + v_flex() + .gap_2() + .children(output.excerpts.keys().map(|path| { + h_flex().gap_2().child(Icon::new(IconName::File)).child( + Label::new(path.path.to_string_lossy().to_string()) + .color(Color::Muted), + ) + })), + ), ), ) } } -pub struct ProjectIndexTool { - project_index: Model, - fs: Arc, -} +impl ToolOutput for ProjectIndexView { + fn generate( + &self, + context: &mut assistant_tooling::ProjectContext, + _: &mut WindowContext, + ) -> String { + match &self.output { + Ok(output) => { + let mut body = "found results in the following paths:\n".to_string(); -pub struct ProjectIndexOutput { - excerpts: Vec, - status: Status, - files_searched: HashSet, + for (project_path, ranges) in &output.excerpts { + context.add_excerpts(project_path.clone(), ranges); + writeln!(&mut body, "* {}", &project_path.path.display()).unwrap(); + } + + if output.status != Status::Idle { + body.push_str("Still indexing. Results may be incomplete.\n"); + } + + body + } + Err(err) => format!("Error: {}", err), + } + } } impl ProjectIndexTool { - pub fn new(project_index: Model, fs: Arc) -> Self { - // Listen for project index status and update the ProjectIndexTool directly - - // TODO: setup a better description based on the user's current codebase. - Self { project_index, fs } + pub fn new(project_index: Model) -> Self { + Self { project_index } } } @@ -151,64 +159,42 @@ impl LanguageModelTool for ProjectIndexTool { fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task> { let project_index = self.project_index.read(cx); let status = project_index.status(); - let results = project_index.search( + let search = project_index.search( query.query.clone(), query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), cx, ); - let fs = self.fs.clone(); + cx.spawn(|mut cx| async move { + let search_results = search.await?; - cx.spawn(|cx| async move { - let results = results.await?; + cx.update(|cx| { + let mut output = ProjectIndexOutput { + status, + excerpts: Default::default(), + }; - let excerpts = results.into_iter().map(|result| { - let abs_path = result - .worktree - .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path)); - let fs = fs.clone(); + for search_result in search_results { + let path = ProjectPath { + worktree_id: search_result.worktree.read(cx).id(), + path: search_result.path.clone(), + }; - async move { - let path = result.path.clone(); - let text = fs.load(&abs_path?).await?; - - let mut start = result.range.start; - let mut end = result.range.end.min(text.len()); - while !text.is_char_boundary(start) { - start += 1; - } - while !text.is_char_boundary(end) { - end -= 1; - } - - anyhow::Ok(CodebaseExcerpt { - path: path.to_string_lossy().to_string().into(), - text: SharedString::from(text[start..end].to_string()), - score: result.score, - }) + let excerpts_for_path = output.excerpts.entry(path).or_default(); + let ix = match excerpts_for_path + .binary_search_by_key(&search_result.range.start, |r| r.start) + { + Ok(ix) | Err(ix) => ix, + }; + excerpts_for_path.insert(ix, search_result.range); } - }); - let mut files_searched = HashSet::new(); - let excerpts = futures::future::join_all(excerpts) - .await - .into_iter() - .filter_map(|result| result.log_err()) - .inspect(|excerpt| { - files_searched.insert(excerpt.path.clone()); - }) - .collect::>(); - - anyhow::Ok(ProjectIndexOutput { - excerpts, - status, - files_searched, + output }) }) } fn output_view( - _tool_call_id: String, input: Self::Input, output: Result, cx: &mut WindowContext, @@ -220,34 +206,4 @@ impl LanguageModelTool for ProjectIndexTool { CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false) .start_slot("Searching code base") } - - fn format(_input: &Self::Input, output: &Result) -> String { - match &output { - Ok(output) => { - let mut body = "Semantic search results:\n".to_string(); - - if output.status != Status::Idle { - body.push_str("Still indexing. Results may be incomplete.\n"); - } - - if output.excerpts.is_empty() { - body.push_str("No results found"); - return body; - } - - for excerpt in &output.excerpts { - body.push_str("Excerpt from "); - body.push_str(excerpt.path.as_ref()); - body.push_str(", score "); - body.push_str(&excerpt.score.to_string()); - body.push_str(":\n"); - body.push_str("~~~\n"); - body.push_str(excerpt.text.as_ref()); - body.push_str("~~~\n"); - } - body - } - Err(err) => format!("Error: {}", err), - } - } } diff --git a/crates/assistant2/src/ui/active_file_button.rs b/crates/assistant2/src/ui/active_file_button.rs index d6381b6b04..3a2ac8e04d 100644 --- a/crates/assistant2/src/ui/active_file_button.rs +++ b/crates/assistant2/src/ui/active_file_button.rs @@ -1,4 +1,5 @@ -use crate::attachments::{ActiveEditorAttachmentTool, UserAttachmentStore}; +use crate::attachments::ActiveEditorAttachmentTool; +use assistant_tooling::AttachmentRegistry; use editor::Editor; use gpui::{prelude::*, Subscription, View}; use std::sync::Arc; @@ -13,7 +14,7 @@ enum Status { } pub struct ActiveFileButton { - attachment_store: Arc, + attachment_registry: Arc, status: Status, #[allow(dead_code)] workspace_subscription: Subscription, @@ -21,7 +22,7 @@ pub struct ActiveFileButton { impl ActiveFileButton { pub fn new( - attachment_store: Arc, + attachment_store: Arc, workspace: View, cx: &mut ViewContext, ) -> Self { @@ -30,14 +31,14 @@ impl ActiveFileButton { cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx)); Self { - attachment_store, + attachment_registry: attachment_store, status: Status::NoFile, workspace_subscription, } } pub fn set_enabled(&mut self, enabled: bool) { - self.attachment_store + self.attachment_registry .set_attachment_tool_enabled::(enabled); } @@ -79,7 +80,7 @@ impl ActiveFileButton { impl Render for ActiveFileButton { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let is_enabled = self - .attachment_store + .attachment_registry .is_attachment_tool_enabled::(); let icon = if is_enabled { diff --git a/crates/assistant2/src/ui/composer.rs b/crates/assistant2/src/ui/composer.rs index 2b866e3ad5..b6c1ca1c48 100644 --- a/crates/assistant2/src/ui/composer.rs +++ b/crates/assistant2/src/ui/composer.rs @@ -11,7 +11,7 @@ use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, T #[derive(IntoElement)] pub struct Composer { editor: View, - project_index_button: Option>, + project_index_button: View, active_file_button: Option>, model_selector: AnyElement, } @@ -19,7 +19,7 @@ pub struct Composer { impl Composer { pub fn new( editor: View, - project_index_button: Option>, + project_index_button: View, active_file_button: Option>, model_selector: AnyElement, ) -> Self { @@ -32,11 +32,7 @@ impl Composer { } fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement { - h_flex().children( - self.project_index_button - .clone() - .map(|view| view.into_any_element()), - ) + h_flex().child(self.project_index_button.clone()) } fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement { diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml index 8a7e7ab185..a69d1729d3 100644 --- a/crates/assistant_tooling/Cargo.toml +++ b/crates/assistant_tooling/Cargo.toml @@ -13,10 +13,18 @@ path = "src/assistant_tooling.rs" [dependencies] anyhow.workspace = true +collections.workspace = true +futures.workspace = true gpui.workspace = true +project.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +sum_tree.workspace = true +util.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs index 93d81cbb9d..6e5903c1f4 100644 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -1,5 +1,9 @@ -pub mod registry; -pub mod tool; +mod attachment_registry; +mod project_context; +mod tool_registry; -pub use crate::registry::ToolRegistry; -pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition}; +pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment}; +pub use project_context::ProjectContext; +pub use tool_registry::{ + LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry, +}; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs new file mode 100644 index 0000000000..8c0ae347a0 --- /dev/null +++ b/crates/assistant_tooling/src/attachment_registry.rs @@ -0,0 +1,148 @@ +use crate::{ProjectContext, ToolOutput}; +use anyhow::{anyhow, Result}; +use collections::HashMap; +use futures::future::join_all; +use gpui::{AnyView, Render, Task, View, WindowContext}; +use std::{ + any::TypeId, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; +use util::ResultExt as _; + +pub struct AttachmentRegistry { + registered_attachments: HashMap, +} + +pub trait LanguageModelAttachment { + type Output: 'static; + type View: Render + ToolOutput; + + fn run(&self, cx: &mut WindowContext) -> Task>; + + fn view(output: Result, cx: &mut WindowContext) -> View; +} + +/// A collected attachment from running an attachment tool +pub struct UserAttachment { + pub view: AnyView, + generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String, +} + +/// Internal representation of an attachment tool to allow us to treat them dynamically +struct RegisteredAttachment { + enabled: AtomicBool, + call: Box Task>>, +} + +impl AttachmentRegistry { + pub fn new() -> Self { + Self { + registered_attachments: HashMap::default(), + } + } + + pub fn register(&mut self, attachment: A) { + let call = Box::new(move |cx: &mut WindowContext| { + let result = attachment.run(cx); + + cx.spawn(move |mut cx| async move { + let result: Result = result.await; + let view = cx.update(|cx| A::view(result, cx))?; + + Ok(UserAttachment { + view: view.into(), + generate_fn: generate::, + }) + }) + }); + + self.registered_attachments.insert( + TypeId::of::(), + RegisteredAttachment { + call, + enabled: AtomicBool::new(true), + }, + ); + return; + + fn generate( + view: AnyView, + project: &mut ProjectContext, + cx: &mut WindowContext, + ) -> String { + view.downcast::() + .unwrap() + .update(cx, |view, cx| T::View::generate(view, project, cx)) + } + } + + pub fn set_attachment_tool_enabled( + &self, + is_enabled: bool, + ) { + if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { + attachment.enabled.store(is_enabled, SeqCst); + } + } + + pub fn is_attachment_tool_enabled(&self) -> bool { + if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { + attachment.enabled.load(SeqCst) + } else { + false + } + } + + pub fn call( + &self, + cx: &mut WindowContext, + ) -> Task> { + let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) else { + return Task::ready(Err(anyhow!("no attachment tool"))); + }; + + (attachment.call)(cx) + } + + pub fn call_all_attachment_tools( + self: Arc, + cx: &mut WindowContext<'_>, + ) -> Task>> { + let this = self.clone(); + cx.spawn(|mut cx| async move { + let attachment_tasks = cx.update(|cx| { + let mut tasks = Vec::new(); + for attachment in this + .registered_attachments + .values() + .filter(|attachment| attachment.enabled.load(SeqCst)) + { + tasks.push((attachment.call)(cx)) + } + + tasks + })?; + + let attachments = join_all(attachment_tasks.into_iter()).await; + + Ok(attachments + .into_iter() + .filter_map(|attachment| attachment.log_err()) + .collect()) + }) + } +} + +impl UserAttachment { + pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option { + let result = (self.generate_fn)(self.view.clone(), output, cx); + if result.is_empty() { + None + } else { + Some(result) + } + } +} diff --git a/crates/assistant_tooling/src/project_context.rs b/crates/assistant_tooling/src/project_context.rs new file mode 100644 index 0000000000..aafe2728bf --- /dev/null +++ b/crates/assistant_tooling/src/project_context.rs @@ -0,0 +1,296 @@ +use anyhow::{anyhow, Result}; +use gpui::{AppContext, Model, Task, WeakModel}; +use project::{Fs, Project, ProjectPath, Worktree}; +use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc}; +use sum_tree::TreeMap; + +pub struct ProjectContext { + files: TreeMap, + project: WeakModel, + fs: Arc, +} + +#[derive(Debug, Clone)] +enum PathState { + PathOnly, + EntireFile, + Excerpts { ranges: Vec> }, +} + +impl ProjectContext { + pub fn new(project: WeakModel, fs: Arc) -> Self { + Self { + files: TreeMap::default(), + fs, + project, + } + } + + pub fn add_path(&mut self, project_path: ProjectPath) { + if self.files.get(&project_path).is_none() { + self.files.insert(project_path, PathState::PathOnly); + } + } + + pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range]) { + let previous_state = self + .files + .get(&project_path) + .unwrap_or(&PathState::PathOnly); + + let mut ranges = match previous_state { + PathState::EntireFile => return, + PathState::PathOnly => Vec::new(), + PathState::Excerpts { ranges } => ranges.to_vec(), + }; + + for new_range in new_ranges { + let ix = ranges.binary_search_by(|probe| { + if probe.end < new_range.start { + Ordering::Less + } else if probe.start > new_range.end { + Ordering::Greater + } else { + Ordering::Equal + } + }); + + match ix { + Ok(mut ix) => { + let existing = &mut ranges[ix]; + existing.start = existing.start.min(new_range.start); + existing.end = existing.end.max(new_range.end); + while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end { + ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end); + ranges.remove(ix + 1); + } + while ix > 0 && ranges[ix - 1].end >= ranges[ix].start { + ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start); + ranges.remove(ix - 1); + ix -= 1; + } + } + Err(ix) => { + ranges.insert(ix, new_range.clone()); + } + } + } + + self.files + .insert(project_path, PathState::Excerpts { ranges }); + } + + pub fn add_file(&mut self, project_path: ProjectPath) { + self.files.insert(project_path, PathState::EntireFile); + } + + pub fn generate_system_message(&self, cx: &mut AppContext) -> Task> { + let project = self + .project + .upgrade() + .ok_or_else(|| anyhow!("project dropped")); + let files = self.files.clone(); + let fs = self.fs.clone(); + cx.spawn(|cx| async move { + let project = project?; + let mut result = "project structure:\n".to_string(); + + let mut last_worktree: Option> = None; + for (project_path, path_state) in files.iter() { + if let Some(worktree) = &last_worktree { + if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id { + last_worktree = None; + } + } + + let worktree; + if let Some(last_worktree) = &last_worktree { + worktree = last_worktree.clone(); + } else if let Some(tree) = project.read_with(&cx, |project, cx| { + project.worktree_for_id(project_path.worktree_id, cx) + })? { + worktree = tree; + last_worktree = Some(worktree.clone()); + let worktree_name = + worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?; + writeln!(&mut result, "# {}", worktree_name).unwrap(); + } else { + continue; + } + + let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?; + let path = &project_path.path; + writeln!(&mut result, "## {}", path.display()).unwrap(); + + match path_state { + PathState::PathOnly => {} + PathState::EntireFile => { + let text = fs.load(&worktree_abs_path.join(&path)).await?; + writeln!(&mut result, "~~~\n{text}\n~~~").unwrap(); + } + PathState::Excerpts { ranges } => { + let text = fs.load(&worktree_abs_path.join(&path)).await?; + + writeln!(&mut result, "~~~").unwrap(); + + // Assumption: ranges are in order, not overlapping + let mut prev_range_end = 0; + for range in ranges { + if range.start > prev_range_end { + writeln!(&mut result, "...").unwrap(); + prev_range_end = range.end; + } + + let mut start = range.start; + let mut end = range.end.min(text.len()); + while !text.is_char_boundary(start) { + start += 1; + } + while !text.is_char_boundary(end) { + end -= 1; + } + result.push_str(&text[start..end]); + if !result.ends_with('\n') { + result.push('\n'); + } + } + + if prev_range_end < text.len() { + writeln!(&mut result, "...").unwrap(); + } + + writeln!(&mut result, "~~~").unwrap(); + } + } + } + Ok(result) + }) + } +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use super::*; + use gpui::TestAppContext; + use project::FakeFs; + use serde_json::json; + use settings::SettingsStore; + + use unindent::Unindent as _; + + #[gpui::test] + async fn test_system_message_generation(cx: &mut TestAppContext) { + init_test(cx); + + let file_3_contents = r#" + fn test1() {} + fn test2() {} + fn test3() {} + "# + .unindent(); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/code", + json!({ + "root1": { + "lib": { + "file1.rs": "mod example;", + "file2.rs": "", + }, + "test": { + "file3.rs": file_3_contents, + } + }, + "root2": { + "src": { + "main.rs": "" + } + } + }), + ) + .await; + + let project = Project::test( + fs.clone(), + ["/code/root1".as_ref(), "/code/root2".as_ref()], + cx, + ) + .await; + + let worktree_ids = project.read_with(cx, |project, cx| { + project + .worktrees() + .map(|worktree| worktree.read(cx).id()) + .collect::>() + }); + + let mut ax = ProjectContext::new(project.downgrade(), fs); + + ax.add_file(ProjectPath { + worktree_id: worktree_ids[0], + path: Path::new("lib/file1.rs").into(), + }); + + let message = cx + .update(|cx| ax.generate_system_message(cx)) + .await + .unwrap(); + assert_eq!( + r#" + project structure: + # root1 + ## lib/file1.rs + ~~~ + mod example; + ~~~ + "# + .unindent(), + message + ); + + ax.add_excerpts( + ProjectPath { + worktree_id: worktree_ids[0], + path: Path::new("test/file3.rs").into(), + }, + &[ + file_3_contents.find("fn test2").unwrap() + ..file_3_contents.find("fn test3").unwrap(), + ], + ); + + let message = cx + .update(|cx| ax.generate_system_message(cx)) + .await + .unwrap(); + assert_eq!( + r#" + project structure: + # root1 + ## lib/file1.rs + ~~~ + mod example; + ~~~ + ## test/file3.rs + ~~~ + ... + fn test2() {} + ... + ~~~ + "# + .unindent(), + message + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + }); + } +} diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs deleted file mode 100644 index 8bc55bac80..0000000000 --- a/crates/assistant_tooling/src/tool.rs +++ /dev/null @@ -1,111 +0,0 @@ -use anyhow::Result; -use gpui::{div, AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext}; -use schemars::{schema::RootSchema, schema_for, JsonSchema}; -use serde::Deserialize; -use std::fmt::Display; - -#[derive(Default, Deserialize)] -pub struct ToolFunctionCall { - pub id: String, - pub name: String, - pub arguments: String, - #[serde(skip)] - pub result: Option, -} - -pub enum ToolFunctionCallResult { - NoSuchTool, - ParsingFailed, - Finished { for_model: String, view: AnyView }, -} - -impl ToolFunctionCallResult { - pub fn format(&self, name: &String) -> String { - match self { - ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"), - ToolFunctionCallResult::ParsingFailed => { - format!("Unable to parse arguments for {name}") - } - ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(), - } - } - - pub fn into_any_element(&self, name: &String) -> AnyElement { - match self { - ToolFunctionCallResult::NoSuchTool => { - format!("Language Model attempted to call {name}").into_any_element() - } - ToolFunctionCallResult::ParsingFailed => { - format!("Language Model called {name} with bad arguments").into_any_element() - } - ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(), - } - } -} - -#[derive(Clone)] -pub struct ToolFunctionDefinition { - pub name: String, - pub description: String, - pub parameters: RootSchema, -} - -impl Display for ToolFunctionDefinition { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let schema = serde_json::to_string(&self.parameters).ok(); - let schema = schema.unwrap_or("None".to_string()); - write!(f, "Name: {}:\n", self.name)?; - write!(f, "Description: {}\n", self.description)?; - write!(f, "Parameters: {}", schema) - } -} - -pub trait LanguageModelTool { - /// The input type that will be passed in to `execute` when the tool is called - /// by the language model. - type Input: for<'de> Deserialize<'de> + JsonSchema; - - /// The output returned by executing the tool. - type Output: 'static; - - type View: Render; - - /// Returns the name of the tool. - /// - /// This name is exposed to the language model to allow the model to pick - /// which tools to use. As this name is used to identify the tool within a - /// tool registry, it should be unique. - fn name(&self) -> String; - - /// Returns the description of the tool. - /// - /// This can be used to _prompt_ the model as to what the tool does. - fn description(&self) -> String; - - /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. - fn definition(&self) -> ToolFunctionDefinition { - let root_schema = schema_for!(Self::Input); - - ToolFunctionDefinition { - name: self.name(), - description: self.description(), - parameters: root_schema, - } - } - - /// Executes the tool with the given input. - fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task>; - - fn format(input: &Self::Input, output: &Result) -> String; - - fn output_view( - tool_call_id: String, - input: Self::Input, - output: Result, - cx: &mut WindowContext, - ) -> View; - - fn render_running(_cx: &mut WindowContext) -> impl IntoElement { - div() - } -} diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/tool_registry.rs similarity index 59% rename from crates/assistant_tooling/src/registry.rs rename to crates/assistant_tooling/src/tool_registry.rs index 4c3c1a082d..5e1da303f9 100644 --- a/crates/assistant_tooling/src/registry.rs +++ b/crates/assistant_tooling/src/tool_registry.rs @@ -1,54 +1,115 @@ use anyhow::{anyhow, Result}; -use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext}; +use gpui::{ + div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext, +}; +use schemars::{schema::RootSchema, schema_for, JsonSchema}; +use serde::Deserialize; use std::{ any::TypeId, collections::HashMap, + fmt::Display, sync::atomic::{AtomicBool, Ordering::SeqCst}, }; -use crate::tool::{ - LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, -}; +use crate::ProjectContext; -// Internal Tool representation for the registry -pub struct Tool { - enabled: AtomicBool, - type_id: TypeId, - call: Box Task>>, - render_running: Box gpui::AnyElement>, - definition: ToolFunctionDefinition, +pub struct ToolRegistry { + registered_tools: HashMap, } -impl Tool { - fn new( - type_id: TypeId, - call: Box Task>>, - render_running: Box gpui::AnyElement>, - definition: ToolFunctionDefinition, - ) -> Self { - Self { - enabled: AtomicBool::new(true), - type_id, - call, - render_running, - definition, +#[derive(Default, Deserialize)] +pub struct ToolFunctionCall { + pub id: String, + pub name: String, + pub arguments: String, + #[serde(skip)] + pub result: Option, +} + +pub enum ToolFunctionCallResult { + NoSuchTool, + ParsingFailed, + Finished { + view: AnyView, + generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String, + }, +} + +#[derive(Clone)] +pub struct ToolFunctionDefinition { + pub name: String, + pub description: String, + pub parameters: RootSchema, +} + +pub trait LanguageModelTool { + /// The input type that will be passed in to `execute` when the tool is called + /// by the language model. + type Input: for<'de> Deserialize<'de> + JsonSchema; + + /// The output returned by executing the tool. + type Output: 'static; + + type View: Render + ToolOutput; + + /// Returns the name of the tool. + /// + /// This name is exposed to the language model to allow the model to pick + /// which tools to use. As this name is used to identify the tool within a + /// tool registry, it should be unique. + fn name(&self) -> String; + + /// Returns the description of the tool. + /// + /// This can be used to _prompt_ the model as to what the tool does. + fn description(&self) -> String; + + /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. + fn definition(&self) -> ToolFunctionDefinition { + let root_schema = schema_for!(Self::Input); + + ToolFunctionDefinition { + name: self.name(), + description: self.description(), + parameters: root_schema, } } + + /// Executes the tool with the given input. + fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task>; + + fn output_view( + input: Self::Input, + output: Result, + cx: &mut WindowContext, + ) -> View; + + fn render_running(_cx: &mut WindowContext) -> impl IntoElement { + div() + } } -pub struct ToolRegistry { - tools: HashMap, +pub trait ToolOutput: Sized { + fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; +} + +struct RegisteredTool { + enabled: AtomicBool, + type_id: TypeId, + call: Box Task>>, + render_running: fn(&mut WindowContext) -> gpui::AnyElement, + definition: ToolFunctionDefinition, } impl ToolRegistry { pub fn new() -> Self { Self { - tools: HashMap::new(), + registered_tools: HashMap::new(), } } pub fn set_tool_enabled(&self, is_enabled: bool) { - for tool in self.tools.values() { + for tool in self.registered_tools.values() { if tool.type_id == TypeId::of::() { tool.enabled.store(is_enabled, SeqCst); return; @@ -57,7 +118,7 @@ impl ToolRegistry { } pub fn is_tool_enabled(&self) -> bool { - for tool in self.tools.values() { + for tool in self.registered_tools.values() { if tool.type_id == TypeId::of::() { return tool.enabled.load(SeqCst); } @@ -66,7 +127,7 @@ impl ToolRegistry { } pub fn definitions(&self) -> Vec { - self.tools + self.registered_tools .values() .filter(|tool| tool.enabled.load(SeqCst)) .map(|tool| tool.definition.clone()) @@ -84,7 +145,7 @@ impl ToolRegistry { .child(result.into_any_element(&tool_call.name)) .into_any_element(), None => self - .tools + .registered_tools .get(&tool_call.name) .map(|tool| (tool.render_running)(cx)) .unwrap_or_else(|| div().into_any_element()), @@ -96,13 +157,12 @@ impl ToolRegistry { tool: T, _cx: &mut WindowContext, ) -> Result<()> { - let definition = tool.definition(); - let name = tool.name(); - - let registered_tool = Tool::new( - TypeId::of::(), - Box::new( + let registered_tool = RegisteredTool { + type_id: TypeId::of::(), + definition: tool.definition(), + enabled: AtomicBool::new(true), + call: Box::new( move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { let name = tool_call.name.clone(); let arguments = tool_call.arguments.clone(); @@ -121,8 +181,7 @@ impl ToolRegistry { cx.spawn(move |mut cx| async move { let result: Result = result.await; - let for_model = T::format(&input, &result); - let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?; + let view = cx.update(|cx| T::output_view(input, result, cx))?; Ok(ToolFunctionCall { id, @@ -130,23 +189,35 @@ impl ToolRegistry { arguments, result: Some(ToolFunctionCallResult::Finished { view: view.into(), - for_model, + generate_fn: generate::, }), }) }) }, ), - Box::new(|cx| T::render_running(cx).into_any_element()), - definition, - ); - - let previous = self.tools.insert(name.clone(), registered_tool); + render_running: render_running::, + }; + let previous = self.registered_tools.insert(name.clone(), registered_tool); if previous.is_some() { return Err(anyhow!("already registered a tool with name {}", name)); } - Ok(()) + return Ok(()); + + fn render_running(cx: &mut WindowContext) -> AnyElement { + T::render_running(cx).into_any_element() + } + + fn generate( + view: AnyView, + project: &mut ProjectContext, + cx: &mut WindowContext, + ) -> String { + view.downcast::() + .unwrap() + .update(cx, |view, cx| T::View::generate(view, project, cx)) + } } /// Task yields an error if the window for the given WindowContext is closed before the task completes. @@ -159,7 +230,7 @@ impl ToolRegistry { let arguments = tool_call.arguments.clone(); let id = tool_call.id.clone(); - let tool = match self.tools.get(&name) { + let tool = match self.registered_tools.get(&name) { Some(tool) => tool, None => { let name = name.clone(); @@ -176,6 +247,47 @@ impl ToolRegistry { } } +impl ToolFunctionCallResult { + pub fn generate( + &self, + name: &String, + project: &mut ProjectContext, + cx: &mut WindowContext, + ) -> String { + match self { + ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"), + ToolFunctionCallResult::ParsingFailed => { + format!("Unable to parse arguments for {name}") + } + ToolFunctionCallResult::Finished { generate_fn, view } => { + (generate_fn)(view.clone(), project, cx) + } + } + } + + fn into_any_element(&self, name: &String) -> AnyElement { + match self { + ToolFunctionCallResult::NoSuchTool => { + format!("Language Model attempted to call {name}").into_any_element() + } + ToolFunctionCallResult::ParsingFailed => { + format!("Language Model called {name} with bad arguments").into_any_element() + } + ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(), + } + } +} + +impl Display for ToolFunctionDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let schema = serde_json::to_string(&self.parameters).ok(); + let schema = schema.unwrap_or("None".to_string()); + write!(f, "Name: {}:\n", self.name)?; + write!(f, "Description: {}\n", self.description)?; + write!(f, "Parameters: {}", schema) + } +} + #[cfg(test)] mod test { use super::*; @@ -213,6 +325,12 @@ mod test { } } + impl ToolOutput for WeatherView { + fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String { + serde_json::to_string(&self.result).unwrap() + } + } + impl LanguageModelTool for WeatherTool { type Input = WeatherQuery; type Output = WeatherResult; @@ -240,7 +358,6 @@ mod test { } fn output_view( - _tool_call_id: String, _input: Self::Input, result: Result, cx: &mut WindowContext, @@ -250,10 +367,6 @@ mod test { WeatherView { result } }) } - - fn format(_: &Self::Input, output: &Result) -> String { - serde_json::to_string(&output.as_ref().unwrap()).unwrap() - } } #[gpui::test] diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 1501104a25..98ca2f25c7 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -163,6 +163,10 @@ impl ProjectIndex { self.project.clone() } + pub fn fs(&self) -> Arc { + self.fs.clone() + } + fn handle_project_event( &mut self, _: Model,