From a4cdca5141e75cd5e9ea5e47d96b4be9b3ba0ffe Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 25 Jun 2024 13:41:55 +0200 Subject: [PATCH] Refine UX for assistants (#13502) image Release Notes: - N/A --- crates/assistant/src/assistant.rs | 26 +- crates/assistant/src/assistant_panel.rs | 71 +++- crates/assistant/src/inline_assistant.rs | 519 +++++++++++++++-------- crates/assistant/src/prompt_library.rs | 2 +- crates/zed/src/main.rs | 2 +- crates/zed/src/zed.rs | 2 +- 6 files changed, 419 insertions(+), 203 deletions(-) diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 4ca07d3dd4..f8b5047a99 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -10,14 +10,14 @@ mod search; mod slash_command; mod streaming_diff; -pub use assistant_panel::AssistantPanel; - +pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub(crate) use completion_provider::*; pub(crate) use context_store::*; +use fs::Fs; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; pub(crate) use inline_assistant::*; pub(crate) use model_selector::*; @@ -264,7 +264,7 @@ impl Assistant { } } -pub fn init(client: Arc, cx: &mut AppContext) { +pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { cx.set_global(Assistant::default()); AssistantSettings::register(cx); @@ -288,7 +288,7 @@ pub fn init(client: Arc, cx: &mut AppContext) { assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); - inline_assistant::init(client.telemetry().clone(), cx); + inline_assistant::init(fs.clone(), client.telemetry().clone(), cx); RustdocStore::init_global(cx); CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -324,6 +324,24 @@ fn register_slash_commands(cx: &mut AppContext) { slash_command_registry.register_command(fetch_command::FetchSlashCommand, false); } +pub fn humanize_token_count(count: usize) -> String { + match count { + 0..=999 => count.to_string(), + 1000..=9999 => { + let thousands = count / 1000; + let hundreds = (count % 1000 + 50) / 100; + if hundreds == 0 { + format!("{}k", thousands) + } else if hundreds == 10 { + format!("{}k", thousands + 1) + } else { + format!("{}.{}k", thousands, hundreds) + } + } + _ => format!("{}k", (count + 500) / 1000), + } +} + #[cfg(test)] #[ctor::ctor] fn init_logger() { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 2279a98ac2..d13a8379fb 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,5 +1,6 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, + humanize_token_count, prompt_library::open_prompt_library, search::*, slash_command::{ @@ -89,6 +90,10 @@ pub fn init(cx: &mut AppContext) { .detach(); } +pub enum AssistantPanelEvent { + ContextEdited, +} + pub struct AssistantPanel { workspace: WeakView, width: Option, @@ -360,11 +365,11 @@ impl AssistantPanel { return; } - let Some(assistant) = workspace.panel::(cx) else { + let Some(assistant_panel) = workspace.panel::(cx) else { return; }; - let context_editor = assistant + let context_editor = assistant_panel .read(cx) .active_context_editor() .and_then(|editor| { @@ -391,25 +396,37 @@ impl AssistantPanel { return; }; - if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) { + if assistant_panel.update(cx, |panel, cx| panel.is_authenticated(cx)) { InlineAssistant::update_global(cx, |assistant, cx| { assistant.assist( &active_editor, Some(cx.view().downgrade()), - include_context, + include_context.then_some(&assistant_panel), cx, ) }) } else { - let assistant = assistant.downgrade(); + let assistant_panel = assistant_panel.downgrade(); cx.spawn(|workspace, mut cx| async move { - assistant + assistant_panel .update(&mut cx, |assistant, cx| assistant.authenticate(cx))? .await?; - if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? { + if assistant_panel + .update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? + { cx.update(|cx| { + let assistant_panel = if include_context { + assistant_panel.upgrade() + } else { + None + }; InlineAssistant::update_global(cx, |assistant, cx| { - assistant.assist(&active_editor, Some(workspace), include_context, cx) + assistant.assist( + &active_editor, + Some(workspace), + assistant_panel.as_ref(), + cx, + ) }) })? } else { @@ -460,7 +477,7 @@ impl AssistantPanel { _subscriptions: subscriptions, }); self.show_saved_contexts = false; - + cx.emit(AssistantPanelEvent::ContextEdited); cx.notify(); } @@ -472,6 +489,7 @@ impl AssistantPanel { ) { match event { ContextEditorEvent::TabContentChanged => cx.notify(), + ContextEditorEvent::Edited => cx.emit(AssistantPanelEvent::ContextEdited), } } @@ -863,18 +881,33 @@ impl AssistantPanel { context: &Model, cx: &mut ViewContext, ) -> Option { - let remaining_tokens = context.read(cx).remaining_tokens(cx)?; - let remaining_tokens_color = if remaining_tokens <= 0 { + let model = CompletionProvider::global(cx).model(); + let token_count = context.read(cx).token_count()?; + let max_token_count = model.max_token_count(); + + let remaining_tokens = max_token_count as isize - token_count as isize; + let token_count_color = if remaining_tokens <= 0 { Color::Error - } else if remaining_tokens <= 500 { + } else if token_count as f32 / max_token_count as f32 >= 0.8 { Color::Warning } else { Color::Muted }; + Some( - Label::new(remaining_tokens.to_string()) - .size(LabelSize::Small) - .color(remaining_tokens_color), + h_flex() + .gap_0p5() + .child( + Label::new(humanize_token_count(token_count)) + .size(LabelSize::Small) + .color(token_count_color), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child( + Label::new(humanize_token_count(max_token_count)) + .size(LabelSize::Small) + .color(Color::Muted), + ), ) } } @@ -978,6 +1011,7 @@ impl Panel for AssistantPanel { } impl EventEmitter for AssistantPanel {} +impl EventEmitter for AssistantPanel {} impl FocusableView for AssistantPanel { fn focus_handle(&self, _cx: &AppContext) -> FocusHandle { @@ -1538,11 +1572,6 @@ impl Context { } } - fn remaining_tokens(&self, cx: &AppContext) -> Option { - let model = CompletionProvider::global(cx).model(); - Some(model.max_token_count() as isize - self.token_count? as isize) - } - fn completion_provider_changed(&mut self, cx: &mut ModelContext) { self.count_remaining_tokens(cx); } @@ -2183,6 +2212,7 @@ struct PendingCompletion { } enum ContextEditorEvent { + Edited, TabContentChanged, } @@ -2775,6 +2805,7 @@ impl ContextEditor { EditorEvent::SelectionsChanged { .. } => { self.scroll_position = self.cursor_scroll_position(cx); } + EditorEvent::BufferEdited => cx.emit(ContextEditorEvent::Edited), _ => {} } } diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 76131384a9..ad4856b29d 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,8 +1,9 @@ use crate::{ - prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk, - LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff, + assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt, + AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest, + LanguageModelRequestMessage, Role, StreamingDiff, }; -use anyhow::{Context as _, Result}; +use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; use collections::{hash_map, HashMap, HashSet, VecDeque}; use editor::{ @@ -14,6 +15,7 @@ use editor::{ Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, }; +use fs::Fs; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{ point, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, @@ -24,7 +26,7 @@ use language::{Buffer, Point, Selection, TransactionId}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; -use settings::Settings; +use settings::{update_settings_file, Settings}; use similar::TextDiff; use std::{ cmp, mem, @@ -32,15 +34,15 @@ use std::{ pin::Pin, sync::Arc, task::{self, Poll}, - time::Instant, + time::{Duration, Instant}, }; use theme::ThemeSettings; -use ui::{prelude::*, Tooltip}; +use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip}; use util::RangeExt; use workspace::{notifications::NotificationId, Toast, Workspace}; -pub fn init(telemetry: Arc, cx: &mut AppContext) { - cx.set_global(InlineAssistant::new(telemetry)); +pub fn init(fs: Arc, telemetry: Arc, cx: &mut AppContext) { + cx.set_global(InlineAssistant::new(fs, telemetry)); } const PROMPT_HISTORY_MAX_LEN: usize = 20; @@ -53,12 +55,13 @@ pub struct InlineAssistant { assist_groups: HashMap, prompt_history: VecDeque, telemetry: Option>, + fs: Arc, } impl Global for InlineAssistant {} impl InlineAssistant { - pub fn new(telemetry: Arc) -> Self { + pub fn new(fs: Arc, telemetry: Arc) -> Self { Self { next_assist_id: InlineAssistId::default(), next_assist_group_id: InlineAssistGroupId::default(), @@ -67,6 +70,7 @@ impl InlineAssistant { assist_groups: HashMap::default(), prompt_history: VecDeque::default(), telemetry: Some(telemetry), + fs, } } @@ -74,7 +78,7 @@ impl InlineAssistant { &mut self, editor: &View, workspace: Option>, - include_context: bool, + assistant_panel: Option<&View>, cx: &mut WindowContext, ) { let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); @@ -151,7 +155,10 @@ impl InlineAssistant { self.prompt_history.clone(), prompt_buffer.clone(), codegen.clone(), + editor, + assistant_panel, workspace.clone(), + self.fs.clone(), cx, ) }); @@ -208,7 +215,7 @@ impl InlineAssistant { InlineAssist::new( assist_id, assist_group_id, - include_context, + assistant_panel.is_some(), editor, &prompt_editor, block_ids[0], @@ -706,8 +713,6 @@ impl InlineAssistant { return; } - assist.codegen.update(cx, |codegen, cx| codegen.undo(cx)); - let Some(user_prompt) = assist .decorations .as_ref() @@ -716,115 +721,138 @@ impl InlineAssistant { return; }; - let context = if assist.include_context { - assist.workspace.as_ref().and_then(|workspace| { - let workspace = workspace.upgrade()?.read(cx); - let assistant_panel = workspace.panel::(cx)?; - assistant_panel.read(cx).active_context(cx) - }) - } else { - None - }; - - let editor = if let Some(editor) = assist.editor.upgrade() { - editor - } else { - return; - }; - - let project_name = assist.workspace.as_ref().and_then(|workspace| { - let workspace = workspace.upgrade()?; - Some( - workspace - .read(cx) - .project() - .read(cx) - .worktree_root_names(cx) - .collect::>() - .join("/"), - ) - }); - self.prompt_history.retain(|prompt| *prompt != user_prompt); self.prompt_history.push_back(user_prompt.clone()); if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN { self.prompt_history.pop_front(); } + assist.codegen.update(cx, |codegen, cx| codegen.undo(cx)); let codegen = assist.codegen.clone(); - let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let range = codegen.read(cx).range.clone(); - let start = snapshot.point_to_buffer_offset(range.start); - let end = snapshot.point_to_buffer_offset(range.end); - let (buffer, range) = if let Some((start, end)) = start.zip(end) { - let (start_buffer, start_buffer_offset) = start; - let (end_buffer, end_buffer_offset) = end; - if start_buffer.remote_id() == end_buffer.remote_id() { - (start_buffer.clone(), start_buffer_offset..end_buffer_offset) - } else { - self.finish_assist(assist_id, false, cx); - return; - } - } else { - self.finish_assist(assist_id, false, cx); - return; - }; - - let language = buffer.language_at(range.start); - let language_name = if let Some(language) = language.as_ref() { - if Arc::ptr_eq(language, &language::PLAIN_TEXT) { - None - } else { - Some(language.name()) - } - } else { - None - }; - - // Higher Temperature increases the randomness of model outputs. - // If Markdown or No Language is Known, increase the randomness for more creative output - // If Code, decrease temperature to get more deterministic outputs - let temperature = if let Some(language) = language_name.clone() { - if language.as_ref() == "Markdown" { - 1.0 - } else { - 0.5 - } - } else { - 1.0 - }; - - let prompt = cx.background_executor().spawn(async move { - let language_name = language_name.as_deref(); - generate_content_prompt(user_prompt, language_name, buffer, range, project_name) - }); - - let mut messages = Vec::new(); - if let Some(context) = context { - let request = context.read(cx).to_completion_request(cx); - messages = request.messages; - } - let model = CompletionProvider::global(cx).model(); + let request = self.request_for_inline_assist(assist_id, cx); cx.spawn(|mut cx| async move { - let prompt = prompt.await?; + let request = request.await?; + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + fn request_for_inline_assist( + &self, + assist_id: InlineAssistId, + cx: &mut WindowContext, + ) -> Task> { + cx.spawn(|mut cx| async move { + let (user_prompt, context_request, project_name, buffer, range, model) = cx + .read_global(|this: &InlineAssistant, cx: &WindowContext| { + let assist = this.assists.get(&assist_id).context("invalid assist")?; + let decorations = assist.decorations.as_ref().context("invalid assist")?; + let editor = assist.editor.upgrade().context("invalid assist")?; + let user_prompt = decorations.prompt_editor.read(cx).prompt(cx); + let context_request = if assist.include_context { + assist.workspace.as_ref().and_then(|workspace| { + let workspace = workspace.upgrade()?.read(cx); + let assistant_panel = workspace.panel::(cx)?; + Some( + assistant_panel + .read(cx) + .active_context(cx)? + .read(cx) + .to_completion_request(cx), + ) + }) + } else { + None + }; + let project_name = assist.workspace.as_ref().and_then(|workspace| { + let workspace = workspace.upgrade()?; + Some( + workspace + .read(cx) + .project() + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + }); + let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); + let range = assist.codegen.read(cx).range.clone(); + let model = CompletionProvider::global(cx).model(); + anyhow::Ok(( + user_prompt, + context_request, + project_name, + buffer, + range, + model, + )) + })??; + + let language = buffer.language_at(range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None + } else { + Some(language.name()) + } + } else { + None + }; + + // Higher Temperature increases the randomness of model outputs. + // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.as_ref() == "Markdown" { + 1.0 + } else { + 0.5 + } + } else { + 1.0 + }; + + let prompt = cx + .background_executor() + .spawn(async move { + let language_name = language_name.as_deref(); + let start = buffer.point_to_buffer_offset(range.start); + let end = buffer.point_to_buffer_offset(range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + return Err(anyhow!("invalid transformation range")); + } + } else { + return Err(anyhow!("invalid transformation range")); + }; + generate_content_prompt(user_prompt, language_name, buffer, range, project_name) + }) + .await?; + + let mut messages = Vec::new(); + if let Some(context_request) = context_request { + messages = context_request.messages; + } messages.push(LanguageModelRequestMessage { role: Role::User, content: prompt, }); - let request = LanguageModelRequest { + Ok(LanguageModelRequest { model, messages, stop: vec!["|END|>".to_string()], temperature, - }; - - codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; - anyhow::Ok(()) + }) }) - .detach_and_log_err(cx); } fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { @@ -1142,6 +1170,7 @@ enum PromptEditorEvent { struct PromptEditor { id: InlineAssistId, + fs: Arc, height_in_lines: u8, editor: View, edited_since_done: bool, @@ -1150,9 +1179,12 @@ struct PromptEditor { prompt_history_ix: Option, pending_prompt: String, codegen: Model, - workspace: Option>, _codegen_subscription: Subscription, editor_subscriptions: Vec, + pending_token_count: Task>, + token_count: Option, + _token_count_subscriptions: Vec, + workspace: Option>, } impl EventEmitter for PromptEditor {} @@ -1160,6 +1192,7 @@ impl EventEmitter for PromptEditor {} impl Render for PromptEditor { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let gutter_dimensions = *self.gutter_dimensions.lock(); + let fs = self.fs.clone(); let buttons = match &self.codegen.read(cx).status { CodegenStatus::Idle => { @@ -1245,85 +1278,100 @@ impl Render for PromptEditor { } }; - v_flex().h_full().w_full().justify_end().child( - h_flex() - .bg(cx.theme().colors().editor_background) - .border_y_1() - .border_color(cx.theme().status().info_border) - .py_1p5() - .w_full() - .on_action(cx.listener(Self::confirm)) - .on_action(cx.listener(Self::cancel)) - .on_action(cx.listener(Self::move_up)) - .on_action(cx.listener(Self::move_down)) - .child( - h_flex() - .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) - // .pr(gutter_dimensions.fold_area_width()) - .justify_center() - .gap_2() - .children(self.workspace.clone().map(|workspace| { - IconButton::new("context", IconName::Context) - .size(ButtonSize::None) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .on_click({ - let workspace = workspace.clone(); - cx.listener(move |_, _, cx| { - workspace - .update(cx, |workspace, cx| { - workspace.focus_panel::(cx); - }) - .ok(); - }) + h_flex() + .bg(cx.theme().colors().editor_background) + .border_y_1() + .border_color(cx.theme().status().info_border) + .py_1p5() + .h_full() + .w_full() + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::move_up)) + .on_action(cx.listener(Self::move_down)) + .child( + h_flex() + .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) + .justify_center() + .gap_2() + .child( + PopoverMenu::new("model-switcher") + .menu(move |cx| { + ContextMenu::build(cx, |mut menu, cx| { + for model in CompletionProvider::global(cx).available_models() { + menu = menu.custom_entry( + { + let model = model.clone(); + move |_| { + Label::new(model.display_name()) + .into_any_element() + } + }, + { + let fs = fs.clone(); + let model = model.clone(); + move |cx| { + let model = model.clone(); + update_settings_file::( + fs.clone(), + cx, + move |settings| settings.set_model(model), + ); + } + }, + ); + } + menu }) - .tooltip(move |cx| { - let token_count = workspace.upgrade().and_then(|workspace| { - let panel = - workspace.read(cx).panel::(cx)?; - let context = panel.read(cx).active_context(cx)?; - context.read(cx).token_count() - }); - if let Some(token_count) = token_count { + .into() + }) + .trigger( + IconButton::new("context", IconName::Settings) + .size(ButtonSize::None) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(move |cx| { Tooltip::with_meta( format!( - "{} Additional Context Tokens from Assistant", - token_count + "Using {}", + CompletionProvider::global(cx) + .model() + .display_name() ), - Some(&crate::ToggleFocus), - "Click to open…", + None, + "Click to Change Model", cx, ) - } else { - Tooltip::for_action( - "Toggle Assistant Panel", - &crate::ToggleFocus, - cx, - ) - } - }) - })) - .children( - if let CodegenStatus::Error(error) = &self.codegen.read(cx).status { - let error_message = SharedString::from(error.to_string()); - Some( - div() - .id("error") - .tooltip(move |cx| Tooltip::text(error_message.clone(), cx)) - .child( - Icon::new(IconName::XCircle) - .size(IconSize::Small) - .color(Color::Error), - ), - ) - } else { - None - }, - ), - ) - .child(div().flex_1().child(self.render_prompt_editor(cx))) - .child(h_flex().gap_2().pr_4().children(buttons)), - ) + }), + ) + .anchor(gpui::AnchorCorner::BottomRight), + ) + .children( + if let CodegenStatus::Error(error) = &self.codegen.read(cx).status { + let error_message = SharedString::from(error.to_string()); + Some( + div() + .id("error") + .tooltip(move |cx| Tooltip::text(error_message.clone(), cx)) + .child( + Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error), + ), + ) + } else { + None + }, + ), + ) + .child(div().flex_1().child(self.render_prompt_editor(cx))) + .child( + h_flex() + .gap_2() + .pr_4() + .children(self.render_token_count(cx)) + .children(buttons), + ) } } @@ -1336,13 +1384,17 @@ impl FocusableView for PromptEditor { impl PromptEditor { const MAX_LINES: u8 = 8; + #[allow(clippy::too_many_arguments)] fn new( id: InlineAssistId, gutter_dimensions: Arc>, prompt_history: VecDeque, prompt_buffer: Model, codegen: Model, + parent_editor: &View, + assistant_panel: Option<&View>, workspace: Option>, + fs: Arc, cx: &mut ViewContext, ) -> Self { let prompt_editor = cx.new_view(|cx| { @@ -1363,6 +1415,15 @@ impl PromptEditor { editor.set_placeholder_text("Add a prompt…", cx); editor }); + + let mut token_count_subscriptions = Vec::new(); + token_count_subscriptions + .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event)); + if let Some(assistant_panel) = assistant_panel { + token_count_subscriptions + .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event)); + } + let mut this = Self { id, height_in_lines: 1, @@ -1375,9 +1436,14 @@ impl PromptEditor { _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed), editor_subscriptions: Vec::new(), codegen, + fs, + pending_token_count: Task::ready(Ok(())), + token_count: None, + _token_count_subscriptions: token_count_subscriptions, workspace, }; this.count_lines(cx); + this.count_tokens(cx); this.subscribe_to_editor(cx); this } @@ -1436,6 +1502,47 @@ impl PromptEditor { } } + fn handle_parent_editor_event( + &mut self, + _: View, + event: &EditorEvent, + cx: &mut ViewContext, + ) { + if let EditorEvent::BufferEdited { .. } = event { + self.count_tokens(cx); + } + } + + fn handle_assistant_panel_event( + &mut self, + _: View, + event: &AssistantPanelEvent, + cx: &mut ViewContext, + ) { + let AssistantPanelEvent::ContextEdited { .. } = event; + self.count_tokens(cx); + } + + fn count_tokens(&mut self, cx: &mut ViewContext) { + let assist_id = self.id; + self.pending_token_count = cx.spawn(|this, mut cx| async move { + cx.background_executor().timer(Duration::from_secs(1)).await; + let request = cx + .update_global(|inline_assistant: &mut InlineAssistant, cx| { + inline_assistant.request_for_inline_assist(assist_id, cx) + })? + .await?; + + let token_count = cx + .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? + .await?; + this.update(&mut cx, |this, cx| { + this.token_count = Some(token_count); + cx.notify(); + }) + }) + } + fn handle_prompt_editor_changed(&mut self, _: View, cx: &mut ViewContext) { self.count_lines(cx); } @@ -1460,6 +1567,9 @@ impl PromptEditor { self.edited_since_done = true; cx.notify(); } + EditorEvent::BufferEdited => { + self.count_tokens(cx); + } _ => {} } } @@ -1551,6 +1661,63 @@ impl PromptEditor { } } + fn render_token_count(&self, cx: &mut ViewContext) -> Option { + let model = CompletionProvider::global(cx).model(); + let token_count = self.token_count?; + let max_token_count = model.max_token_count(); + + let remaining_tokens = max_token_count as isize - token_count as isize; + let token_count_color = if remaining_tokens <= 0 { + Color::Error + } else if token_count as f32 / max_token_count as f32 >= 0.8 { + Color::Warning + } else { + Color::Muted + }; + + let mut token_count = h_flex() + .id("token_count") + .gap_0p5() + .child( + Label::new(humanize_token_count(token_count)) + .size(LabelSize::Small) + .color(token_count_color), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child( + Label::new(humanize_token_count(max_token_count)) + .size(LabelSize::Small) + .color(Color::Muted), + ); + if let Some(workspace) = self.workspace.clone() { + token_count = token_count + .tooltip(|cx| { + Tooltip::with_meta( + "Tokens Used by Inline Assistant", + None, + "Click to Open Assistant Panel", + cx, + ) + }) + .cursor_pointer() + .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation()) + .on_click(move |_, cx| { + cx.stop_propagation(); + workspace + .update(cx, |workspace, cx| { + workspace.focus_panel::(cx) + }) + .ok(); + }); + } else { + token_count = token_count + .cursor_default() + .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx)); + } + + Some(token_count) + } + fn render_prompt_editor(&self, cx: &mut ViewContext) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); let text_style = TextStyle { diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index c3047c243d..6d87e383ce 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -569,7 +569,7 @@ impl PromptLibrary { let provider = CompletionProvider::global(cx); if provider.is_authenticated() { InlineAssistant::update_global(cx, |assistant, cx| { - assistant.assist(&prompt_editor, None, false, cx) + assistant.assist(&prompt_editor, None, None, cx) }) } else { for window in cx.windows() { diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index d82cc49205..e551c00026 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -219,7 +219,7 @@ fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { inline_completion_registry::init(app_state.client.telemetry().clone(), cx); - assistant::init(app_state.client.clone(), cx); + assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init(app_state.fs.clone(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index d02b30f6b9..3f205bfb6d 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3181,7 +3181,7 @@ mod tests { project_panel::init((), cx); outline_panel::init((), cx); terminal_view::init(cx); - assistant::init(app_state.client.clone(), cx); + assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); tasks_ui::init(cx); initialize_workspace(app_state.clone(), cx); app_state