mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-19 10:29:35 +03:00
Allow codebase search to be turned on or off within the composer for assistant2 (#11315)
![image](https://github.com/zed-industries/zed/assets/836375/e03d2357-e2e4-49f1-86d6-7593bce13618) ![image](https://github.com/zed-industries/zed/assets/836375/3d769622-82e1-4e6f-bdec-4dce81e43423) ![image](https://github.com/zed-industries/zed/assets/836375/bf79a4ec-1660-47b1-8525-e741575dc5d4) Release Notes: - N/A
This commit is contained in:
parent
43ad470e58
commit
1915a756a0
@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
|
||||
use serde::Deserialize;
|
||||
use settings::Settings;
|
||||
use std::sync::Arc;
|
||||
use ui::Composer;
|
||||
use ui::{Composer, ProjectIndexButton};
|
||||
use util::{paths::EMBEDDINGS_DIR, ResultExt};
|
||||
use workspace::{
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
@ -228,6 +228,7 @@ pub struct AssistantChat {
|
||||
list_state: ListState,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
composer_editor: View<Editor>,
|
||||
project_index_button: Option<View<ProjectIndexButton>>,
|
||||
user_store: Model<UserStore>,
|
||||
next_message_id: MessageId,
|
||||
collapsed_messages: HashMap<MessageId, bool>,
|
||||
@ -263,6 +264,10 @@ impl AssistantChat {
|
||||
},
|
||||
);
|
||||
|
||||
let project_index_button = project_index.clone().map(|project_index| {
|
||||
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
|
||||
});
|
||||
|
||||
Self {
|
||||
model,
|
||||
messages: Vec::new(),
|
||||
@ -275,6 +280,7 @@ impl AssistantChat {
|
||||
list_state,
|
||||
user_store,
|
||||
language_registry,
|
||||
project_index_button,
|
||||
project_index,
|
||||
next_message_id: MessageId(0),
|
||||
editing_message: None,
|
||||
@ -397,7 +403,7 @@ impl AssistantChat {
|
||||
{
|
||||
this.tool_registry.definitions()
|
||||
} else {
|
||||
&[]
|
||||
Vec::new()
|
||||
};
|
||||
call_count += 1;
|
||||
|
||||
@ -590,7 +596,7 @@ impl AssistantChat {
|
||||
element.child(Composer::new(
|
||||
body.clone(),
|
||||
self.user_store.read(cx).current_user(),
|
||||
self.tool_registry.clone(),
|
||||
self.project_index_button.clone(),
|
||||
crate::ui::ModelSelector::new(
|
||||
cx.view().downgrade(),
|
||||
self.model.clone(),
|
||||
@ -768,7 +774,7 @@ impl Render for AssistantChat {
|
||||
.child(Composer::new(
|
||||
self.composer_editor.clone(),
|
||||
self.user_store.read(cx).current_user(),
|
||||
self.tool_registry.clone(),
|
||||
self.project_index_button.clone(),
|
||||
crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
|
||||
.into_any_element(),
|
||||
))
|
||||
|
@ -33,7 +33,7 @@ impl CompletionProvider {
|
||||
messages: Vec<CompletionMessage>,
|
||||
stop: Vec<String>,
|
||||
temperature: f32,
|
||||
tools: &[ToolFunctionDefinition],
|
||||
tools: Vec<ToolFunctionDefinition>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
||||
{
|
||||
self.0.complete(model, messages, stop, temperature, tools)
|
||||
@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static {
|
||||
messages: Vec<CompletionMessage>,
|
||||
stop: Vec<String>,
|
||||
temperature: f32,
|
||||
tools: &[ToolFunctionDefinition],
|
||||
tools: Vec<ToolFunctionDefinition>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
|
||||
}
|
||||
|
||||
@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider {
|
||||
messages: Vec<CompletionMessage>,
|
||||
stop: Vec<String>,
|
||||
temperature: f32,
|
||||
tools: &[ToolFunctionDefinition],
|
||||
tools: Vec<ToolFunctionDefinition>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
||||
{
|
||||
let client = self.client.clone();
|
||||
|
@ -1,14 +1,17 @@
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::LanguageModelTool;
|
||||
use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation};
|
||||
use assistant_tooling::{
|
||||
// assistant_tool_button::{AssistantToolButton, ToolStatus},
|
||||
LanguageModelTool,
|
||||
};
|
||||
use gpui::{prelude::*, Model, Task};
|
||||
use project::Fs;
|
||||
use schemars::JsonSchema;
|
||||
use semantic_index::{ProjectIndex, Status};
|
||||
use serde::Deserialize;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::sync::Arc;
|
||||
use ui::{
|
||||
div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label,
|
||||
SharedString, Tooltip, WindowContext,
|
||||
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
|
||||
WindowContext,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
cx.new_view(|_cx| ProjectIndexView { input, output })
|
||||
}
|
||||
|
||||
fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
|
||||
Some(
|
||||
cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
|
||||
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||
match &output {
|
||||
Ok(output) => {
|
||||
@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProjectIndexStatusView {
|
||||
project_index: Model<ProjectIndex>,
|
||||
}
|
||||
|
||||
impl ProjectIndexStatusView {
|
||||
pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
|
||||
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
Self { project_index }
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ProjectIndexStatusView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let status = self.project_index.read(cx).status();
|
||||
|
||||
let is_enabled = match status {
|
||||
Status::Idle => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
let icon = match status {
|
||||
Status::Idle => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Default),
|
||||
Status::Loading => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
Status::Scanning { .. } => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
};
|
||||
|
||||
let indicator = match status {
|
||||
Status::Idle => Some(Indicator::dot().color(Color::Success)),
|
||||
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
|
||||
Status::Loading => Some(Indicator::icon(
|
||||
Icon::new(IconName::Spinner)
|
||||
.color(Color::Accent)
|
||||
.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||
),
|
||||
)),
|
||||
};
|
||||
|
||||
ButtonLike::new("project-index")
|
||||
.disabled(!is_enabled)
|
||||
.child(
|
||||
ui::IconWithIndicator::new(icon, indicator)
|
||||
.indicator_border_color(Some(gpui::transparent_black())),
|
||||
)
|
||||
.tooltip({
|
||||
move |cx| {
|
||||
let (tooltip, meta) = match status {
|
||||
Status::Idle => (
|
||||
"Project index ready".to_string(),
|
||||
Some("Click to disable".to_string()),
|
||||
),
|
||||
Status::Loading => ("Project index loading...".to_string(), None),
|
||||
Status::Scanning { remaining_count } => (
|
||||
"Project index scanning...".to_string(),
|
||||
Some(format!("{} remaining...", remaining_count)),
|
||||
),
|
||||
};
|
||||
|
||||
if let Some(meta) = meta {
|
||||
Tooltip::with_meta(tooltip, None, meta, cx)
|
||||
} else {
|
||||
Tooltip::text(tooltip, cx)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
mod chat_message;
|
||||
mod chat_notice;
|
||||
mod composer;
|
||||
mod project_index_button;
|
||||
|
||||
#[cfg(feature = "stories")]
|
||||
mod stories;
|
||||
@ -8,6 +9,7 @@ mod stories;
|
||||
pub use chat_message::*;
|
||||
pub use chat_notice::*;
|
||||
pub use composer::*;
|
||||
pub use project_index_button::*;
|
||||
|
||||
#[cfg(feature = "stories")]
|
||||
pub use stories::*;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use assistant_tooling::ToolRegistry;
|
||||
use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider};
|
||||
use client::User;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
|
||||
@ -7,13 +7,11 @@ use std::sync::Arc;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
|
||||
|
||||
use crate::{AssistantChat, CompletionProvider};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct Composer {
|
||||
editor: View<Editor>,
|
||||
player: Option<Arc<User>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
project_index_button: Option<View<ProjectIndexButton>>,
|
||||
model_selector: AnyElement,
|
||||
}
|
||||
|
||||
@ -21,20 +19,28 @@ impl Composer {
|
||||
pub fn new(
|
||||
editor: View<Editor>,
|
||||
player: Option<Arc<User>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
project_index_button: Option<View<ProjectIndexButton>>,
|
||||
model_selector: AnyElement,
|
||||
) -> Self {
|
||||
Self {
|
||||
editor,
|
||||
player,
|
||||
tool_registry,
|
||||
project_index_button,
|
||||
model_selector,
|
||||
}
|
||||
}
|
||||
|
||||
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
||||
h_flex().children(
|
||||
self.project_index_button
|
||||
.clone()
|
||||
.map(|view| view.into_any_element()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for Composer {
|
||||
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
|
||||
if let Some(player) = self.player.clone() {
|
||||
player_avatar = Avatar::new(player.avatar_uri.clone())
|
||||
@ -95,9 +101,7 @@ impl RenderOnce for Composer {
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.child(h_flex().gap_1().children(
|
||||
self.tool_registry.status_views().iter().cloned(),
|
||||
))
|
||||
.child(h_flex().gap_1().child(self.render_tools(cx)))
|
||||
.child(h_flex().gap_1().child(self.model_selector)),
|
||||
),
|
||||
),
|
||||
|
109
crates/assistant2/src/ui/project_index_button.rs
Normal file
109
crates/assistant2/src/ui/project_index_button.rs
Normal file
@ -0,0 +1,109 @@
|
||||
use assistant_tooling::ToolRegistry;
|
||||
use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation};
|
||||
use semantic_index::{ProjectIndex, Status};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip};
|
||||
|
||||
use crate::tools::ProjectIndexTool;
|
||||
|
||||
pub struct ProjectIndexButton {
|
||||
project_index: Model<ProjectIndex>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
}
|
||||
|
||||
impl ProjectIndexButton {
|
||||
pub fn new(
|
||||
project_index: Model<ProjectIndex>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
Self {
|
||||
project_index,
|
||||
tool_registry,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_enabled(&mut self, enabled: bool) {
|
||||
self.tool_registry
|
||||
.set_tool_enabled::<ProjectIndexTool>(enabled);
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ProjectIndexButton {
|
||||
// Expanded information on ToolView
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let status = self.project_index.read(cx).status();
|
||||
let is_enabled = self.tool_registry.is_tool_enabled::<ProjectIndexTool>();
|
||||
|
||||
let icon = if is_enabled {
|
||||
match status {
|
||||
Status::Idle => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Default),
|
||||
Status::Loading => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
Status::Scanning { .. } => Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
}
|
||||
} else {
|
||||
Icon::new(IconName::Code)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Disabled)
|
||||
};
|
||||
|
||||
let indicator = if is_enabled {
|
||||
match status {
|
||||
Status::Idle => Some(Indicator::dot().color(Color::Success)),
|
||||
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
|
||||
Status::Loading => Some(Indicator::icon(
|
||||
Icon::new(IconName::Spinner)
|
||||
.color(Color::Accent)
|
||||
.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||
),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
ButtonLike::new("project-index")
|
||||
.child(
|
||||
ui::IconWithIndicator::new(icon, indicator)
|
||||
.indicator_border_color(Some(gpui::transparent_black())),
|
||||
)
|
||||
.tooltip({
|
||||
move |cx| {
|
||||
let (tooltip, meta) = match status {
|
||||
Status::Idle => (
|
||||
"Project index ready".to_string(),
|
||||
Some("Click to disable".to_string()),
|
||||
),
|
||||
Status::Loading => ("Project index loading...".to_string(), None),
|
||||
Status::Scanning { remaining_count } => (
|
||||
"Project index scanning...".to_string(),
|
||||
Some(format!("{} remaining...", remaining_count)),
|
||||
),
|
||||
};
|
||||
|
||||
if let Some(meta) = meta {
|
||||
Tooltip::with_meta(tooltip, None, meta, cx)
|
||||
} else {
|
||||
Tooltip::text(tooltip, cx)
|
||||
}
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
this.set_enabled(!is_enabled);
|
||||
cx.notify();
|
||||
}))
|
||||
}
|
||||
}
|
@ -1,48 +1,86 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AnyView, Task, WindowContext};
|
||||
use std::collections::HashMap;
|
||||
use gpui::{Task, WindowContext};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
};
|
||||
|
||||
use crate::tool::{
|
||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
};
|
||||
|
||||
// Internal Tool representation for the registry
|
||||
pub struct Tool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
fn new(
|
||||
type_id: TypeId,
|
||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
definition: ToolFunctionDefinition,
|
||||
) -> Self {
|
||||
Self {
|
||||
enabled: AtomicBool::new(true),
|
||||
type_id,
|
||||
call,
|
||||
definition,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<
|
||||
String,
|
||||
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
>,
|
||||
definitions: Vec<ToolFunctionDefinition>,
|
||||
status_views: Vec<AnyView>,
|
||||
tools: HashMap<String, Tool>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
definitions: Vec::new(),
|
||||
status_views: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
|
||||
&self.definitions
|
||||
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
|
||||
for tool in self.tools.values() {
|
||||
if tool.type_id == TypeId::of::<T>() {
|
||||
tool.enabled.store(is_enabled, SeqCst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
|
||||
for tool in self.tools.values() {
|
||||
if tool.type_id == TypeId::of::<T>() {
|
||||
return tool.enabled.load(SeqCst);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|tool| tool.enabled.load(SeqCst))
|
||||
.map(|tool| tool.definition.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
cx: &mut WindowContext,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Result<()> {
|
||||
self.definitions.push(tool.definition());
|
||||
|
||||
if let Some(tool_view) = tool.status_view(cx) {
|
||||
self.status_views.push(tool_view);
|
||||
}
|
||||
let definition = tool.definition();
|
||||
|
||||
let name = tool.name();
|
||||
let previous = self.tools.insert(
|
||||
name.clone(),
|
||||
// registry.call(tool_call, cx)
|
||||
|
||||
let registered_tool = Tool::new(
|
||||
TypeId::of::<T>(),
|
||||
Box::new(
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let name = tool_call.name.clone();
|
||||
@ -77,8 +115,11 @@ impl ToolRegistry {
|
||||
})
|
||||
},
|
||||
),
|
||||
definition,
|
||||
);
|
||||
|
||||
let previous = self.tools.insert(name.clone(), registered_tool);
|
||||
|
||||
if previous.is_some() {
|
||||
return Err(anyhow!("already registered a tool with name {}", name));
|
||||
}
|
||||
@ -109,11 +150,7 @@ impl ToolRegistry {
|
||||
}
|
||||
};
|
||||
|
||||
tool(tool_call, cx)
|
||||
}
|
||||
|
||||
pub fn status_views(&self) -> &[AnyView] {
|
||||
&self.status_views
|
||||
(tool.call)(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,8 +104,4 @@ pub trait LanguageModelTool {
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View>;
|
||||
|
||||
fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user