mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Implement serialization of assistant conversations, including tool calls and attachments (#11577)
Release Notes: - N/A --------- Co-authored-by: Kyle <kylek@zed.dev> Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
24ffa0fcf3
commit
a7aa2578e1
@ -6,19 +6,14 @@ mod saved_conversation_picker;
|
||||
mod tools;
|
||||
pub mod ui;
|
||||
|
||||
use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
|
||||
use crate::saved_conversation_picker::SavedConversationPicker;
|
||||
use crate::{
|
||||
attachments::ActiveEditorAttachmentTool,
|
||||
tools::{CreateBufferTool, ProjectIndexTool},
|
||||
ui::UserOrAssistant,
|
||||
};
|
||||
use crate::ui::UserOrAssistant;
|
||||
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
|
||||
use anyhow::{Context, Result};
|
||||
use assistant_tooling::{
|
||||
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
|
||||
UserAttachment,
|
||||
};
|
||||
use attachments::ActiveEditorAttachmentTool;
|
||||
use client::{proto, Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use completion_provider::*;
|
||||
@ -33,11 +28,13 @@ use gpui::{
|
||||
use language::{language_settings::SoftWrap, LanguageRegistry};
|
||||
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
|
||||
use rich_text::RichText;
|
||||
use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
|
||||
use saved_conversation_picker::SavedConversationPicker;
|
||||
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use std::sync::Arc;
|
||||
use tools::AnnotationTool;
|
||||
use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
|
||||
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
|
||||
@ -506,13 +503,11 @@ impl AssistantChat {
|
||||
while let Some(delta) = stream.next().await {
|
||||
let delta = delta?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
{
|
||||
if messages.is_empty() {
|
||||
messages.push(AssistantMessage {
|
||||
messages.push(AssistantMessagePart {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
@ -563,7 +558,7 @@ impl AssistantChat {
|
||||
|
||||
let mut tool_tasks = Vec::new();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
error: message_error,
|
||||
messages,
|
||||
..
|
||||
@ -592,7 +587,7 @@ impl AssistantChat {
|
||||
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
{
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
@ -608,19 +603,19 @@ impl AssistantChat {
|
||||
|
||||
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
|
||||
// If the last message is a grouped assistant message, add to the grouped message
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
|
||||
self.messages.last_mut()
|
||||
{
|
||||
messages.push(AssistantMessage {
|
||||
messages.push(AssistantMessagePart {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let message = ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
let message = ChatMessage::Assistant(AssistantMessage {
|
||||
id: self.next_message_id.post_inc(),
|
||||
messages: vec![AssistantMessage {
|
||||
messages: vec![AssistantMessagePart {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
}],
|
||||
@ -669,40 +664,30 @@ impl AssistantChat {
|
||||
*entry = !*entry;
|
||||
}
|
||||
|
||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let messages = self
|
||||
.messages
|
||||
.drain(..)
|
||||
.map(|message| {
|
||||
let text = match &message {
|
||||
ChatMessage::User(message) => message.body.read(cx).text(cx),
|
||||
ChatMessage::Assistant(message) => message
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.body.text.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n"),
|
||||
};
|
||||
|
||||
SavedMessage {
|
||||
id: message.id(),
|
||||
role: match message {
|
||||
ChatMessage::User(_) => SavedMessageRole::User,
|
||||
ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
|
||||
},
|
||||
text,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Reset the chat for the new conversation.
|
||||
fn reset(&mut self) {
|
||||
self.messages.clear();
|
||||
self.list_state.reset(0);
|
||||
self.editing_message.take();
|
||||
self.collapsed_messages.clear();
|
||||
}
|
||||
|
||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let messages = std::mem::take(&mut self.messages)
|
||||
.into_iter()
|
||||
.map(|message| self.serialize_message(message, cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.reset();
|
||||
|
||||
let title = messages
|
||||
.first()
|
||||
.map(|message| message.text.clone())
|
||||
.map(|message| match message {
|
||||
SavedChatMessage::User { body, .. } => body.clone(),
|
||||
SavedChatMessage::Assistant { messages, .. } => messages
|
||||
.first()
|
||||
.map(|message| message.body.to_string())
|
||||
.unwrap_or_default(),
|
||||
})
|
||||
.unwrap_or_else(|| "A conversation with the assistant.".to_string());
|
||||
|
||||
let saved_conversation = SavedConversation {
|
||||
@ -836,7 +821,7 @@ impl AssistantChat {
|
||||
}
|
||||
})
|
||||
.into_any(),
|
||||
ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
id,
|
||||
messages,
|
||||
error,
|
||||
@ -917,7 +902,7 @@ impl AssistantChat {
|
||||
content: body.read(cx).text(cx),
|
||||
});
|
||||
}
|
||||
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
|
||||
ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
|
||||
for message in messages {
|
||||
let body = message.body.clone();
|
||||
|
||||
@ -971,6 +956,43 @@ impl AssistantChat {
|
||||
Ok(completion_messages)
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize_message(
|
||||
&self,
|
||||
message: ChatMessage,
|
||||
cx: &mut ViewContext<AssistantChat>,
|
||||
) -> SavedChatMessage {
|
||||
match message {
|
||||
ChatMessage::User(message) => SavedChatMessage::User {
|
||||
id: message.id,
|
||||
body: message.body.read(cx).text(cx),
|
||||
attachments: message
|
||||
.attachments
|
||||
.iter()
|
||||
.map(|attachment| {
|
||||
self.attachment_registry
|
||||
.serialize_user_attachment(attachment)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
|
||||
id: message.id,
|
||||
error: message.error,
|
||||
messages: message
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| SavedAssistantMessagePart {
|
||||
body: message.body.text.clone(),
|
||||
tool_calls: message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
|
||||
.collect(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantChat {
|
||||
@ -1053,17 +1075,10 @@ impl MessageId {
|
||||
|
||||
enum ChatMessage {
|
||||
User(UserMessage),
|
||||
Assistant(GroupedAssistantMessage),
|
||||
Assistant(AssistantMessage),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn id(&self) -> MessageId {
|
||||
match self {
|
||||
ChatMessage::User(message) => message.id,
|
||||
ChatMessage::Assistant(message) => message.id,
|
||||
}
|
||||
}
|
||||
|
||||
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
|
||||
match self {
|
||||
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
|
||||
@ -1073,18 +1088,18 @@ impl ChatMessage {
|
||||
}
|
||||
|
||||
struct UserMessage {
|
||||
id: MessageId,
|
||||
body: View<Editor>,
|
||||
attachments: Vec<UserAttachment>,
|
||||
pub id: MessageId,
|
||||
pub body: View<Editor>,
|
||||
pub attachments: Vec<UserAttachment>,
|
||||
}
|
||||
|
||||
struct AssistantMessagePart {
|
||||
pub body: RichText,
|
||||
pub tool_calls: Vec<ToolFunctionCall>,
|
||||
}
|
||||
|
||||
struct AssistantMessage {
|
||||
body: RichText,
|
||||
tool_calls: Vec<ToolFunctionCall>,
|
||||
}
|
||||
|
||||
struct GroupedAssistantMessage {
|
||||
id: MessageId,
|
||||
messages: Vec<AssistantMessage>,
|
||||
error: Option<SharedString>,
|
||||
pub id: MessageId,
|
||||
pub messages: Vec<AssistantMessagePart>,
|
||||
pub error: Option<SharedString>,
|
||||
}
|
||||
|
@ -1,64 +1,68 @@
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
|
||||
use editor::Editor;
|
||||
use gpui::{Render, Task, View, WeakModel, WeakView};
|
||||
use language::Buffer;
|
||||
use project::ProjectPath;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
|
||||
use util::maybe;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ActiveEditorAttachment {
|
||||
buffer: WeakModel<Buffer>,
|
||||
path: Option<ProjectPath>,
|
||||
#[serde(skip)]
|
||||
buffer: Option<WeakModel<Buffer>>,
|
||||
path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
pub struct FileAttachmentView {
|
||||
output: Result<ActiveEditorAttachment>,
|
||||
project_path: Option<ProjectPath>,
|
||||
buffer: Option<WeakModel<Buffer>>,
|
||||
error: Option<anyhow::Error>,
|
||||
}
|
||||
|
||||
impl Render for FileAttachmentView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
match &self.output {
|
||||
Ok(attachment) => {
|
||||
let filename: SharedString = attachment
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|p| p.path.file_name()?.to_str())
|
||||
.unwrap_or("Untitled")
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
// todo!(): make the button link to the actual file to open
|
||||
ButtonLike::new("file-attachment")
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(ui::Icon::new(IconName::File))
|
||||
.child(filename.clone()),
|
||||
)
|
||||
.tooltip({
|
||||
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
Err(err) => div().child(err.to_string()).into_any_element(),
|
||||
if let Some(error) = &self.error {
|
||||
return div().child(error.to_string()).into_any_element();
|
||||
}
|
||||
|
||||
let filename: SharedString = self
|
||||
.project_path
|
||||
.as_ref()
|
||||
.and_then(|p| p.path.file_name()?.to_str())
|
||||
.unwrap_or("Untitled")
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
ButtonLike::new("file-attachment")
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(ui::Icon::new(IconName::File))
|
||||
.child(filename.clone()),
|
||||
)
|
||||
.tooltip(move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx))
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolOutput for FileAttachmentView {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||
if let Ok(result) = &self.output {
|
||||
if let Some(path) = &result.path {
|
||||
project.add_file(path.clone());
|
||||
return format!("current file: {}", path.path.display());
|
||||
} else if let Some(buffer) = result.buffer.upgrade() {
|
||||
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
|
||||
}
|
||||
if let Some(path) = &self.project_path {
|
||||
project.add_file(path.clone());
|
||||
return format!("current file: {}", path.path.display());
|
||||
}
|
||||
|
||||
if let Some(buffer) = self.buffer.as_ref().and_then(|buffer| buffer.upgrade()) {
|
||||
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
@ -77,6 +81,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
type Output = ActiveEditorAttachment;
|
||||
type View = FileAttachmentView;
|
||||
|
||||
fn name(&self) -> Arc<str> {
|
||||
"active-editor-attachment".into()
|
||||
}
|
||||
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<ActiveEditorAttachment>> {
|
||||
Task::ready(maybe!({
|
||||
let active_buffer = self
|
||||
@ -91,13 +99,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
let buffer = active_buffer.read(cx);
|
||||
|
||||
if let Some(buffer) = buffer.as_singleton() {
|
||||
let path =
|
||||
project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path.clone(),
|
||||
});
|
||||
let path = project::File::from_dyn(buffer.read(cx).file())
|
||||
.and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok());
|
||||
return Ok(ActiveEditorAttachment {
|
||||
buffer: buffer.downgrade(),
|
||||
buffer: Some(buffer.downgrade()),
|
||||
path,
|
||||
});
|
||||
} else {
|
||||
@ -106,7 +111,34 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
}))
|
||||
}
|
||||
|
||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| FileAttachmentView { output })
|
||||
fn view(
|
||||
&self,
|
||||
output: Result<ActiveEditorAttachment>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
let error;
|
||||
let project_path;
|
||||
let buffer;
|
||||
match output {
|
||||
Ok(output) => {
|
||||
error = None;
|
||||
let workspace = self.workspace.upgrade().unwrap();
|
||||
let project = workspace.read(cx).project();
|
||||
project_path = output
|
||||
.path
|
||||
.and_then(|path| project.read(cx).project_path_for_absolute_path(&path, cx));
|
||||
buffer = output.buffer;
|
||||
}
|
||||
Err(err) => {
|
||||
error = Some(err);
|
||||
buffer = None;
|
||||
project_path = None;
|
||||
}
|
||||
}
|
||||
cx.new_view(|_cx| FileAttachmentView {
|
||||
project_path,
|
||||
buffer,
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
use assistant_tooling::{SavedToolFunctionCall, SavedUserAttachment};
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::MessageId;
|
||||
@ -8,21 +10,27 @@ pub struct SavedConversation {
|
||||
pub version: String,
|
||||
/// The title of the conversation, generated by the Assistant.
|
||||
pub title: String,
|
||||
pub messages: Vec<SavedMessage>,
|
||||
pub messages: Vec<SavedChatMessage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SavedMessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
pub enum SavedChatMessage {
|
||||
User {
|
||||
id: MessageId,
|
||||
body: String,
|
||||
attachments: Vec<SavedUserAttachment>,
|
||||
},
|
||||
Assistant {
|
||||
id: MessageId,
|
||||
messages: Vec<SavedAssistantMessagePart>,
|
||||
error: Option<SharedString>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedMessage {
|
||||
pub id: MessageId,
|
||||
pub role: SavedMessageRole,
|
||||
pub text: String,
|
||||
pub struct SavedAssistantMessagePart {
|
||||
pub body: SharedString,
|
||||
pub tool_calls: Vec<SavedToolFunctionCall>,
|
||||
}
|
||||
|
||||
/// Returns a list of placeholder conversations for mocking the UI.
|
||||
|
@ -6,7 +6,7 @@ use editor::{
|
||||
};
|
||||
use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
|
||||
use language::ToPoint;
|
||||
use project::{Project, ProjectPath};
|
||||
use project::{search::SearchQuery, Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
@ -29,17 +29,18 @@ impl AnnotationTool {
|
||||
pub struct AnnotationInput {
|
||||
/// Name for this set of annotations
|
||||
title: String,
|
||||
annotations: Vec<Annotation>,
|
||||
/// Excerpts from the file to show to the user.
|
||||
excerpts: Vec<Excerpt>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema, Clone)]
|
||||
struct Annotation {
|
||||
struct Excerpt {
|
||||
/// Path to the file
|
||||
path: String,
|
||||
/// Name of a symbol in the code
|
||||
symbol_name: String,
|
||||
/// Text to display near the symbol definition
|
||||
text: String,
|
||||
/// A short, distinctive string that appears in the file, used to define a location in the file.
|
||||
text_passage: String,
|
||||
/// Text to display above the code excerpt
|
||||
annotation: String,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for AnnotationTool {
|
||||
@ -58,7 +59,7 @@ impl LanguageModelTool for AnnotationTool {
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||
let workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let excerpts = input.annotations.clone();
|
||||
let excerpts = input.excerpts.clone();
|
||||
let title = input.title.clone();
|
||||
|
||||
let worktree_id = project.update(cx, |project, cx| {
|
||||
@ -74,15 +75,16 @@ impl LanguageModelTool for AnnotationTool {
|
||||
};
|
||||
|
||||
let buffer_tasks = project.update(cx, |project, cx| {
|
||||
let excerpts = excerpts.clone();
|
||||
excerpts
|
||||
.iter()
|
||||
.map(|excerpt| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: Path::new(&excerpt.path).into(),
|
||||
};
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id,
|
||||
path: Path::new(&excerpt.path).into(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
@ -99,39 +101,43 @@ impl LanguageModelTool for AnnotationTool {
|
||||
for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
|
||||
let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
if let Some(outline) = snapshot.outline(None) {
|
||||
let matches = outline
|
||||
.search(&excerpt.symbol_name, cx.background_executor().clone())
|
||||
.await;
|
||||
if let Some(mat) = matches.first() {
|
||||
let item = &outline.items[mat.candidate_id];
|
||||
let start = item.range.start.to_point(&snapshot);
|
||||
editor.update(&mut cx, |editor, cx| {
|
||||
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
|
||||
multibuffer.push_excerpts_with_context_lines(
|
||||
buffer.clone(),
|
||||
vec![start..start],
|
||||
5,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let explanation = SharedString::from(excerpt.text.clone());
|
||||
editor.insert_blocks(
|
||||
[BlockProperties {
|
||||
position: ranges[0].start,
|
||||
height: 2,
|
||||
style: BlockStyle::Fixed,
|
||||
render: Box::new(move |cx| {
|
||||
Self::render_note_block(&explanation, cx)
|
||||
}),
|
||||
disposition: BlockDisposition::Above,
|
||||
}],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
let query =
|
||||
SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
|
||||
|
||||
let matches = query.search(&snapshot, None).await;
|
||||
let Some(first_match) = matches.first() else {
|
||||
log::warn!(
|
||||
"text {:?} does not appear in '{}'",
|
||||
excerpt.text_passage,
|
||||
excerpt.path
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let mut start = first_match.start.to_point(&snapshot);
|
||||
start.column = 0;
|
||||
|
||||
editor.update(&mut cx, |editor, cx| {
|
||||
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
|
||||
multibuffer.push_excerpts_with_context_lines(
|
||||
buffer.clone(),
|
||||
vec![start..start],
|
||||
5,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let annotation = SharedString::from(excerpt.annotation.clone());
|
||||
editor.insert_blocks(
|
||||
[BlockProperties {
|
||||
position: ranges[0].start,
|
||||
height: annotation.split('\n').count() as u8 + 1,
|
||||
style: BlockStyle::Fixed,
|
||||
render: Box::new(move |cx| Self::render_note_block(&annotation, cx)),
|
||||
disposition: BlockDisposition::Above,
|
||||
}],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
workspace
|
||||
@ -144,7 +150,8 @@ impl LanguageModelTool for AnnotationTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
_: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -86,7 +86,8 @@ impl LanguageModelTool for CreateBufferTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -1,13 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tooling::{LanguageModelTool, ToolOutput};
|
||||
use collections::BTreeMap;
|
||||
use gpui::{prelude::*, Model, Task};
|
||||
use project::ProjectPath;
|
||||
use schemars::JsonSchema;
|
||||
use semantic_index::{ProjectIndex, Status};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{fmt::Write as _, ops::Range};
|
||||
use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
|
||||
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
|
||||
|
||||
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
||||
@ -29,28 +29,24 @@ pub struct CodebaseQuery {
|
||||
|
||||
pub struct ProjectIndexView {
|
||||
input: CodebaseQuery,
|
||||
output: Result<ProjectIndexOutput>,
|
||||
status: Status,
|
||||
excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
|
||||
element_id: ElementId,
|
||||
expanded_header: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ProjectIndexOutput {
|
||||
status: Status,
|
||||
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
|
||||
worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct WorktreeIndexOutput {
|
||||
excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl ProjectIndexView {
|
||||
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
|
||||
let element_id = ElementId::Name(nanoid::nanoid!().into());
|
||||
|
||||
Self {
|
||||
input,
|
||||
output,
|
||||
element_id,
|
||||
expanded_header: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_header(&mut self, cx: &mut ViewContext<Self>) {
|
||||
self.expanded_header = !self.expanded_header;
|
||||
cx.notify();
|
||||
@ -60,18 +56,14 @@ impl ProjectIndexView {
|
||||
impl Render for ProjectIndexView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let query = self.input.query.clone();
|
||||
|
||||
let result = &self.output;
|
||||
|
||||
let output = match result {
|
||||
let excerpts = match &self.excerpts {
|
||||
Err(err) => {
|
||||
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
|
||||
}
|
||||
Ok(output) => output,
|
||||
Ok(excerpts) => excerpts,
|
||||
};
|
||||
|
||||
let file_count = output.excerpts.len();
|
||||
|
||||
let file_count = excerpts.len();
|
||||
let header = h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::File))
|
||||
@ -97,16 +89,12 @@ impl Render for ProjectIndexView {
|
||||
.child(Icon::new(IconName::MagnifyingGlass))
|
||||
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.children(output.excerpts.keys().map(|path| {
|
||||
h_flex().gap_2().child(Icon::new(IconName::File)).child(
|
||||
Label::new(path.path.to_string_lossy().to_string())
|
||||
.color(Color::Muted),
|
||||
)
|
||||
})),
|
||||
),
|
||||
.child(v_flex().gap_2().children(excerpts.keys().map(|path| {
|
||||
h_flex().gap_2().child(Icon::new(IconName::File)).child(
|
||||
Label::new(path.path.to_string_lossy().to_string())
|
||||
.color(Color::Muted),
|
||||
)
|
||||
}))),
|
||||
),
|
||||
)
|
||||
}
|
||||
@ -118,16 +106,16 @@ impl ToolOutput for ProjectIndexView {
|
||||
context: &mut assistant_tooling::ProjectContext,
|
||||
_: &mut WindowContext,
|
||||
) -> String {
|
||||
match &self.output {
|
||||
Ok(output) => {
|
||||
match &self.excerpts {
|
||||
Ok(excerpts) => {
|
||||
let mut body = "found results in the following paths:\n".to_string();
|
||||
|
||||
for (project_path, ranges) in &output.excerpts {
|
||||
for (project_path, ranges) in excerpts {
|
||||
context.add_excerpts(project_path.clone(), ranges);
|
||||
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
|
||||
}
|
||||
|
||||
if output.status != Status::Idle {
|
||||
if self.status != Status::Idle {
|
||||
body.push_str("Still indexing. Results may be incomplete.\n");
|
||||
}
|
||||
|
||||
@ -172,16 +160,20 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
cx.update(|cx| {
|
||||
let mut output = ProjectIndexOutput {
|
||||
status,
|
||||
excerpts: Default::default(),
|
||||
worktrees: Default::default(),
|
||||
};
|
||||
|
||||
for search_result in search_results {
|
||||
let path = ProjectPath {
|
||||
worktree_id: search_result.worktree.read(cx).id(),
|
||||
path: search_result.path.clone(),
|
||||
};
|
||||
let worktree_path = search_result.worktree.read(cx).abs_path();
|
||||
let excerpts = &mut output
|
||||
.worktrees
|
||||
.entry(worktree_path)
|
||||
.or_insert(WorktreeIndexOutput {
|
||||
excerpts: Default::default(),
|
||||
})
|
||||
.excerpts;
|
||||
|
||||
let excerpts_for_path = output.excerpts.entry(path).or_default();
|
||||
let excerpts_for_path = excerpts.entry(search_result.path).or_default();
|
||||
let ix = match excerpts_for_path
|
||||
.binary_search_by_key(&search_result.range.start, |r| r.start)
|
||||
{
|
||||
@ -195,12 +187,57 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> gpui::View<Self::View> {
|
||||
cx.new_view(|_cx| ProjectIndexView::new(input, output))
|
||||
cx.new_view(|cx| {
|
||||
let status;
|
||||
let excerpts;
|
||||
match output {
|
||||
Ok(output) => {
|
||||
status = output.status;
|
||||
let project_index = self.project_index.read(cx);
|
||||
if let Some(project) = project_index.project().upgrade() {
|
||||
let project = project.read(cx);
|
||||
excerpts = Ok(output
|
||||
.worktrees
|
||||
.into_iter()
|
||||
.filter_map(|(abs_path, output)| {
|
||||
for worktree in project.worktrees() {
|
||||
let worktree = worktree.read(cx);
|
||||
if worktree.abs_path() == abs_path {
|
||||
return Some((worktree.id(), output.excerpts));
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
.flat_map(|(worktree_id, excerpts)| {
|
||||
excerpts.into_iter().map(move |(path, ranges)| {
|
||||
(ProjectPath { worktree_id, path }, ranges)
|
||||
})
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>());
|
||||
} else {
|
||||
excerpts = Err(anyhow!("project was dropped"));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
status = Status::Idle;
|
||||
excerpts = Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
ProjectIndexView {
|
||||
input,
|
||||
status,
|
||||
excerpts,
|
||||
element_id: ElementId::Name(nanoid::nanoid!().into()),
|
||||
expanded_header: false,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
|
||||
|
@ -2,9 +2,12 @@ mod attachment_registry;
|
||||
mod project_context;
|
||||
mod tool_registry;
|
||||
|
||||
pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
|
||||
pub use attachment_registry::{
|
||||
AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
|
||||
};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition,
|
||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
|
||||
SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
ToolOutput, ToolRegistry,
|
||||
};
|
||||
|
@ -3,6 +3,8 @@ use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use futures::future::join_all;
|
||||
use gpui::{AnyView, Render, Task, View, WindowContext};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
sync::{
|
||||
@ -17,24 +19,34 @@ pub struct AttachmentRegistry {
|
||||
}
|
||||
|
||||
pub trait LanguageModelAttachment {
|
||||
type Output: 'static;
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
fn name(&self) -> Arc<str>;
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
/// A collected attachment from running an attachment tool
|
||||
pub struct UserAttachment {
|
||||
pub view: AnyView,
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedUserAttachment {
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
}
|
||||
|
||||
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
||||
struct RegisteredAttachment {
|
||||
name: Arc<str>,
|
||||
enabled: AtomicBool,
|
||||
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
||||
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
|
||||
}
|
||||
|
||||
impl AttachmentRegistry {
|
||||
@ -45,24 +57,65 @@ impl AttachmentRegistry {
|
||||
}
|
||||
|
||||
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
|
||||
let call = Box::new(move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = Arc::new(attachment);
|
||||
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let view = cx.update(|cx| A::view(result, cx))?;
|
||||
let call = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = attachment.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let serialized_output =
|
||||
result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
|
||||
let view = cx.update(|cx| attachment.view(result, cx))?;
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: attachment.name(),
|
||||
view: view.into(),
|
||||
generate_fn: generate::<A>,
|
||||
serialized_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let deserialize = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
|
||||
let serialized_output = saved_attachment.serialized_output.clone();
|
||||
let output = match &serialized_output {
|
||||
Ok(serialized_output) => {
|
||||
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
|
||||
}
|
||||
Err(error) => Err(anyhow!("{error}")),
|
||||
};
|
||||
let view = attachment.view(output, cx).into();
|
||||
|
||||
Ok(UserAttachment {
|
||||
view: view.into(),
|
||||
name: saved_attachment.name.clone(),
|
||||
view,
|
||||
serialized_output,
|
||||
generate_fn: generate::<A>,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
self.registered_attachments.insert(
|
||||
TypeId::of::<A>(),
|
||||
RegisteredAttachment {
|
||||
name: attachment.name(),
|
||||
call,
|
||||
deserialize,
|
||||
enabled: AtomicBool::new(true),
|
||||
},
|
||||
);
|
||||
@ -134,6 +187,35 @@ impl AttachmentRegistry {
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_user_attachment(
|
||||
&self,
|
||||
user_attachment: &UserAttachment,
|
||||
) -> SavedUserAttachment {
|
||||
SavedUserAttachment {
|
||||
name: user_attachment.name.clone(),
|
||||
serialized_output: user_attachment.serialized_output.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_user_attachment(
|
||||
&self,
|
||||
saved_user_attachment: SavedUserAttachment,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<UserAttachment> {
|
||||
if let Some(registered_attachment) = self
|
||||
.registered_attachments
|
||||
.values()
|
||||
.find(|attachment| attachment.name == saved_user_attachment.name)
|
||||
{
|
||||
(registered_attachment.deserialize)(&saved_user_attachment, cx)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"no attachment tool for name {}",
|
||||
saved_user_attachment.name
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAttachment {
|
||||
|
@ -1,41 +1,60 @@
|
||||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{
|
||||
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
|
||||
};
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
fmt::Display,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::ProjectContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
registered_tools: HashMap<String, RegisteredTool>,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
#[serde(skip)]
|
||||
pub result: Option<ToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct SavedToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub result: Option<SavedToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
pub enum ToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
view: AnyView,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum SavedToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition {
|
||||
pub trait LanguageModelTool {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type Output: 'static;
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
@ -80,7 +99,8 @@ pub trait LanguageModelTool {
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
/// A view of the output of running the tool, for displaying to the user.
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
@ -102,7 +122,8 @@ pub trait ToolOutput: Sized {
|
||||
struct RegisteredTool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
|
||||
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
@ -162,23 +183,125 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
|
||||
SavedToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: call.result.as_ref().map(|result| match result {
|
||||
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
|
||||
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
|
||||
ToolFunctionCallResult::Finished {
|
||||
serialized_output, ..
|
||||
} => SavedToolFunctionCallResult::Finished {
|
||||
serialized_output: match serialized_output {
|
||||
Ok(value) => Ok(value.clone()),
|
||||
Err(e) => Err(e.to_string()),
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_tool_call(
|
||||
&self,
|
||||
call: &SavedToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> ToolFunctionCall {
|
||||
if let Some(tool) = &self.registered_tools.get(&call.name) {
|
||||
(tool.deserialize)(call, cx)
|
||||
} else {
|
||||
ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Result<()> {
|
||||
let name = tool.name();
|
||||
let tool = Arc::new(tool);
|
||||
let registered_tool = RegisteredTool {
|
||||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
enabled: AtomicBool::new(true),
|
||||
call: Box::new(
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
deserialize: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
};
|
||||
};
|
||||
|
||||
let result = match &tool_call.result {
|
||||
Some(result) => match result {
|
||||
SavedToolFunctionCallResult::NoSuchTool => {
|
||||
Some(ToolFunctionCallResult::NoSuchTool)
|
||||
}
|
||||
SavedToolFunctionCallResult::ParsingFailed => {
|
||||
Some(ToolFunctionCallResult::ParsingFailed)
|
||||
}
|
||||
SavedToolFunctionCallResult::Finished { serialized_output } => {
|
||||
let output = match serialized_output {
|
||||
Ok(value) => {
|
||||
match serde_json::from_str::<T::Output>(value.get()) {
|
||||
Ok(value) => Ok(value),
|
||||
Err(_) => {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(
|
||||
ToolFunctionCallResult::ParsingFailed,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(anyhow!("{e}")),
|
||||
};
|
||||
|
||||
let view = tool.view(input, output, cx).into();
|
||||
Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output: serialized_output.clone(),
|
||||
generate_fn: generate::<T>,
|
||||
view,
|
||||
})
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
ToolFunctionCall {
|
||||
id: tool_call.id.clone(),
|
||||
name: name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
result,
|
||||
}
|
||||
}
|
||||
}),
|
||||
execute: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
@ -188,23 +311,33 @@ impl ToolRegistry {
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
let tool = tool.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<T::Output> = result.await;
|
||||
let view = cx.update(|cx| T::output_view(input, result, cx))?;
|
||||
let result = result.await;
|
||||
let serialized_output = result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
let view = cx.update(|cx| tool.view(input, result, cx))?;
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output,
|
||||
view: view.into(),
|
||||
generate_fn: generate::<T>,
|
||||
}),
|
||||
})
|
||||
})
|
||||
},
|
||||
),
|
||||
}
|
||||
}),
|
||||
render_running: render_running::<T>,
|
||||
};
|
||||
|
||||
@ -259,7 +392,7 @@ impl ToolRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
(tool.call)(tool_call, cx)
|
||||
(tool.execute)(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -275,9 +408,9 @@ impl ToolFunctionCallResult {
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Unable to parse arguments for {name}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished { generate_fn, view } => {
|
||||
(generate_fn)(view.clone(), project, cx)
|
||||
}
|
||||
ToolFunctionCallResult::Finished {
|
||||
generate_fn, view, ..
|
||||
} => (generate_fn)(view.clone(), project, cx),
|
||||
}
|
||||
}
|
||||
|
||||
@ -373,7 +506,8 @@ mod test {
|
||||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
result: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -7864,6 +7864,18 @@ impl Project {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn project_path_for_absolute_path(
|
||||
&self,
|
||||
abs_path: &Path,
|
||||
cx: &AppContext,
|
||||
) -> Option<ProjectPath> {
|
||||
self.find_local_worktree(abs_path, cx)
|
||||
.map(|(worktree, relative_path)| ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: relative_path.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_workspace_root(
|
||||
&self,
|
||||
project_path: &ProjectPath,
|
||||
|
@ -250,6 +250,7 @@ impl SearchQuery {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
&self,
|
||||
buffer: &BufferSnapshot,
|
||||
|
@ -450,7 +450,7 @@ pub struct WorktreeSearchResult {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Status {
|
||||
Idle,
|
||||
Loading,
|
||||
|
Loading…
Reference in New Issue
Block a user