Kyle Kelley 2024-05-02 13:26:46 -07:00 committed by GitHub
parent 43ad470e58
commit 1915a756a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 209 additions and 138 deletions

View File

@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
use serde::Deserialize;
use settings::Settings;
use std::sync::Arc;
use ui::Composer;
use ui::{Composer, ProjectIndexButton};
use util::{paths::EMBEDDINGS_DIR, ResultExt};
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
@ -228,6 +228,7 @@ pub struct AssistantChat {
list_state: ListState,
language_registry: Arc<LanguageRegistry>,
composer_editor: View<Editor>,
project_index_button: Option<View<ProjectIndexButton>>,
user_store: Model<UserStore>,
next_message_id: MessageId,
collapsed_messages: HashMap<MessageId, bool>,
@ -263,6 +264,10 @@ impl AssistantChat {
},
);
let project_index_button = project_index.clone().map(|project_index| {
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
});
Self {
model,
messages: Vec::new(),
@ -275,6 +280,7 @@ impl AssistantChat {
list_state,
user_store,
language_registry,
project_index_button,
project_index,
next_message_id: MessageId(0),
editing_message: None,
@ -397,7 +403,7 @@ impl AssistantChat {
{
this.tool_registry.definitions()
} else {
&[]
Vec::new()
};
call_count += 1;
@ -590,7 +596,7 @@ impl AssistantChat {
element.child(Composer::new(
body.clone(),
self.user_store.read(cx).current_user(),
self.tool_registry.clone(),
self.project_index_button.clone(),
crate::ui::ModelSelector::new(
cx.view().downgrade(),
self.model.clone(),
@ -768,7 +774,7 @@ impl Render for AssistantChat {
.child(Composer::new(
self.composer_editor.clone(),
self.user_store.read(cx).current_user(),
self.tool_registry.clone(),
self.project_index_button.clone(),
crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
.into_any_element(),
))

View File

@ -33,7 +33,7 @@ impl CompletionProvider {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
self.0.complete(model, messages, stop, temperature, tools)
@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
tools: &[ToolFunctionDefinition],
tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
let client = self.client.clone();

View File

@ -1,14 +1,17 @@
use anyhow::Result;
use assistant_tooling::LanguageModelTool;
use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation};
use assistant_tooling::{
// assistant_tool_button::{AssistantToolButton, ToolStatus},
LanguageModelTool,
};
use gpui::{prelude::*, Model, Task};
use project::Fs;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::Deserialize;
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use ui::{
div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label,
SharedString, Tooltip, WindowContext,
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
WindowContext,
};
use util::ResultExt as _;
@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool {
cx.new_view(|_cx| ProjectIndexView { input, output })
}
fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
Some(
cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
.into(),
)
}
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
match &output {
Ok(output) => {
@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool {
}
}
}
struct ProjectIndexStatusView {
project_index: Model<ProjectIndex>,
}
impl ProjectIndexStatusView {
pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
cx.notify();
})
.detach();
Self { project_index }
}
}
impl Render for ProjectIndexStatusView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let status = self.project_index.read(cx).status();
let is_enabled = match status {
Status::Idle => true,
_ => false,
};
let icon = match status {
Status::Idle => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Default),
Status::Loading => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Muted),
Status::Scanning { .. } => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Muted),
};
let indicator = match status {
Status::Idle => Some(Indicator::dot().color(Color::Success)),
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
Status::Loading => Some(Indicator::icon(
Icon::new(IconName::Spinner)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
),
)),
};
ButtonLike::new("project-index")
.disabled(!is_enabled)
.child(
ui::IconWithIndicator::new(icon, indicator)
.indicator_border_color(Some(gpui::transparent_black())),
)
.tooltip({
move |cx| {
let (tooltip, meta) = match status {
Status::Idle => (
"Project index ready".to_string(),
Some("Click to disable".to_string()),
),
Status::Loading => ("Project index loading...".to_string(), None),
Status::Scanning { remaining_count } => (
"Project index scanning...".to_string(),
Some(format!("{} remaining...", remaining_count)),
),
};
if let Some(meta) = meta {
Tooltip::with_meta(tooltip, None, meta, cx)
} else {
Tooltip::text(tooltip, cx)
}
}
})
}
}

View File

@ -1,6 +1,7 @@
mod chat_message;
mod chat_notice;
mod composer;
mod project_index_button;
#[cfg(feature = "stories")]
mod stories;
@ -8,6 +9,7 @@ mod stories;
pub use chat_message::*;
pub use chat_notice::*;
pub use composer::*;
pub use project_index_button::*;
#[cfg(feature = "stories")]
pub use stories::*;

View File

