diff --git a/Cargo.lock b/Cargo.lock index 404214707b..1e3eff9369 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2509,6 +2509,7 @@ dependencies = [ "http 0.1.0", "indoc", "language", + "language_model", "live_kit_client", "live_kit_server", "log", @@ -2678,36 +2679,22 @@ dependencies = [ name = "completion" version = "0.1.0" dependencies = [ - "anthropic", "anyhow", - "client", - "collections", "ctor", "editor", "env_logger", "futures 0.3.28", "gpui", - "http 0.1.0", "language", "language_model", - "log", - "menu", - "ollama", - "open_ai", - "parking_lot", "project", "rand 0.8.5", "serde", - "serde_json", "settings", "smol", - "strum", "text", - "theme", - "tiktoken-rs", "ui", "unindent", - "util", ] [[package]] @@ -6040,11 +6027,19 @@ name = "language_model" version = "0.1.0" dependencies = [ "anthropic", + "anyhow", + "client", + "collections", "ctor", "editor", "env_logger", + "feature_flags", + "futures 0.3.28", + "gpui", + "http 0.1.0", "language", "log", + "menu", "ollama", "open_ai", "project", @@ -6052,9 +6047,15 @@ dependencies = [ "rand 0.8.5", "schemars", "serde", + "serde_json", + "settings", "strum", "text", + "theme", + "tiktoken-rs", + "ui", "unindent", + "util", ] [[package]] @@ -13802,6 +13803,7 @@ dependencies = [ "isahc", "journal", "language", + "language_model", "language_selector", "language_tools", "languages", diff --git a/assets/settings/default.json b/assets/settings/default.json index d3e0f43ed1..0c6ed54e69 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -375,7 +375,7 @@ }, "assistant": { // Version of this setting. - "version": "1", + "version": "2", // Whether the assistant is enabled. "enabled": true, // Whether to show the assistant panel button in the status bar. @@ -386,18 +386,12 @@ "default_width": 640, // Default height when the assistant is docked to the bottom. "default_height": 320, - // AI provider. - "provider": { - "name": "openai", - // The default model to use when creating new contexts. This - // setting can take three values: - // - // 1. "gpt-3.5-turbo" - // 2. "gpt-4" - // 3. "gpt-4-turbo-preview" - // 4. "gpt-4o" - // 5. "gpt-4o-mini" - "default_model": "gpt-4o" + // The default model to use when creating new contexts. + "default_model": { + // The provider to use. + "provider": "openai", + // The model to use. + "model": "gpt-4o" } }, // Whether the screen sharing icon is shown in the os status bar. @@ -858,6 +852,8 @@ } } }, + // Different settings for specific language models. + "language_models": {}, // Zed's Prettier integration settings. // Allows to enable/disable formatting with Prettier // and configure default Prettier, used when no project-level Prettier installation is found. 
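Note: the default.json change above bumps the assistant settings from version "1" to "2": the old nested `provider` object is replaced by a flat `default_model` pair, and per-provider configuration moves to the new top-level `language_models` object. As a rough illustration (values taken from the defaults above, not an exhaustive schema), a migrated user settings.json would look something like:

```json
{
  "assistant": {
    "version": "2",
    "default_model": {
      "provider": "openai",
      "model": "gpt-4o"
    }
  },
  "language_models": {}
}
```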
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 21cb4d75aa..62f796cf25 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -21,11 +21,7 @@ pub enum Model {
     #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
     Claude3Haiku,
     #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        #[serde(default)]
-        max_tokens: Option<usize>,
-    },
+    Custom { name: String, max_tokens: usize },
 }
 
 impl Model {
@@ -39,10 +35,7 @@ impl Model {
         } else if id.starts_with("claude-3-haiku") {
             Ok(Self::Claude3Haiku)
         } else {
-            Ok(Self::Custom {
-                name: id.to_string(),
-                max_tokens: None,
-            })
+            Err(anyhow!("invalid model id"))
         }
     }
 
@@ -52,7 +45,7 @@ impl Model {
             Model::Claude3Opus => "claude-3-opus-20240229",
             Model::Claude3Sonnet => "claude-3-sonnet-20240229",
             Model::Claude3Haiku => "claude-3-haiku-20240307",
-            Model::Custom { name, .. } => name,
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -72,7 +65,7 @@ impl Model {
             | Self::Claude3Opus
             | Self::Claude3Sonnet
             | Self::Claude3Haiku => 200_000,
-            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+            Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
 }
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 0b12cc099c..1c97402f8c 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -15,20 +15,20 @@ use assistant_settings::AssistantSettings;
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
-use completion::CompletionProvider;
+use completion::LanguageModelCompletionProvider;
 pub use context::*;
 pub use context_store::*;
 use fs::Fs;
-use gpui::{
-    actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
-};
+use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 use indexed_docs::IndexedDocsRegistry;
 pub(crate) use inline_assistant::*;
-use language_model::LanguageModelResponseMessage;
+use language_model::{
+    LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
+};
 pub(crate) use model_selector::*;
 use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsStore};
+use settings::{update_settings_file, Settings, SettingsStore};
 use slash_command::{
     active_command, default_command, diagnostics_command, docs_command, fetch_command,
     file_command, now_command, project_command, prompt_command, search_command, symbols_command,
@@ -165,6 +165,16 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     cx.set_global(Assistant::default());
     AssistantSettings::register(cx);
 
+    // TODO: remove this when 0.148.0 is released.
+    if AssistantSettings::get_global(cx).using_outdated_settings_version {
+        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
+            let fs = fs.clone();
+            |content, cx| {
+                content.update_file(fs, cx);
+            }
+        });
+    }
+
     cx.spawn(|mut cx| {
         let client = client.clone();
         async move {
@@ -182,7 +192,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     context_store::init(&client);
     prompt_library::init(cx);
-    init_completion_provider(Arc::clone(&client), cx);
+    init_completion_provider(cx);
     assistant_slash_command::init(cx);
     register_slash_commands(cx);
     assistant_panel::init(cx);
@@ -207,20 +217,38 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     .detach();
 }
 
-fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
-    let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
-    cx.set_global(CompletionProvider::new(provider, Some(client)));
+fn init_completion_provider(cx: &mut AppContext) {
+    completion::init(cx);
+    update_active_language_model_from_settings(cx);
 
-    let mut settings_version = 0;
-    cx.observe_global::<SettingsStore>(move |cx| {
-        settings_version += 1;
-        cx.update_global::<CompletionProvider, _>(|provider, cx| {
-            assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
-        })
+    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
+        .detach();
+    cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
+        update_active_language_model_from_settings(cx)
     })
     .detach();
 }
 
+fn update_active_language_model_from_settings(cx: &mut AppContext) {
+    let settings = AssistantSettings::get_global(cx);
+    let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
+    let model_id = LanguageModelId::from(settings.default_model.model.clone());
+
+    let Some(provider) = LanguageModelRegistry::global(cx)
+        .read(cx)
+        .provider(&provider_name)
+    else {
+        return;
+    };
+
+    let models = provider.provided_models(cx);
+    if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
+        LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
+            completion_provider.set_active_model(model, cx);
+        });
+    }
+}
+
 fn register_slash_commands(cx: &mut AppContext) {
     let slash_command_registry = SlashCommandRegistry::global(cx);
     slash_command_registry.register_command(file_command::FileSlashCommand, true);
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 8524da4066..77bd9ac286 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -18,7 +18,7 @@ use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
 use client::proto;
 use collections::{BTreeSet, HashMap, HashSet};
-use completion::CompletionProvider;
+use completion::LanguageModelCompletionProvider;
 use editor::{
     actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
     display_map::{
@@ -364,13 +364,12 @@ impl AssistantPanel {
             cx.subscribe(&pane, Self::handle_pane_event),
             cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
             cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
-            cx.observe_global::<SettingsStore>({
-                let mut prev_settings_version = CompletionProvider::global(cx).settings_version();
-                move |this, cx| {
-                    this.completion_provider_changed(prev_settings_version, cx);
-                    prev_settings_version = CompletionProvider::global(cx).settings_version();
-                }
-            }),
+            cx.observe(
+                &LanguageModelCompletionProvider::global(cx),
+                |this, _, cx| {
this.completion_provider_changed(cx); + }, + ), ]; Self { @@ -483,37 +482,36 @@ impl AssistantPanel { } } - fn completion_provider_changed( - &mut self, - prev_settings_version: usize, - cx: &mut ViewContext, - ) { - if self.is_authenticated(cx) { - self.authentication_prompt = None; - - match self.active_context_editor(cx) { - Some(editor) => { - editor.update(cx, |active_context, cx| { - active_context - .context - .update(cx, |context, cx| context.completion_provider_changed(cx)) - }); - } - None => { - self.new_context(cx); - } - } - - cx.notify(); - } else if self.authentication_prompt.is_none() - || prev_settings_version != CompletionProvider::global(cx).settings_version() - { - self.authentication_prompt = - Some(cx.update_global::(|provider, cx| { - provider.authentication_prompt(cx) - })); - cx.notify(); + fn completion_provider_changed(&mut self, cx: &mut ViewContext) { + if let Some(editor) = self.active_context_editor(cx) { + editor.update(cx, |active_context, cx| { + active_context + .context + .update(cx, |context, cx| context.completion_provider_changed(cx)) + }) } + + if self.active_context_editor(cx).is_none() { + self.new_context(cx); + } + + let authentication_prompt = Self::authentication_prompt(cx); + for context_editor in self.context_editors(cx) { + context_editor.update(cx, |editor, cx| { + editor.set_authentication_prompt(authentication_prompt.clone(), cx); + }); + } + + cx.notify(); + } + + fn authentication_prompt(cx: &mut WindowContext) -> Option { + if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() { + if !provider.is_authenticated(cx) { + return Some(provider.authentication_prompt(cx)); + } + } + None } pub fn inline_assist( @@ -774,7 +772,7 @@ impl AssistantPanel { } fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { - CompletionProvider::global(cx) + LanguageModelCompletionProvider::read_global(cx) .reset_credentials(cx) .detach_and_log_err(cx); } @@ -783,6 +781,13 @@ impl AssistantPanel { self.model_selector_menu_handle.toggle(cx); } + fn context_editors(&self, cx: &AppContext) -> Vec> { + self.pane + .read(cx) + .items_of_type::() + .collect() + } + fn active_context_editor(&self, cx: &AppContext) -> Option> { self.pane .read(cx) @@ -904,11 +909,11 @@ impl AssistantPanel { } fn is_authenticated(&mut self, cx: &mut ViewContext) -> bool { - CompletionProvider::global(cx).is_authenticated() + LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) } fn authenticate(&mut self, cx: &mut ViewContext) -> Task> { - cx.update_global::(|provider, cx| provider.authenticate(cx)) + LanguageModelCompletionProvider::read_global(cx).authenticate(cx) } fn render_signed_in(&mut self, cx: &mut ViewContext) -> impl IntoElement { @@ -968,14 +973,18 @@ impl Panel for AssistantPanel { } fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext) { - settings::update_settings_file::(self.fs.clone(), cx, move |settings| { - let dock = match position { - DockPosition::Left => AssistantDockPosition::Left, - DockPosition::Bottom => AssistantDockPosition::Bottom, - DockPosition::Right => AssistantDockPosition::Right, - }; - settings.set_dock(dock); - }); + settings::update_settings_file::( + self.fs.clone(), + cx, + move |settings, _| { + let dock = match position { + DockPosition::Left => AssistantDockPosition::Left, + DockPosition::Bottom => AssistantDockPosition::Bottom, + DockPosition::Right => AssistantDockPosition::Right, + }; + settings.set_dock(dock); + }, + ); } fn size(&self, cx: 
&WindowContext) -> Pixels { @@ -1074,6 +1083,7 @@ struct ActiveEditStep { pub struct ContextEditor { context: Model, + authentication_prompt: Option, fs: Arc, workspace: WeakView, project: Model, @@ -1131,6 +1141,7 @@ impl ContextEditor { let sections = context.read(cx).slash_command_output_sections().to_vec(); let mut this = Self { context, + authentication_prompt: None, editor, lsp_adapter_delegate, blocks: Default::default(), @@ -1150,6 +1161,15 @@ impl ContextEditor { this } + fn set_authentication_prompt( + &mut self, + authentication_prompt: Option, + cx: &mut ViewContext, + ) { + self.authentication_prompt = authentication_prompt; + cx.notify(); + } + fn insert_default_prompt(&mut self, cx: &mut ViewContext) { let command_name = DefaultSlashCommand.name(); self.editor.update(cx, |editor, cx| { @@ -1176,6 +1196,10 @@ impl ContextEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { + if self.authentication_prompt.is_some() { + return; + } + if !self.apply_edit_step(cx) { self.send_to_model(cx); } @@ -2203,19 +2227,26 @@ impl Render for ContextEditor { .size_full() .v_flex() .child( - div() - .flex_grow() - .bg(cx.theme().colors().editor_background) - .child(self.editor.clone()) - .child( - h_flex() - .w_full() - .absolute() - .bottom_0() - .p_4() - .justify_end() - .child(self.render_send_button(cx)), - ), + if let Some(authentication_prompt) = self.authentication_prompt.as_ref() { + div() + .flex_grow() + .bg(cx.theme().colors().editor_background) + .child(authentication_prompt.clone().into_any()) + } else { + div() + .flex_grow() + .bg(cx.theme().colors().editor_background) + .child(self.editor.clone()) + .child( + h_flex() + .w_full() + .absolute() + .bottom_0() + .p_4() + .justify_end() + .child(self.render_send_button(cx)), + ) + }, ) } } @@ -2543,7 +2574,7 @@ impl ContextEditorToolbarItem { } fn render_remaining_tokens(&self, cx: &mut ViewContext) -> Option { - let model = CompletionProvider::global(cx).model(); + let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; let context = &self .active_context_editor .as_ref()? 
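The assistant_settings.rs diff below implements that settings migration in code: values saved under the V1 scheme (a per-provider enum that nested the default model inside each variant) are upgraded to the V2 scheme (a single provider/model pair). A minimal sketch of the shape change, using hypothetical stand-in types rather than Zed's real settings structs:

```rust
// Hypothetical stand-ins for the V1/V2 settings shapes; the real types
// are defined in crates/assistant/src/assistant_settings.rs.
enum ProviderV1 {
    OpenAi { default_model: Option<String> },
    Anthropic { default_model: Option<String> },
}

struct DefaultModelV2 {
    provider: String, // "openai" | "anthropic" | "ollama" | "zed.dev"
    model: String,    // e.g. "gpt-4o"
}

// V2 flattens the per-provider enum into one (provider, model) pair.
fn upgrade(v1: ProviderV1) -> Option<DefaultModelV2> {
    let (provider, default_model) = match v1 {
        ProviderV1::OpenAi { default_model } => ("openai", default_model),
        ProviderV1::Anthropic { default_model } => ("anthropic", default_model),
    };
    default_model.map(|model| DefaultModelV2 {
        provider: provider.to_string(),
        model,
    })
}
```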
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index e19dc65a44..09c5a9e733 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,19 +1,14 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use anthropic::Model as AnthropicModel; -use client::Client; -use completion::{ - AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider, - LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider, -}; +use fs::Fs; use gpui::{AppContext, Pixels}; -use language_model::{CloudModel, LanguageModel}; +use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel}; use ollama::Model as OllamaModel; use open_ai::Model as OpenAiModel; -use parking_lot::RwLock; use schemars::{schema::Schema, JsonSchema}; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources}; +use settings::{update_settings_file, Settings, SettingsSources}; #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -24,43 +19,9 @@ pub enum AssistantDockPosition { Bottom, } -#[derive(Debug, PartialEq)] -pub enum AssistantProvider { - ZedDotDev { - model: CloudModel, - }, - OpenAi { - model: OpenAiModel, - api_url: String, - low_speed_timeout_in_seconds: Option, - available_models: Vec, - }, - Anthropic { - model: AnthropicModel, - api_url: String, - low_speed_timeout_in_seconds: Option, - }, - Ollama { - model: OllamaModel, - api_url: String, - low_speed_timeout_in_seconds: Option, - }, -} - -impl Default for AssistantProvider { - fn default() -> Self { - Self::OpenAi { - model: OpenAiModel::default(), - api_url: open_ai::OPEN_AI_API_URL.into(), - low_speed_timeout_in_seconds: None, - available_models: Default::default(), - } - } -} - #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] #[serde(tag = "name", rename_all = "snake_case")] -pub enum AssistantProviderContent { +pub enum AssistantProviderContentV1 { #[serde(rename = "zed.dev")] ZedDotDev { default_model: Option }, #[serde(rename = "openai")] @@ -91,7 +52,8 @@ pub struct AssistantSettings { pub dock: AssistantDockPosition, pub default_width: Pixels, pub default_height: Pixels, - pub provider: AssistantProvider, + pub default_model: AssistantDefaultModel, + pub using_outdated_settings_version: bool, } /// Assistant panel settings @@ -123,34 +85,142 @@ impl Default for AssistantSettingsContent { } impl AssistantSettingsContent { - fn upgrade(&self) -> AssistantSettingsContentV1 { + pub fn is_version_outdated(&self) -> bool { match self { AssistantSettingsContent::Versioned(settings) => match settings { - VersionedAssistantSettingsContent::V1(settings) => settings.clone(), + VersionedAssistantSettingsContent::V1(_) => true, + VersionedAssistantSettingsContent::V2(_) => false, }, - AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 { + AssistantSettingsContent::Legacy(_) => true, + } + } + + pub fn update_file(&mut self, fs: Arc, cx: &AppContext) { + if let AssistantSettingsContent::Versioned(settings) = self { + if let VersionedAssistantSettingsContent::V1(settings) = settings { + if let Some(provider) = settings.provider.clone() { + match provider { + AssistantProviderContentV1::Anthropic { + api_url, + low_speed_timeout_in_seconds, + .. 
+                        } => update_settings_file::<AllLanguageModelSettings>(
+                            fs,
+                            cx,
+                            move |content, _| {
+                                if content.anthropic.is_none() {
+                                    content.anthropic =
+                                        Some(language_model::settings::AnthropicSettingsContent {
+                                            api_url,
+                                            low_speed_timeout_in_seconds,
+                                            ..Default::default()
+                                        });
+                                }
+                            },
+                        ),
+                        AssistantProviderContentV1::Ollama {
+                            api_url,
+                            low_speed_timeout_in_seconds,
+                            ..
+                        } => update_settings_file::<AllLanguageModelSettings>(
+                            fs,
+                            cx,
+                            move |content, _| {
+                                if content.ollama.is_none() {
+                                    content.ollama =
+                                        Some(language_model::settings::OllamaSettingsContent {
+                                            api_url,
+                                            low_speed_timeout_in_seconds,
+                                        });
+                                }
+                            },
+                        ),
+                        AssistantProviderContentV1::OpenAi {
+                            api_url,
+                            low_speed_timeout_in_seconds,
+                            available_models,
+                            ..
+                        } => update_settings_file::<AllLanguageModelSettings>(
+                            fs,
+                            cx,
+                            move |content, _| {
+                                if content.open_ai.is_none() {
+                                    content.open_ai =
+                                        Some(language_model::settings::OpenAiSettingsContent {
+                                            api_url,
+                                            low_speed_timeout_in_seconds,
+                                            available_models,
+                                        });
+                                }
+                            },
+                        ),
+                        _ => {}
+                    }
+                }
+            }
+        }
+
+        *self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
+            self.upgrade(),
+        ));
+    }
+
+    fn upgrade(&self) -> AssistantSettingsContentV2 {
+        match self {
+            AssistantSettingsContent::Versioned(settings) => match settings {
+                VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
+                    enabled: settings.enabled,
+                    button: settings.button,
+                    dock: settings.dock,
+                    default_width: settings.default_width,
+                    default_height: settings.default_height,
+                    default_model: settings
+                        .provider
+                        .clone()
+                        .and_then(|provider| match provider {
+                            AssistantProviderContentV1::ZedDotDev { default_model } => {
+                                default_model.map(|model| AssistantDefaultModel {
+                                    provider: "zed.dev".to_string(),
+                                    model: model.id().to_string(),
+                                })
+                            }
+                            AssistantProviderContentV1::OpenAi { default_model, .. } => {
+                                default_model.map(|model| AssistantDefaultModel {
+                                    provider: "openai".to_string(),
+                                    model: model.id().to_string(),
+                                })
+                            }
+                            AssistantProviderContentV1::Anthropic { default_model, .. } => {
+                                default_model.map(|model| AssistantDefaultModel {
+                                    provider: "anthropic".to_string(),
+                                    model: model.id().to_string(),
+                                })
+                            }
+                            AssistantProviderContentV1::Ollama { default_model, ..
} => { + default_model.map(|model| AssistantDefaultModel { + provider: "ollama".to_string(), + model: model.id().to_string(), + }) + } + }), + }, + VersionedAssistantSettingsContent::V2(settings) => settings.clone(), + }, + AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 { enabled: None, button: settings.button, dock: settings.dock, default_width: settings.default_width, default_height: settings.default_height, - provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() { - Some(AssistantProviderContent::OpenAi { - default_model: settings.default_open_ai_model.clone(), - api_url: Some(open_ai_api_url.clone()), - low_speed_timeout_in_seconds: None, - available_models: Some(Default::default()), - }) - } else { - settings.default_open_ai_model.clone().map(|open_ai_model| { - AssistantProviderContent::OpenAi { - default_model: Some(open_ai_model), - api_url: None, - low_speed_timeout_in_seconds: None, - available_models: Some(Default::default()), - } - }) - }, + default_model: Some(AssistantDefaultModel { + provider: "openai".to_string(), + model: settings + .default_open_ai_model + .clone() + .unwrap_or_default() + .id() + .to_string(), + }), }, } } @@ -161,6 +231,9 @@ impl AssistantSettingsContent { VersionedAssistantSettingsContent::V1(settings) => { settings.dock = Some(dock); } + VersionedAssistantSettingsContent::V2(settings) => { + settings.dock = Some(dock); + } }, AssistantSettingsContent::Legacy(settings) => { settings.dock = Some(dock); @@ -168,74 +241,78 @@ impl AssistantSettingsContent { } } - pub fn set_model(&mut self, new_model: LanguageModel) { + pub fn set_model(&mut self, language_model: Arc) { + let model = language_model.id().0.to_string(); + let provider = language_model.provider_name().0.to_string(); + match self { AssistantSettingsContent::Versioned(settings) => match settings { - VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider { - Some(AssistantProviderContent::ZedDotDev { - default_model: model, - }) => { - if let LanguageModel::Cloud(new_model) = new_model { - *model = Some(new_model); - } + VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() { + "zed.dev" => { + settings.provider = Some(AssistantProviderContentV1::ZedDotDev { + default_model: CloudModel::from_id(&model).ok(), + }); } - Some(AssistantProviderContent::OpenAi { - default_model: model, - .. - }) => { - if let LanguageModel::OpenAi(new_model) = new_model { - *model = Some(new_model); - } + "anthropic" => { + let (api_url, low_speed_timeout_in_seconds) = match &settings.provider { + Some(AssistantProviderContentV1::Anthropic { + api_url, + low_speed_timeout_in_seconds, + .. + }) => (api_url.clone(), *low_speed_timeout_in_seconds), + _ => (None, None), + }; + settings.provider = Some(AssistantProviderContentV1::Anthropic { + default_model: AnthropicModel::from_id(&model).ok(), + api_url, + low_speed_timeout_in_seconds, + }); } - Some(AssistantProviderContent::Anthropic { - default_model: model, - .. - }) => { - if let LanguageModel::Anthropic(new_model) = new_model { - *model = Some(new_model); - } + "ollama" => { + let (api_url, low_speed_timeout_in_seconds) = match &settings.provider { + Some(AssistantProviderContentV1::Ollama { + api_url, + low_speed_timeout_in_seconds, + .. 
+ }) => (api_url.clone(), *low_speed_timeout_in_seconds), + _ => (None, None), + }; + settings.provider = Some(AssistantProviderContentV1::Ollama { + default_model: Some(ollama::Model::new(&model)), + api_url, + low_speed_timeout_in_seconds, + }); } - Some(AssistantProviderContent::Ollama { - default_model: model, - .. - }) => { - if let LanguageModel::Ollama(new_model) = new_model { - *model = Some(new_model); - } + "openai" => { + let (api_url, low_speed_timeout_in_seconds, available_models) = + match &settings.provider { + Some(AssistantProviderContentV1::OpenAi { + api_url, + low_speed_timeout_in_seconds, + available_models, + .. + }) => ( + api_url.clone(), + *low_speed_timeout_in_seconds, + available_models.clone(), + ), + _ => (None, None, None), + }; + settings.provider = Some(AssistantProviderContentV1::OpenAi { + default_model: open_ai::Model::from_id(&model).ok(), + api_url, + low_speed_timeout_in_seconds, + available_models, + }); } - provider => match new_model { - LanguageModel::Cloud(model) => { - *provider = Some(AssistantProviderContent::ZedDotDev { - default_model: Some(model), - }) - } - LanguageModel::OpenAi(model) => { - *provider = Some(AssistantProviderContent::OpenAi { - default_model: Some(model), - api_url: None, - low_speed_timeout_in_seconds: None, - available_models: Some(Default::default()), - }) - } - LanguageModel::Anthropic(model) => { - *provider = Some(AssistantProviderContent::Anthropic { - default_model: Some(model), - api_url: None, - low_speed_timeout_in_seconds: None, - }) - } - LanguageModel::Ollama(model) => { - *provider = Some(AssistantProviderContent::Ollama { - default_model: Some(model), - api_url: None, - low_speed_timeout_in_seconds: None, - }) - } - }, + _ => {} }, + VersionedAssistantSettingsContent::V2(settings) => { + settings.default_model = Some(AssistantDefaultModel { provider, model }); + } }, AssistantSettingsContent::Legacy(settings) => { - if let LanguageModel::OpenAi(model) = new_model { + if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) { settings.default_open_ai_model = Some(model); } } @@ -248,21 +325,78 @@ impl AssistantSettingsContent { pub enum VersionedAssistantSettingsContent { #[serde(rename = "1")] V1(AssistantSettingsContentV1), + #[serde(rename = "2")] + V2(AssistantSettingsContentV2), } impl Default for VersionedAssistantSettingsContent { fn default() -> Self { - Self::V1(AssistantSettingsContentV1 { + Self::V2(AssistantSettingsContentV2 { enabled: None, button: None, dock: None, default_width: None, default_height: None, - provider: None, + default_model: None, }) } } +#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] +pub struct AssistantSettingsContentV2 { + /// Whether the Assistant is enabled. + /// + /// Default: true + enabled: Option, + /// Whether to show the assistant panel button in the status bar. + /// + /// Default: true + button: Option, + /// Where to dock the assistant. + /// + /// Default: right + dock: Option, + /// Default width in pixels when the assistant is docked to the left or right. + /// + /// Default: 640 + default_width: Option, + /// Default height in pixels when the assistant is docked to the bottom. + /// + /// Default: 320 + default_height: Option, + /// The default model to use when creating new contexts. 
+ default_model: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +pub struct AssistantDefaultModel { + #[schemars(schema_with = "providers_schema")] + pub provider: String, + pub model: String, +} + +fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + schemars::schema::SchemaObject { + enum_values: Some(vec![ + "anthropic".into(), + "ollama".into(), + "openai".into(), + "zed.dev".into(), + ]), + ..Default::default() + } + .into() +} + +impl Default for AssistantDefaultModel { + fn default() -> Self { + Self { + provider: "openai".to_string(), + model: "gpt-4".to_string(), + } + } +} + #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] pub struct AssistantSettingsContentV1 { /// Whether the Assistant is enabled. @@ -289,7 +423,7 @@ pub struct AssistantSettingsContentV1 { /// /// This can either be the internal `zed.dev` service or an external `openai` service, /// each with their respective default models and configurations. - provider: Option, + provider: Option, } #[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] @@ -332,6 +466,10 @@ impl Settings for AssistantSettings { let mut settings = AssistantSettings::default(); for value in sources.defaults_and_customizations() { + if value.is_version_outdated() { + settings.using_outdated_settings_version = true; + } + let value = value.upgrade(); merge(&mut settings.enabled, value.enabled); merge(&mut settings.button, value.button); @@ -344,123 +482,10 @@ impl Settings for AssistantSettings { &mut settings.default_height, value.default_height.map(Into::into), ); - if let Some(provider) = value.provider.clone() { - match (&mut settings.provider, provider) { - ( - AssistantProvider::ZedDotDev { model }, - AssistantProviderContent::ZedDotDev { - default_model: model_override, - }, - ) => { - merge(model, model_override); - } - ( - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - }, - AssistantProviderContent::OpenAi { - default_model: model_override, - api_url: api_url_override, - low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, - available_models: available_models_override, - }, - ) => { - merge(model, model_override); - merge(api_url, api_url_override); - merge(available_models, available_models_override); - if let Some(low_speed_timeout_in_seconds_override) = - low_speed_timeout_in_seconds_override - { - *low_speed_timeout_in_seconds = - Some(low_speed_timeout_in_seconds_override); - } - } - ( - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - }, - AssistantProviderContent::Ollama { - default_model: model_override, - api_url: api_url_override, - low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, - }, - ) => { - merge(model, model_override); - merge(api_url, api_url_override); - if let Some(low_speed_timeout_in_seconds_override) = - low_speed_timeout_in_seconds_override - { - *low_speed_timeout_in_seconds = - Some(low_speed_timeout_in_seconds_override); - } - } - ( - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - }, - AssistantProviderContent::Anthropic { - default_model: model_override, - api_url: api_url_override, - low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, - }, - ) => { - merge(model, model_override); - merge(api_url, api_url_override); - if let Some(low_speed_timeout_in_seconds_override) = - low_speed_timeout_in_seconds_override - { - 
*low_speed_timeout_in_seconds = - Some(low_speed_timeout_in_seconds_override); - } - } - (provider, provider_override) => { - *provider = match provider_override { - AssistantProviderContent::ZedDotDev { - default_model: model, - } => AssistantProvider::ZedDotDev { - model: model.unwrap_or_default(), - }, - AssistantProviderContent::OpenAi { - default_model: model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => AssistantProvider::OpenAi { - model: model.unwrap_or_default(), - api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()), - low_speed_timeout_in_seconds, - available_models: available_models.unwrap_or_default(), - }, - AssistantProviderContent::Anthropic { - default_model: model, - api_url, - low_speed_timeout_in_seconds, - } => AssistantProvider::Anthropic { - model: model.unwrap_or_default(), - api_url: api_url - .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()), - low_speed_timeout_in_seconds, - }, - AssistantProviderContent::Ollama { - default_model: model, - api_url, - low_speed_timeout_in_seconds, - } => AssistantProvider::Ollama { - model: model.unwrap_or_default(), - api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()), - low_speed_timeout_in_seconds, - }, - }; - } - } - } + merge( + &mut settings.default_model, + value.default_model.map(Into::into), + ); } Ok(settings) @@ -473,221 +498,103 @@ fn merge(target: &mut T, value: Option) { } } -pub fn update_completion_provider_settings( - provider: &mut CompletionProvider, - version: usize, - cx: &mut AppContext, -) { - let updated = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => provider - .update_current_as::<_, CloudCompletionProvider>(|provider| { - provider.update(model.clone(), version); - }), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| { - provider.update( - choose_openai_model(&model, &available_models), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - cx, - ); - }), - }; +// #[cfg(test)] +// mod tests { +// use gpui::{AppContext, UpdateGlobal}; +// use settings::SettingsStore; - // Previously configured provider was changed to another one - if updated.is_none() { - provider.update_provider(|client| create_provider_from_settings(client, version, cx)); - } -} +// use super::*; -pub(crate) fn create_provider_from_settings( - client: Arc, - settings_version: usize, - cx: &mut AppContext, -) -> Arc> { - match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( - CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), - )), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( - 
choose_openai_model(&model, &available_models), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - available_models.clone(), - ))), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(OllamaCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - cx, - ))), - } -} +// #[gpui::test] +// fn test_deserialize_assistant_settings(cx: &mut AppContext) { +// let store = settings::SettingsStore::test(cx); +// cx.set_global(store); -/// Choose which model to use for openai provider. -/// If the model is not available, try to use the first available model, or fallback to the original model. -fn choose_openai_model( - model: &::open_ai::Model, - available_models: &[::open_ai::Model], -) -> ::open_ai::Model { - available_models - .iter() - .find(|&m| m == model) - .or_else(|| available_models.first()) - .unwrap_or_else(|| model) - .clone() -} +// // Settings default to gpt-4-turbo. +// AssistantSettings::register(cx); +// assert_eq!( +// AssistantSettings::get_global(cx).provider, +// AssistantProvider::OpenAi { +// model: OpenAiModel::FourOmni, +// api_url: open_ai::OPEN_AI_API_URL.into(), +// low_speed_timeout_in_seconds: None, +// available_models: Default::default(), +// } +// ); -#[cfg(test)] -mod tests { - use gpui::{AppContext, UpdateGlobal}; - use settings::SettingsStore; +// // Ensure backward-compatibility. +// SettingsStore::update_global(cx, |store, cx| { +// store +// .set_user_settings( +// r#"{ +// "assistant": { +// "openai_api_url": "test-url", +// } +// }"#, +// cx, +// ) +// .unwrap(); +// }); +// assert_eq!( +// AssistantSettings::get_global(cx).provider, +// AssistantProvider::OpenAi { +// model: OpenAiModel::FourOmni, +// api_url: "test-url".into(), +// low_speed_timeout_in_seconds: None, +// available_models: Default::default(), +// } +// ); +// SettingsStore::update_global(cx, |store, cx| { +// store +// .set_user_settings( +// r#"{ +// "assistant": { +// "default_open_ai_model": "gpt-4-0613" +// } +// }"#, +// cx, +// ) +// .unwrap(); +// }); +// assert_eq!( +// AssistantSettings::get_global(cx).provider, +// AssistantProvider::OpenAi { +// model: OpenAiModel::Four, +// api_url: open_ai::OPEN_AI_API_URL.into(), +// low_speed_timeout_in_seconds: None, +// available_models: Default::default(), +// } +// ); - use super::*; - - #[gpui::test] - fn test_deserialize_assistant_settings(cx: &mut AppContext) { - let store = settings::SettingsStore::test(cx); - cx.set_global(store); - - // Settings default to gpt-4-turbo. - AssistantSettings::register(cx); - assert_eq!( - AssistantSettings::get_global(cx).provider, - AssistantProvider::OpenAi { - model: OpenAiModel::FourOmni, - api_url: open_ai::OPEN_AI_API_URL.into(), - low_speed_timeout_in_seconds: None, - available_models: Default::default(), - } - ); - - // Ensure backward-compatibility. 
- SettingsStore::update_global(cx, |store, cx| { - store - .set_user_settings( - r#"{ - "assistant": { - "openai_api_url": "test-url", - } - }"#, - cx, - ) - .unwrap(); - }); - assert_eq!( - AssistantSettings::get_global(cx).provider, - AssistantProvider::OpenAi { - model: OpenAiModel::FourOmni, - api_url: "test-url".into(), - low_speed_timeout_in_seconds: None, - available_models: Default::default(), - } - ); - SettingsStore::update_global(cx, |store, cx| { - store - .set_user_settings( - r#"{ - "assistant": { - "default_open_ai_model": "gpt-4-0613" - } - }"#, - cx, - ) - .unwrap(); - }); - assert_eq!( - AssistantSettings::get_global(cx).provider, - AssistantProvider::OpenAi { - model: OpenAiModel::Four, - api_url: open_ai::OPEN_AI_API_URL.into(), - low_speed_timeout_in_seconds: None, - available_models: Default::default(), - } - ); - - // The new version supports setting a custom model when using zed.dev. - SettingsStore::update_global(cx, |store, cx| { - store - .set_user_settings( - r#"{ - "assistant": { - "version": "1", - "provider": { - "name": "zed.dev", - "default_model": { - "custom": { - "name": "custom-provider" - } - } - } - } - }"#, - cx, - ) - .unwrap(); - }); - assert_eq!( - AssistantSettings::get_global(cx).provider, - AssistantProvider::ZedDotDev { - model: CloudModel::Custom { - name: "custom-provider".into(), - max_tokens: None - } - } - ); - } -} +// // The new version supports setting a custom model when using zed.dev. +// SettingsStore::update_global(cx, |store, cx| { +// store +// .set_user_settings( +// r#"{ +// "assistant": { +// "version": "1", +// "provider": { +// "name": "zed.dev", +// "default_model": { +// "custom": { +// "name": "custom-provider" +// } +// } +// } +// } +// }"#, +// cx, +// ) +// .unwrap(); +// }); +// assert_eq!( +// AssistantSettings::get_global(cx).provider, +// AssistantProvider::ZedDotDev { +// model: CloudModel::Custom { +// name: "custom-provider".into(), +// max_tokens: None +// } +// } +// ); +// } +// } diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 22511ab896..1b53686e06 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,6 +1,6 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId, - MessageStatus, + prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider, + MessageId, MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ @@ -1124,7 +1124,9 @@ impl Context { .await; let token_count = cx - .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) + })? .await?; this.update(&mut cx, |this, cx| { @@ -1308,7 +1310,9 @@ impl Context { }); let raw_output = cx - .update(|cx| CompletionProvider::global(cx).complete(request, cx))? + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx).complete(request, cx) + })? 
.await?; let operations = Self::parse_edit_operations(&raw_output); @@ -1612,13 +1616,14 @@ impl Context { .then_some(message.id) })?; - if !CompletionProvider::global(cx).is_authenticated() { + if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) { log::info!("completion provider has no credentials"); return None; } let request = self.to_completion_request(cx); - let stream = CompletionProvider::global(cx).stream_completion(request, cx); + let stream = + LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1698,11 +1703,14 @@ impl Context { }); if let Some(telemetry) = this.telemetry.as_ref() { - let model = CompletionProvider::global(cx).model(); + let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|m| m.telemetry_id()) + .unwrap_or_default(); telemetry.report_assistant_event( Some(this.id.0.clone()), AssistantKind::Panel, - model.telemetry_id(), + model_telemetry_id, response_latency, error_message, ); @@ -1727,7 +1735,6 @@ impl Context { .map(|message| message.to_request_message(self.buffer.read(cx))); LanguageModelRequest { - model: CompletionProvider::global(cx).model(), messages: messages.collect(), stop: vec![], temperature: 1.0, @@ -1970,7 +1977,7 @@ impl Context { pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext) { if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { - if !CompletionProvider::global(cx).is_authenticated() { + if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) { return; } @@ -1982,13 +1989,13 @@ impl Context { content: "Summarize the context into a short title without punctuation.".into(), })); let request = LanguageModelRequest { - model: CompletionProvider::global(cx).model(), messages: messages.collect(), stop: vec![], temperature: 1.0, }; - let stream = CompletionProvider::global(cx).stream_completion(request, cx); + let stream = + LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); self.pending_summary = cx.spawn(|this, mut cx| { async move { let mut messages = stream.await?; @@ -2504,7 +2511,6 @@ mod tests { MessageId, }; use assistant_slash_command::{ArgumentCompletion, SlashCommand}; - use completion::FakeCompletionProvider; use fs::FakeFs; use gpui::{AppContext, TestAppContext, WeakView}; use indoc::indoc; @@ -2524,7 +2530,8 @@ mod tests { #[gpui::test] fn test_inserting_and_removing_messages(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); - FakeCompletionProvider::setup_test(cx); + language_model::LanguageModelRegistry::test(cx); + completion::LanguageModelCompletionProvider::test(cx); cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2656,7 +2663,8 @@ mod tests { fn test_message_splitting(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); - FakeCompletionProvider::setup_test(cx); + language_model::LanguageModelRegistry::test(cx); + completion::LanguageModelCompletionProvider::test(cx); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2749,7 +2757,8 @@ mod tests { #[gpui::test] fn test_messages_for_offsets(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); - 
FakeCompletionProvider::setup_test(cx); + language_model::LanguageModelRegistry::test(cx); + completion::LanguageModelCompletionProvider::test(cx); cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2834,7 +2843,8 @@ mod tests { async fn test_slash_commands(cx: &mut TestAppContext) { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); - cx.update(FakeCompletionProvider::setup_test); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); cx.update(Project::init_settings); cx.update(assistant_panel::init); let fs = FakeFs::new(cx.background_executor.clone()); @@ -2959,7 +2969,11 @@ mod tests { cx.update(prompt_library::init); let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); - let fake_provider = cx.update(FakeCompletionProvider::setup_test); + + let fake_provider = cx.update(language_model::LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); + + let fake_model = fake_provider.test_model(); cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); @@ -3025,8 +3039,8 @@ mod tests { }); // Simulate the LLM completion - fake_provider.send_last_completion_chunk(llm_response.to_string()); - fake_provider.finish_last_completion(); + fake_model.send_last_completion_chunk(llm_response.to_string()); + fake_model.finish_last_completion(); // Wait for the completion to be processed cx.run_until_parked(); @@ -3107,7 +3121,8 @@ mod tests { async fn test_serialization(cx: &mut TestAppContext) { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); - cx.update(FakeCompletionProvider::setup_test); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx)); @@ -3183,7 +3198,9 @@ mod tests { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); - cx.update(FakeCompletionProvider::setup_test); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); + cx.update(assistant_panel::init); let slash_commands = cx.update(SlashCommandRegistry::default_global); slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false); diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index a76a015d47..f3915c191d 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,6 +1,6 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt, - AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff, + AssistantPanel, AssistantPanelEvent, Hunk, LanguageModelCompletionProvider, StreamingDiff, }; use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; @@ -27,7 +27,9 @@ use gpui::{ WindowContext, }; use language::{Buffer, Point, Selection, TransactionId}; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; use multi_buffer::MultiBufferRow; use 
parking_lot::Mutex; use rope::Rope; @@ -844,7 +846,10 @@ impl InlineAssistant { } let codegen = assist.codegen.clone(); - let telemetry_id = CompletionProvider::global(cx).model().telemetry_id(); + let telemetry_id = LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|m| m.telemetry_id()) + .unwrap_or_default(); let chunks: LocalBoxFuture>>> = if user_prompt.trim().to_lowercase() == "delete" { async { Ok(stream::empty().boxed()) }.boxed_local() @@ -854,7 +859,10 @@ impl InlineAssistant { async move { let request = request.await?; let chunks = cx - .update(|cx| CompletionProvider::global(cx).stream_completion(request, cx))? + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx) + .stream_completion(request, cx) + })? .await?; Ok(chunks.boxed()) } @@ -871,8 +879,8 @@ impl InlineAssistant { cx: &mut WindowContext, ) -> Task> { cx.spawn(|mut cx| async move { - let (user_prompt, context_request, project_name, buffer, range, model) = cx - .read_global(|this: &InlineAssistant, cx: &WindowContext| { + let (user_prompt, context_request, project_name, buffer, range) = + cx.read_global(|this: &InlineAssistant, cx: &WindowContext| { let assist = this.assists.get(&assist_id).context("invalid assist")?; let decorations = assist.decorations.as_ref().context("invalid assist")?; let editor = assist.editor.upgrade().context("invalid assist")?; @@ -906,15 +914,7 @@ impl InlineAssistant { }); let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); let range = assist.codegen.read(cx).range.clone(); - let model = CompletionProvider::global(cx).model(); - anyhow::Ok(( - user_prompt, - context_request, - project_name, - buffer, - range, - model, - )) + anyhow::Ok((user_prompt, context_request, project_name, buffer, range)) })??; let language = buffer.language_at(range.start); @@ -973,7 +973,6 @@ impl InlineAssistant { }); Ok(LanguageModelRequest { - model, messages, stop: vec!["|END|>".to_string()], temperature, @@ -1432,24 +1431,39 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models() { + for available_model in + LanguageModelRegistry::read_global(cx).available_models(cx) + { menu = menu.custom_entry( { - let model = model.clone(); + let model_name = available_model.name().0.clone(); + let provider = + available_model.provider_name().0.clone(); move |_| { - Label::new(model.display_name()) - .into_any_element() + h_flex() + .w_full() + .justify_between() + .child(Label::new(model_name.clone())) + .child( + div().ml_4().child( + Label::new(provider.clone()) + .color(Color::Muted), + ), + ) + .into_any() } }, { let fs = fs.clone(); - let model = model.clone(); + let model = available_model.clone(); move |cx| { let model = model.clone(); update_settings_file::( fs.clone(), cx, - move |settings| settings.set_model(model), + move |settings, _| { + settings.set_model(model) + }, ); } }, @@ -1468,9 +1482,10 @@ impl Render for PromptEditor { Tooltip::with_meta( format!( "Using {}", - CompletionProvider::global(cx) - .model() - .display_name() + LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|model| model.name().0) + .unwrap_or_else(|| "No model selected".into()), ), None, "Change Model", @@ -1668,7 +1683,9 @@ impl PromptEditor { .await?; let token_count = cx - .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? 
+ .update(|cx| { + LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) + })? .await?; this.update(&mut cx, |this, cx| { this.token_count = Some(token_count); @@ -1796,7 +1813,7 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut ViewContext) -> Option { - let model = CompletionProvider::global(cx).model(); + let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; let token_count = self.token_count?; let max_token_count = model.max_token_count(); @@ -2601,7 +2618,6 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { #[cfg(test)] mod tests { use super::*; - use completion::FakeCompletionProvider; use futures::stream::{self}; use gpui::{Context, TestAppContext}; use indoc::indoc; @@ -2622,7 +2638,8 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { cx.set_global(cx.update(SettingsStore::test)); - cx.update(|cx| FakeCompletionProvider::setup_test(cx)); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); cx.update(language_settings::init); let text = indoc! {" @@ -2749,7 +2766,8 @@ mod tests { cx: &mut TestAppContext, mut rng: StdRng, ) { - cx.update(|cx| FakeCompletionProvider::setup_test(cx)); + cx.update(LanguageModelRegistry::test); + cx.update(completion::LanguageModelCompletionProvider::test); cx.set_global(cx.update(SettingsStore::test)); cx.update(language_settings::init); diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index 6cd50a59da..3d628497b7 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -1,7 +1,10 @@ use std::sync::Arc; -use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector}; +use crate::{ + assistant_settings::AssistantSettings, LanguageModelCompletionProvider, ToggleModelSelector, +}; use fs::Fs; +use language_model::LanguageModelRegistry; use settings::update_settings_file; use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip}; @@ -23,25 +26,64 @@ impl RenderOnce for ModelSelector { .with_handle(self.handle) .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models() { - menu = menu.custom_entry( - { - let model = model.clone(); - move |_| Label::new(model.display_name()).into_any_element() - }, - { - let fs = self.fs.clone(); - let model = model.clone(); - move |cx| { - let model = model.clone(); - update_settings_file::( - fs.clone(), - cx, - move |settings| settings.set_model(model), - ); - } - }, - ); + for (provider, available_models) in LanguageModelRegistry::global(cx) + .read(cx) + .available_models_grouped_by_provider(cx) + { + menu = menu.header(provider.0.clone()); + + if available_models.is_empty() { + menu = menu.custom_entry( + { + move |_| { + h_flex() + .w_full() + .gap_1() + .child(Icon::new(IconName::Settings)) + .child(Label::new("Configure")) + .into_any() + } + }, + { + let provider = provider.clone(); + move |cx| { + LanguageModelCompletionProvider::global(cx).update( + cx, + |completion_provider, cx| { + completion_provider + .set_active_provider(provider.clone(), cx) + }, + ); + } + }, + ); + } + + for available_model in available_models { + menu = menu.custom_entry( + { + let model_name = available_model.name().0.clone(); + move |_| { + h_flex() + .w_full() + .child(Label::new(model_name.clone())) + .into_any() 
+ } + }, + { + let fs = self.fs.clone(); + let model = available_model.clone(); + move |cx| { + let model = model.clone(); + update_settings_file::( + fs.clone(), + cx, + move |settings, _| settings.set_model(model), + ); + } + }, + ); + } } menu }) @@ -61,7 +103,10 @@ impl RenderOnce for ModelSelector { .whitespace_nowrap() .child( Label::new( - CompletionProvider::global(cx).model().display_name(), + LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|model| model.name().0) + .unwrap_or_else(|| "No model selected".into()), ) .size(LabelSize::Small) .color(Color::Muted), diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index a59f4e3c0f..c85aef9314 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,6 @@ use crate::{ - slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, - InlineAssist, InlineAssistant, + slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant, + LanguageModelCompletionProvider, }; use anyhow::{anyhow, Result}; use assets::Assets; @@ -636,9 +636,9 @@ impl PromptLibrary { }; let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor; - let provider = CompletionProvider::global(cx); + let provider = LanguageModelCompletionProvider::read_global(cx); let initial_prompt = action.prompt.clone(); - if provider.is_authenticated() { + if provider.is_authenticated(cx) { InlineAssistant::update_global(cx, |assistant, cx| { assistant.assist(&prompt_editor, None, None, initial_prompt, cx) }) @@ -736,11 +736,8 @@ impl PromptLibrary { cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; let token_count = cx .update(|cx| { - let provider = CompletionProvider::global(cx); - let model = provider.model(); - provider.count_tokens( + LanguageModelCompletionProvider::read_global(cx).count_tokens( LanguageModelRequest { - model, messages: vec![LanguageModelRequestMessage { role: Role::System, content: body.to_string(), @@ -806,7 +803,7 @@ impl PromptLibrary { let prompt_metadata = self.store.metadata(prompt_id)?; let prompt_editor = &self.prompt_editors[&prompt_id]; let focus_handle = prompt_editor.body_editor.focus_handle(cx); - let current_model = CompletionProvider::global(cx).model(); + let current_model = LanguageModelCompletionProvider::read_global(cx).active_model(); let settings = ThemeSettings::get_global(cx); Some( @@ -917,7 +914,11 @@ impl PromptLibrary { format!( "Model: {}", current_model - .display_name() + .as_ref() + .map(|model| model + .name() + .0) + .unwrap_or_default() ), cx, ) diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 4a2b06ab57..c1d5fb898a 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent, - CompletionProvider, + LanguageModelCompletionProvider, }; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -17,7 +17,9 @@ use gpui::{ Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, }; use language::Buffer; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; use settings::{update_settings_file, 
Settings}; use std::{ cmp, @@ -215,8 +217,6 @@ impl TerminalInlineAssistant { ) -> Result { let assist = self.assists.get(&assist_id).context("invalid assist")?; - let model = CompletionProvider::global(cx).model(); - let shell = std::env::var("SHELL").ok(); let working_directory = assist .terminal @@ -268,7 +268,6 @@ impl TerminalInlineAssistant { }); Ok(LanguageModelRequest { - model, messages, stop: Vec::new(), temperature: 1.0, @@ -559,24 +558,39 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models() { + for available_model in + LanguageModelRegistry::read_global(cx).available_models(cx) + { menu = menu.custom_entry( { - let model = model.clone(); + let model_name = available_model.name().0.clone(); + let provider = + available_model.provider_name().0.clone(); move |_| { - Label::new(model.display_name()) - .into_any_element() + h_flex() + .w_full() + .justify_between() + .child(Label::new(model_name.clone())) + .child( + div().ml_4().child( + Label::new(provider.clone()) + .color(Color::Muted), + ), + ) + .into_any() } }, { let fs = fs.clone(); - let model = model.clone(); + let model = available_model.clone(); move |cx| { let model = model.clone(); update_settings_file::( fs.clone(), cx, - move |settings| settings.set_model(model), + move |settings, _| { + settings.set_model(model) + }, ); } }, @@ -595,9 +609,10 @@ impl Render for PromptEditor { Tooltip::with_meta( format!( "Using {}", - CompletionProvider::global(cx) - .model() - .display_name() + LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|model| model.name().0) + .unwrap_or_else(|| "No model selected".into()) ), None, "Change Model", @@ -748,7 +763,9 @@ impl PromptEditor { })??; let token_count = cx - .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) + })? 
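+ // count_tokens resolves on the active model and returns a BoxFuture<'static, _>, so it can be awaited here without holding anything borrowed from the app context.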
.await?; this.update(&mut cx, |this, cx| { this.token_count = Some(token_count); @@ -878,7 +895,7 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut ViewContext) -> Option { - let model = CompletionProvider::global(cx).model(); + let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; let token_count = self.token_count?; let max_token_count = model.max_token_count(); @@ -1023,8 +1040,12 @@ impl Codegen { self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); let telemetry = self.telemetry.clone(); - let model_telemetry_id = prompt.model.telemetry_id(); - let response = CompletionProvider::global(cx).stream_completion(prompt, cx); + let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) + .active_model() + .map(|m| m.telemetry_id()) + .unwrap_or_default(); + let response = + LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx); self.generation = cx.spawn(|this, mut cx| async move { let response = response.await; diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 06269b2948..7d429fb689 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -90,6 +90,7 @@ git_hosting_providers.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } +language_model = { workspace = true, features = ["test-support"] } live_kit_client = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } menu.workspace = true diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index cd5ef9ff64..423455ea0b 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -157,6 +157,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| { if cx.has_global::() { panic!("Same cx used to create two test clients") @@ -265,7 +267,6 @@ impl TestServer { git_hosting_provider_registry .register_hosting_provider(Arc::new(git_hosting_providers::Github)); - let fs = FakeFs::new(cx.executor()); let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx)); let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx)); let language_registry = Arc::new(LanguageRegistry::test(cx.executor())); @@ -297,7 +298,8 @@ impl TestServer { menu::init(); dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); - completion::FakeCompletionProvider::setup_test(cx); + language_model::LanguageModelRegistry::test(cx); + completion::init(cx); assistant::context_store::init(&client); }); diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 67d5047a51..4cede4e12b 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1107,9 +1107,11 @@ impl Panel for ChatPanel { } fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext) { - settings::update_settings_file::(self.fs.clone(), cx, move |settings| { - settings.dock = Some(position) - }); + settings::update_settings_file::( + self.fs.clone(), + cx, + move |settings, _| settings.dock = Some(position), + ); } fn size(&self, cx: &gpui::WindowContext) -> Pixels { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index e1458cbadf..9a91e403bf 
100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2806,7 +2806,7 @@ impl Panel for CollabPanel { settings::update_settings_file::( self.fs.clone(), cx, - move |settings| settings.dock = Some(position), + move |settings, _| settings.dock = Some(position), ); } diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index 08dee3686d..22764beda5 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -672,7 +672,7 @@ impl Panel for NotificationPanel { settings::update_settings_file::( self.fs.clone(), cx, - move |settings| settings.dock = Some(position), + move |settings, _| settings.dock = Some(position), ); } diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml index 18181e7bb5..9e3855676e 100644 --- a/crates/completion/Cargo.toml +++ b/crates/completion/Cargo.toml @@ -16,34 +16,20 @@ doctest = false test-support = [ "editor/test-support", "language/test-support", + "language_model/test-support", "project/test-support", "text/test-support", ] [dependencies] -anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true -client.workspace = true -collections.workspace = true -editor.workspace = true futures.workspace = true gpui.workspace = true -http.workspace = true language_model.workspace = true -log.workspace = true -menu.workspace = true -ollama = { workspace = true, features = ["schemars"] } -open_ai = { workspace = true, features = ["schemars"] } -parking_lot.workspace = true serde.workspace = true -serde_json.workspace = true settings.workspace = true smol.workspace = true -strum.workspace = true -theme.workspace = true -tiktoken-rs.workspace = true ui.workspace = true -util.workspace = true [dev-dependencies] ctor.workspace = true @@ -51,6 +37,7 @@ editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true language = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } +language_model = { workspace = true, features = ["test-support"] } rand.workspace = true text = { workspace = true, features = ["test-support"] } unindent.workspace = true diff --git a/crates/completion/src/anthropic.rs b/crates/completion/src/anthropic.rs deleted file mode 100644 index b1bbe8a35b..0000000000 --- a/crates/completion/src/anthropic.rs +++ /dev/null @@ -1,318 +0,0 @@ -use crate::{count_open_ai_tokens, LanguageModelCompletionProvider}; -use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; -use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage}; -use anyhow::{anyhow, Result}; -use editor::{Editor, EditorElement, EditorStyle}; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, Task, TextStyle, View}; -use http::HttpClient; -use language_model::Role; -use settings::Settings; -use std::time::Duration; -use std::{env, sync::Arc}; -use strum::IntoEnumIterator; -use theme::ThemeSettings; -use ui::prelude::*; -use util::ResultExt; - -pub struct AnthropicCompletionProvider { - api_key: Option, - api_url: String, - model: AnthropicModel, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, -} - -impl LanguageModelCompletionProvider for AnthropicCompletionProvider { - fn available_models(&self) -> Vec { - AnthropicModel::iter() - .map(LanguageModel::Anthropic) - .collect() - } - - fn settings_version(&self) -> usize { - 
self.settings_version - } - - fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn authenticate(&self, cx: &AppContext) -> Task> { - if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = self.api_url.clone(); - cx.spawn(|mut cx| async move { - let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") { - api_key - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - String::from_utf8(api_key)? - }; - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.api_key = Some(api_key); - }); - }) - }) - } - } - - fn reset_credentials(&self, cx: &AppContext) -> Task> { - let delete_credentials = cx.delete_credentials(&self.api_url); - cx.spawn(|mut cx| async move { - delete_credentials.await.log_err(); - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.api_key = None; - }); - }) - }) - } - - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx)) - .into() - } - - fn model(&self) -> LanguageModel { - LanguageModel::Anthropic(self.model.clone()) - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> BoxFuture<'static, Result> { - count_open_ai_tokens(request, cx.background_executor()) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - ) -> BoxFuture<'static, Result>>> { - let request = self.to_anthropic_request(request); - - let http_client = self.http_client.clone(); - let api_key = self.api_key.clone(); - let api_url = self.api_url.clone(); - let low_speed_timeout = self.low_speed_timeout; - async move { - let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; - let request = stream_completion( - http_client.as_ref(), - &api_url, - &api_key, - request, - low_speed_timeout, - ); - let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(response) => match response { - anthropic::ResponseEvent::ContentBlockStart { - content_block, .. - } => match content_block { - anthropic::ContentBlock::Text { text } => Some(Ok(text)), - }, - anthropic::ResponseEvent::ContentBlockDelta { delta, .. 
} => { - match delta { - anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), - } - } - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) - } - .boxed() - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -impl AnthropicCompletionProvider { - pub fn new( - model: AnthropicModel, - api_url: String, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, - ) -> Self { - Self { - api_key: None, - api_url, - model, - http_client, - low_speed_timeout, - settings_version, - } - } - - pub fn update( - &mut self, - model: AnthropicModel, - api_url: String, - low_speed_timeout: Option, - settings_version: usize, - ) { - self.model = model; - self.api_url = api_url; - self.low_speed_timeout = low_speed_timeout; - self.settings_version = settings_version; - } - - fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - request.preprocess_anthropic(); - - let model = match request.model { - LanguageModel::Anthropic(model) => model, - _ => self.model.clone(), - }; - - let mut system_message = String::new(); - if request - .messages - .first() - .map_or(false, |message| message.role == Role::System) - { - system_message = request.messages.remove(0).content; - } - - Request { - model, - messages: request - .messages - .iter() - .map(|msg| RequestMessage { - role: match msg.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("filtered out by preprocess_request"), - }, - content: msg.content.clone(), - }) - .collect(), - stream: true, - system: system_message, - max_tokens: 4092, - } - } -} - -struct AuthenticationPrompt { - api_key: View, - api_url: String, -} - -impl AuthenticationPrompt { - fn new(api_url: String, cx: &mut WindowContext) -> Self { - Self { - api_key: cx.new_view(|cx| { - let mut editor = Editor::single_line(cx); - editor.set_placeholder_text( - "sk-000000000000000000000000000000000000000000000000", - cx, - ); - editor - }), - api_url, - } - } - - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let api_key = self.api_key.read(cx).text(cx); - if api_key.is_empty() { - return; - } - - let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes()); - cx.spawn(|_, mut cx| async move { - write_credentials.await?; - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.api_key = Some(api_key); - }); - }) - }) - .detach_and_log_err(cx); - } - - fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features.clone(), - font_size: rems(0.875).into(), - font_weight: settings.ui_font.weight, - line_height: relative(1.3), - ..Default::default() - }; - EditorElement::new( - &self.api_key, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - ) - } -} - -impl Render for AuthenticationPrompt { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - const INSTRUCTIONS: [&str; 4] = [ - "To use the assistant panel or inline assistant, you need to add your Anthropic API key.", - "You can create an API key at: https://console.anthropic.com/settings/keys", - "", - "Paste 
your Anthropic API key below and hit enter to use the assistant:", - ]; - - v_flex() - .p_4() - .size_full() - .on_action(cx.listener(Self::save_api_key)) - .children( - INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), - ) - .child( - h_flex() - .w_full() - .my_2() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(self.render_api_key_editor(cx)), - ) - .child( - Label::new( - "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.", - ) - .size(LabelSize::Small), - ) - .child( - h_flex() - .gap_2() - .child(Label::new("Click on").size(LabelSize::Small)) - .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) - .child( - Label::new("in the status bar to close this panel.").size(LabelSize::Small), - ), - ) - .into_any() - } -} diff --git a/crates/completion/src/cloud.rs b/crates/completion/src/cloud.rs deleted file mode 100644 index 959394715b..0000000000 --- a/crates/completion/src/cloud.rs +++ /dev/null @@ -1,214 +0,0 @@ -use crate::{ - count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider, - LanguageModelRequest, -}; -use anyhow::{anyhow, Result}; -use client::{proto, Client}; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; -use gpui::{AnyView, AppContext, Task}; -use language_model::CloudModel; -use std::{future, sync::Arc}; -use strum::IntoEnumIterator; -use ui::prelude::*; - -pub struct CloudCompletionProvider { - client: Arc, - model: CloudModel, - settings_version: usize, - status: client::Status, - _maintain_client_status: Task<()>, -} - -impl CloudCompletionProvider { - pub fn new( - model: CloudModel, - client: Arc, - settings_version: usize, - cx: &mut AppContext, - ) -> Self { - let mut status_rx = client.status(); - let status = *status_rx.borrow(); - let maintain_client_status = cx.spawn(|mut cx| async move { - while let Some(status) = status_rx.next().await { - let _ = cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, Self>(|provider| { - provider.status = status; - }); - }); - } - }); - Self { - client, - model, - settings_version, - status, - _maintain_client_status: maintain_client_status, - } - } - - pub fn update(&mut self, model: CloudModel, settings_version: usize) { - self.model = model; - self.settings_version = settings_version; - } -} - -impl LanguageModelCompletionProvider for CloudCompletionProvider { - fn available_models(&self) -> Vec { - let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) { - Some(self.model.clone()) - } else { - None - }; - CloudModel::iter() - .filter_map(move |model| { - if let CloudModel::Custom { .. 
} = model { - custom_model.take() - } else { - Some(model) - } - }) - .map(LanguageModel::Cloud) - .collect() - } - - fn settings_version(&self) -> usize { - self.settings_version - } - - fn is_authenticated(&self) -> bool { - self.status.is_connected() - } - - fn authenticate(&self, cx: &AppContext) -> Task> { - let client = self.client.clone(); - cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) - } - - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|_cx| AuthenticationPrompt).into() - } - - fn reset_credentials(&self, _cx: &AppContext) -> Task> { - Task::ready(Ok(())) - } - - fn model(&self) -> LanguageModel { - LanguageModel::Cloud(self.model.clone()) - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> BoxFuture<'static, Result> { - match &request.model { - LanguageModel::Cloud(CloudModel::Gpt4) - | LanguageModel::Cloud(CloudModel::Gpt4Turbo) - | LanguageModel::Cloud(CloudModel::Gpt4Omni) - | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => { - count_open_ai_tokens(request, cx.background_executor()) - } - LanguageModel::Cloud( - CloudModel::Claude3_5Sonnet - | CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku, - ) => { - // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation. - count_open_ai_tokens(request, cx.background_executor()) - } - LanguageModel::Cloud(CloudModel::Custom { name, .. }) => { - if name.starts_with("anthropic/") { - // Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation. - count_open_ai_tokens(request, cx.background_executor()) - } else { - let request = self.client.request(proto::CountTokensWithLanguageModel { - model: name.clone(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - }); - async move { - let response = request.await?; - Ok(response.token_count as usize) - } - .boxed() - } - } - _ => future::ready(Err(anyhow!("invalid model"))).boxed(), - } - } - - fn stream_completion( - &self, - mut request: LanguageModelRequest, - ) -> BoxFuture<'static, Result>>> { - request.preprocess(); - - let request = proto::CompleteWithLanguageModel { - model: request.model.id().to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - }; - - self.client - .request_stream(request) - .map_ok(|stream| { - stream - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed() - }) - .boxed() - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -struct AuthenticationPrompt; - -impl Render for AuthenticationPrompt { - fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { - const LABEL: &str = "Generate and analyze code with language models. 
You can dialog with the assistant in this panel or transform code inline."; - - v_flex().gap_6().p_4().child(Label::new(LABEL)).child( - v_flex() - .gap_2() - .child( - Button::new("sign_in", "Sign in") - .icon_color(Color::Muted) - .icon(IconName::Github) - .icon_position(IconPosition::Start) - .style(ButtonStyle::Filled) - .full_width() - .on_click(|_, cx| { - CompletionProvider::global(cx) - .authenticate(cx) - .detach_and_log_err(cx); - }), - ) - .child( - div().flex().w_full().items_center().child( - Label::new("Sign in to enable collaboration.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) - } -} diff --git a/crates/completion/src/completion.rs b/crates/completion/src/completion.rs index a219e90b51..f7c2da95cb 100644 --- a/crates/completion/src/completion.rs +++ b/crates/completion/src/completion.rs @@ -1,31 +1,37 @@ -mod anthropic; -mod cloud; -#[cfg(any(test, feature = "test-support"))] -mod fake; -mod ollama; -mod open_ai; - -pub use anthropic::*; -use anyhow::Result; -use client::Client; -pub use cloud::*; -#[cfg(any(test, feature = "test-support"))] -pub use fake::*; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; -use gpui::{AnyView, AppContext, Task, WindowContext}; -use language_model::{LanguageModel, LanguageModelRequest}; -pub use ollama::*; -pub use open_ai::*; -use parking_lot::RwLock; +use anyhow::{anyhow, Result}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{AppContext, Global, Model, ModelContext, Task}; +use language_model::{ + LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry, + LanguageModelRequest, +}; use smol::lock::{Semaphore, SemaphoreGuardArc}; -use std::{any::Any, pin::Pin, sync::Arc, task::Poll}; +use std::{pin::Pin, sync::Arc, task::Poll}; +use ui::Context; -pub struct CompletionResponse { - inner: BoxStream<'static, Result>, +pub fn init(cx: &mut AppContext) { + let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx)); + cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider)); +} + +struct GlobalLanguageModelCompletionProvider(Model); + +impl Global for GlobalLanguageModelCompletionProvider {} + +pub struct LanguageModelCompletionProvider { + active_provider: Option>, + active_model: Option>, + request_limiter: Arc, +} + +const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4; + +pub struct LanguageModelCompletionResponse { + pub inner: BoxStream<'static, Result>, _lock: SemaphoreGuardArc, } -impl futures::Stream for CompletionResponse { +impl futures::Stream for LanguageModelCompletionResponse { type Item = Result; fn poll_next( @@ -36,73 +42,96 @@ impl futures::Stream for CompletionResponse { } } -pub trait LanguageModelCompletionProvider: Send + Sync { - fn available_models(&self) -> Vec; - fn settings_version(&self) -> usize; - fn is_authenticated(&self) -> bool; - fn authenticate(&self, cx: &AppContext) -> Task>; - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; - fn reset_credentials(&self, cx: &AppContext) -> Task>; - fn model(&self) -> LanguageModel; - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> BoxFuture<'static, Result>; - fn stream_completion( - &self, - request: LanguageModelRequest, - ) -> BoxFuture<'static, Result>>>; +impl LanguageModelCompletionProvider { + pub fn global(cx: &AppContext) -> Model { + cx.global::() + .0 + .clone() + } - fn as_any_mut(&mut self) -> &mut dyn Any; -} + pub fn read_global(cx: &AppContext) -> 
&Self { + cx.global::() + .0 + .read(cx) + } -const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4; + #[cfg(any(test, feature = "test-support"))] + pub fn test(cx: &mut AppContext) { + let provider = cx.new_model(|cx| { + let mut this = Self::new(cx); + let available_model = LanguageModelRegistry::read_global(cx) + .available_models(cx) + .first() + .unwrap() + .clone(); + this.set_active_model(available_model, cx); + this + }); + cx.set_global(GlobalLanguageModelCompletionProvider(provider)); + } -pub struct CompletionProvider { - provider: Arc>, - client: Option>, - request_limiter: Arc, -} + pub fn new(cx: &mut ModelContext) -> Self { + cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| { + cx.notify(); + }) + .detach(); -impl CompletionProvider { - pub fn new( - provider: Arc>, - client: Option>, - ) -> Self { Self { - provider, - client, + active_provider: None, + active_model: None, request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)), } } - pub fn available_models(&self) -> Vec { - self.provider.read().available_models() + pub fn active_provider(&self) -> Option> { + self.active_provider.clone() } - pub fn settings_version(&self) -> usize { - self.provider.read().settings_version() + pub fn set_active_provider( + &mut self, + provider_name: LanguageModelProviderName, + cx: &mut ModelContext, + ) { + self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name); + self.active_model = None; + cx.notify(); } - pub fn is_authenticated(&self) -> bool { - self.provider.read().is_authenticated() + pub fn active_model(&self) -> Option> { + self.active_model.clone() + } + + pub fn set_active_model(&mut self, model: Arc, cx: &mut ModelContext) { + if self.active_model.as_ref().map_or(false, |m| { + m.id() == model.id() && m.provider_name() == model.provider_name() + }) { + return; + } + + self.active_provider = + LanguageModelRegistry::read_global(cx).provider(&model.provider_name()); + self.active_model = Some(model); + cx.notify(); + } + + pub fn is_authenticated(&self, cx: &AppContext) -> bool { + self.active_provider + .as_ref() + .map_or(false, |provider| provider.is_authenticated(cx)) } pub fn authenticate(&self, cx: &AppContext) -> Task> { - self.provider.read().authenticate(cx) - } - - pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - self.provider.read().authentication_prompt(cx) + self.active_provider + .as_ref() + .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx)) } pub fn reset_credentials(&self, cx: &AppContext) -> Task> { - self.provider.read().reset_credentials(cx) - } - - pub fn model(&self) -> LanguageModel { - self.provider.read().model() + self.active_provider + .as_ref() + .map_or(Task::ready(Ok(())), |provider| { + provider.reset_credentials(cx) + }) } pub fn count_tokens( @@ -110,25 +139,31 @@ impl CompletionProvider { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - self.provider.read().count_tokens(request, cx) + if let Some(model) = self.active_model() { + model.count_tokens(request, cx) + } else { + std::future::ready(Err(anyhow!("No active model set"))).boxed() + } } pub fn stream_completion( &self, request: LanguageModelRequest, cx: &AppContext, - ) -> Task> { - let rate_limiter = self.request_limiter.clone(); - let provider = self.provider.clone(); - cx.foreground_executor().spawn(async move { - let lock = rate_limiter.acquire_arc().await; - let response = provider.read().stream_completion(request); - let response = response.await?; - 
Ok(CompletionResponse { - inner: response, - _lock: lock, + ) -> Task> { + if let Some(language_model) = self.active_model() { + let rate_limiter = self.request_limiter.clone(); + cx.spawn(|cx| async move { + let lock = rate_limiter.acquire_arc().await; + let response = language_model.stream_completion(request, &cx).await?; + Ok(LanguageModelCompletionResponse { + inner: response, + _lock: lock, + }) }) - }) + } else { + Task::ready(Err(anyhow!("No active model set"))) + } } pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task> { @@ -143,63 +178,43 @@ impl CompletionProvider { Ok(completion) }) } - - pub fn update_provider( - &mut self, - get_provider: impl FnOnce(Arc) -> Arc>, - ) { - if let Some(client) = &self.client { - self.provider = get_provider(Arc::clone(client)); - } else { - log::warn!("completion provider cannot be updated because its client was not set"); - } - } -} - -impl gpui::Global for CompletionProvider {} - -impl CompletionProvider { - pub fn global(cx: &AppContext) -> &Self { - cx.global::() - } - - pub fn update_current_as( - &mut self, - update: impl FnOnce(&mut T) -> R, - ) -> Option { - let mut provider = self.provider.write(); - if let Some(provider) = provider.as_any_mut().downcast_mut::() { - Some(update(provider)) - } else { - None - } - } } #[cfg(test)] mod tests { - use std::sync::Arc; - + use futures::StreamExt; use gpui::AppContext; - use parking_lot::RwLock; use settings::SettingsStore; - use smol::stream::StreamExt; + use ui::Context; use crate::{ - CompletionProvider, FakeCompletionProvider, LanguageModelRequest, - MAX_CONCURRENT_COMPLETION_REQUESTS, + LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS, }; + use language_model::LanguageModelRegistry; + #[gpui::test] fn test_rate_limiting(cx: &mut AppContext) { SettingsStore::test(cx); - let fake_provider = FakeCompletionProvider::setup_test(cx); + let fake_provider = LanguageModelRegistry::test(cx); - let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None); + let model = LanguageModelRegistry::read_global(cx) + .available_models(cx) + .first() + .cloned() + .unwrap(); + + let provider = cx.new_model(|cx| { + let mut provider = LanguageModelCompletionProvider::new(cx); + provider.set_active_model(model.clone(), cx); + provider + }); + + let fake_model = fake_provider.test_model(); // Enqueue some requests for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 { - let response = provider.stream_completion( + let response = provider.read(cx).stream_completion( LanguageModelRequest { temperature: i as f32 / 10.0, ..Default::default() @@ -216,23 +231,18 @@ mod tests { .detach(); } cx.background_executor().run_until_parked(); - assert_eq!( - fake_provider.completion_count(), + fake_model.completion_count(), MAX_CONCURRENT_COMPLETION_REQUESTS ); // Get the first completion request that is in flight and mark it as completed. - let completion = fake_provider - .pending_completions() - .into_iter() - .next() - .unwrap(); - fake_provider.finish_completion(&completion); + let completion = fake_model.pending_completions().into_iter().next().unwrap(); + fake_model.finish_completion(&completion); // Ensure that the number of in-flight completion requests is reduced. assert_eq!( - fake_provider.completion_count(), + fake_model.completion_count(), MAX_CONCURRENT_COMPLETION_REQUESTS - 1 ); @@ -240,32 +250,32 @@ mod tests { // Ensure that another completion request was allowed to acquire the lock. 
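+ // (each LanguageModelCompletionResponse holds a SemaphoreGuardArc, so dropping a finished stream releases one of the MAX_CONCURRENT_COMPLETION_REQUESTS permits)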
assert_eq!( - fake_provider.completion_count(), + fake_model.completion_count(), MAX_CONCURRENT_COMPLETION_REQUESTS ); // Mark all completion requests as finished that are in flight. - for request in fake_provider.pending_completions() { - fake_provider.finish_completion(&request); + for request in fake_model.pending_completions() { + fake_model.finish_completion(&request); } - assert_eq!(fake_provider.completion_count(), 0); + assert_eq!(fake_model.completion_count(), 0); // Wait until the background tasks acquire the lock again. cx.background_executor().run_until_parked(); assert_eq!( - fake_provider.completion_count(), + fake_model.completion_count(), MAX_CONCURRENT_COMPLETION_REQUESTS - 1 ); // Finish all remaining completion requests. - for request in fake_provider.pending_completions() { - fake_provider.finish_completion(&request); + for request in fake_model.pending_completions() { + fake_model.finish_completion(&request); } cx.background_executor().run_until_parked(); - assert_eq!(fake_provider.completion_count(), 0); + assert_eq!(fake_model.completion_count(), 0); } } diff --git a/crates/completion/src/fake.rs b/crates/completion/src/fake.rs deleted file mode 100644 index 9eee0f736f..0000000000 --- a/crates/completion/src/fake.rs +++ /dev/null @@ -1,115 +0,0 @@ -use anyhow::Result; -use collections::HashMap; -use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, Task}; -use std::sync::Arc; -use ui::WindowContext; - -use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest}; - -#[derive(Clone, Default)] -pub struct FakeCompletionProvider { - current_completion_txs: Arc>>>, -} - -impl FakeCompletionProvider { - pub fn setup_test(cx: &mut AppContext) -> Self { - use crate::CompletionProvider; - use parking_lot::RwLock; - - let this = Self::default(); - let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None); - cx.set_global(provider); - this - } - - pub fn pending_completions(&self) -> Vec { - self.current_completion_txs - .lock() - .keys() - .map(|k| serde_json::from_str(k).unwrap()) - .collect() - } - - pub fn completion_count(&self) -> usize { - self.current_completion_txs.lock().len() - } - - pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) { - let json = serde_json::to_string(request).unwrap(); - self.current_completion_txs - .lock() - .get(&json) - .unwrap() - .unbounded_send(chunk) - .unwrap(); - } - - pub fn send_last_completion_chunk(&self, chunk: String) { - self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk); - } - - pub fn finish_completion(&self, request: &LanguageModelRequest) { - self.current_completion_txs - .lock() - .remove(&serde_json::to_string(request).unwrap()) - .unwrap(); - } - - pub fn finish_last_completion(&self) { - self.finish_completion(self.pending_completions().last().unwrap()); - } -} - -impl LanguageModelCompletionProvider for FakeCompletionProvider { - fn available_models(&self) -> Vec { - vec![LanguageModel::default()] - } - - fn settings_version(&self) -> usize { - 0 - } - - fn is_authenticated(&self) -> bool { - true - } - - fn authenticate(&self, _cx: &AppContext) -> Task> { - Task::ready(Ok(())) - } - - fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView { - unimplemented!() - } - - fn reset_credentials(&self, _cx: &AppContext) -> Task> { - Task::ready(Ok(())) - } - - fn model(&self) -> LanguageModel { - LanguageModel::default() - } - - fn count_tokens( 
- &self, - _request: LanguageModelRequest, - _cx: &AppContext, - ) -> BoxFuture<'static, Result> { - futures::future::ready(Ok(0)).boxed() - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - ) -> BoxFuture<'static, Result>>> { - let (tx, rx) = mpsc::unbounded(); - self.current_completion_txs - .lock() - .insert(serde_json::to_string(&_request).unwrap(), tx); - async move { Ok(rx.map(Ok).boxed()) }.boxed() - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 34f1143f89..c0aa4c332c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -10384,7 +10384,7 @@ impl Editor { }; let fs = workspace.read(cx).app_state().fs.clone(); let current_show = TabBarSettings::get_global(cx).show; - update_settings_file::(fs, cx, move |setting| { + update_settings_file::(fs, cx, move |setting, _| { setting.show = Some(!current_show); }); } diff --git a/crates/extensions_ui/src/extension_version_selector.rs b/crates/extensions_ui/src/extension_version_selector.rs index c04efd6701..23208bc710 100644 --- a/crates/extensions_ui/src/extension_version_selector.rs +++ b/crates/extensions_ui/src/extension_version_selector.rs @@ -178,7 +178,7 @@ impl PickerDelegate for ExtensionVersionSelectorDelegate { update_settings_file::(self.fs.clone(), cx, { let extension_id = extension_id.clone(); - move |settings| { + move |settings, _| { settings.auto_update_extensions.insert(extension_id, false); } }); diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index e39bd18c7c..5b268b45fd 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -910,7 +910,7 @@ impl ExtensionsPage { if let Some(workspace) = self.workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); let selection = *selection; - settings::update_settings_file::(fs, cx, move |settings| { + settings::update_settings_file::(fs, cx, move |settings, _| { let value = match selection { Selection::Unselected => false, Selection::Selected => true, diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index c90e3e34aa..4e8b4f0199 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -29,6 +29,11 @@ impl FeatureFlag for Remoting { const NAME: &'static str = "remoting"; } +pub struct LanguageModels {} +impl FeatureFlag for LanguageModels { + const NAME: &'static str = "language-models"; +} + pub struct TerminalInlineAssist {} impl FeatureFlag for TerminalInlineAssist { const NAME: &'static str = "terminal-inline-assist"; @@ -65,6 +70,10 @@ pub trait FeatureFlagAppExt { fn set_staff(&mut self, staff: bool); fn has_flag(&self) -> bool; fn is_staff(&self) -> bool; + + fn observe_flag(&mut self, callback: F) -> Subscription + where + F: Fn(bool, &mut AppContext) + 'static; } impl FeatureFlagAppExt for AppContext { @@ -90,4 +99,14 @@ impl FeatureFlagAppExt for AppContext { .map(|flags| flags.staff) .unwrap_or(false) } + + fn observe_flag(&mut self, callback: F) -> Subscription + where + F: Fn(bool, &mut AppContext) + 'static, + { + self.observe_global::(move |cx| { + let feature_flags = cx.global::(); + callback(feature_flags.has_flag(::NAME), cx); + }) + } } diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index 19e4fe2545..55daa1b040 
100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -420,7 +420,7 @@ async fn configure_disabled_globs( fn toggle_inline_completions_globally(fs: Arc, cx: &mut AppContext) { let show_inline_completions = all_language_settings(None, cx).inline_completions_enabled(None, None); - update_settings_file::(fs, cx, move |file| { + update_settings_file::(fs, cx, move |file, _| { file.defaults.show_inline_completions = Some(!show_inline_completions) }); } @@ -432,7 +432,7 @@ fn toggle_inline_completions_for_language( ) { let show_inline_completions = all_language_settings(None, cx).inline_completions_enabled(Some(&language), None); - update_settings_file::(fs, cx, move |file| { + update_settings_file::(fs, cx, move |file, _| { file.languages .entry(language.name()) .or_default() @@ -441,7 +441,7 @@ fn toggle_inline_completions_for_language( } fn hide_copilot(fs: Arc, cx: &mut AppContext) { - update_settings_file::(fs, cx, move |file| { + update_settings_file::(fs, cx, move |file, _| { file.features .get_or_insert(Default::default()) .inline_completion_provider = Some(InlineCompletionProvider::None); diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index bdc3ad63d5..1324676120 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -22,12 +22,27 @@ test-support = [ [dependencies] anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +feature_flags.workspace = true +futures.workspace = true +gpui.workspace = true +http.workspace = true +menu.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +proto = { workspace = true, features = ["test-support"] } schemars.workspace = true serde.workspace = true +serde_json.workspace = true +settings.workspace = true strum.workspace = true -proto = { workspace = true, features = ["test-support"] } +theme.workspace = true +tiktoken-rs.workspace = true +ui.workspace = true +util.workspace = true [dev-dependencies] ctor.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 09de409ff4..aa0c2d697a 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -1,7 +1,84 @@ mod model; +pub mod provider; +mod registry; mod request; mod role; +pub mod settings; + +use std::sync::Arc; + +use anyhow::Result; +use client::Client; +use futures::{future::BoxFuture, stream::BoxStream}; +use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; pub use model::*; +pub use registry::*; pub use request::*; pub use role::*; + +pub fn init(client: Arc, cx: &mut AppContext) { + settings::init(cx); + registry::init(client, cx); +} + +pub trait LanguageModel: Send + Sync { + fn id(&self) -> LanguageModelId; + fn name(&self) -> LanguageModelName; + fn provider_name(&self) -> LanguageModelProviderName; + fn telemetry_id(&self) -> String; + + fn max_token_count(&self) -> usize; + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result>; + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>>; +} + +pub trait LanguageModelProvider: 'static { + fn name(&self) -> 
LanguageModelProviderName; + fn provided_models(&self, cx: &AppContext) -> Vec>; + fn is_authenticated(&self, cx: &AppContext) -> bool; + fn authenticate(&self, cx: &AppContext) -> Task>; + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; + fn reset_credentials(&self, cx: &AppContext) -> Task>; +} + +pub trait LanguageModelProviderState: 'static { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option; +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct LanguageModelId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct LanguageModelName(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct LanguageModelProviderName(pub SharedString); + +impl From for LanguageModelId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelProviderName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 0460b5dcf1..b7b304a65d 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,4 +1,5 @@ pub use anthropic::Model as AnthropicModel; +use anyhow::{anyhow, Result}; pub use ollama::Model as OllamaModel; pub use open_ai::Model as OpenAiModel; use schemars::JsonSchema; @@ -38,6 +39,23 @@ pub enum CloudModel { } impl CloudModel { + pub fn from_id(value: &str) -> Result { + match value { + "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo), + "gpt-4" => Ok(Self::Gpt4), + "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo), + "gpt-4o" => Ok(Self::Gpt4Omni), + "gpt-4o-mini" => Ok(Self::Gpt4OmniMini), + "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet), + "claude-3-opus" => Ok(Self::Claude3Opus), + "claude-3-sonnet" => Ok(Self::Claude3Sonnet), + "claude-3-haiku" => Ok(Self::Claude3Haiku), + "gemini-1.5-pro" => Ok(Self::Gemini15Pro), + "gemini-1.5-flash" => Ok(Self::Gemini15Flash), + _ => Err(anyhow!("invalid model id")), + } + } + pub fn id(&self) -> &str { match self { Self::Gpt3Point5Turbo => "gpt-3.5-turbo", diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs index b61766308f..7b5ac88dea 100644 --- a/crates/language_model/src/model/mod.rs +++ b/crates/language_model/src/model/mod.rs @@ -4,57 +4,3 @@ pub use anthropic::Model as AnthropicModel; pub use cloud_model::*; pub use ollama::Model as OllamaModel; pub use open_ai::Model as OpenAiModel; - -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub enum LanguageModel { - Cloud(CloudModel), - OpenAi(OpenAiModel), - Anthropic(AnthropicModel), - Ollama(OllamaModel), -} - -impl Default for LanguageModel { - fn default() -> Self { - LanguageModel::Cloud(CloudModel::default()) - } -} - -impl LanguageModel { - pub fn telemetry_id(&self) -> String { - match self { - LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), - LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), - LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), - LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), - } - } - - pub fn display_name(&self) -> String { - match self { - LanguageModel::OpenAi(model) => model.display_name().into(), - LanguageModel::Anthropic(model) => model.display_name().into(), - 
LanguageModel::Cloud(model) => model.display_name().into(), - LanguageModel::Ollama(model) => model.display_name().into(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - LanguageModel::OpenAi(model) => model.max_token_count(), - LanguageModel::Anthropic(model) => model.max_token_count(), - LanguageModel::Cloud(model) => model.max_token_count(), - LanguageModel::Ollama(model) => model.max_token_count(), - } - } - - pub fn id(&self) -> &str { - match self { - LanguageModel::OpenAi(model) => model.id(), - LanguageModel::Anthropic(model) => model.id(), - LanguageModel::Cloud(model) => model.id(), - LanguageModel::Ollama(model) => model.id(), - } - } -} diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs new file mode 100644 index 0000000000..f2713db003 --- /dev/null +++ b/crates/language_model/src/provider.rs @@ -0,0 +1,6 @@ +pub mod anthropic; +pub mod cloud; +#[cfg(any(test, feature = "test-support"))] +pub mod fake; +pub mod ollama; +pub mod open_ai; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs new file mode 100644 index 0000000000..facdfb21ac --- /dev/null +++ b/crates/language_model/src/provider/anthropic.rs @@ -0,0 +1,454 @@ +use anthropic::{stream_completion, Request, RequestMessage}; +use anyhow::{anyhow, Result}; +use collections::HashMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, + WhiteSpace, +}; +use http::HttpClient; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; +use strum::IntoEnumIterator; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt; + +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelRequestMessage, Role, +}; + +const PROVIDER_NAME: &str = "anthropic"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct AnthropicSettings { + pub api_url: String, + pub low_speed_timeout: Option, + pub available_models: Vec, +} + +pub struct AnthropicLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + api_key: Option, + settings: AnthropicSettings, + _subscription: Subscription, +} + +impl AnthropicLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| State { + api_key: None, + settings: AnthropicSettings::default(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone(); + cx.notify(); + }), + }); + + Self { http_client, state } + } +} +impl LanguageModelProviderState for AnthropicLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for AnthropicLanguageModelProvider { + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = HashMap::default(); + + // Add base models from anthropic::Model::iter() + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. 
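+ // (the Custom placeholder is excluded here; user-defined models are merged in from settings just below)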
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &self.state.read(cx).settings.available_models { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).api_key.is_some() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated(cx) { + Task::ready(Ok(())) + } else { + let api_url = self.state.read(cx).settings.api_url.clone(); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) + .into() + } + + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let state = self.state.clone(); + let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + state.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + +pub struct AnthropicModel { + id: LanguageModelId, + model: anthropic::Model, + state: gpui::Model, + http_client: Arc, +} + +impl AnthropicModel { + fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { + preprocess_anthropic_request(&mut request); + + let mut system_message = String::new(); + if request + .messages + .first() + .map_or(false, |message| message.role == Role::System) + { + system_message = request.messages.remove(0).content; + } + + Request { + model: self.model.clone(), + messages: request + .messages + .iter() + .map(|msg| RequestMessage { + role: match msg.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => unreachable!("filtered out by preprocess_request"), + }, + content: msg.content.clone(), + }) + .collect(), + stream: true, + system: system_message, + max_tokens: 4092, + } + } +} + +pub fn count_anthropic_tokens( + request: LanguageModelRequest, + cx: &AppContext, +) -> BoxFuture<'static, Result> { + cx.background_executor() + .spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.content), + name: None, + function_call: None, + }) + .collect::>(); + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. 
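+ // (assumption: GPT-4's cl100k_base encoding gives a usable estimate for the UI token count, since no local Anthropic tokenizer is available)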
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) + }) + .boxed() +} + +impl LanguageModel for AnthropicModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("anthropic/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + count_anthropic_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let request = self.to_anthropic_request(request); + + let http_client = self.http_client.clone(); + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { + ( + state.api_key.clone(), + state.settings.api_url.clone(), + state.settings.low_speed_timeout, + ) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let request = stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + low_speed_timeout, + ); + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(response) => match response { + anthropic::ResponseEvent::ContentBlockStart { + content_block, .. + } => match content_block { + anthropic::ContentBlock::Text { text } => Some(Ok(text)), + }, + anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => { + match delta { + anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), + } + } + _ => None, + }, + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } +} + +pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages.drain(..) 
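+ // Anthropic's messages API expects alternating user/assistant turns with the system prompt passed separately, so consecutive same-role messages are merged and system content is hoisted out below.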
{ + if message.content.is_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == message.role { + last_message.content.push_str("\n\n"); + last_message.content.push_str(&message.content); + continue; + } + } + + new_messages.push(message); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + } + } + } + + if !system_message.is_empty() { + new_messages.insert( + 0, + LanguageModelRequestMessage { + role: Role::System, + content: system_message, + }, + ); + } + + request.messages = new_messages; +} + +struct AuthenticationPrompt { + api_key: View, + state: gpui::Model, +} + +impl AuthenticationPrompt { + fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text( + "sk-000000000000000000000000000000000000000000000000", + cx, + ); + editor + }), + state, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let write_credentials = cx.write_credentials( + &self.state.read(cx).settings.api_url, + "Bearer", + api_key.as_bytes(), + ); + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + write_credentials.await?; + + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 4] = [ + "To use the assistant panel or inline assistant, you need to add your Anthropic API key.", + "You can create an API key at: https://console.anthropic.com/settings/keys", + "", + "Paste your Anthropic API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git 
a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs new file mode 100644 index 0000000000..3b42b72090 --- /dev/null +++ b/crates/language_model/src/provider/cloud.rs @@ -0,0 +1,287 @@ +use super::open_ai::count_open_ai_tokens; +use crate::{ + settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId, + LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, +}; +use anyhow::Result; +use client::Client; +use collections::HashMap; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; +use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use strum::IntoEnumIterator; +use ui::prelude::*; + +use crate::LanguageModelProvider; + +use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request}; + +pub const PROVIDER_NAME: &str = "zed.dev"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct ZedDotDevSettings { + pub available_models: Vec, +} + +pub struct CloudLanguageModelProvider { + client: Arc, + state: gpui::Model, + _maintain_client_status: Task<()>, +} + +struct State { + client: Arc, + status: client::Status, + settings: ZedDotDevSettings, + _subscription: Subscription, +} + +impl State { + fn authenticate(&self, cx: &AppContext) -> Task> { + let client = self.client.clone(); + cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) + } +} + +impl CloudLanguageModelProvider { + pub fn new(client: Arc, cx: &mut AppContext) -> Self { + let mut status_rx = client.status(); + let status = *status_rx.borrow(); + + let state = cx.new_model(|cx| State { + client: client.clone(), + status, + settings: ZedDotDevSettings::default(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone(); + cx.notify(); + }), + }); + + let state_ref = state.downgrade(); + let maintain_client_status = cx.spawn(|mut cx| async move { + while let Some(status) = status_rx.next().await { + if let Some(this) = state_ref.upgrade() { + _ = this.update(&mut cx, |this, cx| { + this.status = status; + cx.notify(); + }); + } else { + break; + } + } + }); + + Self { + client, + state, + _maintain_client_status: maintain_client_status, + } + } +} + +impl LanguageModelProviderState for CloudLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for CloudLanguageModelProvider { + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = HashMap::default(); + + // Add base models from CloudModel::iter() + for model in CloudModel::iter() { + if !matches!(model, CloudModel::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &self.state.read(cx).settings.available_models { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(CloudLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + client: self.client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).status.is_connected() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + self.state.read(cx).authenticate(cx) + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|_cx| AuthenticationPrompt { + state: self.state.clone(), + }) + .into() + } + + fn reset_credentials(&self, _cx: &AppContext) -> Task> { + Task::ready(Ok(())) + } +} + +pub struct CloudLanguageModel { + id: LanguageModelId, + model: CloudModel, + client: Arc, +} + +impl LanguageModel for CloudLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("zed.dev/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + match &self.model { + CloudModel::Gpt3Point5Turbo => { + count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx) + } + CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx), + CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx), + CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx), + CloudModel::Gpt4OmniMini => { + count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx) + } + CloudModel::Claude3_5Sonnet + | CloudModel::Claude3Opus + | CloudModel::Claude3Sonnet + | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx), + _ => { + let request = self.client.request(proto::CountTokensWithLanguageModel { + model: self.model.id().to_string(), + messages: request + .messages + .iter() + .map(|message| message.to_proto()) + .collect(), + }); + async move { + let response = request.await?; + Ok(response.token_count as usize) + } + .boxed() + } + } + } + + fn stream_completion( + &self, + mut request: LanguageModelRequest, + _: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + match &self.model { + CloudModel::Claude3Opus + | CloudModel::Claude3Sonnet + | CloudModel::Claude3Haiku + | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request), + CloudModel::Custom { name, .. 
} if name.starts_with("anthropic/") => { + preprocess_anthropic_request(&mut request) + } + _ => {} + } + + let request = proto::CompleteWithLanguageModel { + model: self.id.0.to_string(), + messages: request + .messages + .iter() + .map(|message| message.to_proto()) + .collect(), + stop: request.stop, + temperature: request.temperature, + tools: Vec::new(), + tool_choice: None, + }; + + self.client + .request_stream(request) + .map_ok(|stream| { + stream + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed() + }) + .boxed() + } +} + +struct AuthenticationPrompt { + state: gpui::Model, +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline."; + + v_flex().gap_6().p_4().child(Label::new(LABEL)).child( + v_flex() + .gap_2() + .child( + Button::new("sign_in", "Sign in") + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .style(ButtonStyle::Filled) + .full_width() + .on_click(cx.listener(move |this, _, cx| { + this.state.update(cx, |provider, cx| { + provider.authenticate(cx).detach_and_log_err(cx); + cx.notify(); + }); + })), + ) + .child( + div().flex().w_full().items_center().child( + Label::new("Sign in to enable collaboration.") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ), + ) + } +} diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs new file mode 100644 index 0000000000..bcd46a6fdd --- /dev/null +++ b/crates/language_model/src/provider/fake.rs @@ -0,0 +1,160 @@ +use std::sync::{Arc, Mutex}; + +use collections::HashMap; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; + +use crate::{ + LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, +}; +use gpui::{AnyView, AppContext, AsyncAppContext, Task}; +use http::Result; +use ui::WindowContext; + +pub fn language_model_id() -> LanguageModelId { + LanguageModelId::from("fake".to_string()) +} + +pub fn language_model_name() -> LanguageModelName { + LanguageModelName::from("Fake".to_string()) +} + +pub fn provider_name() -> LanguageModelProviderName { + LanguageModelProviderName::from("fake".to_string()) +} + +#[derive(Clone, Default)] +pub struct FakeLanguageModelProvider { + current_completion_txs: Arc>>>, +} + +impl LanguageModelProviderState for FakeLanguageModelProvider { + fn subscribe(&self, _: &mut gpui::ModelContext) -> Option { + None + } +} + +impl LanguageModelProvider for FakeLanguageModelProvider { + fn name(&self) -> LanguageModelProviderName { + provider_name() + } + + fn provided_models(&self, _: &AppContext) -> Vec> { + vec![Arc::new(FakeLanguageModel { + current_completion_txs: self.current_completion_txs.clone(), + })] + } + + fn is_authenticated(&self, _: &AppContext) -> bool { + true + } + + fn authenticate(&self, _: &AppContext) -> Task> { + Task::ready(Ok(())) + } + + fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView { + unimplemented!() + } + + fn reset_credentials(&self, _: &AppContext) -> Task> { + Task::ready(Ok(())) + } +} + +impl FakeLanguageModelProvider { + pub fn test_model(&self) -> FakeLanguageModel { + FakeLanguageModel { + 
current_completion_txs: self.current_completion_txs.clone(), + } + } +} + +pub struct FakeLanguageModel { + current_completion_txs: Arc>>>, +} + +impl FakeLanguageModel { + pub fn pending_completions(&self) -> Vec { + self.current_completion_txs + .lock() + .unwrap() + .keys() + .map(|k| serde_json::from_str(k).unwrap()) + .collect() + } + + pub fn completion_count(&self) -> usize { + self.current_completion_txs.lock().unwrap().len() + } + + pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) { + let json = serde_json::to_string(request).unwrap(); + self.current_completion_txs + .lock() + .unwrap() + .get(&json) + .unwrap() + .unbounded_send(chunk) + .unwrap(); + } + + pub fn send_last_completion_chunk(&self, chunk: String) { + self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk); + } + + pub fn finish_completion(&self, request: &LanguageModelRequest) { + self.current_completion_txs + .lock() + .unwrap() + .remove(&serde_json::to_string(request).unwrap()) + .unwrap(); + } + + pub fn finish_last_completion(&self) { + self.finish_completion(self.pending_completions().last().unwrap()); + } +} + +impl LanguageModel for FakeLanguageModel { + fn id(&self) -> LanguageModelId { + language_model_id() + } + + fn name(&self) -> LanguageModelName { + language_model_name() + } + + fn provider_name(&self) -> LanguageModelProviderName { + provider_name() + } + + fn telemetry_id(&self) -> String { + "fake".to_string() + } + + fn max_token_count(&self) -> usize { + 1000000 + } + + fn count_tokens( + &self, + _: LanguageModelRequest, + _: &AppContext, + ) -> BoxFuture<'static, Result> { + futures::future::ready(Ok(0)).boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + _: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let (tx, rx) = mpsc::unbounded(); + self.current_completion_txs + .lock() + .unwrap() + .insert(serde_json::to_string(&request).unwrap(), tx); + async move { Ok(rx.map(Ok).boxed()) }.boxed() + } +} diff --git a/crates/completion/src/ollama.rs b/crates/language_model/src/provider/ollama.rs similarity index 57% rename from crates/completion/src/ollama.rs rename to crates/language_model/src/provider/ollama.rs index 30d797c76b..47a3c3731d 100644 --- a/crates/completion/src/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -1,49 +1,148 @@ -use crate::LanguageModelCompletionProvider; -use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; -use anyhow::Result; -use futures::StreamExt as _; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; -use gpui::{AnyView, AppContext, Task}; +use anyhow::{anyhow, Result}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; use http::HttpClient; -use language_model::Role; -use ollama::Model as OllamaModel; -use ollama::{ - get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, -}; -use std::sync::Arc; -use std::time::Duration; +use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest}; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; use ui::{prelude::*, ButtonLike, ElevationIndex}; +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, Role, +}; + const OLLAMA_DOWNLOAD_URL: 
&str = "https://ollama.com/download"; const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library"; -pub struct OllamaCompletionProvider { - api_url: String, - model: OllamaModel, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, - available_models: Vec, +const PROVIDER_NAME: &str = "ollama"; + +#[derive(Default, Debug, Clone, PartialEq)] +pub struct OllamaSettings { + pub api_url: String, + pub low_speed_timeout: Option, } -impl LanguageModelCompletionProvider for OllamaCompletionProvider { - fn available_models(&self) -> Vec { - self.available_models +pub struct OllamaLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + http_client: Arc, + available_models: Vec, + settings: OllamaSettings, + _subscription: Subscription, +} + +impl State { + fn fetch_models(&self, cx: &mut ModelContext) -> Task> { + let http_client = self.http_client.clone(); + let api_url = self.settings.api_url.clone(); + + // As a proxy for the server being "authenticated", we'll check if its up by fetching the models + cx.spawn(|this, mut cx| async move { + let models = get_models(http_client.as_ref(), &api_url, None).await?; + + let mut models: Vec = models + .into_iter() + // Since there is no metadata from the Ollama API + // indicating which models are embedding models, + // simply filter out models with "-embed" in their name + .filter(|model| !model.name.contains("-embed")) + .map(|model| ollama::Model::new(&model.name)) + .collect(); + + models.sort_by(|a, b| a.name.cmp(&b.name)); + + this.update(&mut cx, |this, cx| { + this.available_models = models; + cx.notify(); + }) + }) + } +} + +impl OllamaLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + Self { + http_client: http_client.clone(), + state: cx.new_model(|cx| State { + http_client, + available_models: Default::default(), + settings: OllamaSettings::default(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone(); + cx.notify(); + }), + }), + } + } + + fn fetch_models(&self, cx: &AppContext) -> Task> { + let http_client = self.http_client.clone(); + let api_url = self.state.read(cx).settings.api_url.clone(); + + let state = self.state.clone(); + // As a proxy for the server being "authenticated", we'll check if its up by fetching the models + cx.spawn(|mut cx| async move { + let models = get_models(http_client.as_ref(), &api_url, None).await?; + + let mut models: Vec = models + .into_iter() + // Since there is no metadata from the Ollama API + // indicating which models are embedding models, + // simply filter out models with "-embed" in their name + .filter(|model| !model.name.contains("-embed")) + .map(|model| ollama::Model::new(&model.name)) + .collect(); + + models.sort_by(|a, b| a.name.cmp(&b.name)); + + state.update(&mut cx, |this, cx| { + this.available_models = models; + cx.notify(); + }) + }) + } +} + +impl LanguageModelProviderState for OllamaLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for OllamaLanguageModelProvider { + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + self.state + .read(cx) + .available_models .iter() - .map(|m| LanguageModel::Ollama(m.clone())) + .map(|model| { + Arc::new(OllamaLanguageModel { + id: 
LanguageModelId::from(model.name.clone()), + model: model.clone(), + http_client: self.http_client.clone(), + state: self.state.clone(), + }) as Arc + }) .collect() } - fn settings_version(&self) -> usize { - self.settings_version - } - - fn is_authenticated(&self) -> bool { - !self.available_models.is_empty() + fn is_authenticated(&self, cx: &AppContext) -> bool { + !self.state.read(cx).available_models.is_empty() } fn authenticate(&self, cx: &AppContext) -> Task> { - if self.is_authenticated() { + if self.is_authenticated(cx) { Task::ready(Ok(())) } else { self.fetch_models(cx) @@ -51,14 +150,9 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider { } fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + let state = self.state.clone(); let fetch_models = Box::new(move |cx: &mut WindowContext| { - cx.update_global::(|provider, cx| { - provider - .update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.fetch_models(cx) - }) - .unwrap_or_else(|| Task::ready(Ok(()))) - }) + state.update(cx, |this, cx| this.fetch_models(cx)) }); cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx)) @@ -68,9 +162,65 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider { fn reset_credentials(&self, cx: &AppContext) -> Task> { self.fetch_models(cx) } +} - fn model(&self) -> LanguageModel { - LanguageModel::Ollama(self.model.clone()) +pub struct OllamaLanguageModel { + id: LanguageModelId, + model: ollama::Model, + state: gpui::Model, + http_client: Arc, +} + +impl OllamaLanguageModel { + fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest { + ChatRequest { + model: self.model.name.clone(), + messages: request + .messages + .into_iter() + .map(|msg| match msg.role { + Role::User => ChatMessage::User { + content: msg.content, + }, + Role::Assistant => ChatMessage::Assistant { + content: msg.content, + }, + Role::System => ChatMessage::System { + content: msg.content, + }, + }) + .collect(), + keep_alive: self.model.keep_alive.clone().unwrap_or_default(), + stream: true, + options: Some(ChatOptions { + num_ctx: Some(self.model.max_tokens), + stop: Some(request.stop), + temperature: Some(request.temperature), + ..Default::default() + }), + } + } +} + +impl LanguageModel for OllamaLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn telemetry_id(&self) -> String { + format!("ollama/{}", self.model.id()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) } fn count_tokens( @@ -93,12 +243,20 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider { fn stream_completion( &self, request: LanguageModelRequest, + cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); - let api_url = self.api_url.clone(); - let low_speed_timeout = self.low_speed_timeout; + let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { + ( + state.settings.api_url.clone(), + state.settings.low_speed_timeout, + ) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + async move { let request = stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout); @@ -122,143 +280,6 @@ impl 
LanguageModelCompletionProvider for OllamaCompletionProvider { } .boxed() } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -impl OllamaCompletionProvider { - pub fn new( - model: OllamaModel, - api_url: String, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, - cx: &AppContext, - ) -> Self { - cx.spawn({ - let api_url = api_url.clone(); - let client = http_client.clone(); - let model = model.name.clone(); - - |_| async move { - if model.is_empty() { - return Ok(()); - } - preload_model(client.as_ref(), &api_url, &model).await - } - }) - .detach_and_log_err(cx); - - Self { - api_url, - model, - http_client, - low_speed_timeout, - settings_version, - available_models: Default::default(), - } - } - - pub fn update( - &mut self, - model: OllamaModel, - api_url: String, - low_speed_timeout: Option, - settings_version: usize, - cx: &AppContext, - ) { - cx.spawn({ - let api_url = api_url.clone(); - let client = self.http_client.clone(); - let model = model.name.clone(); - - |_| async move { preload_model(client.as_ref(), &api_url, &model).await } - }) - .detach_and_log_err(cx); - - if model.name.is_empty() { - self.select_first_available_model() - } else { - self.model = model; - } - - self.api_url = api_url; - self.low_speed_timeout = low_speed_timeout; - self.settings_version = settings_version; - } - - pub fn select_first_available_model(&mut self) { - if let Some(model) = self.available_models.first() { - self.model = model.clone(); - } - } - - pub fn fetch_models(&self, cx: &AppContext) -> Task> { - let http_client = self.http_client.clone(); - let api_url = self.api_url.clone(); - - // As a proxy for the server being "authenticated", we'll check if its up by fetching the models - cx.spawn(|mut cx| async move { - let models = get_models(http_client.as_ref(), &api_url, None).await?; - - let mut models: Vec = models - .into_iter() - // Since there is no metadata from the Ollama API - // indicating which models are embedding models, - // simply filter out models with "-embed" in their name - .filter(|model| !model.name.contains("-embed")) - .map(|model| OllamaModel::new(&model.name)) - .collect(); - - models.sort_by(|a, b| a.name.cmp(&b.name)); - - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.available_models = models; - - if !provider.available_models.is_empty() && provider.model.name.is_empty() { - provider.select_first_available_model() - } - }); - }) - }) - } - - fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest { - let model = match request.model { - LanguageModel::Ollama(model) => model, - _ => self.model.clone(), - }; - - ChatRequest { - model: model.name, - messages: request - .messages - .into_iter() - .map(|msg| match msg.role { - Role::User => ChatMessage::User { - content: msg.content, - }, - Role::Assistant => ChatMessage::Assistant { - content: msg.content, - }, - Role::System => ChatMessage::System { - content: msg.content, - }, - }) - .collect(), - keep_alive: model.keep_alive.unwrap_or_default(), - stream: true, - options: Some(ChatOptions { - num_ctx: Some(model.max_tokens), - stop: Some(request.stop), - temperature: Some(request.temperature), - ..Default::default() - }), - } - } } struct DownloadOllamaMessage { diff --git a/crates/completion/src/open_ai.rs b/crates/language_model/src/provider/open_ai.rs similarity index 56% rename from crates/completion/src/open_ai.rs rename to crates/language_model/src/provider/open_ai.rs 
index 21a0bbd73e..b82df4ca48 100644 --- a/crates/completion/src/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -1,72 +1,159 @@ -use crate::CompletionProvider; -use crate::LanguageModelCompletionProvider; use anyhow::{anyhow, Result}; +use collections::HashMap; use editor::{Editor, EditorElement, EditorStyle}; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, Task, TextStyle, View}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, + WhiteSpace, +}; use http::HttpClient; -use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role}; -use open_ai::Model as OpenAiModel; use open_ai::{stream_completion, Request, RequestMessage}; -use settings::Settings; -use std::time::Duration; -use std::{env, sync::Arc}; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::prelude::*; use util::ResultExt; -pub struct OpenAiCompletionProvider { - api_key: Option, - api_url: String, - model: OpenAiModel, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, - available_models_from_settings: Vec, +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, Role, +}; + +const PROVIDER_NAME: &str = "openai"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct OpenAiSettings { + pub api_url: String, + pub low_speed_timeout: Option, + pub available_models: Vec, } -impl OpenAiCompletionProvider { - pub fn new( - model: OpenAiModel, - api_url: String, - http_client: Arc, - low_speed_timeout: Option, - settings_version: usize, - available_models_from_settings: Vec, - ) -> Self { - Self { +pub struct OpenAiLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + api_key: Option, + settings: OpenAiSettings, + _subscription: Subscription, +} + +impl OpenAiLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| State { api_key: None, - api_url, - model, - http_client, - low_speed_timeout, - settings_version, - available_models_from_settings, + settings: OpenAiSettings::default(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone(); + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for OpenAiLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for OpenAiLanguageModelProvider { + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = HashMap::default(); + + // Add base models from open_ai::Model::iter() + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &self.state.read(cx).settings.available_models { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(OpenAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).api_key.is_some() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated(cx) { + Task::ready(Ok(())) + } else { + let api_url = self.state.read(cx).settings.api_url.clone(); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) } } - pub fn update( - &mut self, - model: OpenAiModel, - api_url: String, - low_speed_timeout: Option, - settings_version: usize, - ) { - self.model = model; - self.api_url = api_url; - self.low_speed_timeout = low_speed_timeout; - self.settings_version = settings_version; + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) + .into() } - fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { - let model = match request.model { - LanguageModel::OpenAi(model) => model, - _ => self.model.clone(), - }; + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + state.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} +pub struct OpenAiLanguageModel { + id: LanguageModelId, + model: open_ai::Model, + state: gpui::Model, + http_client: Arc, +} + +impl OpenAiLanguageModel { + fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { Request { - model, + model: self.model.clone(), messages: request .messages .into_iter() @@ -92,80 +179,25 @@ impl OpenAiCompletionProvider { } } -impl LanguageModelCompletionProvider for OpenAiCompletionProvider { - fn available_models(&self) -> Vec { - if self.available_models_from_settings.is_empty() { - let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) { - vec![self.model.clone()] - } else { - OpenAiModel::iter() - .filter(|model| !matches!(model, OpenAiModel::Custom { .. 
})) - .collect() - }; - available_models - .into_iter() - .map(LanguageModel::OpenAi) - .collect() - } else { - self.available_models_from_settings - .iter() - .cloned() - .map(LanguageModel::OpenAi) - .collect() - } +impl LanguageModel for OpenAiLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() } - fn settings_version(&self) -> usize { - self.settings_version + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) } - fn is_authenticated(&self) -> bool { - self.api_key.is_some() + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) } - fn authenticate(&self, cx: &AppContext) -> Task> { - if self.is_authenticated() { - Task::ready(Ok(())) - } else { - let api_url = self.api_url.clone(); - cx.spawn(|mut cx| async move { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - api_key - } else { - let (_, api_key) = cx - .update(|cx| cx.read_credentials(&api_url))? - .await? - .ok_or_else(|| anyhow!("credentials not found"))?; - String::from_utf8(api_key)? - }; - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, Self>(|provider| { - provider.api_key = Some(api_key); - }); - }) - }) - } + fn telemetry_id(&self) -> String { + format!("openai/{}", self.model.id()) } - fn reset_credentials(&self, cx: &AppContext) -> Task> { - let delete_credentials = cx.delete_credentials(&self.api_url); - cx.spawn(|mut cx| async move { - delete_credentials.await.log_err(); - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, Self>(|provider| { - provider.api_key = None; - }); - }) - }) - } - - fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { - cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx)) - .into() - } - - fn model(&self) -> LanguageModel { - LanguageModel::OpenAi(self.model.clone()) + fn max_token_count(&self) -> usize { + self.model.max_token_count() } fn count_tokens( @@ -173,19 +205,27 @@ impl LanguageModelCompletionProvider for OpenAiCompletionProvider { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - count_open_ai_tokens(request, cx.background_executor()) + count_open_ai_tokens(request, self.model.clone(), cx) } fn stream_completion( &self, request: LanguageModelRequest, - ) -> BoxFuture<'static, Result>>> { + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { let request = self.to_open_ai_request(request); let http_client = self.http_client.clone(); - let api_key = self.api_key.clone(); - let api_url = self.api_url.clone(); - let low_speed_timeout = self.low_speed_timeout; + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { + ( + state.api_key.clone(), + state.settings.api_url.clone(), + state.settings.low_speed_timeout, + ) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + async move { let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; let request = stream_completion( @@ -208,17 +248,14 @@ impl LanguageModelCompletionProvider for OpenAiCompletionProvider { } .boxed() } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } } pub fn count_open_ai_tokens( request: LanguageModelRequest, - background_executor: &gpui::BackgroundExecutor, + model: open_ai::Model, + cx: &AppContext, ) -> BoxFuture<'static, Result> { - background_executor + cx.background_executor() .spawn(async move { let messages = request .messages @@ -235,19 +272,10 @@ pub 
fn count_open_ai_tokens( }) .collect::>(); - match request.model { - LanguageModel::Anthropic(_) - | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet) - | LanguageModel::Cloud(CloudModel::Claude3Opus) - | LanguageModel::Cloud(CloudModel::Claude3Sonnet) - | LanguageModel::Cloud(CloudModel::Claude3Haiku) - | LanguageModel::Cloud(CloudModel::Custom { .. }) - | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => { - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) - } - _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages), + if let open_ai::Model::Custom { .. } = model { + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) + } else { + tiktoken_rs::num_tokens_from_messages(model.id(), &messages) } }) .boxed() @@ -255,11 +283,11 @@ pub fn count_open_ai_tokens( struct AuthenticationPrompt { api_key: View, - api_url: String, + state: gpui::Model, } impl AuthenticationPrompt { - fn new(api_url: String, cx: &mut WindowContext) -> Self { + fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { Self { api_key: cx.new_view(|cx| { let mut editor = Editor::single_line(cx); @@ -269,7 +297,7 @@ impl AuthenticationPrompt { ); editor }), - api_url, + state, } } @@ -279,13 +307,17 @@ impl AuthenticationPrompt { return; } - let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes()); + let write_credentials = cx.write_credentials( + &self.state.read(cx).settings.api_url, + "Bearer", + api_key.as_bytes(), + ); + let state = self.state.clone(); cx.spawn(|_, mut cx| async move { write_credentials.await?; - cx.update_global::(|provider, _cx| { - provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| { - provider.api_key = Some(api_key); - }); + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); }) }) .detach_and_log_err(cx); @@ -299,8 +331,12 @@ impl AuthenticationPrompt { font_features: settings.ui_font.features.clone(), font_size: rems(0.875).into(), font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, line_height: relative(1.3), - ..Default::default() + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, }; EditorElement::new( &self.api_key, diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs new file mode 100644 index 0000000000..5308a2fce8 --- /dev/null +++ b/crates/language_model/src/registry.rs @@ -0,0 +1,172 @@ +use client::Client; +use collections::HashMap; +use gpui::{AppContext, Global, Model, ModelContext}; +use std::sync::Arc; +use ui::Context; + +use crate::{ + provider::{ + anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, + ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, + }, + LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, +}; + +pub fn init(client: Arc, cx: &mut AppContext) { + let registry = cx.new_model(|cx| { + let mut registry = LanguageModelRegistry::default(); + register_language_model_providers(&mut registry, client, cx); + registry + }); + cx.set_global(GlobalLanguageModelRegistry(registry)); +} + +fn register_language_model_providers( + registry: &mut LanguageModelRegistry, + client: Arc, + cx: &mut ModelContext, +) { + use feature_flags::FeatureFlagAppExt; + + registry.register_provider( + AnthropicLanguageModelProvider::new(client.http_client(), cx), + cx, + ); + 
registry.register_provider( + OpenAiLanguageModelProvider::new(client.http_client(), cx), + cx, + ); + registry.register_provider( + OllamaLanguageModelProvider::new(client.http_client(), cx), + cx, + ); + + cx.observe_flag::(move |enabled, cx| { + let client = client.clone(); + LanguageModelRegistry::global(cx).update(cx, move |registry, cx| { + if enabled { + registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx); + } else { + registry.unregister_provider( + &LanguageModelProviderName::from( + crate::provider::cloud::PROVIDER_NAME.to_string(), + ), + cx, + ); + } + }); + }) + .detach(); +} + +struct GlobalLanguageModelRegistry(Model); + +impl Global for GlobalLanguageModelRegistry {} + +#[derive(Default)] +pub struct LanguageModelRegistry { + providers: HashMap>, +} + +impl LanguageModelRegistry { + pub fn global(cx: &AppContext) -> Model { + cx.global::().0.clone() + } + + pub fn read_global(cx: &AppContext) -> &Self { + cx.global::().0.read(cx) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider { + let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default(); + let registry = cx.new_model(|cx| { + let mut registry = Self::default(); + registry.register_provider(fake_provider.clone(), cx); + registry + }); + cx.set_global(GlobalLanguageModelRegistry(registry)); + fake_provider + } + + pub fn register_provider( + &mut self, + provider: T, + cx: &mut ModelContext, + ) { + let name = provider.name(); + + if let Some(subscription) = provider.subscribe(cx) { + subscription.detach(); + } + + self.providers.insert(name, Arc::new(provider)); + cx.notify(); + } + + pub fn unregister_provider( + &mut self, + name: &LanguageModelProviderName, + cx: &mut ModelContext, + ) { + if self.providers.remove(name).is_some() { + cx.notify(); + } + } + + pub fn providers( + &self, + ) -> impl Iterator)> { + self.providers.iter() + } + + pub fn available_models(&self, cx: &AppContext) -> Vec> { + self.providers + .values() + .flat_map(|provider| provider.provided_models(cx)) + .collect() + } + + pub fn available_models_grouped_by_provider( + &self, + cx: &AppContext, + ) -> HashMap>> { + self.providers + .iter() + .map(|(name, provider)| (name.clone(), provider.provided_models(cx))) + .collect() + } + + pub fn provider( + &self, + name: &LanguageModelProviderName, + ) -> Option> { + self.providers.get(name).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::fake::FakeLanguageModelProvider; + + #[gpui::test] + fn test_register_providers(cx: &mut AppContext) { + let registry = cx.new_model(|_| LanguageModelRegistry::default()); + + registry.update(cx, |registry, cx| { + registry.register_provider(FakeLanguageModelProvider::default(), cx); + }); + + let providers = registry.read(cx).providers().collect::>(); + assert_eq!(providers.len(), 1); + assert_eq!(providers[0].0, &crate::provider::fake::provider_name()); + + registry.update(cx, |registry, cx| { + registry.unregister_provider(&crate::provider::fake::provider_name(), cx); + }); + + let providers = registry.read(cx).providers().collect::>(); + assert!(providers.is_empty()); + } +} diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 50a46c55a5..e3e1d3e77b 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,7 +1,4 @@ -use crate::{ - model::{CloudModel, LanguageModel}, - role::Role, -}; +use 
crate::{role::Role, LanguageModelId}; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -23,16 +20,15 @@ impl LanguageModelRequestMessage { #[derive(Debug, Default, Serialize, Deserialize)] pub struct LanguageModelRequest { - pub model: LanguageModel, pub messages: Vec, pub stop: Vec, pub temperature: f32, } impl LanguageModelRequest { - pub fn to_proto(&self) -> proto::CompleteWithLanguageModel { + pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel { proto::CompleteWithLanguageModel { - model: self.model.id().to_string(), + model: model_id.0.to_string(), messages: self.messages.iter().map(|m| m.to_proto()).collect(), stop: self.stop.clone(), temperature: self.temperature, @@ -40,70 +36,6 @@ impl LanguageModelRequest { tools: Vec::new(), } } - - /// Before we send the request to the server, we can perform fixups on it appropriate to the model. - pub fn preprocess(&mut self) { - match &self.model { - LanguageModel::OpenAi(_) => {} - LanguageModel::Anthropic(_) => self.preprocess_anthropic(), - LanguageModel::Ollama(_) => {} - LanguageModel::Cloud(model) => match model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => { - self.preprocess_anthropic(); - } - CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { - self.preprocess_anthropic(); - } - _ => {} - }, - } - } - - pub fn preprocess_anthropic(&mut self) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in self.messages.drain(..) { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - self.messages = new_messages; - } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs new file mode 100644 index 0000000000..0dcc5b4065 --- /dev/null +++ b/crates/language_model/src/settings.rs @@ -0,0 +1,143 @@ +use std::time::Duration; + +use anyhow::Result; +use gpui::AppContext; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsSources}; + +use crate::{ + provider::{ + anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings, + open_ai::OpenAiSettings, + }, + CloudModel, +}; + +/// Initializes the language model settings. 
+pub fn init(cx: &mut AppContext) {
+    AllLanguageModelSettings::register(cx);
+}
+
+#[derive(Default)]
+pub struct AllLanguageModelSettings {
+    pub open_ai: OpenAiSettings,
+    pub anthropic: AnthropicSettings,
+    pub ollama: OllamaSettings,
+    pub zed_dot_dev: ZedDotDevSettings,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct AllLanguageModelSettingsContent {
+    pub anthropic: Option<AnthropicSettingsContent>,
+    pub ollama: Option<OllamaSettingsContent>,
+    pub open_ai: Option<OpenAiSettingsContent>,
+    #[serde(rename = "zed.dev")]
+    pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct AnthropicSettingsContent {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<anthropic::Model>>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OllamaSettingsContent {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OpenAiSettingsContent {
+    pub api_url: Option<String>,
+    pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<open_ai::Model>>,
+}
+
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct ZedDotDevSettingsContent {
+    available_models: Option<Vec<CloudModel>>,
+}
+
+impl settings::Settings for AllLanguageModelSettings {
+    const KEY: Option<&'static str> = Some("language_models");
+
+    type FileContent = AllLanguageModelSettingsContent;
+
+    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
+        fn merge<T>(target: &mut T, value: Option<T>) {
+            if let Some(value) = value {
+                *target = value;
+            }
+        }
+
+        let mut settings = AllLanguageModelSettings::default();
+
+        for value in sources.defaults_and_customizations() {
+            merge(
+                &mut settings.anthropic.api_url,
+                value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
+            );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .anthropic
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.anthropic.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }
+            merge(
+                &mut settings.anthropic.available_models,
+                value
+                    .anthropic
+                    .as_ref()
+                    .and_then(|s| s.available_models.clone()),
+            );
+
+            merge(
+                &mut settings.ollama.api_url,
+                value.ollama.as_ref().and_then(|s| s.api_url.clone()),
+            );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .ollama
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.ollama.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }
+
+            merge(
+                &mut settings.open_ai.api_url,
+                value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
+            );
+            if let Some(low_speed_timeout_in_seconds) = value
+                .open_ai
+                .as_ref()
+                .and_then(|s| s.low_speed_timeout_in_seconds)
+            {
+                settings.open_ai.low_speed_timeout =
+                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
+            }
+            merge(
+                &mut settings.open_ai.available_models,
+                value
+                    .open_ai
+                    .as_ref()
+                    .and_then(|s| s.available_models.clone()),
+            );
+
+            merge(
+                &mut settings.zed_dot_dev.available_models,
+                value
+                    .zed_dot_dev
+                    .as_ref()
+                    .and_then(|s| s.available_models.clone()),
+            );
+        }
+
+        Ok(settings)
+    }
+}
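For reference, here is a minimal sketch of the JSON fragment the new
`language_models` settings key accepts, derived from the content structs
above. The keys follow the serde field names shown ("open_ai", "zed.dev",
and so on); the concrete values are illustrative assumptions, and the shape
of each `available_models` entry depends on the provider's `Model`
serialization, so the arrays are left empty here:

use serde_json::json;

// Hypothetical user settings fragment (values are examples, not defaults).
fn example_language_models_settings() -> serde_json::Value {
    json!({
        "language_models": {
            "anthropic": {
                "api_url": "https://api.anthropic.com",
                "low_speed_timeout_in_seconds": 60
            },
            "ollama": {
                "api_url": "http://localhost:11434"
            },
            "open_ai": {
                "api_url": "https://api.openai.com/v1",
                "low_speed_timeout_in_seconds": 600,
                "available_models": []
            },
            "zed.dev": {
                "available_models": []
            }
        }
    })
}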
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index cd1312fd62..56d934c5b0 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -77,14 +77,14 @@ impl Model {
         }
     }
 
-    pub fn id(&self) -> &'static str {
+    pub fn id(&self) -> &str {
         match self {
             Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
             Self::Four => "gpt-4",
             Self::FourTurbo => "gpt-4-turbo-preview",
             Self::FourOmni => "gpt-4o",
             Self::FourOmniMini => "gpt-4o-mini",
-            Self::Custom { .. } => "custom",
+            Self::Custom { name, .. } => name,
         }
     }
 
diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs
index 5ce4d4f801..bfdbfb8eab 100644
--- a/crates/outline_panel/src/outline_panel.rs
+++ b/crates/outline_panel/src/outline_panel.rs
@@ -2785,7 +2785,7 @@ impl Panel for OutlinePanel {
         settings::update_settings_file::<OutlinePanelSettings>(
             self.fs.clone(),
             cx,
-            move |settings| {
+            move |settings, _| {
                 let dock = match position {
                     DockPosition::Left | DockPosition::Bottom => OutlinePanelDockPosition::Left,
                     DockPosition::Right => OutlinePanelDockPosition::Right,
diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs
index 537f2424ee..f15ab4f6a9 100644
--- a/crates/project_panel/src/project_panel.rs
+++ b/crates/project_panel/src/project_panel.rs
@@ -2572,7 +2572,7 @@ impl Panel for ProjectPanel {
         settings::update_settings_file::<ProjectPanelSettings>(
             self.fs.clone(),
             cx,
-            move |settings| {
+            move |settings, _| {
                 let dock = match position {
                     DockPosition::Left | DockPosition::Bottom => ProjectPanelDockPosition::Left,
                     DockPosition::Right => ProjectPanelDockPosition::Right,
diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs
index 5d6b33d92e..e3c1f91492 100644
--- a/crates/remote_server/src/headless_project.rs
+++ b/crates/remote_server/src/headless_project.rs
@@ -27,7 +27,7 @@ pub struct HeadlessProject {
 
 impl HeadlessProject {
     pub fn init(cx: &mut AppContext) {
-        cx.set_global(SettingsStore::default());
+        cx.set_global(SettingsStore::new(cx));
         WorktreeSettings::register(cx);
     }
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 4c43fc1e46..471c58b5bf 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1263,4 +1263,4 @@ mod tests {
 }
 
 // See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
-type _TODO = completion::CompletionProvider;
+type _TODO = completion::LanguageModelCompletionProvider;
diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs
index fcfc86dd3d..193f1a28a7 100644
--- a/crates/settings/src/settings.rs
+++ b/crates/settings/src/settings.rs
@@ -21,7 +21,7 @@ pub use settings_store::{
 pub struct SettingsAssets;
 
 pub fn init(cx: &mut AppContext) {
-    let mut settings = SettingsStore::default();
+    let mut settings = SettingsStore::new(cx);
     settings
         .set_default_settings(&default_settings(), cx)
         .unwrap();
diff --git a/crates/settings/src/settings_file.rs b/crates/settings/src/settings_file.rs
index ff6927a787..59adabd7ff 100644
--- a/crates/settings/src/settings_file.rs
+++ b/crates/settings/src/settings_file.rs
@@ -1,9 +1,8 @@
 use crate::{settings_store::SettingsStore, Settings};
-use anyhow::{Context, Result};
 use fs::Fs;
 use futures::{channel::mpsc, StreamExt};
-use gpui::{AppContext, BackgroundExecutor, UpdateGlobal};
-use std::{io::ErrorKind, path::PathBuf, sync::Arc, time::Duration};
+use gpui::{AppContext, BackgroundExecutor, ReadGlobal, UpdateGlobal};
+use std::{path::PathBuf, sync::Arc, time::Duration};
 use util::ResultExt;
 
 pub const EMPTY_THEME_NAME: &str = "empty-theme";
@@ -91,46 +90,10 @@ pub fn handle_settings_file_changes(
         .detach();
 }
 
-async fn load_settings(fs: &Arc<dyn Fs>) -> Result<String> {
-    match fs.load(paths::settings_file()).await {
-        result @ Ok(_) => result,
-        Err(err) => {
-            if let Some(e) = err.downcast_ref::<std::io::Error>() {
-                if e.kind() == ErrorKind::NotFound {
-                    return Ok(crate::initial_user_settings_content().to_string());
-                }
-            }
-            Err(err)
-        }
-    }
-}
-
 pub fn update_settings_file<T: Settings>(
     fs: Arc<dyn Fs>,
-    cx: &mut AppContext,
-    update: impl 'static + Send + FnOnce(&mut T::FileContent),
+    cx: &AppContext,
+    update: impl 'static + Send + FnOnce(&mut T::FileContent, &AppContext),
 ) {
-    cx.spawn(|cx| async move {
-        let old_text = load_settings(&fs).await?;
-        let new_text = cx.read_global(|store: &SettingsStore, _cx| {
-            store.new_text_for_update::<T>(old_text, update)
-        })?;
-        let initial_path = paths::settings_file().as_path();
-        if fs.is_file(initial_path).await {
-            let resolved_path = fs.canonicalize(initial_path).await.with_context(|| {
-                format!("Failed to canonicalize settings path {:?}", initial_path)
-            })?;
-
-            fs.atomic_write(resolved_path.clone(), new_text)
-                .await
-                .with_context(|| format!("Failed to write settings to file {:?}", resolved_path))?;
-        } else {
-            fs.atomic_write(initial_path.to_path_buf(), new_text)
-                .await
-                .with_context(|| format!("Failed to write settings to file {:?}", initial_path))?;
-        }
-
-        anyhow::Ok(())
-    })
-    .detach_and_log_err(cx);
+    SettingsStore::global(cx).update_settings_file::<T>(fs, update);
 }
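At call sites the only changes are the closure arity and the borrowed
context; a minimal sketch against a hypothetical `MySettings` type (any
type implementing `settings::Settings`, with a hypothetical `enabled`
field on its `FileContent`):

use std::sync::Arc;

use fs::Fs;
use gpui::AppContext;

// The write is now queued on the global SettingsStore and applied
// asynchronously, so callers no longer spawn their own task; the extra
// &AppContext argument lets the update read other global state while the
// new file text is computed.
fn enable_my_feature(fs: Arc<dyn Fs>, cx: &AppContext) {
    settings::update_settings_file::<MySettings>(fs, cx, move |content, _cx| {
        content.enabled = Some(true);
    });
}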
diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs
index 3d7e4cfe5b..a3417e578e 100644
--- a/crates/settings/src/settings_store.rs
+++ b/crates/settings/src/settings_store.rs
@@ -1,6 +1,8 @@
 use anyhow::{anyhow, Context, Result};
 use collections::{btree_map, hash_map, BTreeMap, HashMap};
-use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Global, UpdateGlobal};
+use fs::Fs;
+use futures::{channel::mpsc, future::LocalBoxFuture, FutureExt, StreamExt};
+use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Global, Task, UpdateGlobal};
 use lazy_static::lazy_static;
 use schemars::{gen::SchemaGenerator, schema::RootSchema, JsonSchema};
 use serde::{de::DeserializeOwned, Deserialize as _, Serialize};
@@ -161,23 +163,14 @@ pub struct SettingsStore {
         TypeId,
         Box<dyn Fn(&dyn Any) -> Option<usize> + Send + Sync + 'static>,
     )>,
+    _setting_file_updates: Task<()>,
+    setting_file_updates_tx: mpsc::UnboundedSender<
+        Box<dyn Send + FnOnce(AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>>,
+    >,
 }
 
 impl Global for SettingsStore {}
 
-impl Default for SettingsStore {
-    fn default() -> Self {
-        SettingsStore {
-            setting_values: Default::default(),
-            raw_default_settings: serde_json::json!({}),
-            raw_user_settings: serde_json::json!({}),
-            raw_extension_settings: serde_json::json!({}),
-            raw_local_settings: Default::default(),
-            tab_size_callback: Default::default(),
-        }
-    }
-}
-
 #[derive(Debug)]
 struct SettingValue<T> {
     global_value: Option<T>,
@@ -207,6 +200,24 @@ trait AnySettingValue: 'static + Send + Sync {
 struct DeserializedSetting(Box<dyn Any>);
 
 impl SettingsStore {
+    pub fn new(cx: &AppContext) -> Self {
+        let (setting_file_updates_tx, mut setting_file_updates_rx) = mpsc::unbounded();
+        Self {
+            setting_values: Default::default(),
+            raw_default_settings: serde_json::json!({}),
+            raw_user_settings: serde_json::json!({}),
+            raw_extension_settings: serde_json::json!({}),
+            raw_local_settings: Default::default(),
+            tab_size_callback: Default::default(),
+            setting_file_updates_tx,
+            _setting_file_updates: cx.spawn(|cx| async move {
+                while let Some(setting_file_update) = setting_file_updates_rx.next().await {
+                    (setting_file_update)(cx.clone()).await.log_err();
+                }
+            }),
+        }
+    }
+
     pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R
     where
         C: BorrowAppContext,
@@ -301,7 +312,7 @@ impl SettingsStore {
 
     #[cfg(any(test, feature = "test-support"))]
     pub fn test(cx: &mut AppContext) -> Self {
-        let mut this = Self::default();
+        let mut this = Self::new(cx);
         this.set_default_settings(&crate::test_settings(), cx)
             .unwrap();
         this.set_user_settings("{}", cx).unwrap();
@@ -323,6 +334,59 @@ impl SettingsStore {
         self.set_user_settings(&new_text, cx).unwrap();
     }
 
+    async fn load_settings(fs: &Arc<dyn Fs>) -> Result<String> {
+        match fs.load(paths::settings_file()).await {
+            result @ Ok(_) => result,
+            Err(err) => {
+                if let Some(e) = err.downcast_ref::<std::io::Error>() {
+                    if e.kind() == std::io::ErrorKind::NotFound {
+                        return Ok(crate::initial_user_settings_content().to_string());
+                    }
+                }
+                Err(err)
+            }
+        }
+    }
+
+    pub fn update_settings_file<T: Settings>(
+        &self,
+        fs: Arc<dyn Fs>,
+        update: impl 'static + Send + FnOnce(&mut T::FileContent, &AppContext),
+    ) {
+        self.setting_file_updates_tx
+            .unbounded_send(Box::new(move |cx: AsyncAppContext| {
+                async move {
+                    let old_text = Self::load_settings(&fs).await?;
+                    let new_text = cx.read_global(|store: &SettingsStore, cx| {
+                        store.new_text_for_update::<T>(old_text, |content| update(content, cx))
+                    })?;
+                    let initial_path = paths::settings_file().as_path();
+                    if fs.is_file(initial_path).await {
+                        let resolved_path =
+                            fs.canonicalize(initial_path).await.with_context(|| {
+                                format!("Failed to canonicalize settings path {:?}", initial_path)
+                            })?;
+
+                        fs.atomic_write(resolved_path.clone(), new_text)
+                            .await
+                            .with_context(|| {
+                                format!("Failed to write settings to file {:?}", resolved_path)
+                            })?;
+                    } else {
+                        fs.atomic_write(initial_path.to_path_buf(), new_text)
+                            .await
+                            .with_context(|| {
+                                format!("Failed to write settings to file {:?}", initial_path)
+                            })?;
+                    }
+
+                    anyhow::Ok(())
+                }
+                .boxed_local()
+            }))
+            .ok();
+    }
+
     /// Updates the value of a setting in a JSON file, returning the new text
     /// for that JSON file.
pub fn new_text_for_update( @@ -1019,7 +1083,7 @@ mod tests { #[gpui::test] fn test_settings_store_basic(cx: &mut AppContext) { - let mut store = SettingsStore::default(); + let mut store = SettingsStore::new(cx); store.register_setting::(cx); store.register_setting::(cx); store.register_setting::(cx); @@ -1148,7 +1212,7 @@ mod tests { #[gpui::test] fn test_setting_store_assign_json_before_register(cx: &mut AppContext) { - let mut store = SettingsStore::default(); + let mut store = SettingsStore::new(cx); store .set_default_settings( r#"{ @@ -1191,7 +1255,7 @@ mod tests { #[gpui::test] fn test_setting_store_update(cx: &mut AppContext) { - let mut store = SettingsStore::default(); + let mut store = SettingsStore::new(cx); store.register_setting::(cx); store.register_setting::(cx); store.register_setting::(cx); diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index 7d50bdb12e..08105970f5 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -760,14 +760,18 @@ impl Panel for TerminalPanel { } fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext) { - settings::update_settings_file::(self.fs.clone(), cx, move |settings| { - let dock = match position { - DockPosition::Left => TerminalDockPosition::Left, - DockPosition::Bottom => TerminalDockPosition::Bottom, - DockPosition::Right => TerminalDockPosition::Right, - }; - settings.dock = Some(dock); - }); + settings::update_settings_file::( + self.fs.clone(), + cx, + move |settings, _| { + let dock = match position { + DockPosition::Left => TerminalDockPosition::Left, + DockPosition::Bottom => TerminalDockPosition::Bottom, + DockPosition::Right => TerminalDockPosition::Right, + }; + settings.dock = Some(dock); + }, + ); } fn size(&self, cx: &WindowContext) -> Pixels { diff --git a/crates/theme_selector/src/theme_selector.rs b/crates/theme_selector/src/theme_selector.rs index a54be4b56c..80a3769539 100644 --- a/crates/theme_selector/src/theme_selector.rs +++ b/crates/theme_selector/src/theme_selector.rs @@ -196,7 +196,7 @@ impl PickerDelegate for ThemeSelectorDelegate { let appearance = Appearance::from(cx.appearance()); - update_settings_file::(self.fs.clone(), cx, move |settings| { + update_settings_file::(self.fs.clone(), cx, move |settings, _| { if let Some(selection) = settings.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 336fd4ffc4..4e0a02fb0b 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -147,7 +147,7 @@ fn register(workspace: &mut Workspace, cx: &mut ViewContext) { workspace.register_action(|workspace: &mut Workspace, _: &ToggleVimMode, cx| { let fs = workspace.app_state().fs.clone(); let currently_enabled = VimModeSetting::get_global(cx).0; - update_settings_file::(fs, cx, move |setting| { + update_settings_file::(fs, cx, move |setting, _| { *setting = Some(!currently_enabled) }) }); diff --git a/crates/welcome/src/base_keymap_picker.rs b/crates/welcome/src/base_keymap_picker.rs index aa65051c0b..96a9df9c3c 100644 --- a/crates/welcome/src/base_keymap_picker.rs +++ b/crates/welcome/src/base_keymap_picker.rs @@ -176,7 +176,7 @@ impl PickerDelegate for BaseKeymapSelectorDelegate { self.telemetry .report_setting_event("keymap", base_keymap.to_string()); - update_settings_file::(self.fs.clone(), cx, move |setting| { + update_settings_file::(self.fs.clone(), cx, move |setting, _| { *setting 
= Some(base_keymap) }); } diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index 718939ab9f..cba91add01 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -279,7 +279,7 @@ impl WelcomePage { if let Some(workspace) = self.workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); let selection = *selection; - settings::update_settings_file::(fs, cx, move |settings| { + settings::update_settings_file::(fs, cx, move |settings, _| { let value = match selection { Selection::Unselected => false, Selection::Selected => true, diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index f33cf26aa6..8b1b7ebf1e 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -56,6 +56,7 @@ install_cli.workspace = true isahc.workspace = true journal.workspace = true language.workspace = true +language_model.workspace = true language_selector.workspace = true language_tools.workspace = true languages.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index a508e571f8..c9efc2a8d6 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -164,6 +164,7 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { SystemAppearance::init(cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); command_palette::init(cx); + language_model::init(app_state.client.clone(), cx); snippet_provider::init(cx); supermaven::init(app_state.client.clone(), cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 11499e554f..fd881dd7fa 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3436,6 +3436,7 @@ mod tests { project_panel::init((), cx); outline_panel::init((), cx); terminal_view::init(cx); + language_model::init(app_state.client.clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init(app_state.fs.clone(), cx); tasks_ui::init(cx);
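Taken together, here is a minimal sketch (not part of this change) of how a
consumer can resolve a model through the new registry and stream a
completion, assuming at least one provider is registered and authenticated:

use futures::StreamExt;
use gpui::AppContext;
use language_model::{LanguageModelRegistry, LanguageModelRequest};

fn stream_first_model(cx: &mut AppContext) {
    // Pick any model exposed by the registered providers.
    let model = LanguageModelRegistry::read_global(cx)
        .available_models(cx)
        .into_iter()
        .next()
        .expect("no language models registered");

    let request = LanguageModelRequest {
        messages: Vec::new(),
        stop: Vec::new(),
        temperature: 1.0,
    };

    cx.spawn(|cx| async move {
        // `stream_completion` resolves to a stream of generated text chunks.
        let mut chunks = model.stream_completion(request, &cx).await?;
        while let Some(chunk) = chunks.next().await {
            print!("{}", chunk?);
        }
        anyhow::Ok(())
    })
    .detach();
}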