Improve model selection in the assistant (#12472)

https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e


Release Notes:

- Improved the UX for selecting a model in the assistant panel. You can
now switch models using just the keyboard by pressing `alt-m`. Also, when
switching models via the UI, settings will now be updated automatically.
This commit is contained in:
Antonio Scandurra 2024-05-30 12:36:07 +02:00 committed by GitHub
parent 5a149b970c
commit 6ff01b17ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 517 additions and 295 deletions

3
Cargo.lock generated
View File

@ -230,6 +230,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum",
"tokio", "tokio",
] ]
@ -376,6 +377,7 @@ dependencies = [
"settings", "settings",
"smol", "smol",
"strsim 0.11.1", "strsim 0.11.1",
"strum",
"telemetry_events", "telemetry_events",
"theme", "theme",
"tiktoken-rs", "tiktoken-rs",
@ -6983,6 +6985,7 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum",
] ]
[[package]] [[package]]

View File

@ -201,7 +201,8 @@
"context": "AssistantPanel", "context": "AssistantPanel",
"bindings": { "bindings": {
"ctrl-g": "search::SelectNextMatch", "ctrl-g": "search::SelectNextMatch",
"ctrl-shift-g": "search::SelectPrevMatch" "ctrl-shift-g": "search::SelectPrevMatch",
"alt-m": "assistant::ToggleModelSelector"
} }
}, },
{ {

View File

@ -214,10 +214,11 @@
} }
}, },
{ {
"context": "AssistantPanel", // Used in the assistant crate, which we're replacing "context": "AssistantPanel",
"bindings": { "bindings": {
"cmd-g": "search::SelectNextMatch", "cmd-g": "search::SelectNextMatch",
"cmd-shift-g": "search::SelectPrevMatch" "cmd-shift-g": "search::SelectPrevMatch",
"alt-m": "assistant::ToggleModelSelector"
} }
}, },
{ {

View File

@ -23,6 +23,7 @@ isahc.workspace = true
schemars = { workspace = true, optional = true } schemars = { workspace = true, optional = true }
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
strum.workspace = true
[dev-dependencies] [dev-dependencies]
tokio.workspace = true tokio.workspace = true

View File

@ -4,11 +4,12 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable; use isahc::config::Configurable;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration}; use std::{convert::TryFrom, time::Duration};
use strum::EnumIter;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model { pub enum Model {
#[default] #[default]
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")] #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]

View File

@ -49,6 +49,7 @@ serde_json.workspace = true
settings.workspace = true settings.workspace = true
smol.workspace = true smol.workspace = true
strsim = "0.11" strsim = "0.11"
strum.workspace = true
telemetry_events.workspace = true telemetry_events.workspace = true
theme.workspace = true theme.workspace = true
tiktoken-rs.workspace = true tiktoken-rs.workspace = true

View File

@ -2,6 +2,7 @@ pub mod assistant_panel;
pub mod assistant_settings; pub mod assistant_settings;
mod codegen; mod codegen;
mod completion_provider; mod completion_provider;
mod model_selector;
mod prompts; mod prompts;
mod saved_conversation; mod saved_conversation;
mod search; mod search;
@ -15,6 +16,7 @@ use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*; pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
pub(crate) use model_selector::*;
pub(crate) use saved_conversation::*; pub(crate) use saved_conversation::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -38,7 +40,8 @@ actions!(
InsertActivePrompt, InsertActivePrompt,
ToggleHistory, ToggleHistory,
ApplyEdit, ApplyEdit,
ConfirmCommand ConfirmCommand,
ToggleModelSelector
] ]
); );

View File