@ -1,4 +1,4 @@
use assistant_tooling::ToolRegistry;
use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider};
use client::User;
use editor::{Editor, EditorElement, EditorStyle};
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
@ -7,13 +7,11 @@ use std::sync::Arc;
use theme::ThemeSettings;
use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
use crate::{AssistantChat, CompletionProvider};
#[derive(IntoElement)]
pub struct Composer {
editor: View<Editor>,
player: Option<Arc<User>>,
tool_registry: Arc<ToolRegistry>,
project_index_button: Option<View<ProjectIndexButton>>,
model_selector: AnyElement,
}
@ -21,20 +19,28 @@ impl Composer {
pub fn new(
editor: View<Editor>,
player: Option<Arc<User>>,
tool_registry: Arc<ToolRegistry>,
project_index_button: Option<View<ProjectIndexButton>>,
model_selector: AnyElement,
) -> Self {
Self {
editor,
player,
tool_registry,
project_index_button,
model_selector,
}
}
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
h_flex().children(
self.project_index_button
.clone()
.map(|view| view.into_any_element()),
)
}
}
impl RenderOnce for Composer {
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
if let Some(player) = self.player.clone() {
player_avatar = Avatar::new(player.avatar_uri.clone())
@ -95,9 +101,7 @@ impl RenderOnce for Composer {
.gap_2()
.justify_between()
.w_full()
.child(h_flex().gap_1().children(
self.tool_registry.status_views().iter().cloned(),
))
.child(h_flex().gap_1().child(self.render_tools(cx)))
.child(h_flex().gap_1().child(self.model_selector)),
),
),

View File

@ -0,0 +1,109 @@
use assistant_tooling::ToolRegistry;
use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation};
use semantic_index::{ProjectIndex, Status};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip};
use crate::tools::ProjectIndexTool;
pub struct ProjectIndexButton {
project_index: Model<ProjectIndex>,
tool_registry: Arc<ToolRegistry>,
}
impl ProjectIndexButton {
pub fn new(
project_index: Model<ProjectIndex>,
tool_registry: Arc<ToolRegistry>,
cx: &mut ViewContext<Self>,
) -> Self {
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
cx.notify();
})
.detach();
Self {
project_index,
tool_registry,
}
}
pub fn set_enabled(&mut self, enabled: bool) {
self.tool_registry
.set_tool_enabled::<ProjectIndexTool>(enabled);
}
}
impl Render for ProjectIndexButton {
// Expanded information on ToolView
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let status = self.project_index.read(cx).status();
let is_enabled = self.tool_registry.is_tool_enabled::<ProjectIndexTool>();
let icon = if is_enabled {
match status {
Status::Idle => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Default),
Status::Loading => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Muted),
Status::Scanning { .. } => Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Muted),
}
} else {
Icon::new(IconName::Code)
.size(IconSize::XSmall)
.color(Color::Disabled)
};
let indicator = if is_enabled {
match status {
Status::Idle => Some(Indicator::dot().color(Color::Success)),
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
Status::Loading => Some(Indicator::icon(
Icon::new(IconName::Spinner)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
),
)),
}
} else {
None
};
ButtonLike::new("project-index")
.child(
ui::IconWithIndicator::new(icon, indicator)
.indicator_border_color(Some(gpui::transparent_black())),
)
.tooltip({
move |cx| {
let (tooltip, meta) = match status {
Status::Idle => (
"Project index ready".to_string(),
Some("Click to disable".to_string()),
),
Status::Loading => ("Project index loading...".to_string(), None),
Status::Scanning { remaining_count } => (
"Project index scanning...".to_string(),
Some(format!("{} remaining...", remaining_count)),
),
};
if let Some(meta) = meta {
Tooltip::with_meta(tooltip, None, meta, cx)
} else {
Tooltip::text(tooltip, cx)
}
}
})
.on_click(cx.listener(move |this, _, cx| {
this.set_enabled(!is_enabled);
cx.notify();
}))
}
}

View File

@ -1,48 +1,86 @@
use anyhow::{anyhow, Result};
use gpui::{AnyView, Task, WindowContext};
use std::collections::HashMap;
use gpui::{Task, WindowContext};
use std::{
any::TypeId,
collections::HashMap,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
// Internal Tool representation for the registry
pub struct Tool {
enabled: AtomicBool,
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
definition: ToolFunctionDefinition,
}
impl Tool {
fn new(
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
definition: ToolFunctionDefinition,
) -> Self {
Self {
enabled: AtomicBool::new(true),
type_id,
call,
definition,
}
}
}
pub struct ToolRegistry {
tools: HashMap<
String,
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
>,
definitions: Vec<ToolFunctionDefinition>,
status_views: Vec<AnyView>,
tools: HashMap<String, Tool>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
definitions: Vec::new(),
status_views: Vec::new(),
}
}
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
&self.definitions
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
}
}
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
}
false
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
.collect()
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
cx: &mut WindowContext,
_cx: &mut WindowContext,
) -> Result<()> {
self.definitions.push(tool.definition());
if let Some(tool_view) = tool.status_view(cx) {
self.status_views.push(tool_view);
}
let definition = tool.definition();
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
// registry.call(tool_call, cx)
let registered_tool = Tool::new(
TypeId::of::<T>(),
Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let name = tool_call.name.clone();
@ -77,8 +115,11 @@ impl ToolRegistry {
})
},
),
definition,
);
let previous = self.tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
@ -109,11 +150,7 @@ impl ToolRegistry {
}
};
tool(tool_call, cx)
}
pub fn status_views(&self) -> &[AnyView] {
&self.status_views
(tool.call)(tool_call, cx)
}
}

View File

@ -104,8 +104,4 @@ pub trait LanguageModelTool {
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
None
}
}