mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-19 02:17:35 +03:00
Improve model selection in the assistant (#12472)
https://github.com/zed-industries/zed/assets/482957/3b017850-b7b6-457a-9b2f-324d5533442e Release Notes: - Improved the UX for selecting a model in the assistant panel. You can now switch model using just the keyboard by pressing `alt-m`. Also, when switching models via the UI, settings will now be updated automatically.
This commit is contained in:
parent
5a149b970c
commit
6ff01b17ca
3
Cargo.lock
generated
3
Cargo.lock
generated
@ -230,6 +230,7 @@ dependencies = [
|
|||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"strum",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -376,6 +377,7 @@ dependencies = [
|
|||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
"strsim 0.11.1",
|
"strsim 0.11.1",
|
||||||
|
"strum",
|
||||||
"telemetry_events",
|
"telemetry_events",
|
||||||
"theme",
|
"theme",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
@ -6983,6 +6985,7 @@ dependencies = [
|
|||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"strum",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -201,7 +201,8 @@
|
|||||||
"context": "AssistantPanel",
|
"context": "AssistantPanel",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"ctrl-g": "search::SelectNextMatch",
|
"ctrl-g": "search::SelectNextMatch",
|
||||||
"ctrl-shift-g": "search::SelectPrevMatch"
|
"ctrl-shift-g": "search::SelectPrevMatch",
|
||||||
|
"alt-m": "assistant::ToggleModelSelector"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -214,10 +214,11 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"context": "AssistantPanel", // Used in the assistant crate, which we're replacing
|
"context": "AssistantPanel",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
"cmd-g": "search::SelectNextMatch",
|
"cmd-g": "search::SelectNextMatch",
|
||||||
"cmd-shift-g": "search::SelectPrevMatch"
|
"cmd-shift-g": "search::SelectPrevMatch",
|
||||||
|
"alt-m": "assistant::ToggleModelSelector"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -23,6 +23,7 @@ isahc.workspace = true
|
|||||||
schemars = { workspace = true, optional = true }
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
strum.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
|
@ -4,11 +4,12 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{convert::TryFrom, time::Duration};
|
use std::{convert::TryFrom, time::Duration};
|
||||||
|
use strum::EnumIter;
|
||||||
|
|
||||||
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||||
|
|
||||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
#[default]
|
#[default]
|
||||||
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
|
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
|
||||||
|
@ -49,6 +49,7 @@ serde_json.workspace = true
|
|||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
strsim = "0.11"
|
strsim = "0.11"
|
||||||
|
strum.workspace = true
|
||||||
telemetry_events.workspace = true
|
telemetry_events.workspace = true
|
||||||
theme.workspace = true
|
theme.workspace = true
|
||||||
tiktoken-rs.workspace = true
|
tiktoken-rs.workspace = true
|
||||||
|
@ -2,6 +2,7 @@ pub mod assistant_panel;
|
|||||||
pub mod assistant_settings;
|
pub mod assistant_settings;
|
||||||
mod codegen;
|
mod codegen;
|
||||||
mod completion_provider;
|
mod completion_provider;
|
||||||
|
mod model_selector;
|
||||||
mod prompts;
|
mod prompts;
|
||||||
mod saved_conversation;
|
mod saved_conversation;
|
||||||
mod search;
|
mod search;
|
||||||
@ -15,6 +16,7 @@ use client::{proto, Client};
|
|||||||
use command_palette_hooks::CommandPaletteFilter;
|
use command_palette_hooks::CommandPaletteFilter;
|
||||||
pub(crate) use completion_provider::*;
|
pub(crate) use completion_provider::*;
|
||||||
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
|
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||||
|
pub(crate) use model_selector::*;
|
||||||
pub(crate) use saved_conversation::*;
|
pub(crate) use saved_conversation::*;
|
||||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -38,7 +40,8 @@ actions!(
|
|||||||
InsertActivePrompt,
|
InsertActivePrompt,
|
||||||
ToggleHistory,
|
ToggleHistory,
|
||||||
ApplyEdit,
|
ApplyEdit,
|
||||||
ConfirmCommand
|
ConfirmCommand,
|
||||||
|
ToggleModelSelector
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
|
use crate::prompts::{generate_content_prompt, PromptLibrary, PromptManager};
|
||||||
use crate::slash_command::{rustdoc_command, search_command, tabs_command};
|
use crate::slash_command::{rustdoc_command, search_command, tabs_command};
|
||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
|
assistant_settings::{AssistantDockPosition, AssistantSettings},
|
||||||
codegen::{self, Codegen, CodegenKind},
|
codegen::{self, Codegen, CodegenKind},
|
||||||
search::*,
|
search::*,
|
||||||
slash_command::{
|
slash_command::{
|
||||||
@ -9,10 +9,11 @@ use crate::{
|
|||||||
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
|
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
|
||||||
},
|
},
|
||||||
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
|
ApplyEdit, Assist, CompletionProvider, ConfirmCommand, CycleMessageRole, InlineAssist,
|
||||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata,
|
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
|
||||||
MessageStatus, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
|
QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
|
||||||
SavedMessage, Split, ToggleFocus, ToggleHistory,
|
Split, ToggleFocus, ToggleHistory,
|
||||||
};
|
};
|
||||||
|
use crate::{ModelSelector, ToggleModelSelector};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
|
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
@ -64,8 +65,8 @@ use std::{
|
|||||||
use telemetry_events::AssistantKind;
|
use telemetry_events::AssistantKind;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{
|
use ui::{
|
||||||
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding, Tab, TabBar,
|
popover_menu, prelude::*, ButtonLike, ContextMenu, ElevationIndex, KeyBinding,
|
||||||
Tooltip,
|
PopoverMenuHandle, Tab, TabBar, Tooltip,
|
||||||
};
|
};
|
||||||
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
|
use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@ -119,8 +120,8 @@ pub struct AssistantPanel {
|
|||||||
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
|
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
|
||||||
inline_prompt_history: VecDeque<String>,
|
inline_prompt_history: VecDeque<String>,
|
||||||
_watch_saved_conversations: Task<Result<()>>,
|
_watch_saved_conversations: Task<Result<()>>,
|
||||||
model: LanguageModel,
|
|
||||||
authentication_prompt: Option<AnyView>,
|
authentication_prompt: Option<AnyView>,
|
||||||
|
model_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ActiveConversationEditor {
|
struct ActiveConversationEditor {
|
||||||
@ -203,7 +204,6 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
];
|
];
|
||||||
let model = CompletionProvider::global(cx).default_model();
|
|
||||||
|
|
||||||
cx.observe_global::<FileIcons>(|_, cx| {
|
cx.observe_global::<FileIcons>(|_, cx| {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@ -244,8 +244,8 @@ impl AssistantPanel {
|
|||||||
pending_inline_assist_ids_by_editor: Default::default(),
|
pending_inline_assist_ids_by_editor: Default::default(),
|
||||||
inline_prompt_history: Default::default(),
|
inline_prompt_history: Default::default(),
|
||||||
_watch_saved_conversations,
|
_watch_saved_conversations,
|
||||||
model,
|
|
||||||
authentication_prompt: None,
|
authentication_prompt: None,
|
||||||
|
model_menu_handle: PopoverMenuHandle::default(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -277,12 +277,20 @@ impl AssistantPanel {
|
|||||||
if self.is_authenticated(cx) {
|
if self.is_authenticated(cx) {
|
||||||
self.authentication_prompt = None;
|
self.authentication_prompt = None;
|
||||||
|
|
||||||
let model = CompletionProvider::global(cx).default_model();
|
if let Some(editor) = self.active_conversation_editor() {
|
||||||
self.set_model(model, cx);
|
editor.update(cx, |active_conversation, cx| {
|
||||||
|
active_conversation
|
||||||
|
.conversation
|
||||||
|
.update(cx, |conversation, cx| {
|
||||||
|
conversation.completion_provider_changed(cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if self.active_conversation_editor().is_none() {
|
if self.active_conversation_editor().is_none() {
|
||||||
self.new_conversation(cx);
|
self.new_conversation(cx);
|
||||||
}
|
}
|
||||||
|
cx.notify();
|
||||||
} else if self.authentication_prompt.is_none()
|
} else if self.authentication_prompt.is_none()
|
||||||
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
|
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
|
||||||
{
|
{
|
||||||
@ -290,6 +298,7 @@ impl AssistantPanel {
|
|||||||
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||||
provider.authentication_prompt(cx)
|
provider.authentication_prompt(cx)
|
||||||
}));
|
}));
|
||||||
|
cx.notify();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -734,7 +743,7 @@ impl AssistantPanel {
|
|||||||
.map(|message| message.to_request_message(buffer)),
|
.map(|message| message.to_request_message(buffer)),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let model = self.model.clone();
|
let model = CompletionProvider::global(cx).model();
|
||||||
|
|
||||||
cx.spawn(|_, mut cx| async move {
|
cx.spawn(|_, mut cx| async move {
|
||||||
// I Don't know if we want to return a ? here.
|
// I Don't know if we want to return a ? here.
|
||||||
@ -809,7 +818,6 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
let editor = cx.new_view(|cx| {
|
let editor = cx.new_view(|cx| {
|
||||||
ConversationEditor::new(
|
ConversationEditor::new(
|
||||||
self.model.clone(),
|
|
||||||
self.languages.clone(),
|
self.languages.clone(),
|
||||||
self.slash_commands.clone(),
|
self.slash_commands.clone(),
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
@ -850,53 +858,6 @@ impl AssistantPanel {
|
|||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
|
|
||||||
let next_model = match &self.model {
|
|
||||||
LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
|
|
||||||
open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
|
|
||||||
open_ai::Model::Four => open_ai::Model::FourTurbo,
|
|
||||||
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
|
|
||||||
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
|
|
||||||
}),
|
|
||||||
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
|
|
||||||
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
|
|
||||||
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
|
|
||||||
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
|
|
||||||
}),
|
|
||||||
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
|
||||||
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
|
||||||
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
|
||||||
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni,
|
|
||||||
ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus,
|
|
||||||
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
|
|
||||||
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
|
|
||||||
ZedDotDevModel::Claude3Haiku => {
|
|
||||||
match CompletionProvider::global(cx).default_model() {
|
|
||||||
LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom,
|
|
||||||
_ => ZedDotDevModel::Gpt3Point5Turbo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.set_model(next_model, cx);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
|
|
||||||
self.model = model.clone();
|
|
||||||
if let Some(editor) = self.active_conversation_editor() {
|
|
||||||
editor.update(cx, |active_conversation, cx| {
|
|
||||||
active_conversation
|
|
||||||
.conversation
|
|
||||||
.update(cx, |conversation, cx| {
|
|
||||||
conversation.set_model(model, cx);
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_conversation_editor_event(
|
fn handle_conversation_editor_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: View<ConversationEditor>,
|
_: View<ConversationEditor>,
|
||||||
@ -978,6 +939,10 @@ impl AssistantPanel {
|
|||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
|
||||||
|
self.model_menu_handle.toggle(cx);
|
||||||
|
}
|
||||||
|
|
||||||
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
|
fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
|
||||||
Some(&self.active_conversation_editor.as_ref()?.editor)
|
Some(&self.active_conversation_editor.as_ref()?.editor)
|
||||||
}
|
}
|
||||||
@ -1133,10 +1098,8 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
|
let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
|
||||||
let model = this.update(&mut cx, |this, _| this.model.clone())?;
|
|
||||||
let conversation = Conversation::deserialize(
|
let conversation = Conversation::deserialize(
|
||||||
saved_conversation,
|
saved_conversation,
|
||||||
model,
|
|
||||||
path.clone(),
|
path.clone(),
|
||||||
languages,
|
languages,
|
||||||
slash_commands,
|
slash_commands,
|
||||||
@ -1206,7 +1169,10 @@ impl AssistantPanel {
|
|||||||
this.child(
|
this.child(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_1()
|
.gap_1()
|
||||||
.child(self.render_model(&conversation, cx))
|
.child(ModelSelector::new(
|
||||||
|
self.model_menu_handle.clone(),
|
||||||
|
self.fs.clone(),
|
||||||
|
))
|
||||||
.children(self.render_remaining_tokens(&conversation, cx)),
|
.children(self.render_remaining_tokens(&conversation, cx)),
|
||||||
)
|
)
|
||||||
.child(
|
.child(
|
||||||
@ -1256,6 +1222,7 @@ impl AssistantPanel {
|
|||||||
.on_action(cx.listener(AssistantPanel::select_prev_match))
|
.on_action(cx.listener(AssistantPanel::select_prev_match))
|
||||||
.on_action(cx.listener(AssistantPanel::handle_editor_cancel))
|
.on_action(cx.listener(AssistantPanel::handle_editor_cancel))
|
||||||
.on_action(cx.listener(AssistantPanel::reset_credentials))
|
.on_action(cx.listener(AssistantPanel::reset_credentials))
|
||||||
|
.on_action(cx.listener(AssistantPanel::toggle_model_selector))
|
||||||
.track_focus(&self.focus_handle)
|
.track_focus(&self.focus_handle)
|
||||||
.child(header)
|
.child(header)
|
||||||
.children(if self.toolbar.read(cx).hidden() {
|
.children(if self.toolbar.read(cx).hidden() {
|
||||||
@ -1314,23 +1281,12 @@ impl AssistantPanel {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_model(
|
|
||||||
&self,
|
|
||||||
conversation: &Model<Conversation>,
|
|
||||||
cx: &mut ViewContext<Self>,
|
|
||||||
) -> impl IntoElement {
|
|
||||||
Button::new("current_model", conversation.read(cx).model.display_name())
|
|
||||||
.style(ButtonStyle::Filled)
|
|
||||||
.tooltip(move |cx| Tooltip::text("Change Model", cx))
|
|
||||||
.on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn render_remaining_tokens(
|
fn render_remaining_tokens(
|
||||||
&self,
|
&self,
|
||||||
conversation: &Model<Conversation>,
|
conversation: &Model<Conversation>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Option<impl IntoElement> {
|
) -> Option<impl IntoElement> {
|
||||||
let remaining_tokens = conversation.read(cx).remaining_tokens()?;
|
let remaining_tokens = conversation.read(cx).remaining_tokens(cx)?;
|
||||||
let remaining_tokens_color = if remaining_tokens <= 0 {
|
let remaining_tokens_color = if remaining_tokens <= 0 {
|
||||||
Color::Error
|
Color::Error
|
||||||
} else if remaining_tokens <= 500 {
|
} else if remaining_tokens <= 500 {
|
||||||
@ -1486,7 +1442,6 @@ pub struct Conversation {
|
|||||||
pending_summary: Task<Option<()>>,
|
pending_summary: Task<Option<()>>,
|
||||||
completion_count: usize,
|
completion_count: usize,
|
||||||
pending_completions: Vec<PendingCompletion>,
|
pending_completions: Vec<PendingCompletion>,
|
||||||
model: LanguageModel,
|
|
||||||
token_count: Option<usize>,
|
token_count: Option<usize>,
|
||||||
pending_token_count: Task<Option<()>>,
|
pending_token_count: Task<Option<()>>,
|
||||||
pending_edit_suggestion_parse: Option<Task<()>>,
|
pending_edit_suggestion_parse: Option<Task<()>>,
|
||||||
@ -1502,7 +1457,6 @@ impl EventEmitter<ConversationEvent> for Conversation {}
|
|||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
fn new(
|
fn new(
|
||||||
model: LanguageModel,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||||
telemetry: Option<Arc<Telemetry>>,
|
telemetry: Option<Arc<Telemetry>>,
|
||||||
@ -1530,7 +1484,6 @@ impl Conversation {
|
|||||||
token_count: None,
|
token_count: None,
|
||||||
pending_token_count: Task::ready(None),
|
pending_token_count: Task::ready(None),
|
||||||
pending_edit_suggestion_parse: None,
|
pending_edit_suggestion_parse: None,
|
||||||
model,
|
|
||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: None,
|
path: None,
|
||||||
@ -1583,7 +1536,6 @@ impl Conversation {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn deserialize(
|
async fn deserialize(
|
||||||
saved_conversation: SavedConversation,
|
saved_conversation: SavedConversation,
|
||||||
model: LanguageModel,
|
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||||
@ -1640,7 +1592,6 @@ impl Conversation {
|
|||||||
token_count: None,
|
token_count: None,
|
||||||
pending_edit_suggestion_parse: None,
|
pending_edit_suggestion_parse: None,
|
||||||
pending_token_count: Task::ready(None),
|
pending_token_count: Task::ready(None),
|
||||||
model,
|
|
||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: Some(path),
|
path: Some(path),
|
||||||
@ -1938,12 +1889,12 @@ impl Conversation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remaining_tokens(&self) -> Option<isize> {
|
fn remaining_tokens(&self, cx: &AppContext) -> Option<isize> {
|
||||||
Some(self.model.max_token_count() as isize - self.token_count? as isize)
|
let model = CompletionProvider::global(cx).model();
|
||||||
|
Some(model.max_token_count() as isize - self.token_count? as isize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext<Self>) {
|
fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
self.model = model;
|
|
||||||
self.count_remaining_tokens(cx);
|
self.count_remaining_tokens(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2079,10 +2030,11 @@ impl Conversation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||||
|
let model = CompletionProvider::global(cx).model();
|
||||||
telemetry.report_assistant_event(
|
telemetry.report_assistant_event(
|
||||||
this.id.clone(),
|
this.id.clone(),
|
||||||
AssistantKind::Panel,
|
AssistantKind::Panel,
|
||||||
this.model.telemetry_id(),
|
model.telemetry_id(),
|
||||||
response_latency,
|
response_latency,
|
||||||
error_message,
|
error_message,
|
||||||
);
|
);
|
||||||
@ -2111,7 +2063,7 @@ impl Conversation {
|
|||||||
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
||||||
|
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
model: self.model.clone(),
|
model: CompletionProvider::global(cx).model(),
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
@ -2300,7 +2252,7 @@ impl Conversation {
|
|||||||
.into(),
|
.into(),
|
||||||
}));
|
}));
|
||||||
let request = LanguageModelRequest {
|
let request = LanguageModelRequest {
|
||||||
model: self.model.clone(),
|
model: CompletionProvider::global(cx).model(),
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
@ -2605,7 +2557,6 @@ pub struct ConversationEditor {
|
|||||||
|
|
||||||
impl ConversationEditor {
|
impl ConversationEditor {
|
||||||
fn new(
|
fn new(
|
||||||
model: LanguageModel,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
slash_command_registry: Arc<SlashCommandRegistry>,
|
slash_command_registry: Arc<SlashCommandRegistry>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
@ -2618,7 +2569,6 @@ impl ConversationEditor {
|
|||||||
|
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation = cx.new_model(|cx| {
|
||||||
Conversation::new(
|
Conversation::new(
|
||||||
model,
|
|
||||||
language_registry,
|
language_registry,
|
||||||
slash_command_registry,
|
slash_command_registry,
|
||||||
Some(telemetry),
|
Some(telemetry),
|
||||||
@ -3847,15 +3797,8 @@ mod tests {
|
|||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation =
|
||||||
Conversation::new(
|
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||||
LanguageModel::default(),
|
|
||||||
registry,
|
|
||||||
Default::default(),
|
|
||||||
None,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -3986,15 +3929,8 @@ mod tests {
|
|||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation =
|
||||||
Conversation::new(
|
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||||
LanguageModel::default(),
|
|
||||||
registry,
|
|
||||||
Default::default(),
|
|
||||||
None,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -4092,15 +4028,8 @@ mod tests {
|
|||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation =
|
||||||
Conversation::new(
|
cx.new_model(|cx| Conversation::new(registry, Default::default(), None, cx));
|
||||||
LanguageModel::default(),
|
|
||||||
registry,
|
|
||||||
Default::default(),
|
|
||||||
None,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -4209,15 +4138,8 @@ mod tests {
|
|||||||
));
|
));
|
||||||
|
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation = cx
|
||||||
Conversation::new(
|
.new_model(|cx| Conversation::new(registry.clone(), slash_command_registry, None, cx));
|
||||||
LanguageModel::default(),
|
|
||||||
registry.clone(),
|
|
||||||
slash_command_registry,
|
|
||||||
None,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
|
let output_ranges = Rc::new(RefCell::new(HashSet::default()));
|
||||||
conversation.update(cx, |_, cx| {
|
conversation.update(cx, |_, cx| {
|
||||||
@ -4390,15 +4312,8 @@ mod tests {
|
|||||||
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
|
cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
|
||||||
cx.update(init);
|
cx.update(init);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
let conversation = cx.new_model(|cx| {
|
let conversation =
|
||||||
Conversation::new(
|
cx.new_model(|cx| Conversation::new(registry.clone(), Default::default(), None, cx));
|
||||||
LanguageModel::default(),
|
|
||||||
registry.clone(),
|
|
||||||
Default::default(),
|
|
||||||
None,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
|
let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
|
||||||
let message_0 =
|
let message_0 =
|
||||||
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
|
conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
|
||||||
@ -4434,7 +4349,6 @@ mod tests {
|
|||||||
|
|
||||||
let deserialized_conversation = Conversation::deserialize(
|
let deserialized_conversation = Conversation::deserialize(
|
||||||
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
|
conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
|
||||||
LanguageModel::default(),
|
|
||||||
Default::default(),
|
Default::default(),
|
||||||
registry.clone(),
|
registry.clone(),
|
||||||
Default::default(),
|
Default::default(),
|
||||||
|
@ -12,8 +12,11 @@ use serde::{
|
|||||||
Deserialize, Deserializer, Serialize, Serializer,
|
Deserialize, Deserializer, Serialize, Serializer,
|
||||||
};
|
};
|
||||||
use settings::{Settings, SettingsSources};
|
use settings::{Settings, SettingsSources};
|
||||||
|
use strum::{EnumIter, IntoEnumIterator};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, PartialEq)]
|
use crate::LanguageModel;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
||||||
pub enum ZedDotDevModel {
|
pub enum ZedDotDevModel {
|
||||||
Gpt3Point5Turbo,
|
Gpt3Point5Turbo,
|
||||||
Gpt4,
|
Gpt4,
|
||||||
@ -53,13 +56,10 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
|||||||
where
|
where
|
||||||
E: de::Error,
|
E: de::Error,
|
||||||
{
|
{
|
||||||
match value {
|
let model = ZedDotDevModel::iter()
|
||||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
|
.find(|model| model.id() == value)
|
||||||
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
|
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
|
||||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
|
Ok(model)
|
||||||
"gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
|
|
||||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,24 +73,23 @@ impl JsonSchema for ZedDotDevModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||||
let variants = vec![
|
let variants = ZedDotDevModel::iter()
|
||||||
"gpt-3.5-turbo".to_owned(),
|
.filter_map(|model| {
|
||||||
"gpt-4".to_owned(),
|
let id = model.id();
|
||||||
"gpt-4-turbo-preview".to_owned(),
|
if id.is_empty() {
|
||||||
"gpt-4o".to_owned(),
|
None
|
||||||
];
|
} else {
|
||||||
|
Some(id.to_string())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
Schema::Object(SchemaObject {
|
Schema::Object(SchemaObject {
|
||||||
instance_type: Some(InstanceType::String.into()),
|
instance_type: Some(InstanceType::String.into()),
|
||||||
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
|
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
||||||
metadata: Some(Box::new(Metadata {
|
metadata: Some(Box::new(Metadata {
|
||||||
title: Some("ZedDotDevModel".to_owned()),
|
title: Some("ZedDotDevModel".to_owned()),
|
||||||
default: Some(serde_json::json!("gpt-4-turbo-preview")),
|
default: Some(ZedDotDevModel::default().id().into()),
|
||||||
examples: vec![
|
examples: variants.into_iter().map(Into::into).collect(),
|
||||||
serde_json::json!("gpt-3.5-turbo"),
|
|
||||||
serde_json::json!("gpt-4"),
|
|
||||||
serde_json::json!("gpt-4-turbo-preview"),
|
|
||||||
serde_json::json!("custom-model-name"),
|
|
||||||
],
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})),
|
})),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@ -145,51 +144,55 @@ pub enum AssistantDockPosition {
|
|||||||
Bottom,
|
Bottom,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
#[serde(tag = "name", rename_all = "snake_case")]
|
|
||||||
pub enum AssistantProvider {
|
pub enum AssistantProvider {
|
||||||
#[serde(rename = "zed.dev")]
|
|
||||||
ZedDotDev {
|
ZedDotDev {
|
||||||
#[serde(default)]
|
model: ZedDotDevModel,
|
||||||
default_model: ZedDotDevModel,
|
|
||||||
},
|
},
|
||||||
#[serde(rename = "openai")]
|
|
||||||
OpenAi {
|
OpenAi {
|
||||||
#[serde(default)]
|
model: OpenAiModel,
|
||||||
default_model: OpenAiModel,
|
|
||||||
#[serde(default = "open_ai_url")]
|
|
||||||
api_url: String,
|
api_url: String,
|
||||||
#[serde(default)]
|
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
},
|
},
|
||||||
#[serde(rename = "anthropic")]
|
|
||||||
Anthropic {
|
Anthropic {
|
||||||
#[serde(default)]
|
model: AnthropicModel,
|
||||||
default_model: AnthropicModel,
|
|
||||||
#[serde(default = "anthropic_api_url")]
|
|
||||||
api_url: String,
|
api_url: String,
|
||||||
#[serde(default)]
|
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AssistantProvider {
|
impl Default for AssistantProvider {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::ZedDotDev {
|
Self::OpenAi {
|
||||||
default_model: ZedDotDevModel::default(),
|
model: OpenAiModel::default(),
|
||||||
|
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||||
|
low_speed_timeout_in_seconds: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_ai_url() -> String {
|
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
open_ai::OPEN_AI_API_URL.to_string()
|
#[serde(tag = "name", rename_all = "snake_case")]
|
||||||
|
pub enum AssistantProviderContent {
|
||||||
|
#[serde(rename = "zed.dev")]
|
||||||
|
ZedDotDev {
|
||||||
|
default_model: Option<ZedDotDevModel>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "openai")]
|
||||||
|
OpenAi {
|
||||||
|
default_model: Option<OpenAiModel>,
|
||||||
|
api_url: Option<String>,
|
||||||
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "anthropic")]
|
||||||
|
Anthropic {
|
||||||
|
default_model: Option<AnthropicModel>,
|
||||||
|
api_url: Option<String>,
|
||||||
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
fn anthropic_api_url() -> String {
|
#[derive(Debug, Default)]
|
||||||
anthropic::ANTHROPIC_API_URL.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct AssistantSettings {
|
pub struct AssistantSettings {
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
pub button: bool,
|
pub button: bool,
|
||||||
@ -240,16 +243,16 @@ impl AssistantSettingsContent {
|
|||||||
default_width: settings.default_width,
|
default_width: settings.default_width,
|
||||||
default_height: settings.default_height,
|
default_height: settings.default_height,
|
||||||
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
||||||
Some(AssistantProvider::OpenAi {
|
Some(AssistantProviderContent::OpenAi {
|
||||||
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
|
default_model: settings.default_open_ai_model.clone(),
|
||||||
api_url: open_ai_api_url.clone(),
|
api_url: Some(open_ai_api_url.clone()),
|
||||||
low_speed_timeout_in_seconds: None,
|
low_speed_timeout_in_seconds: None,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
||||||
AssistantProvider::OpenAi {
|
AssistantProviderContent::OpenAi {
|
||||||
default_model: open_ai_model,
|
default_model: Some(open_ai_model),
|
||||||
api_url: open_ai_url(),
|
api_url: None,
|
||||||
low_speed_timeout_in_seconds: None,
|
low_speed_timeout_in_seconds: None,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -270,6 +273,64 @@ impl AssistantSettingsContent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_model(&mut self, new_model: LanguageModel) {
|
||||||
|
match self {
|
||||||
|
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||||
|
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
|
||||||
|
Some(AssistantProviderContent::ZedDotDev {
|
||||||
|
default_model: model,
|
||||||
|
}) => {
|
||||||
|
if let LanguageModel::ZedDotDev(new_model) = new_model {
|
||||||
|
*model = Some(new_model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(AssistantProviderContent::OpenAi {
|
||||||
|
default_model: model,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
if let LanguageModel::OpenAi(new_model) = new_model {
|
||||||
|
*model = Some(new_model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(AssistantProviderContent::Anthropic {
|
||||||
|
default_model: model,
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
if let LanguageModel::Anthropic(new_model) = new_model {
|
||||||
|
*model = Some(new_model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
provider => match new_model {
|
||||||
|
LanguageModel::ZedDotDev(model) => {
|
||||||
|
*provider = Some(AssistantProviderContent::ZedDotDev {
|
||||||
|
default_model: Some(model),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
LanguageModel::OpenAi(model) => {
|
||||||
|
*provider = Some(AssistantProviderContent::OpenAi {
|
||||||
|
default_model: Some(model),
|
||||||
|
api_url: None,
|
||||||
|
low_speed_timeout_in_seconds: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
LanguageModel::Anthropic(model) => {
|
||||||
|
*provider = Some(AssistantProviderContent::Anthropic {
|
||||||
|
default_model: Some(model),
|
||||||
|
api_url: None,
|
||||||
|
low_speed_timeout_in_seconds: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AssistantSettingsContent::Legacy(settings) => {
|
||||||
|
if let LanguageModel::OpenAi(model) = new_model {
|
||||||
|
settings.default_open_ai_model = Some(model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||||
@ -318,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
|
|||||||
///
|
///
|
||||||
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
||||||
/// each with their respective default models and configurations.
|
/// each with their respective default models and configurations.
|
||||||
provider: Option<AssistantProvider>,
|
provider: Option<AssistantProviderContent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||||
@ -376,31 +437,82 @@ impl Settings for AssistantSettings {
|
|||||||
if let Some(provider) = value.provider.clone() {
|
if let Some(provider) = value.provider.clone() {
|
||||||
match (&mut settings.provider, provider) {
|
match (&mut settings.provider, provider) {
|
||||||
(
|
(
|
||||||
AssistantProvider::ZedDotDev { default_model },
|
AssistantProvider::ZedDotDev { model },
|
||||||
AssistantProvider::ZedDotDev {
|
AssistantProviderContent::ZedDotDev {
|
||||||
default_model: default_model_override,
|
default_model: model_override,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
*default_model = default_model_override;
|
merge(model, model_override);
|
||||||
}
|
}
|
||||||
(
|
(
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
},
|
},
|
||||||
AssistantProvider::OpenAi {
|
AssistantProviderContent::OpenAi {
|
||||||
default_model: default_model_override,
|
default_model: model_override,
|
||||||
api_url: api_url_override,
|
api_url: api_url_override,
|
||||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
*default_model = default_model_override;
|
merge(model, model_override);
|
||||||
*api_url = api_url_override;
|
merge(api_url, api_url_override);
|
||||||
*low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
|
if let Some(low_speed_timeout_in_seconds_override) =
|
||||||
|
low_speed_timeout_in_seconds_override
|
||||||
|
{
|
||||||
|
*low_speed_timeout_in_seconds =
|
||||||
|
Some(low_speed_timeout_in_seconds_override);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
(merged, provider_override) => {
|
(
|
||||||
*merged = provider_override;
|
AssistantProvider::Anthropic {
|
||||||
|
model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
},
|
||||||
|
AssistantProviderContent::Anthropic {
|
||||||
|
default_model: model_override,
|
||||||
|
api_url: api_url_override,
|
||||||
|
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
merge(model, model_override);
|
||||||
|
merge(api_url, api_url_override);
|
||||||
|
if let Some(low_speed_timeout_in_seconds_override) =
|
||||||
|
low_speed_timeout_in_seconds_override
|
||||||
|
{
|
||||||
|
*low_speed_timeout_in_seconds =
|
||||||
|
Some(low_speed_timeout_in_seconds_override);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(provider, provider_override) => {
|
||||||
|
*provider = match provider_override {
|
||||||
|
AssistantProviderContent::ZedDotDev {
|
||||||
|
default_model: model,
|
||||||
|
} => AssistantProvider::ZedDotDev {
|
||||||
|
model: model.unwrap_or_default(),
|
||||||
|
},
|
||||||
|
AssistantProviderContent::OpenAi {
|
||||||
|
default_model: model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
} => AssistantProvider::OpenAi {
|
||||||
|
model: model.unwrap_or_default(),
|
||||||
|
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
},
|
||||||
|
AssistantProviderContent::Anthropic {
|
||||||
|
default_model: model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
} => AssistantProvider::Anthropic {
|
||||||
|
model: model.unwrap_or_default(),
|
||||||
|
api_url: api_url
|
||||||
|
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -410,7 +522,7 @@ impl Settings for AssistantSettings {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
|
fn merge<T>(target: &mut T, value: Option<T>) {
|
||||||
if let Some(value) = value {
|
if let Some(value) = value {
|
||||||
*target = value;
|
*target = value;
|
||||||
}
|
}
|
||||||
@ -433,8 +545,8 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).provider,
|
AssistantSettings::get_global(cx).provider,
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model: OpenAiModel::FourOmni,
|
model: OpenAiModel::FourOmni,
|
||||||
api_url: open_ai_url(),
|
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||||
low_speed_timeout_in_seconds: None,
|
low_speed_timeout_in_seconds: None,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
@ -455,7 +567,7 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).provider,
|
AssistantSettings::get_global(cx).provider,
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model: OpenAiModel::FourOmni,
|
model: OpenAiModel::FourOmni,
|
||||||
api_url: "test-url".into(),
|
api_url: "test-url".into(),
|
||||||
low_speed_timeout_in_seconds: None,
|
low_speed_timeout_in_seconds: None,
|
||||||
}
|
}
|
||||||
@ -475,8 +587,8 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).provider,
|
AssistantSettings::get_global(cx).provider,
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model: OpenAiModel::Four,
|
model: OpenAiModel::Four,
|
||||||
api_url: open_ai_url(),
|
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||||
low_speed_timeout_in_seconds: None,
|
low_speed_timeout_in_seconds: None,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
@ -501,7 +613,7 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).provider,
|
AssistantSettings::get_global(cx).provider,
|
||||||
AssistantProvider::ZedDotDev {
|
AssistantProvider::ZedDotDev {
|
||||||
default_model: ZedDotDevModel::Custom("custom".into())
|
model: ZedDotDevModel::Custom("custom".into())
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -25,31 +25,26 @@ use std::time::Duration;
|
|||||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
let mut settings_version = 0;
|
let mut settings_version = 0;
|
||||||
let provider = match &AssistantSettings::get_global(cx).provider {
|
let provider = match &AssistantSettings::get_global(cx).provider {
|
||||||
AssistantProvider::ZedDotDev { default_model } => {
|
AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
|
||||||
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
|
||||||
default_model.clone(),
|
),
|
||||||
client.clone(),
|
|
||||||
settings_version,
|
|
||||||
cx,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
client.http_client(),
|
client.http_client(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
settings_version,
|
settings_version,
|
||||||
)),
|
)),
|
||||||
AssistantProvider::Anthropic {
|
AssistantProvider::Anthropic {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
client.http_client(),
|
client.http_client(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
@ -65,13 +60,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
(
|
(
|
||||||
CompletionProvider::OpenAi(provider),
|
CompletionProvider::OpenAi(provider),
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
provider.update(
|
provider.update(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
settings_version,
|
settings_version,
|
||||||
@ -80,13 +75,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
(
|
(
|
||||||
CompletionProvider::Anthropic(provider),
|
CompletionProvider::Anthropic(provider),
|
||||||
AssistantProvider::Anthropic {
|
AssistantProvider::Anthropic {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
provider.update(
|
provider.update(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
settings_version,
|
settings_version,
|
||||||
@ -94,13 +89,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
}
|
}
|
||||||
(
|
(
|
||||||
CompletionProvider::ZedDotDev(provider),
|
CompletionProvider::ZedDotDev(provider),
|
||||||
AssistantProvider::ZedDotDev { default_model },
|
AssistantProvider::ZedDotDev { model },
|
||||||
) => {
|
) => {
|
||||||
provider.update(default_model.clone(), settings_version);
|
provider.update(model.clone(), settings_version);
|
||||||
}
|
}
|
||||||
(_, AssistantProvider::ZedDotDev { default_model }) => {
|
(_, AssistantProvider::ZedDotDev { model }) => {
|
||||||
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
client.clone(),
|
client.clone(),
|
||||||
settings_version,
|
settings_version,
|
||||||
cx,
|
cx,
|
||||||
@ -109,13 +104,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
(
|
(
|
||||||
_,
|
_,
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
client.http_client(),
|
client.http_client(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
@ -125,13 +120,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
(
|
(
|
||||||
_,
|
_,
|
||||||
AssistantProvider::Anthropic {
|
AssistantProvider::Anthropic {
|
||||||
default_model,
|
model,
|
||||||
api_url,
|
api_url,
|
||||||
low_speed_timeout_in_seconds,
|
low_speed_timeout_in_seconds,
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||||
default_model.clone(),
|
model.clone(),
|
||||||
api_url.clone(),
|
api_url.clone(),
|
||||||
client.http_client(),
|
client.http_client(),
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
@ -159,6 +154,25 @@ impl CompletionProvider {
|
|||||||
cx.global::<Self>()
|
cx.global::<Self>()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn available_models(&self) -> Vec<LanguageModel> {
|
||||||
|
match self {
|
||||||
|
CompletionProvider::OpenAi(provider) => provider
|
||||||
|
.available_models()
|
||||||
|
.map(LanguageModel::OpenAi)
|
||||||
|
.collect(),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider
|
||||||
|
.available_models()
|
||||||
|
.map(LanguageModel::Anthropic)
|
||||||
|
.collect(),
|
||||||
|
CompletionProvider::ZedDotDev(provider) => provider
|
||||||
|
.available_models()
|
||||||
|
.map(LanguageModel::ZedDotDev)
|
||||||
|
.collect(),
|
||||||
|
#[cfg(test)]
|
||||||
|
CompletionProvider::Fake(_) => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
||||||
@ -209,17 +223,13 @@ impl CompletionProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn default_model(&self) -> LanguageModel {
|
pub fn model(&self) -> LanguageModel {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
|
||||||
CompletionProvider::Anthropic(provider) => {
|
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
|
||||||
LanguageModel::Anthropic(provider.default_model())
|
CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
|
||||||
}
|
|
||||||
CompletionProvider::ZedDotDev(provider) => {
|
|
||||||
LanguageModel::ZedDotDev(provider.default_model())
|
|
||||||
}
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => unimplemented!(),
|
CompletionProvider::Fake(_) => LanguageModel::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ use http::HttpClient;
|
|||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{env, sync::Arc};
|
use std::{env, sync::Arc};
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
@ -19,7 +20,7 @@ use util::ResultExt;
|
|||||||
pub struct AnthropicCompletionProvider {
|
pub struct AnthropicCompletionProvider {
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
default_model: AnthropicModel,
|
model: AnthropicModel,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
@ -27,7 +28,7 @@ pub struct AnthropicCompletionProvider {
|
|||||||
|
|
||||||
impl AnthropicCompletionProvider {
|
impl AnthropicCompletionProvider {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
default_model: AnthropicModel,
|
model: AnthropicModel,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
@ -36,7 +37,7 @@ impl AnthropicCompletionProvider {
|
|||||||
Self {
|
Self {
|
||||||
api_key: None,
|
api_key: None,
|
||||||
api_url,
|
api_url,
|
||||||
default_model,
|
model,
|
||||||
http_client,
|
http_client,
|
||||||
low_speed_timeout,
|
low_speed_timeout,
|
||||||
settings_version,
|
settings_version,
|
||||||
@ -45,17 +46,21 @@ impl AnthropicCompletionProvider {
|
|||||||
|
|
||||||
pub fn update(
|
pub fn update(
|
||||||
&mut self,
|
&mut self,
|
||||||
default_model: AnthropicModel,
|
model: AnthropicModel,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
) {
|
) {
|
||||||
self.default_model = default_model;
|
self.model = model;
|
||||||
self.api_url = api_url;
|
self.api_url = api_url;
|
||||||
self.low_speed_timeout = low_speed_timeout;
|
self.low_speed_timeout = low_speed_timeout;
|
||||||
self.settings_version = settings_version;
|
self.settings_version = settings_version;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
|
||||||
|
AnthropicModel::iter()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
self.settings_version
|
self.settings_version
|
||||||
}
|
}
|
||||||
@ -105,8 +110,8 @@ impl AnthropicCompletionProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn default_model(&self) -> AnthropicModel {
|
pub fn model(&self) -> AnthropicModel {
|
||||||
self.default_model.clone()
|
self.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_tokens(
|
pub fn count_tokens(
|
||||||
@ -165,7 +170,7 @@ impl AnthropicCompletionProvider {
|
|||||||
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
|
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
|
||||||
let model = match request.model {
|
let model = match request.model {
|
||||||
LanguageModel::Anthropic(model) => model,
|
LanguageModel::Anthropic(model) => model,
|
||||||
_ => self.default_model(),
|
_ => self.model(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut system_message = String::new();
|
let mut system_message = String::new();
|
||||||
|
@ -11,6 +11,7 @@ use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
|||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{env, sync::Arc};
|
use std::{env, sync::Arc};
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
@ -18,7 +19,7 @@ use util::ResultExt;
|
|||||||
pub struct OpenAiCompletionProvider {
|
pub struct OpenAiCompletionProvider {
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
default_model: OpenAiModel,
|
model: OpenAiModel,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
@ -26,7 +27,7 @@ pub struct OpenAiCompletionProvider {
|
|||||||
|
|
||||||
impl OpenAiCompletionProvider {
|
impl OpenAiCompletionProvider {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
default_model: OpenAiModel,
|
model: OpenAiModel,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
@ -35,7 +36,7 @@ impl OpenAiCompletionProvider {
|
|||||||
Self {
|
Self {
|
||||||
api_key: None,
|
api_key: None,
|
||||||
api_url,
|
api_url,
|
||||||
default_model,
|
model,
|
||||||
http_client,
|
http_client,
|
||||||
low_speed_timeout,
|
low_speed_timeout,
|
||||||
settings_version,
|
settings_version,
|
||||||
@ -44,17 +45,21 @@ impl OpenAiCompletionProvider {
|
|||||||
|
|
||||||
pub fn update(
|
pub fn update(
|
||||||
&mut self,
|
&mut self,
|
||||||
default_model: OpenAiModel,
|
model: OpenAiModel,
|
||||||
api_url: String,
|
api_url: String,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
) {
|
) {
|
||||||
self.default_model = default_model;
|
self.model = model;
|
||||||
self.api_url = api_url;
|
self.api_url = api_url;
|
||||||
self.low_speed_timeout = low_speed_timeout;
|
self.low_speed_timeout = low_speed_timeout;
|
||||||
self.settings_version = settings_version;
|
self.settings_version = settings_version;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
|
||||||
|
OpenAiModel::iter()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
self.settings_version
|
self.settings_version
|
||||||
}
|
}
|
||||||
@ -104,8 +109,8 @@ impl OpenAiCompletionProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn default_model(&self) -> OpenAiModel {
|
pub fn model(&self) -> OpenAiModel {
|
||||||
self.default_model.clone()
|
self.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_tokens(
|
pub fn count_tokens(
|
||||||
@ -152,7 +157,7 @@ impl OpenAiCompletionProvider {
|
|||||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||||
let model = match request.model {
|
let model = match request.model {
|
||||||
LanguageModel::OpenAi(model) => model,
|
LanguageModel::OpenAi(model) => model,
|
||||||
_ => self.default_model(),
|
_ => self.model(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Request {
|
Request {
|
||||||
|
@ -7,11 +7,12 @@ use client::{proto, Client};
|
|||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||||
use gpui::{AnyView, AppContext, Task};
|
use gpui::{AnyView, AppContext, Task};
|
||||||
use std::{future, sync::Arc};
|
use std::{future, sync::Arc};
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
|
||||||
pub struct ZedDotDevCompletionProvider {
|
pub struct ZedDotDevCompletionProvider {
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
default_model: ZedDotDevModel,
|
model: ZedDotDevModel,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
status: client::Status,
|
status: client::Status,
|
||||||
_maintain_client_status: Task<()>,
|
_maintain_client_status: Task<()>,
|
||||||
@ -19,7 +20,7 @@ pub struct ZedDotDevCompletionProvider {
|
|||||||
|
|
||||||
impl ZedDotDevCompletionProvider {
|
impl ZedDotDevCompletionProvider {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
default_model: ZedDotDevModel,
|
model: ZedDotDevModel,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
settings_version: usize,
|
settings_version: usize,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
@ -39,24 +40,39 @@ impl ZedDotDevCompletionProvider {
|
|||||||
});
|
});
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
default_model,
|
model,
|
||||||
settings_version,
|
settings_version,
|
||||||
status,
|
status,
|
||||||
_maintain_client_status: maintain_client_status,
|
_maintain_client_status: maintain_client_status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
|
pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
|
||||||
self.default_model = default_model;
|
self.model = model;
|
||||||
self.settings_version = settings_version;
|
self.settings_version = settings_version;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
|
||||||
|
let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
|
||||||
|
Some(custom_model)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
ZedDotDevModel::iter().filter_map(move |model| {
|
||||||
|
if let ZedDotDevModel::Custom(_) = model {
|
||||||
|
Some(ZedDotDevModel::Custom(custom_model.take()?))
|
||||||
|
} else {
|
||||||
|
Some(model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
self.settings_version
|
self.settings_version
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn default_model(&self) -> ZedDotDevModel {
|
pub fn model(&self) -> ZedDotDevModel {
|
||||||
self.default_model.clone()
|
self.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn is_authenticated(&self) -> bool {
|
||||||
|
84
crates/assistant/src/model_selector.rs
Normal file
84
crates/assistant/src/model_selector.rs
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
|
||||||
|
use fs::Fs;
|
||||||
|
use settings::update_settings_file;
|
||||||
|
use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, PopoverMenuHandle, Tooltip};
|
||||||
|
|
||||||
|
#[derive(IntoElement)]
|
||||||
|
pub struct ModelSelector {
|
||||||
|
handle: PopoverMenuHandle<ContextMenu>,
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelSelector {
|
||||||
|
pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
|
||||||
|
ModelSelector { handle, fs }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RenderOnce for ModelSelector {
|
||||||
|
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
||||||
|
popover_menu("model-switcher")
|
||||||
|
.with_handle(self.handle)
|
||||||
|
.menu(move |cx| {
|
||||||
|
ContextMenu::build(cx, |mut menu, cx| {
|
||||||
|
for model in CompletionProvider::global(cx).available_models() {
|
||||||
|
menu = menu.custom_entry(
|
||||||
|
{
|
||||||
|
let model = model.clone();
|
||||||
|
move |_| Label::new(model.display_name()).into_any_element()
|
||||||
|
},
|
||||||
|
{
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
let model = model.clone();
|
||||||
|
move |cx| {
|
||||||
|
let model = model.clone();
|
||||||
|
update_settings_file::<AssistantSettings>(
|
||||||
|
fs.clone(),
|
||||||
|
cx,
|
||||||
|
move |settings| settings.set_model(model),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
menu
|
||||||
|
})
|
||||||
|
.into()
|
||||||
|
})
|
||||||
|
.trigger(
|
||||||
|
ButtonLike::new("active-model")
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.gap_0p5()
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.overflow_x_hidden()
|
||||||
|
.flex_grow()
|
||||||
|
.whitespace_nowrap()
|
||||||
|
.child(
|
||||||
|
Label::new(
|
||||||
|
CompletionProvider::global(cx).model().display_name(),
|
||||||
|
)
|
||||||
|
.size(LabelSize::Small)
|
||||||
|
.color(Color::Muted),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
div().child(
|
||||||
|
Icon::new(IconName::ChevronDown)
|
||||||
|
.color(Color::Muted)
|
||||||
|
.size(IconSize::XSmall),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.style(ButtonStyle::Subtle)
|
||||||
|
.tooltip(move |cx| {
|
||||||
|
Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.anchor(gpui::AnchorCorner::BottomRight)
|
||||||
|
}
|
||||||
|
}
|
@ -20,3 +20,4 @@ isahc.workspace = true
|
|||||||
schemars = { workspace = true, optional = true }
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
strum.workspace = true
|
||||||
|
@ -4,8 +4,8 @@ use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
use std::time::Duration;
|
use std::{convert::TryFrom, future::Future, time::Duration};
|
||||||
use std::{convert::TryFrom, future::Future};
|
use strum::EnumIter;
|
||||||
|
|
||||||
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ impl From<Role> for String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
|
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
|
||||||
ThreePointFiveTurbo,
|
ThreePointFiveTurbo,
|
||||||
|
@ -13,6 +13,51 @@ pub trait PopoverTrigger: IntoElement + Clickable + Selectable + 'static {}
|
|||||||
|
|
||||||
impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {}
|
impl<T: IntoElement + Clickable + Selectable + 'static> PopoverTrigger for T {}
|
||||||
|
|
||||||
|
pub struct PopoverMenuHandle<M>(Rc<RefCell<Option<PopoverMenuHandleState<M>>>>);
|
||||||
|
|
||||||
|
impl<M> Clone for PopoverMenuHandle<M> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M> Default for PopoverMenuHandle<M> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self(Rc::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PopoverMenuHandleState<M> {
|
||||||
|
menu_builder: Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
|
||||||
|
menu: Rc<RefCell<Option<View<M>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: ManagedView> PopoverMenuHandle<M> {
|
||||||
|
pub fn show(&self, cx: &mut WindowContext) {
|
||||||
|
if let Some(state) = self.0.borrow().as_ref() {
|
||||||
|
show_menu(&state.menu_builder, &state.menu, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn hide(&self, cx: &mut WindowContext) {
|
||||||
|
if let Some(state) = self.0.borrow().as_ref() {
|
||||||
|
if let Some(menu) = state.menu.borrow().as_ref() {
|
||||||
|
menu.update(cx, |_, cx| cx.emit(DismissEvent));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn toggle(&self, cx: &mut WindowContext) {
|
||||||
|
if let Some(state) = self.0.borrow().as_ref() {
|
||||||
|
if state.menu.borrow().is_some() {
|
||||||
|
self.hide(cx);
|
||||||
|
} else {
|
||||||
|
self.show(cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct PopoverMenu<M: ManagedView> {
|
pub struct PopoverMenu<M: ManagedView> {
|
||||||
id: ElementId,
|
id: ElementId,
|
||||||
child_builder: Option<
|
child_builder: Option<
|
||||||
@ -28,6 +73,7 @@ pub struct PopoverMenu<M: ManagedView> {
|
|||||||
anchor: AnchorCorner,
|
anchor: AnchorCorner,
|
||||||
attach: Option<AnchorCorner>,
|
attach: Option<AnchorCorner>,
|
||||||
offset: Option<Point<Pixels>>,
|
offset: Option<Point<Pixels>>,
|
||||||
|
trigger_handle: Option<PopoverMenuHandle<M>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: ManagedView> PopoverMenu<M> {
|
impl<M: ManagedView> PopoverMenu<M> {
|
||||||
@ -36,35 +82,17 @@ impl<M: ManagedView> PopoverMenu<M> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_handle(mut self, handle: PopoverMenuHandle<M>) -> Self {
|
||||||
|
self.trigger_handle = Some(handle);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self {
|
pub fn trigger<T: PopoverTrigger>(mut self, t: T) -> Self {
|
||||||
self.child_builder = Some(Box::new(|menu, builder| {
|
self.child_builder = Some(Box::new(|menu, builder| {
|
||||||
let open = menu.borrow().is_some();
|
let open = menu.borrow().is_some();
|
||||||
t.selected(open)
|
t.selected(open)
|
||||||
.when_some(builder, |el, builder| {
|
.when_some(builder, |el, builder| {
|
||||||
el.on_click({
|
el.on_click(move |_, cx| show_menu(&builder, &menu, cx))
|
||||||
move |_, cx| {
|
|
||||||
let Some(new_menu) = (builder)(cx) else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
let menu2 = menu.clone();
|
|
||||||
let previous_focus_handle = cx.focused();
|
|
||||||
|
|
||||||
cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
|
|
||||||
if modal.focus_handle(cx).contains_focused(cx) {
|
|
||||||
if let Some(previous_focus_handle) =
|
|
||||||
previous_focus_handle.as_ref()
|
|
||||||
{
|
|
||||||
cx.focus(previous_focus_handle);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*menu2.borrow_mut() = None;
|
|
||||||
cx.refresh();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
cx.focus_view(&new_menu);
|
|
||||||
*menu.borrow_mut() = Some(new_menu);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.into_any_element()
|
.into_any_element()
|
||||||
}));
|
}));
|
||||||
@ -111,6 +139,32 @@ impl<M: ManagedView> PopoverMenu<M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn show_menu<M: ManagedView>(
|
||||||
|
builder: &Rc<dyn Fn(&mut WindowContext) -> Option<View<M>>>,
|
||||||
|
menu: &Rc<RefCell<Option<View<M>>>>,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) {
|
||||||
|
let Some(new_menu) = (builder)(cx) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let menu2 = menu.clone();
|
||||||
|
let previous_focus_handle = cx.focused();
|
||||||
|
|
||||||
|
cx.subscribe(&new_menu, move |modal, _: &DismissEvent, cx| {
|
||||||
|
if modal.focus_handle(cx).contains_focused(cx) {
|
||||||
|
if let Some(previous_focus_handle) = previous_focus_handle.as_ref() {
|
||||||
|
cx.focus(previous_focus_handle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*menu2.borrow_mut() = None;
|
||||||
|
cx.refresh();
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
cx.focus_view(&new_menu);
|
||||||
|
*menu.borrow_mut() = Some(new_menu);
|
||||||
|
cx.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a [`PopoverMenu`]
|
/// Creates a [`PopoverMenu`]
|
||||||
pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> {
|
pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M> {
|
||||||
PopoverMenu {
|
PopoverMenu {
|
||||||
@ -120,6 +174,7 @@ pub fn popover_menu<M: ManagedView>(id: impl Into<ElementId>) -> PopoverMenu<M>
|
|||||||
anchor: AnchorCorner::TopLeft,
|
anchor: AnchorCorner::TopLeft,
|
||||||
attach: None,
|
attach: None,
|
||||||
offset: None,
|
offset: None,
|
||||||
|
trigger_handle: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,6 +245,15 @@ impl<M: ManagedView> Element for PopoverMenu<M> {
|
|||||||
(child_builder)(element_state.menu.clone(), self.menu_builder.clone())
|
(child_builder)(element_state.menu.clone(), self.menu_builder.clone())
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(trigger_handle) = self.trigger_handle.take() {
|
||||||
|
if let Some(menu_builder) = self.menu_builder.clone() {
|
||||||
|
*trigger_handle.0.borrow_mut() = Some(PopoverMenuHandleState {
|
||||||
|
menu_builder,
|
||||||
|
menu: element_state.menu.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let child_layout_id = child_element
|
let child_layout_id = child_element
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.map(|child_element| child_element.request_layout(cx));
|
.map(|child_element| child_element.request_layout(cx));
|
||||||
|
Loading…
Reference in New Issue
Block a user