Implement serialization of assistant conversations, including tool calls and attachments (#11577)

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-08 14:52:15 -07:00 committed by GitHub
parent 24ffa0fcf3
commit a7aa2578e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 585 additions and 253 deletions

View File

@ -6,19 +6,14 @@ mod saved_conversation_picker;
mod tools;
pub mod ui;
use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
use crate::saved_conversation_picker::SavedConversationPicker;
use crate::{
attachments::ActiveEditorAttachmentTool,
tools::{CreateBufferTool, ProjectIndexTool},
ui::UserOrAssistant,
};
use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
UserAttachment,
};
use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
@ -33,11 +28,13 @@ use gpui::{
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use rich_text::RichText;
use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
use saved_conversation_picker::SavedConversationPicker;
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use tools::AnnotationTool;
use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
use util::paths::CONVERSATIONS_DIR;
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
@ -506,13 +503,11 @@ impl AssistantChat {
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
messages,
..
})) = this.messages.last_mut()
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if messages.is_empty() {
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
})
@ -563,7 +558,7 @@ impl AssistantChat {
let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
if let Some(ChatMessage::Assistant(AssistantMessage {
error: message_error,
messages,
..
@ -592,7 +587,7 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if let Some(current_message) = messages.last_mut() {
@ -608,19 +603,19 @@ impl AssistantChat {
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
// If the last message is a grouped assistant message, add to the grouped message
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
self.messages.last_mut()
{
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}
let message = ChatMessage::Assistant(GroupedAssistantMessage {
let message = ChatMessage::Assistant(AssistantMessage {
id: self.next_message_id.post_inc(),
messages: vec![AssistantMessage {
messages: vec![AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
}],
@ -669,40 +664,30 @@ impl AssistantChat {
*entry = !*entry;
}
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = self
.messages
.drain(..)
.map(|message| {
let text = match &message {
ChatMessage::User(message) => message.body.read(cx).text(cx),
ChatMessage::Assistant(message) => message
.messages
.iter()
.map(|message| message.body.text.to_string())
.collect::<Vec<_>>()
.join("\n\n"),
};
SavedMessage {
id: message.id(),
role: match message {
ChatMessage::User(_) => SavedMessageRole::User,
ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
},
text,
}
})
.collect::<Vec<_>>();
// Reset the chat for the new conversation.
fn reset(&mut self) {
self.messages.clear();
self.list_state.reset(0);
self.editing_message.take();
self.collapsed_messages.clear();
}
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = std::mem::take(&mut self.messages)
.into_iter()
.map(|message| self.serialize_message(message, cx))
.collect::<Vec<_>>();
self.reset();
let title = messages
.first()
.map(|message| message.text.clone())
.map(|message| match message {
SavedChatMessage::User { body, .. } => body.clone(),
SavedChatMessage::Assistant { messages, .. } => messages
.first()
.map(|message| message.body.to_string())
.unwrap_or_default(),
})
.unwrap_or_else(|| "A conversation with the assistant.".to_string());
let saved_conversation = SavedConversation {
@ -836,7 +821,7 @@ impl AssistantChat {
}
})
.into_any(),
ChatMessage::Assistant(GroupedAssistantMessage {
ChatMessage::Assistant(AssistantMessage {
id,
messages,
error,
@ -917,7 +902,7 @@ impl AssistantChat {
content: body.read(cx).text(cx),
});
}
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
for message in messages {
let body = message.body.clone();
@ -971,6 +956,43 @@ impl AssistantChat {
Ok(completion_messages)
})
}
fn serialize_message(
&self,
message: ChatMessage,
cx: &mut ViewContext<AssistantChat>,
) -> SavedChatMessage {
match message {
ChatMessage::User(message) => SavedChatMessage::User {
id: message.id,
body: message.body.read(cx).text(cx),
attachments: message
.attachments
.iter()
.map(|attachment| {
self.attachment_registry
.serialize_user_attachment(attachment)
})
.collect(),
},
ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
id: message.id,
error: message.error,
messages: message
.messages
.iter()
.map(|message| SavedAssistantMessagePart {
body: message.body.text.clone(),
tool_calls: message
.tool_calls
.iter()
.map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
.collect(),
})
.collect(),
},
}
}
}
impl Render for AssistantChat {
@ -1053,17 +1075,10 @@ impl MessageId {
enum ChatMessage {
User(UserMessage),
Assistant(GroupedAssistantMessage),
Assistant(AssistantMessage),
}
impl ChatMessage {
pub fn id(&self) -> MessageId {
match self {
ChatMessage::User(message) => message.id,
ChatMessage::Assistant(message) => message.id,
}
}
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
match self {
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
@ -1073,18 +1088,18 @@ impl ChatMessage {
}
struct UserMessage {
id: MessageId,
body: View<Editor>,
attachments: Vec<UserAttachment>,
pub id: MessageId,
pub body: View<Editor>,
pub attachments: Vec<UserAttachment>,
}
struct AssistantMessagePart {
pub body: RichText,
pub tool_calls: Vec<ToolFunctionCall>,
}
struct AssistantMessage {
body: RichText,
tool_calls: Vec<ToolFunctionCall>,
}
struct GroupedAssistantMessage {
id: MessageId,
messages: Vec<AssistantMessage>,
error: Option<SharedString>,
pub id: MessageId,
pub messages: Vec<AssistantMessagePart>,
pub error: Option<SharedString>,
}

View File

@ -1,64 +1,68 @@
use std::{path::PathBuf, sync::Arc};
use anyhow::{anyhow, Result};
use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
use editor::Editor;
use gpui::{Render, Task, View, WeakModel, WeakView};
use language::Buffer;
use project::ProjectPath;
use serde::{Deserialize, Serialize};
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
use util::maybe;
use workspace::Workspace;
#[derive(Serialize, Deserialize)]
pub struct ActiveEditorAttachment {
buffer: WeakModel<Buffer>,
path: Option<ProjectPath>,
#[serde(skip)]
buffer: Option<WeakModel<Buffer>>,
path: Option<PathBuf>,
}
pub struct FileAttachmentView {
output: Result<ActiveEditorAttachment>,
project_path: Option<ProjectPath>,
buffer: Option<WeakModel<Buffer>>,
error: Option<anyhow::Error>,
}
impl Render for FileAttachmentView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
match &self.output {
Ok(attachment) => {
let filename: SharedString = attachment
.path
.as_ref()
.and_then(|p| p.path.file_name()?.to_str())
.unwrap_or("Untitled")
.to_string()
.into();
// todo!(): make the button link to the actual file to open
ButtonLike::new("file-attachment")
.child(
h_flex()
.gap_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(ui::Icon::new(IconName::File))
.child(filename.clone()),
)
.tooltip({
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
})
.into_any_element()
}
Err(err) => div().child(err.to_string()).into_any_element(),
if let Some(error) = &self.error {
return div().child(error.to_string()).into_any_element();
}
let filename: SharedString = self
.project_path
.as_ref()
.and_then(|p| p.path.file_name()?.to_str())
.unwrap_or("Untitled")
.to_string()
.into();
ButtonLike::new("file-attachment")
.child(
h_flex()
.gap_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(ui::Icon::new(IconName::File))
.child(filename.clone()),
)
.tooltip(move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx))
.into_any_element()
}
}
impl ToolOutput for FileAttachmentView {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
if let Ok(result) = &self.output {
if let Some(path) = &result.path {
project.add_file(path.clone());
return format!("current file: {}", path.path.display());
} else if let Some(buffer) = result.buffer.upgrade() {
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
}
if let Some(path) = &self.project_path {
project.add_file(path.clone());
return format!("current file: {}", path.path.display());
}
if let Some(buffer) = self.buffer.as_ref().and_then(|buffer| buffer.upgrade()) {
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
}
String::new()
}
}
@ -77,6 +81,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
type Output = ActiveEditorAttachment;
type View = FileAttachmentView;
fn name(&self) -> Arc<str> {
"active-editor-attachment".into()
}
fn run(&self, cx: &mut WindowContext) -> Task<Result<ActiveEditorAttachment>> {
Task::ready(maybe!({
let active_buffer = self
@ -91,13 +99,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
let buffer = active_buffer.read(cx);
if let Some(buffer) = buffer.as_singleton() {
let path =
project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
worktree_id: file.worktree_id(cx),
path: file.path.clone(),
});
let path = project::File::from_dyn(buffer.read(cx).file())
.and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok());
return Ok(ActiveEditorAttachment {
buffer: buffer.downgrade(),
buffer: Some(buffer.downgrade()),
path,
});
} else {
@ -106,7 +111,34 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
}))
}
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
cx.new_view(|_cx| FileAttachmentView { output })
fn view(
&self,
output: Result<ActiveEditorAttachment>,
cx: &mut WindowContext,
) -> View<Self::View> {
let error;
let project_path;
let buffer;
match output {
Ok(output) => {
error = None;
let workspace = self.workspace.upgrade().unwrap();
let project = workspace.read(cx).project();
project_path = output
.path
.and_then(|path| project.read(cx).project_path_for_absolute_path(&path, cx));
buffer = output.buffer;
}
Err(err) => {
error = Some(err);
buffer = None;
project_path = None;
}
}
cx.new_view(|_cx| FileAttachmentView {
project_path,
buffer,
error,
})
}
}

View File

@ -1,3 +1,5 @@
use assistant_tooling::{SavedToolFunctionCall, SavedUserAttachment};
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use crate::MessageId;
@ -8,21 +10,27 @@ pub struct SavedConversation {
pub version: String,
/// The title of the conversation, generated by the Assistant.
pub title: String,
pub messages: Vec<SavedMessage>,
pub messages: Vec<SavedChatMessage>,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SavedMessageRole {
User,
Assistant,
pub enum SavedChatMessage {
User {
id: MessageId,
body: String,
attachments: Vec<SavedUserAttachment>,
},
Assistant {
id: MessageId,
messages: Vec<SavedAssistantMessagePart>,
error: Option<SharedString>,
},
}
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,
pub role: SavedMessageRole,
pub text: String,
pub struct SavedAssistantMessagePart {
pub body: SharedString,
pub tool_calls: Vec<SavedToolFunctionCall>,
}
/// Returns a list of placeholder conversations for mocking the UI.

View File

@ -6,7 +6,7 @@ use editor::{
};
use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
use language::ToPoint;
use project::{Project, ProjectPath};
use project::{search::SearchQuery, Project, ProjectPath};
use schemars::JsonSchema;
use serde::Deserialize;
use std::path::Path;
@ -29,17 +29,18 @@ impl AnnotationTool {
pub struct AnnotationInput {
/// Name for this set of annotations
title: String,
annotations: Vec<Annotation>,
/// Excerpts from the file to show to the user.
excerpts: Vec<Excerpt>,
}
#[derive(Debug, Deserialize, JsonSchema, Clone)]
struct Annotation {
struct Excerpt {
/// Path to the file
path: String,
/// Name of a symbol in the code
symbol_name: String,
/// Text to display near the symbol definition
text: String,
/// A short, distinctive string that appears in the file, used to define a location in the file.
text_passage: String,
/// Text to display above the code excerpt
annotation: String,
}
impl LanguageModelTool for AnnotationTool {
@ -58,7 +59,7 @@ impl LanguageModelTool for AnnotationTool {
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let workspace = self.workspace.clone();
let project = self.project.clone();
let excerpts = input.annotations.clone();
let excerpts = input.excerpts.clone();
let title = input.title.clone();
let worktree_id = project.update(cx, |project, cx| {
@ -74,15 +75,16 @@ impl LanguageModelTool for AnnotationTool {
};
let buffer_tasks = project.update(cx, |project, cx| {
let excerpts = excerpts.clone();
excerpts
.iter()
.map(|excerpt| {
let project_path = ProjectPath {
worktree_id,
path: Path::new(&excerpt.path).into(),
};
project.open_buffer(project_path.clone(), cx)
project.open_buffer(
ProjectPath {
worktree_id,
path: Path::new(&excerpt.path).into(),
},
cx,
)
})
.collect::<Vec<_>>()
});
@ -99,39 +101,43 @@ impl LanguageModelTool for AnnotationTool {
for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
if let Some(outline) = snapshot.outline(None) {
let matches = outline
.search(&excerpt.symbol_name, cx.background_executor().clone())
.await;
if let Some(mat) = matches.first() {
let item = &outline.items[mat.candidate_id];
let start = item.range.start.to_point(&snapshot);
editor.update(&mut cx, |editor, cx| {
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.push_excerpts_with_context_lines(
buffer.clone(),
vec![start..start],
5,
cx,
)
});
let explanation = SharedString::from(excerpt.text.clone());
editor.insert_blocks(
[BlockProperties {
position: ranges[0].start,
height: 2,
style: BlockStyle::Fixed,
render: Box::new(move |cx| {
Self::render_note_block(&explanation, cx)
}),
disposition: BlockDisposition::Above,
}],
None,
cx,
);
})?;
}
}
let query =
SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
let matches = query.search(&snapshot, None).await;
let Some(first_match) = matches.first() else {
log::warn!(
"text {:?} does not appear in '{}'",
excerpt.text_passage,
excerpt.path
);
continue;
};
let mut start = first_match.start.to_point(&snapshot);
start.column = 0;
editor.update(&mut cx, |editor, cx| {
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
multibuffer.push_excerpts_with_context_lines(
buffer.clone(),
vec![start..start],
5,
cx,
)
});
let annotation = SharedString::from(excerpt.annotation.clone());
editor.insert_blocks(
[BlockProperties {
position: ranges[0].start,
height: annotation.split('\n').count() as u8 + 1,
style: BlockStyle::Fixed,
render: Box::new(move |cx| Self::render_note_block(&annotation, cx)),
disposition: BlockDisposition::Above,
}],
None,
cx,
);
})?;
}
workspace
@ -144,7 +150,8 @@ impl LanguageModelTool for AnnotationTool {
})
}
fn output_view(
fn view(
&self,
_: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,

View File

@ -86,7 +86,8 @@ impl LanguageModelTool for CreateBufferTool {
})
}
fn output_view(
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,

View File

@ -1,13 +1,13 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use assistant_tooling::{LanguageModelTool, ToolOutput};
use collections::BTreeMap;
use gpui::{prelude::*, Model, Task};
use project::ProjectPath;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{fmt::Write as _, ops::Range};
use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
const DEFAULT_SEARCH_LIMIT: usize = 20;
@ -29,28 +29,24 @@ pub struct CodebaseQuery {
pub struct ProjectIndexView {
input: CodebaseQuery,
output: Result<ProjectIndexOutput>,
status: Status,
excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
element_id: ElementId,
expanded_header: bool,
}
#[derive(Serialize, Deserialize)]
pub struct ProjectIndexOutput {
status: Status,
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
}
#[derive(Serialize, Deserialize)]
struct WorktreeIndexOutput {
excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
}
impl ProjectIndexView {
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
let element_id = ElementId::Name(nanoid::nanoid!().into());
Self {
input,
output,
element_id,
expanded_header: false,
}
}
fn toggle_header(&mut self, cx: &mut ViewContext<Self>) {
self.expanded_header = !self.expanded_header;
cx.notify();
@ -60,18 +56,14 @@ impl ProjectIndexView {
impl Render for ProjectIndexView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let query = self.input.query.clone();
let result = &self.output;
let output = match result {
let excerpts = match &self.excerpts {
Err(err) => {
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
}
Ok(output) => output,
Ok(excerpts) => excerpts,
};
let file_count = output.excerpts.len();
let file_count = excerpts.len();
let header = h_flex()
.gap_2()
.child(Icon::new(IconName::File))
@ -97,16 +89,12 @@ impl Render for ProjectIndexView {
.child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
.child(
v_flex()
.gap_2()
.children(output.excerpts.keys().map(|path| {
h_flex().gap_2().child(Icon::new(IconName::File)).child(
Label::new(path.path.to_string_lossy().to_string())
.color(Color::Muted),
)
})),
),
.child(v_flex().gap_2().children(excerpts.keys().map(|path| {
h_flex().gap_2().child(Icon::new(IconName::File)).child(
Label::new(path.path.to_string_lossy().to_string())
.color(Color::Muted),
)
}))),
),
)
}
@ -118,16 +106,16 @@ impl ToolOutput for ProjectIndexView {
context: &mut assistant_tooling::ProjectContext,
_: &mut WindowContext,
) -> String {
match &self.output {
Ok(output) => {
match &self.excerpts {
Ok(excerpts) => {
let mut body = "found results in the following paths:\n".to_string();
for (project_path, ranges) in &output.excerpts {
for (project_path, ranges) in excerpts {
context.add_excerpts(project_path.clone(), ranges);
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
}
if output.status != Status::Idle {
if self.status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
@ -172,16 +160,20 @@ impl LanguageModelTool for ProjectIndexTool {
cx.update(|cx| {
let mut output = ProjectIndexOutput {
status,
excerpts: Default::default(),
worktrees: Default::default(),
};
for search_result in search_results {
let path = ProjectPath {
worktree_id: search_result.worktree.read(cx).id(),
path: search_result.path.clone(),
};
let worktree_path = search_result.worktree.read(cx).abs_path();
let excerpts = &mut output
.worktrees
.entry(worktree_path)
.or_insert(WorktreeIndexOutput {
excerpts: Default::default(),
})
.excerpts;
let excerpts_for_path = output.excerpts.entry(path).or_default();
let excerpts_for_path = excerpts.entry(search_result.path).or_default();
let ix = match excerpts_for_path
.binary_search_by_key(&search_result.range.start, |r| r.start)
{
@ -195,12 +187,57 @@ impl LanguageModelTool for ProjectIndexTool {
})
}
fn output_view(
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> gpui::View<Self::View> {
cx.new_view(|_cx| ProjectIndexView::new(input, output))
cx.new_view(|cx| {
let status;
let excerpts;
match output {
Ok(output) => {
status = output.status;
let project_index = self.project_index.read(cx);
if let Some(project) = project_index.project().upgrade() {
let project = project.read(cx);
excerpts = Ok(output
.worktrees
.into_iter()
.filter_map(|(abs_path, output)| {
for worktree in project.worktrees() {
let worktree = worktree.read(cx);
if worktree.abs_path() == abs_path {
return Some((worktree.id(), output.excerpts));
}
}
None
})
.flat_map(|(worktree_id, excerpts)| {
excerpts.into_iter().map(move |(path, ranges)| {
(ProjectPath { worktree_id, path }, ranges)
})
})
.collect::<BTreeMap<_, _>>());
} else {
excerpts = Err(anyhow!("project was dropped"));
}
}
Err(err) => {
status = Status::Idle;
excerpts = Err(err);
}
};
ProjectIndexView {
input,
status,
excerpts,
element_id: ElementId::Name(nanoid::nanoid!().into()),
expanded_header: false,
}
})
}
fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {

View File

@ -2,9 +2,12 @@ mod attachment_registry;
mod project_context;
mod tool_registry;
pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
pub use attachment_registry::{
AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
};
pub use project_context::ProjectContext;
pub use tool_registry::{
tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition,
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
ToolOutput, ToolRegistry,
};

View File

@ -3,6 +3,8 @@ use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
use gpui::{AnyView, Render, Task, View, WindowContext};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
any::TypeId,
sync::{
@ -17,24 +19,34 @@ pub struct AttachmentRegistry {
}
pub trait LanguageModelAttachment {
type Output: 'static;
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub view: AnyView,
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
#[derive(Serialize, Deserialize)]
pub struct SavedUserAttachment {
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
}
/// Internal representation of an attachment tool to allow us to treat them dynamically
struct RegisteredAttachment {
name: Arc<str>,
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
}
impl AttachmentRegistry {
@ -45,24 +57,65 @@ impl AttachmentRegistry {
}
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
let call = Box::new(move |cx: &mut WindowContext| {
let result = attachment.run(cx);
let attachment = Arc::new(attachment);
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let view = cx.update(|cx| A::view(result, cx))?;
let call = Box::new({
let attachment = attachment.clone();
move |cx: &mut WindowContext| {
let result = attachment.run(cx);
let attachment = attachment.clone();
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let serialized_output =
result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| attachment.view(result, cx))?;
Ok(UserAttachment {
name: attachment.name(),
view: view.into(),
generate_fn: generate::<A>,
serialized_output,
})
})
}
});
let deserialize = Box::new({
let attachment = attachment.clone();
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
let serialized_output = saved_attachment.serialized_output.clone();
let output = match &serialized_output {
Ok(serialized_output) => {
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
}
Err(error) => Err(anyhow!("{error}")),
};
let view = attachment.view(output, cx).into();
Ok(UserAttachment {
view: view.into(),
name: saved_attachment.name.clone(),
view,
serialized_output,
generate_fn: generate::<A>,
})
})
}
});
self.registered_attachments.insert(
TypeId::of::<A>(),
RegisteredAttachment {
name: attachment.name(),
call,
deserialize,
enabled: AtomicBool::new(true),
},
);
@ -134,6 +187,35 @@ impl AttachmentRegistry {
.collect())
})
}
pub fn serialize_user_attachment(
&self,
user_attachment: &UserAttachment,
) -> SavedUserAttachment {
SavedUserAttachment {
name: user_attachment.name.clone(),
serialized_output: user_attachment.serialized_output.clone(),
}
}
pub fn deserialize_user_attachment(
&self,
saved_user_attachment: SavedUserAttachment,
cx: &mut WindowContext,
) -> Result<UserAttachment> {
if let Some(registered_attachment) = self
.registered_attachments
.values()
.find(|attachment| attachment.name == saved_user_attachment.name)
{
(registered_attachment.deserialize)(&saved_user_attachment, cx)
} else {
Err(anyhow!(
"no attachment tool for name {}",
saved_user_attachment.name
))
}
}
}
impl UserAttachment {

View File

@ -1,41 +1,60 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
use serde_json::Value;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
sync::atomic::{AtomicBool, Ordering::SeqCst},
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use crate::ProjectContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default, Deserialize)]
#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<SavedToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
}
#[derive(Serialize, Deserialize)]
pub enum SavedToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
serialized_output: Result<Box<RawValue>, String>,
},
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition {
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
@ -80,7 +99,8 @@ pub trait LanguageModelTool {
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
/// A view of the output of running the tool, for displaying to the user.
fn output_view(
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@ -102,7 +122,8 @@ pub trait ToolOutput: Sized {
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
definition: ToolFunctionDefinition,
}
@ -162,23 +183,125 @@ impl ToolRegistry {
}
}
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: call.result.as_ref().map(|result| match result {
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
ToolFunctionCallResult::Finished {
serialized_output, ..
} => SavedToolFunctionCallResult::Finished {
serialized_output: match serialized_output {
Ok(value) => Ok(value.clone()),
Err(e) => Err(e.to_string()),
},
},
}),
}
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> ToolFunctionCall {
if let Some(tool) = &self.registered_tools.get(&call.name) {
(tool.deserialize)(call, cx)
} else {
ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: Some(ToolFunctionCallResult::NoSuchTool),
}
}
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
let name = tool.name();
let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
call: Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
deserialize: Box::new({
let tool = tool.clone();
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
};
};
let result = match &tool_call.result {
Some(result) => match result {
SavedToolFunctionCallResult::NoSuchTool => {
Some(ToolFunctionCallResult::NoSuchTool)
}
SavedToolFunctionCallResult::ParsingFailed => {
Some(ToolFunctionCallResult::ParsingFailed)
}
SavedToolFunctionCallResult::Finished { serialized_output } => {
let output = match serialized_output {
Ok(value) => {
match serde_json::from_str::<T::Output>(value.get()) {
Ok(value) => Ok(value),
Err(_) => {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(
ToolFunctionCallResult::ParsingFailed,
),
};
}
}
}
Err(e) => Err(anyhow!("{e}")),
};
let view = tool.view(input, output, cx).into();
Some(ToolFunctionCallResult::Finished {
serialized_output: serialized_output.clone(),
generate_fn: generate::<T>,
view,
})
}
},
None => None,
};
ToolFunctionCall {
id: tool_call.id.clone(),
name: name.clone(),
arguments: tool_call.arguments.clone(),
result,
}
}
}),
execute: Box::new({
let tool = tool.clone();
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
@ -188,23 +311,33 @@ impl ToolRegistry {
};
let result = tool.execute(&input, cx);
let tool = tool.clone();
cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await;
let view = cx.update(|cx| T::output_view(input, result, cx))?;
let result = result.await;
let serialized_output = result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| tool.view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
serialized_output,
view: view.into(),
generate_fn: generate::<T>,
}),
})
})
},
),
}
}),
render_running: render_running::<T>,
};
@ -259,7 +392,7 @@ impl ToolRegistry {
}
};
(tool.call)(tool_call, cx)
(tool.execute)(tool_call, cx)
}
}
@ -275,9 +408,9 @@ impl ToolFunctionCallResult {
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished { generate_fn, view } => {
(generate_fn)(view.clone(), project, cx)
}
ToolFunctionCallResult::Finished {
generate_fn, view, ..
} => (generate_fn)(view.clone(), project, cx),
}
}
@ -373,7 +506,8 @@ mod test {
Task::ready(Ok(weather))
}
fn output_view(
fn view(
&self,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,

View File

@ -7864,6 +7864,18 @@ impl Project {
})
}
pub fn project_path_for_absolute_path(
&self,
abs_path: &Path,
cx: &AppContext,
) -> Option<ProjectPath> {
self.find_local_worktree(abs_path, cx)
.map(|(worktree, relative_path)| ProjectPath {
worktree_id: worktree.read(cx).id(),
path: relative_path.into(),
})
}
pub fn get_workspace_root(
&self,
project_path: &ProjectPath,

View File

@ -250,6 +250,7 @@ impl SearchQuery {
}
}
}
pub async fn search(
&self,
buffer: &BufferSnapshot,

View File

@ -450,7 +450,7 @@ pub struct WorktreeSearchResult {
pub score: f32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
Idle,
Loading,