@ -1,7 +1,7 @@
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager}; use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
use crate::slash_command::{rustdoc_command, search_command, tabs_command}; use crate::slash_command::{rustdoc_command, search_command, tabs_command};
use crate::{ use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel}, assistant_settings::{AssistantDockPosition, AssistantSettings},
codegen::{self, Codegen, CodegenKind}, codegen::{self, Codegen, CodegenKind},
search::*, search::*,
slash_command::{ slash_command::{
@ -9,10 +9,11 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry, SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
}, },
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist, ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
SavedMessage, Split, ToggleFocus, ToggleHistory, Split, ToggleFocus, ToggleHistory,
}; };
use crate::{ModelSelector, ToggleModelSelector};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection}; use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
@ -64,8 +65,8 @@ use std::{
use telemetry_events::AssistantKind; use telemetry_events::AssistantKind;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{ use ui::{
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar, popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
Tooltip, PopoverMenuHandle, Tab, TabBar, Tooltip,
}; };
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt}; use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
use uuid::Uuid; use uuid::Uuid;
@ -119,8 +120,8 @@ pub struct AssistantPanel {
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>, pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
inline_prompt_history: VecDeque<String>, inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>, _watch_saved_conversations: Task<Result<()>>,
model: LanguageModel,
authentication_prompt: Option<AnyView>, authentication_prompt: Option<AnyView>,
model_menu_handle: PopoverMenuHandle<ContextMenu>,
} }
struct ActiveConversationEditor { struct ActiveConversationEditor {
@ -203,7 +204,6 @@ impl AssistantPanel {
} }
}), }),
]; ];
let model = CompletionProvider::global(cx).default_model();
cx.observe_global::<FileIcons>(|_, cx| { cx.observe_global::<FileIcons>(|_, cx| {
cx.notify(); cx.notify();
@ -244,8 +244,8 @@ impl AssistantPanel {
pending_inline_assist_ids_by_editor: Default::default(), pending_inline_assist_ids_by_editor: Default::default(),
inline_prompt_history: Default::default(), inline_prompt_history: Default::default(),
_watch_saved_conversations, _watch_saved_conversations,
model,
authentication_prompt: None, authentication_prompt: None,
model_menu_handle: PopoverMenuHandle::default(),
} }
}) })
}) })
@ -277,12 +277,20 @@ impl AssistantPanel {
if self.is_authenticated(cx) { if self.is_authenticated(cx) {
self.authentication_prompt = None; self.authentication_prompt = None;
let model = CompletionProvider::global(cx).default_model(); if let Some(editor) = self.active_conversation_editor() {
self.set_model(model, cx); editor.update(cx, |active_conversation, cx| {
active_conversation
.conversation
.update(cx, |conversation, cx| {
conversation.completion_provider_changed(cx)
})
})
}
if self.active_conversation_editor().is_none() { if self.active_conversation_editor().is_none() {
self.new_conversation(cx); self.new_conversation(cx);
} }
cx.notify();
} else if self.authentication_prompt.is_none() } else if self.authentication_prompt.is_none()
|| prev_settings_version != CompletionProvider::global(cx).settings_version() || prev_settings_version != CompletionProvider::global(cx).settings_version()
{ {
@ -290,6 +298,7 @@ impl AssistantPanel {
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| { Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.authentication_prompt(cx) provider.authentication_prompt(cx)
})); }));
cx.notify();
} }
} }
@ -734,7 +743,7 @@ impl AssistantPanel {
.map(|message| message.to_request_message(buffer)), .map(|message| message.to_request_message(buffer)),
); );
} }
let model = self.model.clone(); let model = CompletionProvider::global(cx).model();
cx.spawn(|_, mut cx| async move { cx.spawn(|_, mut cx| async move {
// I Don't know if we want to return a ? here. // I Don't know if we want to return a ? here.
@ -809,7 +818,6 @@ impl AssistantPanel {
let editor = cx.new_view(|cx| { let editor = cx.new_view(|cx| {
ConversationEditor::new( ConversationEditor::new(
self.model.clone(),
self.languages.clone(), self.languages.clone(),
self.slash_commands.clone(), self.slash_commands.clone(),
self.fs.clone(), self.fs.clone(),
@ -850,53 +858,6 @@ impl AssistantPanel {
cx.notify(); cx.notify();
} }
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
let next_model = match &self.model {
LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
open_ai::Model::Four => open_ai::Model::FourTurbo,
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
}),
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
}),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
ZedDotDevModel::Claude3Haiku => {
match CompletionProvider::global(cx).default_model() {
LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
_ => ZedDotDevModel::Gpt3Point5Turbo,
}
}
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
}),
};
self.set_model(next_model, cx);
}
fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
self.model = model.clone();
if let Some(editor) = self.active_conversation_editor() {
editor.update(cx, |active_conversation, cx| {
active_conversation
.conversation
.update(cx, |conversation, cx| {
conversation.set_model(model, cx);
})
})
}
cx.notify();
}
fn handle_conversation_editor_event( fn handle_conversation_editor_event(
&mut self, &mut self,
_: View<ConversationEditor>, _: View<ConversationEditor>,
@ -978,6 +939,10 @@ impl AssistantPanel {
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
self.model_menu_handle.toggle(cx);
}
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> { fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
Some(&self.active_conversation_editor.as_ref()?.editor) Some(&self.active_conversation_editor.as_ref()?.editor)
} }
@ -1133,10 +1098,8 @@ impl AssistantPanel {
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?; let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
let model = this.update(&mut cx, |this, _| this.model.clone())?;
let conversation = Conversation::deserialize( let conversation = Conversation::deserialize(
saved_conversation, saved_conversation,
model,
path.clone(), path.clone(),
languages, languages,
slash_commands, slash_commands,
@ -1206,7 +1169,10 @@ impl AssistantPanel {
this.child( this.child(
h_flex() h_flex()
.gap_1() .gap_1()
.child(self.render_model(&conversation, cx)) .child(ModelSelector::new(
self.model_menu_handle.clone(),
self.fs.clone(),
))
.children(self.render_remaining_tokens(&conversation, cx)), .children(self.render_remaining_tokens(&conversation, cx)),
) )
.child( .child(
@ -1256,6 +1222,7 @@ impl AssistantPanel {
.on_action(cx.listener(AssistantPanel::select_prev_match)) .on_action(cx.listener(AssistantPanel::select_prev_match))
.on_action(cx.listener(AssistantPanel::handle_editor_cancel)) .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
.on_action(cx.listener(AssistantPanel::reset_credentials)) .on_action(cx.listener(AssistantPanel::reset_credentials))
.on_action(cx.listener(AssistantPanel::toggle_model_selector))
.track_focus(&self.focus_handle) .track_focus(&self.focus_handle)
.child(header) .child(header)
.children(if self.toolbar.read(cx).hidden() { .children(if self.toolbar.read(cx).hidden() {
@ -1314,23 +1281,12 @@ impl AssistantPanel {
)) ))
} }
fn render_model(
&self,
conversation: &Model<Conversation>,
cx: &mut ViewContext<Self>,
) -> impl IntoElement {
Button::new("current_model", conversation.read(cx).model.display_name())
.style(ButtonStyle::Filled)
.tooltip(move |cx| Tooltip::text("Change Model", cx))
.on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
}
fn render_remaining_tokens( fn render_remaining_tokens(
&self, &self,
conversation: &Model<Conversation>, conversation: &Model<Conversation>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Option<impl IntoElement> { ) -> Option<impl IntoElement> {
let remaining_tokens = conversation.read(cx).remaining_tokens()?; let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
let remaining_tokens_color = if remaining_tokens <= 0 { let remaining_tokens_color = if remaining_tokens <= 0 {
Color::Error Color::Error
} else if remaining_tokens <= 500 { } else if remaining_tokens <= 500 {
@ -1486,7 +1442,6 @@ pub struct Conversation {
pending_summary: Task<Option<()>>, pending_summary: Task<Option<()>>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
model: LanguageModel,
token_count: Option<usize>, token_count: Option<usize>,
pending_token_count: Task<Option<()>>, pending_token_count: Task<Option<()>>,
pending_edit_suggestion_parse: Option<Task<()>>, pending_edit_suggestion_parse: Option<Task<()>>,
@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
impl Conversation { impl Conversation {
fn new( fn new(
model: LanguageModel,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>, slash_command_registry: Arc<SlashCommandRegistry>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
@ -1530,7 +1484,6 @@ impl Conversation {
token_count: None, token_count: None,
pending_token_count: Task::ready(None), pending_token_count: Task::ready(None),
pending_edit_suggestion_parse: None, pending_edit_suggestion_parse: None,
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())), pending_save: Task::ready(Ok(())),
path: None, path: None,
@ -1583,7 +1536,6 @@ impl Conversation {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn deserialize( async fn deserialize(
saved_conversation: SavedConversation, saved_conversation: SavedConversation,
model: LanguageModel,
path: PathBuf, path: PathBuf,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>, slash_command_registry: Arc<SlashCommandRegistry>,
@ -1640,7 +1592,6 @@ impl Conversation {
token_count: None, token_count: None,
pending_edit_suggestion_parse: None, pending_edit_suggestion_parse: None,
pending_token_count: Task::ready(None), pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())), pending_save: Task::ready(Ok(())),
path: Some(path), path: Some(path),
@ -1938,12 +1889,12 @@ impl Conversation {
} }
} }
fn remaining_tokens(&self) -> Option<isize> { fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
Some(self.model.max_token_count() as isize - self.token_count? as isize) let model = CompletionProvider::global(cx).model();
Some(model.max_token_count() as isize - self.token_count? as isize)
} }
fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) { fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
self.model = model;
self.count_remaining_tokens(cx); self.count_remaining_tokens(cx);
} }
@ -2079,10 +2030,11 @@ impl Conversation {
} }
if let Some(telemetry) = this.telemetry.as_ref() { if let Some(telemetry) = this.telemetry.as_ref() {
let model = CompletionProvider::global(cx).model();
telemetry.report_assistant_event( telemetry.report_assistant_event(
this.id.clone(), this.id.clone(),
AssistantKind::Panel, AssistantKind::Panel,
this.model.telemetry_id(), model.telemetry_id(),
response_latency, response_latency,
error_message, error_message,
); );
@ -2111,7 +2063,7 @@ impl Conversation {
.map(|message| message.to_request_message(self.buffer.read(cx))); .map(|message| message.to_request_message(self.buffer.read(cx)));
LanguageModelRequest { LanguageModelRequest {
model: self.model.clone(), model: CompletionProvider::global(cx).model(),
messages: messages.collect(), messages: messages.collect(),
stop: vec![], stop: vec![],
temperature: 1.0, temperature: 1.0,
@ -2300,7 +2252,7 @@ impl Conversation {
.into(), .into(),
})); }));
let request = LanguageModelRequest { let request = LanguageModelRequest {
model: self.model.clone(), model: CompletionProvider::global(cx).model(),
messages: messages.collect(), messages: messages.collect(),
stop: vec![], stop: vec![],
temperature: 1.0, temperature: 1.0,
@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
impl ConversationEditor { impl ConversationEditor {
fn new( fn new(
model: LanguageModel,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>, slash_command_registry: Arc<SlashCommandRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
@ -2618,7 +2569,6 @@ impl ConversationEditor {
let conversation = cx.new_model(|cx| { let conversation = cx.new_model(|cx| {
Conversation::new( Conversation::new(
model,
language_registry, language_registry,
slash_command_registry, slash_command_registry,
Some(telemetry), Some(telemetry),
@ -3847,15 +3797,8 @@ mod tests {
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| { let conversation =
Conversation::new( cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3986,15 +3929,8 @@ mod tests {
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| { let conversation =
Conversation::new( cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -4092,15 +4028,8 @@ mod tests {
cx.set_global(settings_store); cx.set_global(settings_store);
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let conversation = cx.new_model(|cx| { let conversation =
Conversation::new( cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
LanguageModel::default(),
registry,
Default::default(),
None,
cx,
)
});
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -4209,15 +4138,8 @@ mod tests {
)); ));
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let conversation = cx.new_model(|cx| { let conversation = cx
Conversation::new( .new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
LanguageModel::default(),
registry.clone(),
slash_command_registry,
None,
cx,
)
});
let output_ranges = Rc::new(RefCell::new(HashSet::default())); let output_ranges = Rc::new(RefCell::new(HashSet::default()));
conversation.update(cx, |_, cx| { conversation.update(cx, |_, cx| {
@ -4390,15 +4312,8 @@ mod tests {
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
cx.update(init); cx.update(init);
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let conversation = cx.new_model(|cx| { let conversation =
Conversation::new( cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
LanguageModel::default(),
registry.clone(),
Default::default(),
None,
cx,
)
});
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone()); let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
let message_0 = let message_0 =
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id); conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
@ -4434,7 +4349,6 @@ mod tests {
let deserialized_conversation = Conversation::deserialize( let deserialized_conversation = Conversation::deserialize(
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)), conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
LanguageModel::default(),
Default::default(), Default::default(),
registry.clone(), registry.clone(),
Default::default(), Default::default(),

View File

@ -12,8 +12,11 @@ use serde::{
Deserialize, Deserializer, Serialize, Serializer, Deserialize, Deserializer, Serialize, Serializer,
}; };
use settings::{Settings, SettingsSources}; use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq)] use crate::LanguageModel;
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel { pub enum ZedDotDevModel {
Gpt3Point5Turbo, Gpt3Point5Turbo,
Gpt4, Gpt4,
@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where where
E: de::Error, E: de::Error,
{ {
match value { let model = ZedDotDevModel::iter()
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo), .find(|model| model.id() == value)
"gpt-4" => Ok(ZedDotDevModel::Gpt4), .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo), Ok(model)
"gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
} }
} }
@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel {
} }
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = vec![ let variants = ZedDotDevModel::iter()
"gpt-3.5-turbo".to_owned(), .filter_map(|model| {
"gpt-4".to_owned(), let id = model.id();
"gpt-4-turbo-preview".to_owned(), if id.is_empty() {
"gpt-4o".to_owned(), None
]; } else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject { Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()), instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()), enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata { metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()), title: Some("ZedDotDevModel".to_owned()),
default: Some(serde_json::json!("gpt-4-turbo-preview")), default: Some(ZedDotDevModel::default().id().into()),
examples: vec![ examples: variants.into_iter().map(Into::into).collect(),
serde_json::json!("gpt-3.5-turbo"),
serde_json::json!("gpt-4"),
serde_json::json!("gpt-4-turbo-preview"),
serde_json::json!("custom-model-name"),
],
..Default::default() ..Default::default()
})), })),
..Default::default() ..Default::default()
@ -145,51 +144,55 @@ pub enum AssistantDockPosition {
Bottom, Bottom,
} }
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] #[derive(Debug, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider { pub enum AssistantProvider {
#[serde(rename = "zed.dev")]
ZedDotDev { ZedDotDev {
#[serde(default)] model: ZedDotDevModel,
default_model: ZedDotDevModel,
}, },
#[serde(rename = "openai")]
OpenAi { OpenAi {
#[serde(default)] model: OpenAiModel,
default_model: OpenAiModel,
#[serde(default = "open_ai_url")]
api_url: String, api_url: String,
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>, low_speed_timeout_in_seconds: Option<u64>,
}, },
#[serde(rename = "anthropic")]
Anthropic { Anthropic {
#[serde(default)] model: AnthropicModel,
default_model: AnthropicModel,
#[serde(default = "anthropic_api_url")]
api_url: String, api_url: String,
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>, low_speed_timeout_in_seconds: Option<u64>,
}, },
} }
impl Default for AssistantProvider { impl Default for AssistantProvider {
fn default() -> Self { fn default() -> Self {
Self::ZedDotDev { Self::OpenAi {
default_model: ZedDotDevModel::default(), model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
} }
} }
} }
fn open_ai_url() -> String { #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
open_ai::OPEN_AI_API_URL.to_string() #[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev {
default_model: Option<ZedDotDevModel>,
},
#[serde(rename = "openai")]
OpenAi {
default_model: Option<OpenAiModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
#[serde(rename = "anthropic")]
Anthropic {
default_model: Option<AnthropicModel>,
api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
} }
fn anthropic_api_url() -> String { #[derive(Debug, Default)]
anthropic::ANTHROPIC_API_URL.to_string()
}
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings { pub struct AssistantSettings {
pub enabled: bool, pub enabled: bool,
pub button: bool, pub button: bool,
@ -240,16 +243,16 @@ impl AssistantSettingsContent {
default_width: settings.default_width, default_width: settings.default_width,
default_height: settings.default_height, default_height: settings.default_height,
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() { provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProvider::OpenAi { Some(AssistantProviderContent::OpenAi {
default_model: settings.default_open_ai_model.clone().unwrap_or_default(), default_model: settings.default_open_ai_model.clone(),
api_url: open_ai_api_url.clone(), api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
}) })
} else { } else {
settings.default_open_ai_model.clone().map(|open_ai_model| { settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProvider::OpenAi { AssistantProviderContent::OpenAi {
default_model: open_ai_model, default_model: Some(open_ai_model),
api_url: open_ai_url(), api_url: None,
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
} }
}) })
@ -270,6 +273,64 @@ impl AssistantSettingsContent {
} }
} }
} }
pub fn set_model(&mut self, new_model: LanguageModel) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::ZedDotDev(new_model) = new_model {
*model = Some(new_model);
}
}
Some(AssistantProviderContent::OpenAi {
default_model: model,
..
}) => {
if let LanguageModel::OpenAi(new_model) = new_model {
*model = Some(new_model);
}
}
Some(AssistantProviderContent::Anthropic {
default_model: model,
..
}) => {
if let LanguageModel::Anthropic(new_model) = new_model {
*model = Some(new_model);
}
}
provider => match new_model {
LanguageModel::ZedDotDev(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
}
LanguageModel::OpenAi(model) => {
*provider = Some(AssistantProviderContent::OpenAi {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
LanguageModel::Anthropic(model) => {
*provider = Some(AssistantProviderContent::Anthropic {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
},
},
},
AssistantSettingsContent::Legacy(settings) => {
if let LanguageModel::OpenAi(model) = new_model {
settings.default_open_ai_model = Some(model);
}
}
}
}
} }
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
/// ///
/// This can either be the internal `zed.dev` service or an external `openai` service, /// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations. /// each with their respective default models and configurations.
provider: Option<AssistantProvider>, provider: Option<AssistantProviderContent>,
} }
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@ -376,31 +437,82 @@ impl Settings for AssistantSettings {
if let Some(provider) = value.provider.clone() { if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) { match (&mut settings.provider, provider) {
( (
AssistantProvider::ZedDotDev { default_model }, AssistantProvider::ZedDotDev { model },
AssistantProvider::ZedDotDev { AssistantProviderContent::ZedDotDev {
default_model: default_model_override, default_model: model_override,
}, },
) => { ) => {
*default_model = default_model_override; merge(model, model_override);
} }
( (
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
}, },
AssistantProvider::OpenAi { AssistantProviderContent::OpenAi {
default_model: default_model_override, default_model: model_override,
api_url: api_url_override, api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
}, },
) => { ) => {
*default_model = default_model_override; merge(model, model_override);
*api_url = api_url_override; merge(api_url, api_url_override);
*low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override; if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
} }
(merged, provider_override) => { (
*merged = provider_override; AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(provider, provider_override) => {
*provider = match provider_override {
AssistantProviderContent::ZedDotDev {
default_model: model,
} => AssistantProvider::ZedDotDev {
model: model.unwrap_or_default(),
},
AssistantProviderContent::OpenAi {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::OpenAi {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Anthropic {
model: model.unwrap_or_default(),
api_url: api_url
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
};
} }
} }
} }
@ -410,7 +522,7 @@ impl Settings for AssistantSettings {
} }
} }
fn merge<T: Copy>(target: &mut T, value: Option<T>) { fn merge<T>(target: &mut T, value: Option<T>) {
if let Some(value) = value { if let Some(value) = value {
*target = value; *target = value;
} }
@ -433,8 +545,8 @@ mod tests {
assert_eq!( assert_eq!(
AssistantSettings::get_global(cx).provider, AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model: OpenAiModel::FourOmni, model: OpenAiModel::FourOmni,
api_url: open_ai_url(), api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
} }
); );
@ -455,7 +567,7 @@ mod tests {
assert_eq!( assert_eq!(
AssistantSettings::get_global(cx).provider, AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model: OpenAiModel::FourOmni, model: OpenAiModel::FourOmni,
api_url: "test-url".into(), api_url: "test-url".into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
} }
@ -475,8 +587,8 @@ mod tests {
assert_eq!( assert_eq!(
AssistantSettings::get_global(cx).provider, AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model: OpenAiModel::Four, model: OpenAiModel::Four,
api_url: open_ai_url(), api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
} }
); );
@ -501,7 +613,7 @@ mod tests {
assert_eq!( assert_eq!(
AssistantSettings::get_global(cx).provider, AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev { AssistantProvider::ZedDotDev {
default_model: ZedDotDevModel::Custom("custom".into()) model: ZedDotDevModel::Custom("custom".into())
} }
); );
} }

View File

@ -25,31 +25,26 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0; let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider { let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { default_model } => { AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
default_model.clone(), ),
client.clone(),
settings_version,
cx,
))
}
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new( } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version, settings_version,
)), )),
AssistantProvider::Anthropic { AssistantProvider::Anthropic {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new( } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -65,13 +60,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
( (
CompletionProvider::OpenAi(provider), CompletionProvider::OpenAi(provider),
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
}, },
) => { ) => {
provider.update( provider.update(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version, settings_version,
@ -80,13 +75,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
( (
CompletionProvider::Anthropic(provider), CompletionProvider::Anthropic(provider),
AssistantProvider::Anthropic { AssistantProvider::Anthropic {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
}, },
) => { ) => {
provider.update( provider.update(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version, settings_version,
@ -94,13 +89,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
} }
( (
CompletionProvider::ZedDotDev(provider), CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model }, AssistantProvider::ZedDotDev { model },
) => { ) => {
provider.update(default_model.clone(), settings_version); provider.update(model.clone(), settings_version);
} }
(_, AssistantProvider::ZedDotDev { default_model }) => { (_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(), model.clone(),
client.clone(), client.clone(),
settings_version, settings_version,
cx, cx,
@ -109,13 +104,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
( (
_, _,
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
}, },
) => { ) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new( *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -125,13 +120,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
( (
_, _,
AssistantProvider::Anthropic { AssistantProvider::Anthropic {
default_model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
}, },
) => { ) => {
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new( *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(), model.clone(),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -159,6 +154,25 @@ impl CompletionProvider {
cx.global::<Self>() cx.global::<Self>()
} }
pub fn available_models(&self) -> Vec<LanguageModel> {
match self {
CompletionProvider::OpenAi(provider) => provider
.available_models()
.map(LanguageModel::OpenAi)
.collect(),
CompletionProvider::Anthropic(provider) => provider
.available_models()
.map(LanguageModel::Anthropic)
.collect(),
CompletionProvider::ZedDotDev(provider) => provider
.available_models()
.map(LanguageModel::ZedDotDev)
.collect(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(), CompletionProvider::OpenAi(provider) => provider.settings_version(),
@ -209,17 +223,13 @@ impl CompletionProvider {
} }
} }
pub fn default_model(&self) -> LanguageModel { pub fn model(&self) -> LanguageModel {
match self { match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()), CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => { CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
LanguageModel::Anthropic(provider.default_model()) CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
}
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(), CompletionProvider::Fake(_) => LanguageModel::default(),
} }
} }

View File

@ -12,6 +12,7 @@ use http::HttpClient;
use settings::Settings; use settings::Settings;
use std::time::Duration; use std::time::Duration;
use std::{env, sync::Arc}; use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::prelude::*; use ui::prelude::*;
use util::ResultExt; use util::ResultExt;
@ -19,7 +20,7 @@ use util::ResultExt;
pub struct AnthropicCompletionProvider { pub struct AnthropicCompletionProvider {
api_key: Option<String>, api_key: Option<String>,
api_url: String, api_url: String,
default_model: AnthropicModel, model: AnthropicModel,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
@ -27,7 +28,7 @@ pub struct AnthropicCompletionProvider {
impl AnthropicCompletionProvider { impl AnthropicCompletionProvider {
pub fn new( pub fn new(
default_model: AnthropicModel, model: AnthropicModel,
api_url: String, api_url: String,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
@ -36,7 +37,7 @@ impl AnthropicCompletionProvider {
Self { Self {
api_key: None, api_key: None,
api_url, api_url,
default_model, model,
http_client, http_client,
low_speed_timeout, low_speed_timeout,
settings_version, settings_version,
@ -45,17 +46,21 @@ impl AnthropicCompletionProvider {
pub fn update( pub fn update(
&mut self, &mut self,
default_model: AnthropicModel, model: AnthropicModel,
api_url: String, api_url: String,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
) { ) {
self.default_model = default_model; self.model = model;
self.api_url = api_url; self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout; self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version; self.settings_version = settings_version;
} }
pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
AnthropicModel::iter()
}
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
self.settings_version self.settings_version
} }
@ -105,8 +110,8 @@ impl AnthropicCompletionProvider {
.into() .into()
} }
pub fn default_model(&self) -> AnthropicModel { pub fn model(&self) -> AnthropicModel {
self.default_model.clone() self.model.clone()
} }
pub fn count_tokens( pub fn count_tokens(
@ -165,7 +170,7 @@ impl AnthropicCompletionProvider {
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request { fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model { let model = match request.model {
LanguageModel::Anthropic(model) => model, LanguageModel::Anthropic(model) => model,
_ => self.default_model(), _ => self.model(),
}; };
let mut system_message = String::new(); let mut system_message = String::new();

View File

@ -11,6 +11,7 @@ use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings; use settings::Settings;
use std::time::Duration; use std::time::Duration;
use std::{env, sync::Arc}; use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::prelude::*; use ui::prelude::*;
use util::ResultExt; use util::ResultExt;
@ -18,7 +19,7 @@ use util::ResultExt;
pub struct OpenAiCompletionProvider { pub struct OpenAiCompletionProvider {
api_key: Option<String>, api_key: Option<String>,
api_url: String, api_url: String,
default_model: OpenAiModel, model: OpenAiModel,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
@ -26,7 +27,7 @@ pub struct OpenAiCompletionProvider {
impl OpenAiCompletionProvider { impl OpenAiCompletionProvider {
pub fn new( pub fn new(
default_model: OpenAiModel, model: OpenAiModel,
api_url: String, api_url: String,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
@ -35,7 +36,7 @@ impl OpenAiCompletionProvider {
Self { Self {
api_key: None, api_key: None,
api_url, api_url,
default_model, model,
http_client, http_client,
low_speed_timeout, low_speed_timeout,
settings_version, settings_version,
@ -44,17 +45,21 @@ impl OpenAiCompletionProvider {
pub fn update( pub fn update(
&mut self, &mut self,
default_model: OpenAiModel, model: OpenAiModel,
api_url: String, api_url: String,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
) { ) {
self.default_model = default_model; self.model = model;
self.api_url = api_url; self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout; self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version; self.settings_version = settings_version;
} }
pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
OpenAiModel::iter()
}
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
self.settings_version self.settings_version
} }
@ -104,8 +109,8 @@ impl OpenAiCompletionProvider {
.into() .into()
} }
pub fn default_model(&self) -> OpenAiModel { pub fn model(&self) -> OpenAiModel {
self.default_model.clone() self.model.clone()
} }
pub fn count_tokens( pub fn count_tokens(
@ -152,7 +157,7 @@ impl OpenAiCompletionProvider {
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model { let model = match request.model {
LanguageModel::OpenAi(model) => model, LanguageModel::OpenAi(model) => model,
_ => self.default_model(), _ => self.model(),
}; };
Request { Request {

View File

@ -7,11 +7,12 @@ use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task}; use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc}; use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*; use ui::prelude::*;
pub struct ZedDotDevCompletionProvider { pub struct ZedDotDevCompletionProvider {
client: Arc<Client>, client: Arc<Client>,
default_model: ZedDotDevModel, model: ZedDotDevModel,
settings_version: usize, settings_version: usize,
status: client::Status, status: client::Status,
_maintain_client_status: Task<()>, _maintain_client_status: Task<()>,
@ -19,7 +20,7 @@ pub struct ZedDotDevCompletionProvider {
impl ZedDotDevCompletionProvider { impl ZedDotDevCompletionProvider {
pub fn new( pub fn new(
default_model: ZedDotDevModel, model: ZedDotDevModel,
client: Arc<Client>, client: Arc<Client>,
settings_version: usize, settings_version: usize,
cx: &mut AppContext, cx: &mut AppContext,
@ -39,24 +40,39 @@ impl ZedDotDevCompletionProvider {
}); });
Self { Self {
client, client,
default_model, model,
settings_version, settings_version,
status, status,
_maintain_client_status: maintain_client_status, _maintain_client_status: maintain_client_status,
} }
} }
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) { pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
self.default_model = default_model; self.model = model;
self.settings_version = settings_version; self.settings_version = settings_version;
} }
/// Lists the selectable zed.dev models.
///
/// Yields every enum variant, except that the `Custom` placeholder is
/// replaced by the provider's currently configured custom model name — and
/// skipped entirely when no custom model is configured.
pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
    // Extract the configured custom model name, if any; `take()` below
    // ensures it is emitted at most once.
    let mut custom = match self.model.clone() {
        ZedDotDevModel::Custom(name) => Some(name),
        _ => None,
    };
    ZedDotDevModel::iter().filter_map(move |variant| match variant {
        ZedDotDevModel::Custom(_) => custom.take().map(ZedDotDevModel::Custom),
        other => Some(other),
    })
}
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
self.settings_version self.settings_version
} }
pub fn default_model(&self) -> ZedDotDevModel { pub fn model(&self) -> ZedDotDevModel {
self.default_model.clone() self.model.clone()
} }
pub fn is_authenticated(&self) -> bool { pub fn is_authenticated(&self) -> bool {

View File

@ -0,0 +1,84 @@
use std::sync::Arc;
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
use fs::Fs;
use settings::update_settings_file;
use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, PopoverMenuHandle, Tooltip};
/// A dropdown control for picking the assistant's active language model.
///
/// Rendered as a button showing the current model's name; opening it lists
/// every model the active completion provider exposes, and selecting one
/// persists the choice to the settings file.
#[derive(IntoElement)]
pub struct ModelSelector {
    // Shared popover handle, so the menu can also be toggled from outside the
    // trigger button (e.g. via the `assistant::ToggleModelSelector` action).
    handle: PopoverMenuHandle<ContextMenu>,
    // Filesystem used by `update_settings_file` to persist the selected model.
    fs: Arc<dyn Fs>,
}
impl ModelSelector {
    /// Builds a selector driven by the given popover `handle`; `fs` is the
    /// filesystem through which the chosen model is written to settings.
    pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
        Self { handle, fs }
    }
}
impl RenderOnce for ModelSelector {
    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
        popover_menu("model-switcher")
            // Register the shared handle so the menu can also be toggled
            // programmatically, not only by clicking the trigger.
            .with_handle(self.handle)
            .menu(move |cx| {
                // Build one entry per model the active provider offers.
                // Clicking an entry writes the choice back to the settings
                // file via `AssistantSettings::set_model`.
                ContextMenu::build(cx, |mut menu, cx| {
                    for model in CompletionProvider::global(cx).available_models() {
                        menu = menu.custom_entry(
                            {
                                // Entry label: the model's display name.
                                let model = model.clone();
                                move |_| Label::new(model.display_name()).into_any_element()
                            },
                            {
                                // Entry action: persist the selection.
                                let fs = self.fs.clone();
                                let model = model.clone();
                                move |cx| {
                                    let model = model.clone();
                                    update_settings_file::<AssistantSettings>(
                                        fs.clone(),
                                        cx,
                                        move |settings| settings.set_model(model),
                                    );
                                }
                            },
                        );
                    }
                    menu
                })
                .into()
            })
            .trigger(
                // Trigger button: current model name (muted, small) plus a
                // chevron-down affordance, styled as a subtle button.
                ButtonLike::new("active-model")
                    .child(
                        h_flex()
                            .w_full()
                            .gap_0p5()
                            .child(
                                div()
                                    .overflow_x_hidden()
                                    .flex_grow()
                                    .whitespace_nowrap()
                                    .child(
                                        Label::new(
                                            CompletionProvider::global(cx).model().display_name(),
                                        )
                                        .size(LabelSize::Small)
                                        .color(Color::Muted),
                                    ),
                            )
                            .child(
                                div().child(
                                    Icon::new(IconName::ChevronDown)
                                        .color(Color::Muted)
                                        .size(IconSize::XSmall),
                                ),
                            ),
                    )
                    .style(ButtonStyle::Subtle)
                    .tooltip(move |cx| {
                        // Tooltip advertises the keybindable toggle action.
                        Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
                    }),
            )
            // Open the menu hanging from the trigger's bottom-right corner.
            .anchor(gpui::AnchorCorner::BottomRight)
    }
}

View File

@ -20,3 +20,4 @@ isahc.workspace = true
schemars = { workspace = true, optional = true } schemars = { workspace = true, optional = true }
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
strum.workspace = true

View File

@ -4,8 +4,8 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable; use isahc::config::Configurable;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use std::time::Duration; use std::{convert::TryFrom, future::Future, time::Duration};
use std::{convert::TryFrom, future::Future}; use strum::EnumIter;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
@ -44,7 +44,7 @@ impl From<Role> for String {
} }
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model { pub enum Model {
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")] #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo, ThreePointFiveTurbo,

View File

@ -13,6 +13,51 @@ pub trait PopoverTrigger: IntoElement + Clickable + Selectable + 'static {}
impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {} impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {}
/// A clonable handle onto a [`PopoverMenu`]'s state, letting code outside the
/// trigger show, hide, or toggle the menu. The inner option is `None` until a
/// `PopoverMenu` built with `with_handle` registers its state during layout.
pub struct PopoverMenuHandle<M>(Rc<RefCell<Option<PopoverMenuHandleState<M>>>>);
impl<M> Clone for PopoverMenuHandle<M> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<M> Default for PopoverMenuHandle<M> {
fn default() -> Self {
Self(Rc::default())
}
}
/// State shared between a [`PopoverMenuHandle`] and the `PopoverMenu` that
/// registered it.
struct PopoverMenuHandleState<M> {
    /// Builds a fresh menu view on demand; returning `None` means no menu
    /// should be opened.
    menu_builder: Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
    /// The currently open menu view, if any.
    menu: Rc<RefCell<Option<View<M>>>>,
}
impl<M: ManagedView> PopoverMenuHandle<M> {
    /// Opens the menu using the registered builder.
    /// No-op if no `PopoverMenu` has registered with this handle yet.
    pub fn show(&self, cx: &mut WindowContext) {
        if let Some(state) = self.0.borrow().as_ref() {
            show_menu(&state.menu_builder, &state.menu, cx);
        }
    }

    /// Dismisses the open menu by emitting `DismissEvent` on its view.
    /// No-op if the menu is closed or the handle is unregistered.
    pub fn hide(&self, cx: &mut WindowContext) {
        if let Some(state) = self.0.borrow().as_ref() {
            if let Some(menu) = state.menu.borrow().as_ref() {
                menu.update(cx, |_, cx| cx.emit(DismissEvent));
            }
        }
    }

    /// Hides the menu if it is currently open, shows it otherwise.
    pub fn toggle(&self, cx: &mut WindowContext) {
        if let Some(state) = self.0.borrow().as_ref() {
            // Presence of a stored view is the "is open" signal.
            if state.menu.borrow().is_some() {
                self.hide(cx);
            } else {
                self.show(cx);
            }
        }
    }
}
pub struct PopoverMenu<M: ManagedView> { pub struct PopoverMenu<M: ManagedView> {
id: ElementId, id: ElementId,
child_builder: Option< child_builder: Option<
@ -28,6 +73,7 @@ pub struct PopoverMenu<M: ManagedView> {
anchor: AnchorCorner, anchor: AnchorCorner,
attach: Option<AnchorCorner>, attach: Option<AnchorCorner>,
offset: Option<Point<Pixels>>, offset: Option<Point<Pixels>>,
trigger_handle: Option<PopoverMenuHandle<M>>,
} }
impl<M: ManagedView> PopoverMenu<M> { impl<M: ManagedView> PopoverMenu<M> {
@ -36,35 +82,17 @@ impl<M: ManagedView> PopoverMenu<M> {
self self
} }
/// Registers an external handle, letting callers show, hide, or toggle this
/// menu without clicking its trigger.
pub fn with_handle(mut self, handle: PopoverMenuHandle<M>) -> Self {
    self.trigger_handle.replace(handle);
    self
}
pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self { pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self {
self.child_builder = Some(Box::new(|menu, builder| { self.child_builder = Some(Box::new(|menu, builder| {
let open = menu.borrow().is_some(); let open = menu.borrow().is_some();
t.selected(open) t.selected(open)
.when_some(builder, |el, builder| { .when_some(builder, |el, builder| {
el.on_click({ el.on_click(move |_, cx| show_menu(&builder, &menu, cx))
move |_, cx| {
let Some(new_menu) = (builder)(cx) else {
return;
};
let menu2 = menu.clone();
let previous_focus_handle = cx.focused();
cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
if modal.focus_handle(cx).contains_focused(cx) {
if let Some(previous_focus_handle) =
previous_focus_handle.as_ref()
{
cx.focus(previous_focus_handle);
}
}
*menu2.borrow_mut() = None;
cx.refresh();
})
.detach();
cx.focus_view(&new_menu);
*menu.borrow_mut() = Some(new_menu);
}
})
}) })
.into_any_element() .into_any_element()
})); }));
@ -111,6 +139,32 @@ impl<M: ManagedView> PopoverMenu<M> {
} }
} }
/// Builds a menu view via `builder`, focuses it, and stores it in `menu`.
///
/// Subscribes to the new view's `DismissEvent` so that when the menu closes,
/// focus is restored to whatever was focused before it opened (only if focus
/// is still inside the menu) and the stored view is cleared.
fn show_menu<M: ManagedView>(
    builder: &Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
    menu: &Rc<RefCell<Option<View<M>>>>,
    cx: &mut WindowContext,
) {
    // A builder returning `None` aborts: no menu is shown.
    let Some(new_menu) = (builder)(cx) else {
        return;
    };
    let menu2 = menu.clone();
    // Capture the focus target now, to restore it on dismiss.
    let previous_focus_handle = cx.focused();
    cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
        // Restore focus only if the menu still owns it; otherwise the user
        // has already focused something else and we leave it alone.
        if modal.focus_handle(cx).contains_focused(cx) {
            if let Some(previous_focus_handle) = previous_focus_handle.as_ref() {
                cx.focus(previous_focus_handle);
            }
        }
        *menu2.borrow_mut() = None;
        cx.refresh();
    })
    .detach();
    cx.focus_view(&new_menu);
    *menu.borrow_mut() = Some(new_menu);
    cx.refresh();
}
/// Creates a [`PopoverMenu`] /// Creates a [`PopoverMenu`]
pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> { pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> {
PopoverMenu { PopoverMenu {
@ -120,6 +174,7 @@ pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M>
anchor: AnchorCorner::TopLeft, anchor: AnchorCorner::TopLeft,
attach: None, attach: None,
offset: None, offset: None,
trigger_handle: None,
} }
} }
@ -190,6 +245,15 @@ impl<M: ManagedView> Element for PopoverMenu<M> {
(child_builder)(element_state.menu.clone(), self.menu_builder.clone()) (child_builder)(element_state.menu.clone(), self.menu_builder.clone())
}); });
if let Some(trigger_handle) = self.trigger_handle.take() {
if let Some(menu_builder) = self.menu_builder.clone() {
*trigger_handle.0.borrow_mut() = Some(PopoverMenuHandleState {
menu_builder,
menu: element_state.menu.clone(),
});
}
}
let child_layout_id = child_element let child_layout_id = child_element
.as_mut() .as_mut()
.map(|child_element| child_element.request_layout(cx)); .map(|child_element| child_element.request_layout(cx));