From af4b9805c99d30d75314d8e3bb04855712fa1c0d Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 24 Jul 2024 11:21:31 +0200 Subject: [PATCH] assistant: Fix issues when configuring different providers (#15072) Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra --- assets/settings/default.json | 12 +++- crates/assistant/src/assistant.rs | 4 +- crates/assistant/src/assistant_settings.rs | 6 +- crates/assistant/src/inline_assistant.rs | 2 +- .../src/terminal_inline_assistant.rs | 2 +- crates/completion/src/completion.rs | 15 +++-- crates/language_model/src/language_model.rs | 12 ++++ .../language_model/src/provider/anthropic.rs | 52 ++++++++++----- crates/language_model/src/provider/cloud.rs | 30 ++++++--- crates/language_model/src/provider/fake.rs | 17 ++++- crates/language_model/src/provider/ollama.rs | 64 +++++++++++------- crates/language_model/src/provider/open_ai.rs | 52 +++++++++------ crates/language_model/src/registry.rs | 20 +++--- crates/language_model/src/settings.rs | 16 ++--- crates/ollama/src/ollama.rs | 4 +- docs/src/language-model-integration.md | 65 ++++++------------- 16 files changed, 225 insertions(+), 148 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 0c6ed54e69..743ec545f1 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -853,7 +853,17 @@ } }, // Different settings for specific language models. - "language_models": {}, + "language_models": { + "anthropic": { + "api_url": "https://api.anthropic.com" + }, + "openai": { + "api_url": "https://api.openai.com/v1" + }, + "ollama": { + "api_url": "http://localhost:11434" + } + }, // Zed's Prettier integration settings. // Allows to enable/disable formatting with Prettier // and configure default Prettier, used when no project-level Prettier installation is found. 
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 1c97402f8c..181a4165c1 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; use language_model::{ - LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage, + LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage, }; pub(crate) use model_selector::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; @@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) { fn update_active_language_model_from_settings(cx: &mut AppContext) { let settings = AssistantSettings::get_global(cx); - let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone()); + let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone()); let model_id = LanguageModelId::from(settings.default_model.model.clone()); let Some(provider) = LanguageModelRegistry::global(cx) diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 09c5a9e733..05c5b56f1c 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -144,8 +144,8 @@ impl AssistantSettingsContent { fs, cx, move |content, _| { - if content.open_ai.is_none() { - content.open_ai = + if content.openai.is_none() { + content.openai = Some(language_model::settings::OpenAiSettingsContent { api_url, low_speed_timeout_in_seconds, @@ -243,7 +243,7 @@ impl AssistantSettingsContent { pub fn set_model(&mut self, language_model: Arc) { let model = language_model.id().0.to_string(); - let provider = language_model.provider_name().0.to_string(); + let provider = language_model.provider_id().0.to_string(); match self { 
AssistantSettingsContent::Versioned(settings) => match settings { diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index f3915c191d..39d66e8100 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1438,7 +1438,7 @@ impl Render for PromptEditor { { let model_name = available_model.name().0.clone(); let provider = - available_model.provider_name().0.clone(); + available_model.provider_id().0.clone(); move |_| { h_flex() .w_full() diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index c1d5fb898a..5a391f8ac4 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -565,7 +565,7 @@ impl Render for PromptEditor { { let model_name = available_model.name().0.clone(); let provider = - available_model.provider_name().0.clone(); + available_model.provider_id().0.clone(); move |_| { h_flex() .w_full() diff --git a/crates/completion/src/completion.rs b/crates/completion/src/completion.rs index f7c2da95cb..22179907d0 100644 --- a/crates/completion/src/completion.rs +++ b/crates/completion/src/completion.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AppContext, Global, Model, ModelContext, Task}; use language_model::{ - LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry, + LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, }; use smol::lock::{Semaphore, SemaphoreGuardArc}; @@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider { pub fn set_active_provider( &mut self, - provider_name: LanguageModelProviderName, + provider_name: LanguageModelProviderId, cx: &mut ModelContext, ) { self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name); @@ -103,14 
+103,19 @@ impl LanguageModelCompletionProvider { pub fn set_active_model(&mut self, model: Arc, cx: &mut ModelContext) { if self.active_model.as_ref().map_or(false, |m| { - m.id() == model.id() && m.provider_name() == model.provider_name() + m.id() == model.id() && m.provider_id() == model.provider_id() }) { return; } self.active_provider = - LanguageModelRegistry::read_global(cx).provider(&model.provider_name()); - self.active_model = Some(model); + LanguageModelRegistry::read_global(cx).provider(&model.provider_id()); + self.active_model = Some(model.clone()); + + if let Some(provider) = self.active_provider.as_ref() { + provider.load_model(model, cx); + } + cx.notify(); } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index aa0c2d697a..4eb3f8a32b 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -25,6 +25,7 @@ pub fn init(client: Arc, cx: &mut AppContext) { pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; + fn provider_id(&self) -> LanguageModelProviderId; fn provider_name(&self) -> LanguageModelProviderName; fn telemetry_id(&self) -> String; @@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync { } pub trait LanguageModelProvider: 'static { + fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; fn provided_models(&self, cx: &AppContext) -> Vec>; + fn load_model(&self, _model: Arc, _cx: &AppContext) {} fn is_authenticated(&self, cx: &AppContext) -> bool; fn authenticate(&self, cx: &AppContext) -> Task>; fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; @@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString); #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct LanguageModelName(pub SharedString); +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +pub struct LanguageModelProviderId(pub SharedString); + #[derive(Clone, Eq, 
PartialEq, Hash, Debug)] pub struct LanguageModelProviderName(pub SharedString); @@ -77,6 +83,12 @@ impl From for LanguageModelName { } } +impl From for LanguageModelProviderId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + impl From for LanguageModelProviderName { fn from(value: String) -> Self { Self(SharedString::from(value)) diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 620c99017b..093cc25057 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -1,6 +1,5 @@ use anthropic::{stream_completion, Request, RequestMessage}; use anyhow::{anyhow, Result}; -use collections::HashMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{ @@ -9,7 +8,7 @@ use gpui::{ }; use http_client::HttpClient; use settings::{Settings, SettingsStore}; -use std::{sync::Arc, time::Duration}; +use std::{collections::BTreeMap, sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::prelude::*; @@ -17,11 +16,12 @@ use util::ResultExt; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelRequestMessage, Role, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role, }; -const PROVIDER_NAME: &str = "anthropic"; +const PROVIDER_ID: &str = "anthropic"; +const PROVIDER_NAME: &str = "Anthropic"; #[derive(Default, Clone, Debug, PartialEq)] pub struct AnthropicSettings { @@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider { struct State { api_key: Option, - settings: AnthropicSettings, _subscription: Subscription, } @@ -45,9 +44,7 @@ impl 
AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { api_key: None, - settings: AnthropicSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone(); + _subscription: cx.observe_global::(|_, cx| { cx.notify(); }), }); @@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider { } impl LanguageModelProvider for AnthropicLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn name(&self) -> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } fn provided_models(&self, cx: &AppContext) -> Vec> { - let mut models = HashMap::default(); + let mut models = BTreeMap::default(); // Add base models from anthropic::Model::iter() for model in anthropic::Model::iter() { @@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } // Override with available models from settings - for model in &self.state.read(cx).settings.available_models { + for model in AllLanguageModelSettings::get_global(cx) + .anthropic + .available_models + .iter() + { models.insert(model.id().to_string(), model.clone()); } @@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { if self.is_authenticated(cx) { Task::ready(Ok(())) } else { - let api_url = self.state.read(cx).settings.api_url.clone(); + let api_url = AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .clone(); let state = self.state.clone(); cx.spawn(|mut cx| async move { let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") { @@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { fn reset_credentials(&self, cx: &AppContext) -> Task> { let state = self.state.clone(); - let delete_credentials = 
cx.delete_credentials(&self.state.read(cx).settings.api_url); + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); cx.spawn(|mut cx| async move { delete_credentials.await.log_err(); state.update(&mut cx, |this, cx| { @@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel { LanguageModelName::from(self.model.display_name().to_string()) } + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn provider_name(&self) -> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } @@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel { let request = self.to_anthropic_request(request); let http_client = self.http_client.clone(); - let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { + + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).anthropic; ( state.api_key.clone(), - state.settings.api_url.clone(), - state.settings.low_speed_timeout, + settings.api_url.clone(), + settings.low_speed_timeout, ) }) else { return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); @@ -365,7 +380,10 @@ impl AuthenticationPrompt { } let write_credentials = cx.write_credentials( - &self.state.read(cx).settings.api_url, + AllLanguageModelSettings::get_global(cx) + .anthropic + .api_url + .as_str(), "Bearer", api_key.as_bytes(), ); diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 3b42b72090..c0cbdbdf1d 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -1,15 +1,15 @@ use super::open_ai::count_open_ai_tokens; use crate::{ settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId, - LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, 
LanguageModelRequest, + LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, }; use anyhow::Result; use client::Client; -use collections::HashMap; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; @@ -17,6 +17,7 @@ use crate::LanguageModelProvider; use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request}; +pub const PROVIDER_ID: &str = "zed.dev"; pub const PROVIDER_NAME: &str = "zed.dev"; #[derive(Default, Clone, Debug, PartialEq)] @@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider { struct State { client: Arc, status: client::Status, - settings: ZedDotDevSettings, _subscription: Subscription, } @@ -52,9 +52,7 @@ impl CloudLanguageModelProvider { let state = cx.new_model(|cx| State { client: client.clone(), status, - settings: ZedDotDevSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone(); + _subscription: cx.observe_global::(|_, cx| { cx.notify(); }), }); @@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider { } impl LanguageModelProvider for CloudLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn name(&self) -> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } fn provided_models(&self, cx: &AppContext) -> Vec> { - let mut models = HashMap::default(); + let mut models = BTreeMap::default(); // Add base models from CloudModel::iter() for model in CloudModel::iter() { @@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } // Override with available 
models from settings - for model in &self.state.read(cx).settings.available_models { + for model in &AllLanguageModelSettings::get_global(cx) + .zed_dot_dev + .available_models + { models.insert(model.id().to_string(), model.clone()); } @@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel { LanguageModelName::from(self.model.display_name().to_string()) } + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn provider_name(&self) -> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } @@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel { | CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx), + CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { + count_anthropic_tokens(request, cx) + } _ => { let request = self.client.request(proto::CountTokensWithLanguageModel { model: self.model.id().to_string(), diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs index 81261bbe64..8f91155cd4 100644 --- a/crates/language_model/src/provider/fake.rs +++ b/crates/language_model/src/provider/fake.rs @@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St use crate::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, }; use gpui::{AnyView, AppContext, AsyncAppContext, Task}; use http_client::Result; @@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName { LanguageModelName::from("Fake".to_string()) } +pub fn provider_id() -> LanguageModelProviderId { + LanguageModelProviderId::from("fake".to_string()) +} + pub fn provider_name() -> LanguageModelProviderName { - 
LanguageModelProviderName::from("fake".to_string()) + LanguageModelProviderName::from("Fake".to_string()) } #[derive(Clone, Default)] @@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider { } impl LanguageModelProvider for FakeLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + provider_id() + } + fn name(&self) -> LanguageModelProviderName { provider_name() } @@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel { language_model_name() } + fn provider_id(&self) -> LanguageModelProviderId { + provider_id() + } + fn provider_name(&self) -> LanguageModelProviderName { provider_name() } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 4dd0f4dcb1..3b6a3fb3b3 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -2,21 +2,24 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; use http_client::HttpClient; -use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest}; +use ollama::{ + get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, +}; use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use ui::{prelude::*, ButtonLike, ElevationIndex}; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, Role, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, Role, }; const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download"; const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library"; -const PROVIDER_NAME: &str = "ollama"; +const PROVIDER_ID: &str = 
"ollama"; +const PROVIDER_NAME: &str = "Ollama"; #[derive(Default, Debug, Clone, PartialEq)] pub struct OllamaSettings { @@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider { struct State { http_client: Arc, available_models: Vec, - settings: OllamaSettings, _subscription: Subscription, } impl State { - fn fetch_models(&self, cx: &mut ModelContext) -> Task> { + fn fetch_models(&self, cx: &ModelContext) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; let http_client = self.http_client.clone(); - let api_url = self.settings.api_url.clone(); + let api_url = settings.api_url.clone(); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(|this, mut cx| async move { @@ -66,23 +69,25 @@ impl State { impl OllamaLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { - Self { + let this = Self { http_client: http_client.clone(), state: cx.new_model(|cx| State { http_client, available_models: Default::default(), - settings: OllamaSettings::default(), _subscription: cx.observe_global::(|this: &mut State, cx| { - this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone(); + this.fetch_models(cx).detach_and_log_err(cx); cx.notify(); }), }), - } + }; + this.fetch_models(cx).detach_and_log_err(cx); + this } fn fetch_models(&self, cx: &AppContext) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; let http_client = self.http_client.clone(); - let api_url = self.state.read(cx).settings.api_url.clone(); + let api_url = settings.api_url.clone(); let state = self.state.clone(); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models @@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider { } impl LanguageModelProvider for OllamaLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn name(&self) 
-> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } @@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { id: LanguageModelId::from(model.name.clone()), model: model.clone(), http_client: self.http_client.clone(), - state: self.state.clone(), }) as Arc }) .collect() } + fn load_model(&self, model: Arc, cx: &AppContext) { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + let http_client = self.http_client.clone(); + let api_url = settings.api_url.clone(); + let id = model.id().0.to_string(); + cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await }) + .detach_and_log_err(cx); + } + fn is_authenticated(&self, cx: &AppContext) -> bool { !self.state.read(cx).available_models.is_empty() } @@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { pub struct OllamaLanguageModel { id: LanguageModelId, model: ollama::Model, - state: gpui::Model, http_client: Arc, } @@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel { LanguageModelName::from(self.model.display_name().to_string()) } + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + fn max_token_count(&self) -> usize { self.model.max_token_count() } @@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel { format!("ollama/{}", self.model.id()) } - fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) - } - fn count_tokens( &self, request: LanguageModelRequest, @@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel { let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); - let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { - ( - state.settings.api_url.clone(), - state.settings.low_speed_timeout, - ) + 
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + (settings.api_url.clone(), settings.low_speed_timeout) }) else { return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); }; diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index 8135f4e941..bc31ccafed 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use collections::HashMap; +use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use gpui::{ @@ -17,11 +17,12 @@ use util::ResultExt; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, - LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, Role, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, Role, }; -const PROVIDER_NAME: &str = "openai"; +const PROVIDER_ID: &str = "openai"; +const PROVIDER_NAME: &str = "OpenAI"; #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiSettings { @@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider { struct State { api_key: Option, - settings: OpenAiSettings, _subscription: Subscription, } @@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { let state = cx.new_model(|cx| State { api_key: None, - settings: OpenAiSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone(); + _subscription: cx.observe_global::(|_this: &mut State, cx| { cx.notify(); }), }); @@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider { } impl 
LanguageModelProvider for OpenAiLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn name(&self) -> LanguageModelProviderName { LanguageModelProviderName(PROVIDER_NAME.into()) } fn provided_models(&self, cx: &AppContext) -> Vec> { - let mut models = HashMap::default(); + let mut models = BTreeMap::default(); // Add base models from open_ai::Model::iter() for model in open_ai::Model::iter() { @@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } // Override with available models from settings - for model in &self.state.read(cx).settings.available_models { + for model in &AllLanguageModelSettings::get_global(cx) + .openai + .available_models + { models.insert(model.id().to_string(), model.clone()); } @@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { if self.is_authenticated(cx) { Task::ready(Ok(())) } else { - let api_url = self.state.read(cx).settings.api_url.clone(); + let api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); let state = self.state.clone(); cx.spawn(|mut cx| async move { let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") { @@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn reset_credentials(&self, cx: &AppContext) -> Task> { - let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url); + let settings = &AllLanguageModelSettings::get_global(cx).openai; + let delete_credentials = cx.delete_credentials(&settings.api_url); let state = self.state.clone(); cx.spawn(|mut cx| async move { delete_credentials.await.log_err(); @@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel { LanguageModelName::from(self.model.display_name().to_string()) } + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + fn provider_name(&self) -> LanguageModelProviderName { 
LanguageModelProviderName(PROVIDER_NAME.into()) } @@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel { let request = self.to_open_ai_request(request); let http_client = self.http_client.clone(); - let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| { + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).openai; ( state.api_key.clone(), - state.settings.api_url.clone(), - state.settings.low_speed_timeout, + settings.api_url.clone(), + settings.low_speed_timeout, ) }) else { return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); @@ -307,11 +321,9 @@ impl AuthenticationPrompt { return; } - let write_credentials = cx.write_credentials( - &self.state.read(cx).settings.api_url, - "Bearer", - api_key.as_bytes(), - ); + let settings = &AllLanguageModelSettings::get_global(cx).openai; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); let state = self.state.clone(); cx.spawn(|_, mut cx| async move { write_credentials.await?; diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 5308a2fce8..9c92f912bd 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -9,7 +9,7 @@ use crate::{ anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, }, - LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState, + LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; pub fn init(client: Arc, cx: &mut AppContext) { @@ -48,7 +48,7 @@ fn register_language_model_providers( registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx); } else { registry.unregister_provider( - &LanguageModelProviderName::from( + 
&LanguageModelProviderId::from( crate::provider::cloud::PROVIDER_NAME.to_string(), ), cx, @@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {} #[derive(Default)] pub struct LanguageModelRegistry { - providers: HashMap>, + providers: HashMap>, } impl LanguageModelRegistry { @@ -94,7 +94,7 @@ impl LanguageModelRegistry { provider: T, cx: &mut ModelContext, ) { - let name = provider.name(); + let name = provider.id(); if let Some(subscription) = provider.subscribe(cx) { subscription.detach(); @@ -106,7 +106,7 @@ impl LanguageModelRegistry { pub fn unregister_provider( &mut self, - name: &LanguageModelProviderName, + name: &LanguageModelProviderId, cx: &mut ModelContext, ) { if self.providers.remove(name).is_some() { @@ -116,7 +116,7 @@ impl LanguageModelRegistry { pub fn providers( &self, - ) -> impl Iterator)> { + ) -> impl Iterator)> { self.providers.iter() } @@ -130,7 +130,7 @@ impl LanguageModelRegistry { pub fn available_models_grouped_by_provider( &self, cx: &AppContext, - ) -> HashMap>> { + ) -> HashMap>> { self.providers .iter() .map(|(name, provider)| (name.clone(), provider.provided_models(cx))) @@ -139,7 +139,7 @@ impl LanguageModelRegistry { pub fn provider( &self, - name: &LanguageModelProviderName, + name: &LanguageModelProviderId, ) -> Option> { self.providers.get(name).cloned() } @@ -160,10 +160,10 @@ mod tests { let providers = registry.read(cx).providers().collect::>(); assert_eq!(providers.len(), 1); - assert_eq!(providers[0].0, &crate::provider::fake::provider_name()); + assert_eq!(providers[0].0, &crate::provider::fake::provider_id()); registry.update(cx, |registry, cx| { - registry.unregister_provider(&crate::provider::fake::provider_name(), cx); + registry.unregister_provider(&crate::provider::fake::provider_id(), cx); }); let providers = registry.read(cx).providers().collect::>(); diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 0dcc5b4065..262e14937a 100644 --- 
a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) { #[derive(Default)] pub struct AllLanguageModelSettings { - pub open_ai: OpenAiSettings, pub anthropic: AnthropicSettings, pub ollama: OllamaSettings, + pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, } @@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings { pub struct AllLanguageModelSettingsContent { pub anthropic: Option, pub ollama: Option, - pub open_ai: Option, + pub openai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, } @@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings { } merge( - &mut settings.open_ai.api_url, - value.open_ai.as_ref().and_then(|s| s.api_url.clone()), + &mut settings.openai.api_url, + value.openai.as_ref().and_then(|s| s.api_url.clone()), ); if let Some(low_speed_timeout_in_seconds) = value - .open_ai + .openai .as_ref() .and_then(|s| s.low_speed_timeout_in_seconds) { - settings.open_ai.low_speed_timeout = + settings.openai.low_speed_timeout = Some(Duration::from_secs(low_speed_timeout_in_seconds)); } merge( - &mut settings.open_ai.available_models, + &mut settings.openai.available_models, value - .open_ai + .openai .as_ref() .and_then(|s| s.available_models.clone()), ); diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 269698ebab..e627118072 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{convert::TryFrom, time::Duration}; +use std::{convert::TryFrom, sync::Arc, time::Duration}; pub const OLLAMA_API_URL: &str = "http://localhost:11434"; @@ -243,7 +243,7 @@ pub async fn get_models( } /// Sends an empty request to Ollama to trigger loading the model -pub async fn preload_model(client: &dyn HttpClient, 
api_url: &str, model: &str) -> Result<()> { +pub async fn preload_model(client: Arc, api_url: &str, model: &str) -> Result<()> { let uri = format!("{api_url}/api/generate"); let request = HttpRequest::builder() .method(Method::POST) diff --git a/docs/src/language-model-integration.md b/docs/src/language-model-integration.md index 76fa5bcbab..24f97379dc 100644 --- a/docs/src/language-model-integration.md +++ b/docs/src/language-model-integration.md @@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`: ```json { - "assistant": { - "version": "1", - "provider": { - "name": "openai", - "type": "openai", - "default_model": "gpt-4-turbo-preview", + "language_models": { + "openai": { "api_url": "http://localhost:11434/v1" } } @@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`. You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint. -1. Add the following to your Zed `settings.json`: +1. Download, for example, the `mistral` model with Ollama: + ``` + ollama pull mistral + ``` +2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching: + ``` + ollama serve + ``` +3. In the assistant panel, select one of the Ollama models using the model dropdown. +4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings: - ```json - { - "assistant": { - "version": "1", - "provider": { - "name": "openai", - "type": "openai", - "default_model": "gpt-4-turbo-preview", - "api_url": "http://localhost:11434/v1" - } +```json +{ + "language_models": { + "ollama": { + "api_url": "http://localhost:11434" } } - ``` -2. Download, for example, the `mistral` model with Ollama: - ``` - ollama run mistral - ``` -3. Copy the model and change its name to match the model in the Zed `settings.json`: - ``` - ollama cp mistral gpt-4-turbo-preview - ``` -4. 
Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key: - ``` - ollama - ``` -5. Restart Zed +} +``` ### Using Claude 3.5 Sonnet -You can use Claude with the Zed assistant by adding the following settings: +You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel. -```json -"assistant": { - "version": "1", - "provider": { - "default_model": "claude-3-5-sonnet", - "name": "anthropic" - } -}, -``` - -When you save the settings, the assistant panel will open and ask you to add your Anthropic API key. -You need can obtain this key [here](https://console.anthropic.com/settings/keys). +You can obtain an API key [here](https://console.anthropic.com/settings/keys). Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.