assistant: Fix issues when configuring different providers (#15072)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Bennet Bo Fenner 2024-07-24 11:21:31 +02:00 committed by GitHub
parent ba6c36f370
commit af4b9805c9
16 changed files with 225 additions and 148 deletions

View File

@@ -853,7 +853,17 @@
     }
   },
   // Different settings for specific language models.
-  "language_models": {},
+  "language_models": {
+    "anthropic": {
+      "api_url": "https://api.anthropic.com"
+    },
+    "openai": {
+      "api_url": "https://api.openai.com/v1"
+    },
+    "ollama": {
+      "api_url": "http://localhost:11434"
+    }
+  },
   // Zed's Prettier integration settings.
   // Allows to enable/disable formatting with Prettier
   // and configure default Prettier, used when no project-level Prettier installation is found.
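With these defaults in place, a user's `settings.json` only needs to override the fields that differ. A minimal sketch, assuming user settings mirror the default shape above (the proxy URL is hypothetical):

```json
{
  "language_models": {
    "openai": {
      "api_url": "https://my-openai-proxy.example.com/v1"
    }
  }
}
```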

View File

@@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal
 use indexed_docs::IndexedDocsRegistry;
 pub(crate) use inline_assistant::*;
 use language_model::{
-    LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
+    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 };
 pub(crate) use model_selector::*;
 use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
@@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) {
 fn update_active_language_model_from_settings(cx: &mut AppContext) {
     let settings = AssistantSettings::get_global(cx);
-    let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
+    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
     let model_id = LanguageModelId::from(settings.default_model.model.clone());
     let Some(provider) = LanguageModelRegistry::global(cx)

View File

@@ -144,8 +144,8 @@ impl AssistantSettingsContent {
                 fs,
                 cx,
                 move |content, _| {
-                    if content.open_ai.is_none() {
-                        content.open_ai =
+                    if content.openai.is_none() {
+                        content.openai =
                             Some(language_model::settings::OpenAiSettingsContent {
                                 api_url,
                                 low_speed_timeout_in_seconds,
@@ -243,7 +243,7 @@ impl AssistantSettingsContent {
     pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
         let model = language_model.id().0.to_string();
-        let provider = language_model.provider_name().0.to_string();
+        let provider = language_model.provider_id().0.to_string();
         match self {
             AssistantSettingsContent::Versioned(settings) => match settings {

View File

@@ -1438,7 +1438,7 @@ impl Render for PromptEditor {
                             {
                                 let model_name = available_model.name().0.clone();
                                 let provider =
-                                    available_model.provider_name().0.clone();
+                                    available_model.provider_id().0.clone();
                                 move |_| {
                                     h_flex()
                                         .w_full()

View File

@@ -565,7 +565,7 @@ impl Render for PromptEditor {
                             {
                                 let model_name = available_model.name().0.clone();
                                 let provider =
-                                    available_model.provider_name().0.clone();
+                                    available_model.provider_id().0.clone();
                                 move |_| {
                                     h_flex()
                                         .w_full()

View File

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AppContext, Global, Model, ModelContext, Task};
 use language_model::{
-    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
+    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
     LanguageModelRequest,
 };
 use smol::lock::{Semaphore, SemaphoreGuardArc};
@@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider {
     pub fn set_active_provider(
         &mut self,
-        provider_name: LanguageModelProviderName,
+        provider_name: LanguageModelProviderId,
         cx: &mut ModelContext<Self>,
     ) {
         self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
@@ -103,14 +103,19 @@ impl LanguageModelCompletionProvider {
     pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
         if self.active_model.as_ref().map_or(false, |m| {
-            m.id() == model.id() && m.provider_name() == model.provider_name()
+            m.id() == model.id() && m.provider_id() == model.provider_id()
         }) {
             return;
         }

         self.active_provider =
-            LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
-        self.active_model = Some(model);
+            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
+        self.active_model = Some(model.clone());
+        if let Some(provider) = self.active_provider.as_ref() {
+            provider.load_model(model, cx);
+        }
         cx.notify();
     }

View File

@@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
+    fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
     fn telemetry_id(&self) -> String;
@@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
 }

 pub trait LanguageModelProvider: 'static {
+    fn id(&self) -> LanguageModelProviderId;
     fn name(&self) -> LanguageModelProviderName;
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
+    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
     fn is_authenticated(&self, cx: &AppContext) -> bool;
     fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
@@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
 pub struct LanguageModelName(pub SharedString);

+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub struct LanguageModelProviderId(pub SharedString);
+
 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
 pub struct LanguageModelProviderName(pub SharedString);
@@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
     }
 }

+impl From<String> for LanguageModelProviderId {
+    fn from(value: String) -> Self {
+        Self(SharedString::from(value))
+    }
+}
+
 impl From<String> for LanguageModelProviderName {
     fn from(value: String) -> Self {
         Self(SharedString::from(value))
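The new `LanguageModelProviderId` deliberately sits alongside `LanguageModelProviderName` rather than replacing it: the id is a stable machine-readable key, while the name is free to be a display string ("openai" vs. "OpenAI"). A minimal standalone sketch of that split, with `String` standing in for gpui's `SharedString` so it runs on its own:

```rust
// Mirrors the two newtypes above; String stands in for gpui::SharedString.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct LanguageModelProviderId(String);

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct LanguageModelProviderName(String);

fn main() {
    // Registries and settings key providers by the stable id ("openai"),
    // while the human-readable name ("OpenAI") is only used in the UI.
    let id = LanguageModelProviderId("openai".into());
    let name = LanguageModelProviderName("OpenAI".into());
    assert_eq!(id.0, "openai");
    assert_eq!(name.0, "OpenAI");
    println!("provider {:?} is shown as {:?}", id, name);
}
```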

View File

@@ -1,6 +1,5 @@
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
-use collections::HashMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{
@@ -9,7 +8,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -17,11 +16,12 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, LanguageModelRequestMessage, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
 };

-const PROVIDER_NAME: &str = "anthropic";
+const PROVIDER_ID: &str = "anthropic";
+const PROVIDER_NAME: &str = "Anthropic";

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct AnthropicSettings {
@@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
 struct State {
     api_key: Option<String>,
-    settings: AnthropicSettings,
     _subscription: Subscription,
 }
@@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
         let state = cx.new_model(|cx| State {
             api_key: None,
-            settings: AnthropicSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),
         });
@@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 }

 impl LanguageModelProvider for AnthropicLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }

     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();

         // Add base models from anthropic::Model::iter()
         for model in anthropic::Model::iter() {
@@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         }

         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in AllLanguageModelSettings::get_global(cx)
+            .anthropic
+            .available_models
+            .iter()
+        {
             models.insert(model.id().to_string(), model.clone());
         }
@@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            let api_url = self.state.read(cx).settings.api_url.clone();
+            let api_url = AllLanguageModelSettings::get_global(cx)
+                .anthropic
+                .api_url
+                .clone();
             let state = self.state.clone();
             cx.spawn(|mut cx| async move {
                 let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
@@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
     fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let state = self.state.clone();
-        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+        let delete_credentials =
+            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             state.update(&mut cx, |this, cx| {
@@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }

+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
         let request = self.to_anthropic_request(request);

         let http_client = self.http_client.clone();
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
             (
                 state.api_key.clone(),
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
             )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -365,7 +380,10 @@ impl AuthenticationPrompt {
         }

         let write_credentials = cx.write_credentials(
-            &self.state.read(cx).settings.api_url,
+            AllLanguageModelSettings::get_global(cx)
+                .anthropic
+                .api_url
+                .as_str(),
             "Bearer",
             api_key.as_bytes(),
         );

View File

@@ -1,15 +1,15 @@
 use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
-    LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest,
 };
 use anyhow::Result;
 use client::Client;
-use collections::HashMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{collections::BTreeMap, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
 use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};

+pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "zed.dev";

 #[derive(Default, Clone, Debug, PartialEq)]
@@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
 struct State {
     client: Arc<Client>,
     status: client::Status,
-    settings: ZedDotDevSettings,
     _subscription: Subscription,
 }
@@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
         let state = cx.new_model(|cx| State {
             client: client.clone(),
             status,
-            settings: ZedDotDevSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),
         });
@@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
 }

 impl LanguageModelProvider for CloudLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }

     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();

         // Add base models from CloudModel::iter()
         for model in CloudModel::iter() {
@@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         }

         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in &AllLanguageModelSettings::get_global(cx)
+            .zed_dot_dev
+            .available_models
+        {
             models.insert(model.id().to_string(), model.clone());
         }
@@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }

+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
             | CloudModel::Claude3Opus
             | CloudModel::Claude3Sonnet
             | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
+            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
+                count_anthropic_tokens(request, cx)
+            }
             _ => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
                     model: self.model.id().to_string(),

View File

@@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest,
 };
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
@@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
     LanguageModelName::from("Fake".to_string())
 }

+pub fn provider_id() -> LanguageModelProviderId {
+    LanguageModelProviderId::from("fake".to_string())
+}
+
 pub fn provider_name() -> LanguageModelProviderName {
-    LanguageModelProviderName::from("fake".to_string())
+    LanguageModelProviderName::from("Fake".to_string())
 }

 #[derive(Clone, Default)]
@@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
 }

 impl LanguageModelProvider for FakeLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        provider_id()
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         provider_name()
     }
@@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
         language_model_name()
     }

+    fn provider_id(&self) -> LanguageModelProviderId {
+        provider_id()
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         provider_name()
     }

View File

@@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
 use http_client::HttpClient;
-use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
+use ollama::{
+    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+};
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, ElevationIndex};

 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
 };

 const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
-const PROVIDER_NAME: &str = "ollama";
+const PROVIDER_ID: &str = "ollama";
+const PROVIDER_NAME: &str = "Ollama";

 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct OllamaSettings {
@@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
 struct State {
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<ollama::Model>,
-    settings: OllamaSettings,
     _subscription: Subscription,
 }

 impl State {
-    fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
-        let api_url = self.settings.api_url.clone();
+        let api_url = settings.api_url.clone();

         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
         cx.spawn(|this, mut cx| async move {
@@ -66,23 +69,25 @@ impl State {
 impl OllamaLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
-        Self {
+        let this = Self {
             http_client: http_client.clone(),
             state: cx.new_model(|cx| State {
                 http_client,
                 available_models: Default::default(),
-                settings: OllamaSettings::default(),
                 _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                    this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
+                    this.fetch_models(cx).detach_and_log_err(cx);
                     cx.notify();
                 }),
             }),
-        }
+        };
+        this.fetch_models(cx).detach_and_log_err(cx);
+        this
     }

     fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
-        let api_url = self.state.read(cx).settings.api_url.clone();
+        let api_url = settings.api_url.clone();
         let state = self.state.clone();

         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
@@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
 }

 impl LanguageModelProvider for OllamaLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     id: LanguageModelId::from(model.name.clone()),
                     model: model.clone(),
                     http_client: self.http_client.clone(),
-                    state: self.state.clone(),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
     }

+    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+        let http_client = self.http_client.clone();
+        let api_url = settings.api_url.clone();
+        let id = model.id().0.to_string();
+        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
+            .detach_and_log_err(cx);
+    }
+
     fn is_authenticated(&self, cx: &AppContext) -> bool {
         !self.state.read(cx).available_models.is_empty()
     }
@@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
 pub struct OllamaLanguageModel {
     id: LanguageModelId,
     model: ollama::Model,
-    state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
 }
@@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }

+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn provider_name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
     fn max_token_count(&self) -> usize {
         self.model.max_token_count()
     }
@@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
         format!("ollama/{}", self.model.id())
     }

-    fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
-    }
-
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
         let request = self.to_ollama_request(request);

         let http_client = self.http_client.clone();
-        let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
-            (
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
-            )
+        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+            (settings.api_url.clone(), settings.low_speed_timeout)
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };

View File

@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Result};
-use collections::HashMap;
+use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
 use gpui::{
@@ -17,11 +17,12 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
 };

-const PROVIDER_NAME: &str = "openai";
+const PROVIDER_ID: &str = "openai";
+const PROVIDER_NAME: &str = "OpenAI";

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenAiSettings {
@@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
 struct State {
     api_key: Option<String>,
-    settings: OpenAiSettings,
     _subscription: Subscription,
 }
@@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
         let state = cx.new_model(|cx| State {
             api_key: None,
-            settings: OpenAiSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
                 cx.notify();
             }),
         });
@@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 }

 impl LanguageModelProvider for OpenAiLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }

     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();

         // Add base models from open_ai::Model::iter()
         for model in open_ai::Model::iter() {
@@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         }

         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in &AllLanguageModelSettings::get_global(cx)
+            .openai
+            .available_models
+        {
             models.insert(model.id().to_string(), model.clone());
         }
@@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            let api_url = self.state.read(cx).settings.api_url.clone();
+            let api_url = AllLanguageModelSettings::get_global(cx)
+                .openai
+                .api_url
+                .clone();
             let state = self.state.clone();
             cx.spawn(|mut cx| async move {
                 let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
@@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
     }

     fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+        let settings = &AllLanguageModelSettings::get_global(cx).openai;
+        let delete_credentials = cx.delete_credentials(&settings.api_url);
         let state = self.state.clone();
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
@@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }

+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
         let request = self.to_open_ai_request(request);

         let http_client = self.http_client.clone();
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).openai;
             (
                 state.api_key.clone(),
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
             )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -307,11 +321,9 @@ impl AuthenticationPrompt {
             return;
         }

-        let write_credentials = cx.write_credentials(
-            &self.state.read(cx).settings.api_url,
-            "Bearer",
-            api_key.as_bytes(),
-        );
+        let settings = &AllLanguageModelSettings::get_global(cx).openai;
+        let write_credentials =
+            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
         let state = self.state.clone();
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;

View File

@@ -9,7 +9,7 @@ use crate::{
         anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
         ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
     },
-    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
 };

 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@@ -48,7 +48,7 @@ fn register_language_model_providers(
         registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
     } else {
         registry.unregister_provider(
-            &LanguageModelProviderName::from(
+            &LanguageModelProviderId::from(
                 crate::provider::cloud::PROVIDER_NAME.to_string(),
             ),
             cx,
@@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
 #[derive(Default)]
 pub struct LanguageModelRegistry {
-    providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
+    providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 }

 impl LanguageModelRegistry {
@@ -94,7 +94,7 @@ impl LanguageModelRegistry {
         provider: T,
         cx: &mut ModelContext<Self>,
     ) {
-        let name = provider.name();
+        let name = provider.id();

         if let Some(subscription) = provider.subscribe(cx) {
             subscription.detach();
@@ -106,7 +106,7 @@ impl LanguageModelRegistry {
     pub fn unregister_provider(
         &mut self,
-        name: &LanguageModelProviderName,
+        name: &LanguageModelProviderId,
         cx: &mut ModelContext<Self>,
     ) {
         if self.providers.remove(name).is_some() {
@@ -116,7 +116,7 @@ impl LanguageModelRegistry {
     pub fn providers(
         &self,
-    ) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
+    ) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
         self.providers.iter()
     }
@@ -130,7 +130,7 @@ impl LanguageModelRegistry {
     pub fn available_models_grouped_by_provider(
         &self,
         cx: &AppContext,
-    ) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
+    ) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
         self.providers
             .iter()
             .map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@@ -139,7 +139,7 @@ impl LanguageModelRegistry {
     pub fn provider(
         &self,
-        name: &LanguageModelProviderName,
+        name: &LanguageModelProviderId,
     ) -> Option<Arc<dyn LanguageModelProvider>> {
         self.providers.get(name).cloned()
     }
@@ -160,10 +160,10 @@ mod tests {
         let providers = registry.read(cx).providers().collect::<Vec<_>>();
         assert_eq!(providers.len(), 1);
-        assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
+        assert_eq!(providers[0].0, &crate::provider::fake::provider_id());

         registry.update(cx, |registry, cx| {
-            registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
+            registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
         });

         let providers = registry.read(cx).providers().collect::<Vec<_>>();

View File

@@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
 #[derive(Default)]
 pub struct AllLanguageModelSettings {
-    pub open_ai: OpenAiSettings,
     pub anthropic: AnthropicSettings,
     pub ollama: OllamaSettings,
+    pub openai: OpenAiSettings,
     pub zed_dot_dev: ZedDotDevSettings,
 }
@@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
 pub struct AllLanguageModelSettingsContent {
     pub anthropic: Option<AnthropicSettingsContent>,
     pub ollama: Option<OllamaSettingsContent>,
-    pub open_ai: Option<OpenAiSettingsContent>,
+    pub openai: Option<OpenAiSettingsContent>,
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
 }
@@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
         }

         merge(
-            &mut settings.open_ai.api_url,
-            value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
+            &mut settings.openai.api_url,
+            value.openai.as_ref().and_then(|s| s.api_url.clone()),
         );
         if let Some(low_speed_timeout_in_seconds) = value
-            .open_ai
+            .openai
             .as_ref()
             .and_then(|s| s.low_speed_timeout_in_seconds)
         {
-            settings.open_ai.low_speed_timeout =
+            settings.openai.low_speed_timeout =
                 Some(Duration::from_secs(low_speed_timeout_in_seconds));
         }
         merge(
-            &mut settings.open_ai.available_models,
+            &mut settings.openai.available_models,
             value
-                .open_ai
+                .openai
                 .as_ref()
                 .and_then(|s| s.available_models.clone()),
         );

View File

@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::{convert::TryFrom, sync::Arc, time::Duration};

 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -243,7 +243,7 @@ pub async fn get_models(
 }

 /// Sends an empty request to Ollama to trigger loading the model
-pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
     let uri = format!("{api_url}/api/generate");
     let request = HttpRequest::builder()
         .method(Method::POST)
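The switch from `&dyn HttpClient` to `Arc<dyn HttpClient>` is what lets the new `load_model` hand the preload off to a detached background task: the spawned future must be `'static`, so it has to own the client rather than borrow it. A standalone sketch of that constraint (all names here are illustrative stand-ins, not the real crate's types; the `futures` crate is assumed for `block_on`):

```rust
use std::sync::Arc;

// Stand-in for the real HttpClient trait object.
trait HttpClient: Send + Sync {
    fn post(&self, uri: &str);
}

struct StubClient;
impl HttpClient for StubClient {
    fn post(&self, uri: &str) {
        println!("POST {uri}");
    }
}

async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) {
    client.post(&format!("{api_url}/api/generate ({model})"));
}

fn main() {
    let client: Arc<dyn HttpClient> = Arc::new(StubClient);
    let api_url = String::from("http://localhost:11434");
    let id = String::from("mistral");
    // The async block owns the Arc and the Strings, so the resulting future
    // is 'static and could be handed to a background executor; a borrowed
    // &dyn HttpClient parameter would tie it to the caller's lifetime.
    let task = async move { preload_model(client, &api_url, &id).await };
    futures::executor::block_on(task);
}
```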

View File

@@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`:
 ```json
 {
-  "assistant": {
-    "version": "1",
-    "provider": {
-      "name": "openai",
-      "type": "openai",
-      "default_model": "gpt-4-turbo-preview",
+  "language_models": {
+    "openai": {
       "api_url": "http://localhost:11434/v1"
     }
   }
 }
@@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`.

 You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.

-1. Add the following to your Zed `settings.json`:
+1. Download, for example, the `mistral` model with Ollama:

-   ```json
-   {
-     "assistant": {
-       "version": "1",
-       "provider": {
-         "name": "openai",
-         "type": "openai",
-         "default_model": "gpt-4-turbo-preview",
-         "api_url": "http://localhost:11434/v1"
-       }
-     }
-   }
    ```
+   ollama pull mistral
+   ```

-2. Download, for example, the `mistral` model with Ollama:
+2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:

    ```
-   ollama run mistral
+   ollama serve
    ```

-3. Copy the model and change its name to match the model in the Zed `settings.json`:
+3. In the assistant panel, select one of the Ollama models using the model dropdown.

-   ```
-   ollama cp mistral gpt-4-turbo-preview
-   ```
+4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings:

-4. Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key:
-
-   ```
-   ollama
-   ```
-
-5. Restart Zed
+   ```json
+   {
+     "language_models": {
+       "ollama": {
+         "api_url": "http://localhost:11434"
+       }
+     }
+   }
+   ```

 ### Using Claude 3.5 Sonnet

-You can use Claude with the Zed assistant by adding the following settings:
+You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel.
+You can obtain an API key [here](https://console.anthropic.com/settings/keys).

-```json
-"assistant": {
-  "version": "1",
-  "provider": {
-    "default_model": "claude-3-5-sonnet",
-    "name": "anthropic"
-  }
-},
-```
-
-When you save the settings, the assistant panel will open and ask you to add your Anthropic API key.
-You need can obtain this key [here](https://console.anthropic.com/settings/keys).
-
 Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.
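After this change, Anthropic shares the same configuration surface as the other providers. A minimal sketch for pointing it at a different endpoint, based on the `language_models` defaults added earlier in this commit (the URL shown is just the stock default):

```json
{
  "language_models": {
    "anthropic": {
      "api_url": "https://api.anthropic.com"
    }
  }
}
```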