From 6ff01b17ca6e4a1fb70a5361d7ff52e31cbfb579 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 30 May 2024 12:36:07 +0200 Subject: [PATCH] Improve model selection in the assistant (#12472) https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e Release Notes: - Improved the UX for selecting a model in the assistant panel. You can now switch model using just the keyboard by pressing `alt-m`. Also, when switching models via the UI, settings will now be updated automatically. --- Cargo.lock | 3 + assets/keymaps/default-linux.json | 3 +- assets/keymaps/default-macos.json | 5 +- crates/anthropic/Cargo.toml | 1 + crates/anthropic/src/anthropic.rs | 3 +- crates/assistant/Cargo.toml | 1 + crates/assistant/src/assistant.rs | 5 +- crates/assistant/src/assistant_panel.rs | 184 ++++--------- crates/assistant/src/assistant_settings.rs | 256 +++++++++++++----- crates/assistant/src/completion_provider.rs | 76 +++--- .../src/completion_provider/anthropic.rs | 21 +- .../src/completion_provider/open_ai.rs | 21 +- .../assistant/src/completion_provider/zed.rs | 30 +- crates/assistant/src/model_selector.rs | 84 ++++++ crates/open_ai/Cargo.toml | 1 + crates/open_ai/src/open_ai.rs | 6 +- crates/ui/src/components/popover_menu.rs | 112 ++++++-- 17 files changed, 517 insertions(+), 295 deletions(-) create mode 100644 crates/assistant/src/model_selector.rs diff --git a/Cargo.lock b/Cargo.lock index 43a5ac03cc..6567bf9f8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,6 +230,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "strum", "tokio", ] @@ -376,6 +377,7 @@ dependencies = [ "settings", "smol", "strsim 0.11.1", + "strum", "telemetry_events", "theme", "tiktoken-rs", @@ -6983,6 +6985,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "strum", ] [[package]] diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 67e12b0235..b62fe2522f 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -201,7 +201,8 @@ "context": "AssistantPanel", "bindings": { "ctrl-g": "search::SelectNextMatch", - "ctrl-shift-g": "search::SelectPrevMatch" + "ctrl-shift-g": "search::SelectPrevMatch", + "alt-m": "assistant::ToggleModelSelector" } }, { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index a1e8ea2f4f..7141ee810b 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -214,10 +214,11 @@ } }, { - "context": "AssistantPanel", // Used in the assistant crate, which we're replacing + "context": "AssistantPanel", "bindings": { "cmd-g": "search::SelectNextMatch", - "cmd-shift-g": "search::SelectPrevMatch" + "cmd-shift-g": "search::SelectPrevMatch", + "alt-m": "assistant::ToggleModelSelector" } }, { diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 484a9b3e10..0ea24d5c07 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -23,6 +23,7 @@ isahc.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +strum.workspace = true [dev-dependencies] tokio.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 65df4e74dc..7927b88000 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -4,11 +4,12 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; use std::{convert::TryFrom, time::Duration}; +use strum::EnumIter; pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { #[default] #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")] diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 8f745a4ee0..f3d79ba467 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -49,6 +49,7 @@ serde_json.workspace = true settings.workspace = true smol.workspace = true strsim = "0.11" +strum.workspace = true telemetry_events.workspace = true theme.workspace = true tiktoken-rs.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 19ae88e5a2..b6b49da9e1 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -2,6 +2,7 @@ pub mod assistant_panel; pub mod assistant_settings; mod codegen; mod completion_provider; +mod model_selector; mod prompts; mod saved_conversation; mod search; @@ -15,6 +16,7 @@ use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub(crate) use completion_provider::*; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; +pub(crate) use model_selector::*; pub(crate) use saved_conversation::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; @@ -38,7 +40,8 @@ actions!( InsertActivePrompt, ToggleHistory, ApplyEdit, - ConfirmCommand + ConfirmCommand, + ToggleModelSelector ] ); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 6b2b31d14a..26aecfbe5f 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,7 +1,7 @@ use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager}; use crate::slash_command::{rustdoc_command, search_command, tabs_command}; use crate::{ - assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel}, + assistant_settings::{AssistantDockPosition, AssistantSettings}, codegen::{self, Codegen, CodegenKind}, search::*, slash_command::{ @@ -9,10 +9,11 @@ use crate::{ SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry, }, ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist, - LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, - MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, - SavedMessage, Split, ToggleFocus, ToggleHistory, + LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus, + QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage, + Split, ToggleFocus, ToggleHistory, }; +use crate::{ModelSelector, ToggleModelSelector}; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection}; use client::telemetry::Telemetry; @@ -64,8 +65,8 @@ use std::{ use telemetry_events::AssistantKind; use theme::ThemeSettings; use ui::{ - popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar, - Tooltip, + popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, + PopoverMenuHandle, Tab, TabBar, Tooltip, }; use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt}; use uuid::Uuid; @@ -119,8 +120,8 @@ pub struct AssistantPanel { pending_inline_assist_ids_by_editor: HashMap, Vec>, inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, - model: LanguageModel, authentication_prompt: Option, + model_menu_handle: PopoverMenuHandle, } struct ActiveConversationEditor { @@ -203,7 +204,6 @@ impl AssistantPanel { } }), ]; - let model = CompletionProvider::global(cx).default_model(); cx.observe_global::(|_, cx| { cx.notify(); @@ -244,8 +244,8 @@ impl AssistantPanel { pending_inline_assist_ids_by_editor: Default::default(), inline_prompt_history: Default::default(), _watch_saved_conversations, - model, authentication_prompt: None, + model_menu_handle: PopoverMenuHandle::default(), } }) }) @@ -277,12 +277,20 @@ impl AssistantPanel { if self.is_authenticated(cx) { self.authentication_prompt = None; - let model = CompletionProvider::global(cx).default_model(); - self.set_model(model, cx); + if let Some(editor) = self.active_conversation_editor() { + editor.update(cx, |active_conversation, cx| { + active_conversation + .conversation + .update(cx, |conversation, cx| { + conversation.completion_provider_changed(cx) + }) + }) + } if self.active_conversation_editor().is_none() { self.new_conversation(cx); } + cx.notify(); } else if self.authentication_prompt.is_none() || prev_settings_version != CompletionProvider::global(cx).settings_version() { @@ -290,6 +298,7 @@ impl AssistantPanel { Some(cx.update_global::(|provider, cx| { provider.authentication_prompt(cx) })); + cx.notify(); } } @@ -734,7 +743,7 @@ impl AssistantPanel { .map(|message| message.to_request_message(buffer)), ); } - let model = self.model.clone(); + let model = CompletionProvider::global(cx).model(); cx.spawn(|_, mut cx| async move { // I Don't know if we want to return a ? here. @@ -809,7 +818,6 @@ impl AssistantPanel { let editor = cx.new_view(|cx| { ConversationEditor::new( - self.model.clone(), self.languages.clone(), self.slash_commands.clone(), self.fs.clone(), @@ -850,53 +858,6 @@ impl AssistantPanel { cx.notify(); } - fn cycle_model(&mut self, cx: &mut ViewContext) { - let next_model = match &self.model { - LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model { - open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four, - open_ai::Model::Four => open_ai::Model::FourTurbo, - open_ai::Model::FourTurbo => open_ai::Model::FourOmni, - open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo, - }), - LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model { - anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet, - anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku, - anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus, - }), - LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model { - ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4, - ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo, - ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni, - ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus, - ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet, - ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku, - ZedDotDevModel::Claude3Haiku => { - match CompletionProvider::global(cx).default_model() { - LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom, - _ => ZedDotDevModel::Gpt3Point5Turbo, - } - } - ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo, - }), - }; - - self.set_model(next_model, cx); - } - - fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext) { - self.model = model.clone(); - if let Some(editor) = self.active_conversation_editor() { - editor.update(cx, |active_conversation, cx| { - active_conversation - .conversation - .update(cx, |conversation, cx| { - conversation.set_model(model, cx); - }) - }) - } - cx.notify(); - } - fn handle_conversation_editor_event( &mut self, _: View, @@ -978,6 +939,10 @@ impl AssistantPanel { .detach_and_log_err(cx); } + fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext) { + self.model_menu_handle.toggle(cx); + } + fn active_conversation_editor(&self) -> Option<&View> { Some(&self.active_conversation_editor.as_ref()?.editor) } @@ -1133,10 +1098,8 @@ impl AssistantPanel { cx.spawn(|this, mut cx| async move { let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?; - let model = this.update(&mut cx, |this, _| this.model.clone())?; let conversation = Conversation::deserialize( saved_conversation, - model, path.clone(), languages, slash_commands, @@ -1206,7 +1169,10 @@ impl AssistantPanel { this.child( h_flex() .gap_1() - .child(self.render_model(&conversation, cx)) + .child(ModelSelector::new( + self.model_menu_handle.clone(), + self.fs.clone(), + )) .children(self.render_remaining_tokens(&conversation, cx)), ) .child( @@ -1256,6 +1222,7 @@ impl AssistantPanel { .on_action(cx.listener(AssistantPanel::select_prev_match)) .on_action(cx.listener(AssistantPanel::handle_editor_cancel)) .on_action(cx.listener(AssistantPanel::reset_credentials)) + .on_action(cx.listener(AssistantPanel::toggle_model_selector)) .track_focus(&self.focus_handle) .child(header) .children(if self.toolbar.read(cx).hidden() { @@ -1314,23 +1281,12 @@ impl AssistantPanel { )) } - fn render_model( - &self, - conversation: &Model, - cx: &mut ViewContext, - ) -> impl IntoElement { - Button::new("current_model", conversation.read(cx).model.display_name()) - .style(ButtonStyle::Filled) - .tooltip(move |cx| Tooltip::text("Change Model", cx)) - .on_click(cx.listener(|this, _, cx| this.cycle_model(cx))) - } - fn render_remaining_tokens( &self, conversation: &Model, cx: &mut ViewContext, ) -> Option { - let remaining_tokens = conversation.read(cx).remaining_tokens()?; + let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?; let remaining_tokens_color = if remaining_tokens <= 0 { Color::Error } else if remaining_tokens <= 500 { @@ -1486,7 +1442,6 @@ pub struct Conversation { pending_summary: Task>, completion_count: usize, pending_completions: Vec, - model: LanguageModel, token_count: Option, pending_token_count: Task>, pending_edit_suggestion_parse: Option>, @@ -1502,7 +1457,6 @@ impl EventEmitter for Conversation {} impl Conversation { fn new( - model: LanguageModel, language_registry: Arc, slash_command_registry: Arc, telemetry: Option>, @@ -1530,7 +1484,6 @@ impl Conversation { token_count: None, pending_token_count: Task::ready(None), pending_edit_suggestion_parse: None, - model, _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, @@ -1583,7 +1536,6 @@ impl Conversation { #[allow(clippy::too_many_arguments)] async fn deserialize( saved_conversation: SavedConversation, - model: LanguageModel, path: PathBuf, language_registry: Arc, slash_command_registry: Arc, @@ -1640,7 +1592,6 @@ impl Conversation { token_count: None, pending_edit_suggestion_parse: None, pending_token_count: Task::ready(None), - model, _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), @@ -1938,12 +1889,12 @@ impl Conversation { } } - fn remaining_tokens(&self) -> Option { - Some(self.model.max_token_count() as isize - self.token_count? as isize) + fn remaining_tokens(&self, cx: &AppContext) -> Option { + let model = CompletionProvider::global(cx).model(); + Some(model.max_token_count() as isize - self.token_count? as isize) } - fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext) { - self.model = model; + fn completion_provider_changed(&mut self, cx: &mut ModelContext) { self.count_remaining_tokens(cx); } @@ -2079,10 +2030,11 @@ impl Conversation { } if let Some(telemetry) = this.telemetry.as_ref() { + let model = CompletionProvider::global(cx).model(); telemetry.report_assistant_event( this.id.clone(), AssistantKind::Panel, - this.model.telemetry_id(), + model.telemetry_id(), response_latency, error_message, ); @@ -2111,7 +2063,7 @@ impl Conversation { .map(|message| message.to_request_message(self.buffer.read(cx))); LanguageModelRequest { - model: self.model.clone(), + model: CompletionProvider::global(cx).model(), messages: messages.collect(), stop: vec![], temperature: 1.0, @@ -2300,7 +2252,7 @@ impl Conversation { .into(), })); let request = LanguageModelRequest { - model: self.model.clone(), + model: CompletionProvider::global(cx).model(), messages: messages.collect(), stop: vec![], temperature: 1.0, @@ -2605,7 +2557,6 @@ pub struct ConversationEditor { impl ConversationEditor { fn new( - model: LanguageModel, language_registry: Arc, slash_command_registry: Arc, fs: Arc, @@ -2618,7 +2569,6 @@ impl ConversationEditor { let conversation = cx.new_model(|cx| { Conversation::new( - model, language_registry, slash_command_registry, Some(telemetry), @@ -3847,15 +3797,8 @@ mod tests { init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let conversation = cx.new_model(|cx| { - Conversation::new( - LanguageModel::default(), - registry, - Default::default(), - None, - cx, - ) - }); + let conversation = + cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3986,15 +3929,8 @@ mod tests { init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let conversation = cx.new_model(|cx| { - Conversation::new( - LanguageModel::default(), - registry, - Default::default(), - None, - cx, - ) - }); + let conversation = + cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -4092,15 +4028,8 @@ mod tests { cx.set_global(settings_store); init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let conversation = cx.new_model(|cx| { - Conversation::new( - LanguageModel::default(), - registry, - Default::default(), - None, - cx, - ) - }); + let conversation = + cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -4209,15 +4138,8 @@ mod tests { )); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let conversation = cx.new_model(|cx| { - Conversation::new( - LanguageModel::default(), - registry.clone(), - slash_command_registry, - None, - cx, - ) - }); + let conversation = cx + .new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx)); let output_ranges = Rc::new(RefCell::new(HashSet::default())); conversation.update(cx, |_, cx| { @@ -4390,15 +4312,8 @@ mod tests { cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); cx.update(init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let conversation = cx.new_model(|cx| { - Conversation::new( - LanguageModel::default(), - registry.clone(), - Default::default(), - None, - cx, - ) - }); + let conversation = + cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx)); let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone()); let message_0 = conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id); @@ -4434,7 +4349,6 @@ mod tests { let deserialized_conversation = Conversation::deserialize( conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)), - LanguageModel::default(), Default::default(), registry.clone(), Default::default(), diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index d822db8f70..5d866b6efc 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -12,8 +12,11 @@ use serde::{ Deserialize, Deserializer, Serialize, Serializer, }; use settings::{Settings, SettingsSources}; +use strum::{EnumIter, IntoEnumIterator}; -#[derive(Clone, Debug, Default, PartialEq)] +use crate::LanguageModel; + +#[derive(Clone, Debug, Default, PartialEq, EnumIter)] pub enum ZedDotDevModel { Gpt3Point5Turbo, Gpt4, @@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { where E: de::Error, { - match value { - "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo), - "gpt-4" => Ok(ZedDotDevModel::Gpt4), - "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo), - "gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni), - _ => Ok(ZedDotDevModel::Custom(value.to_owned())), - } + let model = ZedDotDevModel::iter() + .find(|model| model.id() == value) + .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string())); + Ok(model) } } @@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel { } fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { - let variants = vec![ - "gpt-3.5-turbo".to_owned(), - "gpt-4".to_owned(), - "gpt-4-turbo-preview".to_owned(), - "gpt-4o".to_owned(), - ]; + let variants = ZedDotDevModel::iter() + .filter_map(|model| { + let id = model.id(); + if id.is_empty() { + None + } else { + Some(id.to_string()) + } + }) + .collect::>(); Schema::Object(SchemaObject { instance_type: Some(InstanceType::String.into()), - enum_values: Some(variants.into_iter().map(|s| s.into()).collect()), + enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), metadata: Some(Box::new(Metadata { title: Some("ZedDotDevModel".to_owned()), - default: Some(serde_json::json!("gpt-4-turbo-preview")), - examples: vec![ - serde_json::json!("gpt-3.5-turbo"), - serde_json::json!("gpt-4"), - serde_json::json!("gpt-4-turbo-preview"), - serde_json::json!("custom-model-name"), - ], + default: Some(ZedDotDevModel::default().id().into()), + examples: variants.into_iter().map(Into::into).collect(), ..Default::default() })), ..Default::default() @@ -145,51 +144,55 @@ pub enum AssistantDockPosition { Bottom, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] -#[serde(tag = "name", rename_all = "snake_case")] +#[derive(Debug, PartialEq)] pub enum AssistantProvider { - #[serde(rename = "zed.dev")] ZedDotDev { - #[serde(default)] - default_model: ZedDotDevModel, + model: ZedDotDevModel, }, - #[serde(rename = "openai")] OpenAi { - #[serde(default)] - default_model: OpenAiModel, - #[serde(default = "open_ai_url")] + model: OpenAiModel, api_url: String, - #[serde(default)] low_speed_timeout_in_seconds: Option, }, - #[serde(rename = "anthropic")] Anthropic { - #[serde(default)] - default_model: AnthropicModel, - #[serde(default = "anthropic_api_url")] + model: AnthropicModel, api_url: String, - #[serde(default)] low_speed_timeout_in_seconds: Option, }, } impl Default for AssistantProvider { fn default() -> Self { - Self::ZedDotDev { - default_model: ZedDotDevModel::default(), + Self::OpenAi { + model: OpenAiModel::default(), + api_url: open_ai::OPEN_AI_API_URL.into(), + low_speed_timeout_in_seconds: None, } } } -fn open_ai_url() -> String { - open_ai::OPEN_AI_API_URL.to_string() +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +#[serde(tag = "name", rename_all = "snake_case")] +pub enum AssistantProviderContent { + #[serde(rename = "zed.dev")] + ZedDotDev { + default_model: Option, + }, + #[serde(rename = "openai")] + OpenAi { + default_model: Option, + api_url: Option, + low_speed_timeout_in_seconds: Option, + }, + #[serde(rename = "anthropic")] + Anthropic { + default_model: Option, + api_url: Option, + low_speed_timeout_in_seconds: Option, + }, } -fn anthropic_api_url() -> String { - anthropic::ANTHROPIC_API_URL.to_string() -} - -#[derive(Default, Debug, Deserialize, Serialize)] +#[derive(Debug, Default)] pub struct AssistantSettings { pub enabled: bool, pub button: bool, @@ -240,16 +243,16 @@ impl AssistantSettingsContent { default_width: settings.default_width, default_height: settings.default_height, provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() { - Some(AssistantProvider::OpenAi { - default_model: settings.default_open_ai_model.clone().unwrap_or_default(), - api_url: open_ai_api_url.clone(), + Some(AssistantProviderContent::OpenAi { + default_model: settings.default_open_ai_model.clone(), + api_url: Some(open_ai_api_url.clone()), low_speed_timeout_in_seconds: None, }) } else { settings.default_open_ai_model.clone().map(|open_ai_model| { - AssistantProvider::OpenAi { - default_model: open_ai_model, - api_url: open_ai_url(), + AssistantProviderContent::OpenAi { + default_model: Some(open_ai_model), + api_url: None, low_speed_timeout_in_seconds: None, } }) @@ -270,6 +273,64 @@ impl AssistantSettingsContent { } } } + + pub fn set_model(&mut self, new_model: LanguageModel) { + match self { + AssistantSettingsContent::Versioned(settings) => match settings { + VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider { + Some(AssistantProviderContent::ZedDotDev { + default_model: model, + }) => { + if let LanguageModel::ZedDotDev(new_model) = new_model { + *model = Some(new_model); + } + } + Some(AssistantProviderContent::OpenAi { + default_model: model, + .. + }) => { + if let LanguageModel::OpenAi(new_model) = new_model { + *model = Some(new_model); + } + } + Some(AssistantProviderContent::Anthropic { + default_model: model, + .. + }) => { + if let LanguageModel::Anthropic(new_model) = new_model { + *model = Some(new_model); + } + } + provider => match new_model { + LanguageModel::ZedDotDev(model) => { + *provider = Some(AssistantProviderContent::ZedDotDev { + default_model: Some(model), + }) + } + LanguageModel::OpenAi(model) => { + *provider = Some(AssistantProviderContent::OpenAi { + default_model: Some(model), + api_url: None, + low_speed_timeout_in_seconds: None, + }) + } + LanguageModel::Anthropic(model) => { + *provider = Some(AssistantProviderContent::Anthropic { + default_model: Some(model), + api_url: None, + low_speed_timeout_in_seconds: None, + }) + } + }, + }, + }, + AssistantSettingsContent::Legacy(settings) => { + if let LanguageModel::OpenAi(model) = new_model { + settings.default_open_ai_model = Some(model); + } + } + } + } } #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] @@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 { /// /// This can either be the internal `zed.dev` service or an external `openai` service, /// each with their respective default models and configurations. - provider: Option, + provider: Option, } #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] @@ -376,31 +437,82 @@ impl Settings for AssistantSettings { if let Some(provider) = value.provider.clone() { match (&mut settings.provider, provider) { ( - AssistantProvider::ZedDotDev { default_model }, - AssistantProvider::ZedDotDev { - default_model: default_model_override, + AssistantProvider::ZedDotDev { model }, + AssistantProviderContent::ZedDotDev { + default_model: model_override, }, ) => { - *default_model = default_model_override; + merge(model, model_override); } ( AssistantProvider::OpenAi { - default_model, + model, api_url, low_speed_timeout_in_seconds, }, - AssistantProvider::OpenAi { - default_model: default_model_override, + AssistantProviderContent::OpenAi { + default_model: model_override, api_url: api_url_override, low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, }, ) => { - *default_model = default_model_override; - *api_url = api_url_override; - *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override; + merge(model, model_override); + merge(api_url, api_url_override); + if let Some(low_speed_timeout_in_seconds_override) = + low_speed_timeout_in_seconds_override + { + *low_speed_timeout_in_seconds = + Some(low_speed_timeout_in_seconds_override); + } } - (merged, provider_override) => { - *merged = provider_override; + ( + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + }, + AssistantProviderContent::Anthropic { + default_model: model_override, + api_url: api_url_override, + low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, + }, + ) => { + merge(model, model_override); + merge(api_url, api_url_override); + if let Some(low_speed_timeout_in_seconds_override) = + low_speed_timeout_in_seconds_override + { + *low_speed_timeout_in_seconds = + Some(low_speed_timeout_in_seconds_override); + } + } + (provider, provider_override) => { + *provider = match provider_override { + AssistantProviderContent::ZedDotDev { + default_model: model, + } => AssistantProvider::ZedDotDev { + model: model.unwrap_or_default(), + }, + AssistantProviderContent::OpenAi { + default_model: model, + api_url, + low_speed_timeout_in_seconds, + } => AssistantProvider::OpenAi { + model: model.unwrap_or_default(), + api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()), + low_speed_timeout_in_seconds, + }, + AssistantProviderContent::Anthropic { + default_model: model, + api_url, + low_speed_timeout_in_seconds, + } => AssistantProvider::Anthropic { + model: model.unwrap_or_default(), + api_url: api_url + .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()), + low_speed_timeout_in_seconds, + }, + }; } } } @@ -410,7 +522,7 @@ impl Settings for AssistantSettings { } } -fn merge(target: &mut T, value: Option) { +fn merge(target: &mut T, value: Option) { if let Some(value) = value { *target = value; } @@ -433,8 +545,8 @@ mod tests { assert_eq!( AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { - default_model: OpenAiModel::FourOmni, - api_url: open_ai_url(), + model: OpenAiModel::FourOmni, + api_url: open_ai::OPEN_AI_API_URL.into(), low_speed_timeout_in_seconds: None, } ); @@ -455,7 +567,7 @@ mod tests { assert_eq!( AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { - default_model: OpenAiModel::FourOmni, + model: OpenAiModel::FourOmni, api_url: "test-url".into(), low_speed_timeout_in_seconds: None, } @@ -475,8 +587,8 @@ mod tests { assert_eq!( AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { - default_model: OpenAiModel::Four, - api_url: open_ai_url(), + model: OpenAiModel::Four, + api_url: open_ai::OPEN_AI_API_URL.into(), low_speed_timeout_in_seconds: None, } ); @@ -501,7 +613,7 @@ mod tests { assert_eq!( AssistantSettings::get_global(cx).provider, AssistantProvider::ZedDotDev { - default_model: ZedDotDevModel::Custom("custom".into()) + model: ZedDotDevModel::Custom("custom".into()) } ); } diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 666dab5dfc..99b8b407fb 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/assistant/src/completion_provider.rs @@ -25,31 +25,26 @@ use std::time::Duration; pub fn init(client: Arc, cx: &mut AppContext) { let mut settings_version = 0; let provider = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { default_model } => { - CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( - default_model.clone(), - client.clone(), - settings_version, - cx, - )) - } + AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev( + ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), + ), AssistantProvider::OpenAi { - default_model, + model, api_url, low_speed_timeout_in_seconds, } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new( - default_model.clone(), + model.clone(), api_url.clone(), client.http_client(), low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, )), AssistantProvider::Anthropic { - default_model, + model, api_url, low_speed_timeout_in_seconds, } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new( - default_model.clone(), + model.clone(), api_url.clone(), client.http_client(), low_speed_timeout_in_seconds.map(Duration::from_secs), @@ -65,13 +60,13 @@ pub fn init(client: Arc, cx: &mut AppContext) { ( CompletionProvider::OpenAi(provider), AssistantProvider::OpenAi { - default_model, + model, api_url, low_speed_timeout_in_seconds, }, ) => { provider.update( - default_model.clone(), + model.clone(), api_url.clone(), low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, @@ -80,13 +75,13 @@ pub fn init(client: Arc, cx: &mut AppContext) { ( CompletionProvider::Anthropic(provider), AssistantProvider::Anthropic { - default_model, + model, api_url, low_speed_timeout_in_seconds, }, ) => { provider.update( - default_model.clone(), + model.clone(), api_url.clone(), low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, @@ -94,13 +89,13 @@ pub fn init(client: Arc, cx: &mut AppContext) { } ( CompletionProvider::ZedDotDev(provider), - AssistantProvider::ZedDotDev { default_model }, + AssistantProvider::ZedDotDev { model }, ) => { - provider.update(default_model.clone(), settings_version); + provider.update(model.clone(), settings_version); } - (_, AssistantProvider::ZedDotDev { default_model }) => { + (_, AssistantProvider::ZedDotDev { model }) => { *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( - default_model.clone(), + model.clone(), client.clone(), settings_version, cx, @@ -109,13 +104,13 @@ pub fn init(client: Arc, cx: &mut AppContext) { ( _, AssistantProvider::OpenAi { - default_model, + model, api_url, low_speed_timeout_in_seconds, }, ) => { *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new( - default_model.clone(), + model.clone(), api_url.clone(), client.http_client(), low_speed_timeout_in_seconds.map(Duration::from_secs), @@ -125,13 +120,13 @@ pub fn init(client: Arc, cx: &mut AppContext) { ( _, AssistantProvider::Anthropic { - default_model, + model, api_url, low_speed_timeout_in_seconds, }, ) => { *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new( - default_model.clone(), + model.clone(), api_url.clone(), client.http_client(), low_speed_timeout_in_seconds.map(Duration::from_secs), @@ -159,6 +154,25 @@ impl CompletionProvider { cx.global::() } + pub fn available_models(&self) -> Vec { + match self { + CompletionProvider::OpenAi(provider) => provider + .available_models() + .map(LanguageModel::OpenAi) + .collect(), + CompletionProvider::Anthropic(provider) => provider + .available_models() + .map(LanguageModel::Anthropic) + .collect(), + CompletionProvider::ZedDotDev(provider) => provider + .available_models() + .map(LanguageModel::ZedDotDev) + .collect(), + #[cfg(test)] + CompletionProvider::Fake(_) => unimplemented!(), + } + } + pub fn settings_version(&self) -> usize { match self { CompletionProvider::OpenAi(provider) => provider.settings_version(), @@ -209,17 +223,13 @@ impl CompletionProvider { } } - pub fn default_model(&self) -> LanguageModel { + pub fn model(&self) -> LanguageModel { match self { - CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()), - CompletionProvider::Anthropic(provider) => { - LanguageModel::Anthropic(provider.default_model()) - } - CompletionProvider::ZedDotDev(provider) => { - LanguageModel::ZedDotDev(provider.default_model()) - } + CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()), + CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()), + CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()), #[cfg(test)] - CompletionProvider::Fake(_) => unimplemented!(), + CompletionProvider::Fake(_) => LanguageModel::default(), } } diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/assistant/src/completion_provider/anthropic.rs index f098f7eccc..8ae40993bc 100644 --- a/crates/assistant/src/completion_provider/anthropic.rs +++ b/crates/assistant/src/completion_provider/anthropic.rs @@ -12,6 +12,7 @@ use http::HttpClient; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; +use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::prelude::*; use util::ResultExt; @@ -19,7 +20,7 @@ use util::ResultExt; pub struct AnthropicCompletionProvider { api_key: Option, api_url: String, - default_model: AnthropicModel, + model: AnthropicModel, http_client: Arc, low_speed_timeout: Option, settings_version: usize, @@ -27,7 +28,7 @@ pub struct AnthropicCompletionProvider { impl AnthropicCompletionProvider { pub fn new( - default_model: AnthropicModel, + model: AnthropicModel, api_url: String, http_client: Arc, low_speed_timeout: Option, @@ -36,7 +37,7 @@ impl AnthropicCompletionProvider { Self { api_key: None, api_url, - default_model, + model, http_client, low_speed_timeout, settings_version, @@ -45,17 +46,21 @@ impl AnthropicCompletionProvider { pub fn update( &mut self, - default_model: AnthropicModel, + model: AnthropicModel, api_url: String, low_speed_timeout: Option, settings_version: usize, ) { - self.default_model = default_model; + self.model = model; self.api_url = api_url; self.low_speed_timeout = low_speed_timeout; self.settings_version = settings_version; } + pub fn available_models(&self) -> impl Iterator { + AnthropicModel::iter() + } + pub fn settings_version(&self) -> usize { self.settings_version } @@ -105,8 +110,8 @@ impl AnthropicCompletionProvider { .into() } - pub fn default_model(&self) -> AnthropicModel { - self.default_model.clone() + pub fn model(&self) -> AnthropicModel { + self.model.clone() } pub fn count_tokens( @@ -165,7 +170,7 @@ impl AnthropicCompletionProvider { fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request { let model = match request.model { LanguageModel::Anthropic(model) => model, - _ => self.default_model(), + _ => self.model(), }; let mut system_message = String::new(); diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs index bf99a95e05..6ab43d773b 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -11,6 +11,7 @@ use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; +use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::prelude::*; use util::ResultExt; @@ -18,7 +19,7 @@ use util::ResultExt; pub struct OpenAiCompletionProvider { api_key: Option, api_url: String, - default_model: OpenAiModel, + model: OpenAiModel, http_client: Arc, low_speed_timeout: Option, settings_version: usize, @@ -26,7 +27,7 @@ pub struct OpenAiCompletionProvider { impl OpenAiCompletionProvider { pub fn new( - default_model: OpenAiModel, + model: OpenAiModel, api_url: String, http_client: Arc, low_speed_timeout: Option, @@ -35,7 +36,7 @@ impl OpenAiCompletionProvider { Self { api_key: None, api_url, - default_model, + model, http_client, low_speed_timeout, settings_version, @@ -44,17 +45,21 @@ impl OpenAiCompletionProvider { pub fn update( &mut self, - default_model: OpenAiModel, + model: OpenAiModel, api_url: String, low_speed_timeout: Option, settings_version: usize, ) { - self.default_model = default_model; + self.model = model; self.api_url = api_url; self.low_speed_timeout = low_speed_timeout; self.settings_version = settings_version; } + pub fn available_models(&self) -> impl Iterator { + OpenAiModel::iter() + } + pub fn settings_version(&self) -> usize { self.settings_version } @@ -104,8 +109,8 @@ impl OpenAiCompletionProvider { .into() } - pub fn default_model(&self) -> OpenAiModel { - self.default_model.clone() + pub fn model(&self) -> OpenAiModel { + self.model.clone() } pub fn count_tokens( @@ -152,7 +157,7 @@ impl OpenAiCompletionProvider { fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { let model = match request.model { LanguageModel::OpenAi(model) => model, - _ => self.default_model(), + _ => self.model(), }; Request { diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs index 8fa1498072..d300541a88 100644 --- a/crates/assistant/src/completion_provider/zed.rs +++ b/crates/assistant/src/completion_provider/zed.rs @@ -7,11 +7,12 @@ use client::{proto, Client}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use gpui::{AnyView, AppContext, Task}; use std::{future, sync::Arc}; +use strum::IntoEnumIterator; use ui::prelude::*; pub struct ZedDotDevCompletionProvider { client: Arc, - default_model: ZedDotDevModel, + model: ZedDotDevModel, settings_version: usize, status: client::Status, _maintain_client_status: Task<()>, @@ -19,7 +20,7 @@ pub struct ZedDotDevCompletionProvider { impl ZedDotDevCompletionProvider { pub fn new( - default_model: ZedDotDevModel, + model: ZedDotDevModel, client: Arc, settings_version: usize, cx: &mut AppContext, @@ -39,24 +40,39 @@ impl ZedDotDevCompletionProvider { }); Self { client, - default_model, + model, settings_version, status, _maintain_client_status: maintain_client_status, } } - pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) { - self.default_model = default_model; + pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) { + self.model = model; self.settings_version = settings_version; } + pub fn available_models(&self) -> impl Iterator { + let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() { + Some(custom_model) + } else { + None + }; + ZedDotDevModel::iter().filter_map(move |model| { + if let ZedDotDevModel::Custom(_) = model { + Some(ZedDotDevModel::Custom(custom_model.take()?)) + } else { + Some(model) + } + }) + } + pub fn settings_version(&self) -> usize { self.settings_version } - pub fn default_model(&self) -> ZedDotDevModel { - self.default_model.clone() + pub fn model(&self) -> ZedDotDevModel { + self.model.clone() } pub fn is_authenticated(&self) -> bool { diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs new file mode 100644 index 0000000000..3a407bb547 --- /dev/null +++ b/crates/assistant/src/model_selector.rs @@ -0,0 +1,84 @@ +use std::sync::Arc; + +use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector}; +use fs::Fs; +use settings::update_settings_file; +use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, PopoverMenuHandle, Tooltip}; + +#[derive(IntoElement)] +pub struct ModelSelector { + handle: PopoverMenuHandle, + fs: Arc, +} + +impl ModelSelector { + pub fn new(handle: PopoverMenuHandle, fs: Arc) -> Self { + ModelSelector { handle, fs } + } +} + +impl RenderOnce for ModelSelector { + fn render(self, cx: &mut WindowContext) -> impl IntoElement { + popover_menu("model-switcher") + .with_handle(self.handle) + .menu(move |cx| { + ContextMenu::build(cx, |mut menu, cx| { + for model in CompletionProvider::global(cx).available_models() { + menu = menu.custom_entry( + { + let model = model.clone(); + move |_| Label::new(model.display_name()).into_any_element() + }, + { + let fs = self.fs.clone(); + let model = model.clone(); + move |cx| { + let model = model.clone(); + update_settings_file::( + fs.clone(), + cx, + move |settings| settings.set_model(model), + ); + } + }, + ); + } + menu + }) + .into() + }) + .trigger( + ButtonLike::new("active-model") + .child( + h_flex() + .w_full() + .gap_0p5() + .child( + div() + .overflow_x_hidden() + .flex_grow() + .whitespace_nowrap() + .child( + Label::new( + CompletionProvider::global(cx).model().display_name(), + ) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .child( + div().child( + Icon::new(IconName::ChevronDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ), + ) + .style(ButtonStyle::Subtle) + .tooltip(move |cx| { + Tooltip::for_action("Change Model", &ToggleModelSelector, cx) + }), + ) + .anchor(gpui::AnchorCorner::BottomRight) + } +} diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 3ebb5c10c9..eae3a306a7 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -20,3 +20,4 @@ isahc.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +strum.workspace = true diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 3c25bca9f9..a1cb63d60b 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -4,8 +4,8 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; -use std::time::Duration; -use std::{convert::TryFrom, future::Future}; +use std::{convert::TryFrom, future::Future, time::Duration}; +use strum::EnumIter; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; @@ -44,7 +44,7 @@ impl From for String { } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")] ThreePointFiveTurbo, diff --git a/crates/ui/src/components/popover_menu.rs b/crates/ui/src/components/popover_menu.rs index 3fb5f7834c..cddb8f797a 100644 --- a/crates/ui/src/components/popover_menu.rs +++ b/crates/ui/src/components/popover_menu.rs @@ -13,6 +13,51 @@ pub trait PopoverTrigger: IntoElement + Clickable + Selectable + 'static {} impl PopoverTrigger for T {} +pub struct PopoverMenuHandle(Rc>>>); + +impl Clone for PopoverMenuHandle { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Default for PopoverMenuHandle { + fn default() -> Self { + Self(Rc::default()) + } +} + +struct PopoverMenuHandleState { + menu_builder: Rc Option>>, + menu: Rc>>>, +} + +impl PopoverMenuHandle { + pub fn show(&self, cx: &mut WindowContext) { + if let Some(state) = self.0.borrow().as_ref() { + show_menu(&state.menu_builder, &state.menu, cx); + } + } + + pub fn hide(&self, cx: &mut WindowContext) { + if let Some(state) = self.0.borrow().as_ref() { + if let Some(menu) = state.menu.borrow().as_ref() { + menu.update(cx, |_, cx| cx.emit(DismissEvent)); + } + } + } + + pub fn toggle(&self, cx: &mut WindowContext) { + if let Some(state) = self.0.borrow().as_ref() { + if state.menu.borrow().is_some() { + self.hide(cx); + } else { + self.show(cx); + } + } + } +} + pub struct PopoverMenu { id: ElementId, child_builder: Option< @@ -28,6 +73,7 @@ pub struct PopoverMenu { anchor: AnchorCorner, attach: Option, offset: Option>, + trigger_handle: Option>, } impl PopoverMenu { @@ -36,35 +82,17 @@ impl PopoverMenu { self } + pub fn with_handle(mut self, handle: PopoverMenuHandle) -> Self { + self.trigger_handle = Some(handle); + self + } + pub fn trigger(mut self, t: T) -> Self { self.child_builder = Some(Box::new(|menu, builder| { let open = menu.borrow().is_some(); t.selected(open) .when_some(builder, |el, builder| { - el.on_click({ - move |_, cx| { - let Some(new_menu) = (builder)(cx) else { - return; - }; - let menu2 = menu.clone(); - let previous_focus_handle = cx.focused(); - - cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| { - if modal.focus_handle(cx).contains_focused(cx) { - if let Some(previous_focus_handle) = - previous_focus_handle.as_ref() - { - cx.focus(previous_focus_handle); - } - } - *menu2.borrow_mut() = None; - cx.refresh(); - }) - .detach(); - cx.focus_view(&new_menu); - *menu.borrow_mut() = Some(new_menu); - } - }) + el.on_click(move |_, cx| show_menu(&builder, &menu, cx)) }) .into_any_element() })); @@ -111,6 +139,32 @@ impl PopoverMenu { } } +fn show_menu( + builder: &Rc Option>>, + menu: &Rc>>>, + cx: &mut WindowContext, +) { + let Some(new_menu) = (builder)(cx) else { + return; + }; + let menu2 = menu.clone(); + let previous_focus_handle = cx.focused(); + + cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| { + if modal.focus_handle(cx).contains_focused(cx) { + if let Some(previous_focus_handle) = previous_focus_handle.as_ref() { + cx.focus(previous_focus_handle); + } + } + *menu2.borrow_mut() = None; + cx.refresh(); + }) + .detach(); + cx.focus_view(&new_menu); + *menu.borrow_mut() = Some(new_menu); + cx.refresh(); +} + /// Creates a [`PopoverMenu`] pub fn popover_menu(id: impl Into) -> PopoverMenu { PopoverMenu { @@ -120,6 +174,7 @@ pub fn popover_menu(id: impl Into) -> PopoverMenu anchor: AnchorCorner::TopLeft, attach: None, offset: None, + trigger_handle: None, } } @@ -190,6 +245,15 @@ impl Element for PopoverMenu { (child_builder)(element_state.menu.clone(), self.menu_builder.clone()) }); + if let Some(trigger_handle) = self.trigger_handle.take() { + if let Some(menu_builder) = self.menu_builder.clone() { + *trigger_handle.0.borrow_mut() = Some(PopoverMenuHandleState { + menu_builder, + menu: element_state.menu.clone(), + }); + } + } + let child_layout_id = child_element .as_mut() .map(|child_element| child_element.request_layout(cx));