diff --git a/Cargo.lock b/Cargo.lock
index dc7126a608..7c77215e07 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -93,6 +93,7 @@ dependencies = [
  "postage",
  "rand 0.8.5",
  "rusqlite",
+ "schemars",
  "serde",
  "serde_json",
  "tiktoken-rs",
diff --git a/assets/settings/default.json b/assets/settings/default.json
index b73dff2d9f..d07efe7ab9 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -228,15 +228,29 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // Deprecated: Please use `provider.api_url` instead.
     // The default OpenAI API endpoint to use when starting new conversations.
     "openai_api_url": "https://api.openai.com/v1",
+    // Deprecated: Please use `provider.default_model` instead.
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //
     // 1. "gpt-3.5-turbo-0613""
     // 2. "gpt-4-0613""
     // 3. "gpt-4-1106-preview"
-    "default_open_ai_model": "gpt-4-1106-preview"
+    "default_open_ai_model": "gpt-4-1106-preview",
+    "provider": {
+      "type": "openai",
+      // The default OpenAI API endpoint to use when starting new conversations.
+      "api_url": "https://api.openai.com/v1",
+      // The default OpenAI model to use when starting new conversations. This
+      // setting can take three values:
+      //
+      // 1. "gpt-3.5-turbo-0613""
+      // 2. "gpt-4-0613""
+      // 3. "gpt-4-1106-preview"
+      "default_model": "gpt-4-1106-preview"
+    }
   },
   // Whether the screen sharing icon is shown in the os status bar.
   "show_call_status_icon": true,
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
index 35fe5354c4..1aa2f6d48e 100644
--- a/crates/ai/Cargo.toml
+++ b/crates/ai/Cargo.toml
@@ -29,6 +29,7 @@ parse_duration = "2.1.1"
 postage.workspace = true
 rand.workspace = true
 rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 tiktoken-rs.workspace = true
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index 5cf6658ba2..04cc358894 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -1,3 +1,10 @@
+use std::{
+    env,
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+
 use anyhow::{anyhow, Result};
 use futures::{
     future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
@@ -6,23 +13,17 @@ use futures::{
 use gpui::{AppContext, BackgroundExecutor};
 use isahc::{http::StatusCode, Request, RequestExt};
 use parking_lot::RwLock;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{
-    env,
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
 use util::ResultExt;
 
+use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 use crate::{
     auth::{CredentialProvider, ProviderCredential},
     completion::{CompletionProvider, CompletionRequest},
     models::LanguageModel,
 };
 
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
 pub enum Role {
@@ -196,12 +197,56 @@ async fn stream_completion(
     }
 }
 
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
+pub enum AzureOpenAiApiVersion {
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-03-15-preview")]
+    V2023_03_15Preview,
+    #[serde(rename = "2023-05-15")]
+    V2023_05_15,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-06-01-preview")]
+    V2023_06_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-07-01-preview")]
+    V2023_07_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-08-01-preview")]
+    V2023_08_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-09-01-preview")]
+    V2023_09_01Preview,
+    #[serde(rename = "2023-12-01-preview")]
+    V2023_12_01Preview,
+    #[serde(rename = "2024-02-15-preview")]
+    V2024_02_15Preview,
+}
+
+impl fmt::Display for AzureOpenAiApiVersion {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                Self::V2023_03_15Preview => "2023-03-15-preview",
+                Self::V2023_05_15 => "2023-05-15",
+                Self::V2023_06_01Preview => "2023-06-01-preview",
+                Self::V2023_07_01Preview => "2023-07-01-preview",
+                Self::V2023_08_01Preview => "2023-08-01-preview",
+                Self::V2023_09_01Preview => "2023-09-01-preview",
+                Self::V2023_12_01Preview => "2023-12-01-preview",
+                Self::V2024_02_15Preview => "2024-02-15-preview",
+            }
+        )
+    }
+}
+
 #[derive(Clone)]
 pub enum OpenAiCompletionProviderKind {
     OpenAi,
     AzureOpenAi {
         deployment_id: String,
-        api_version: String,
+        api_version: AzureOpenAiApiVersion,
     },
 }
 
@@ -217,8 +262,8 @@ impl OpenAiCompletionProviderKind {
                 deployment_id,
                 api_version,
             } => {
-                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
-                format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+                format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
             }
         }
     }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 2826b93c3a..d5bc08b7bf 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -124,16 +124,18 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
-            let (api_url, model_name) = cx.update(|cx| {
+            let (provider_kind, api_url, model_name) = cx.update(|cx| {
                 let settings = AssistantSettings::get_global(cx);
-                (
-                    settings.openai_api_url.clone(),
-                    settings.default_open_ai_model.full_name().to_string(),
-                )
-            })?;
+                anyhow::Ok((
+                    settings.provider_kind()?,
+                    settings.provider_api_url()?,
+                    settings.provider_model_name()?,
+                ))
+            })??;
+
             let completion_provider = OpenAiCompletionProvider::new(
                 api_url,
-                OpenAiCompletionProviderKind::OpenAi,
+                provider_kind,
                 model_name,
                 cx.background_executor().clone(),
             )
@@ -693,24 +695,29 @@
            Task::ready(Ok(Vec::new()))
        };
 
-        let mut model = AssistantSettings::get_global(cx)
-            .default_open_ai_model
-            .clone();
-        let model_name = model.full_name();
+        let Some(mut model_name) = AssistantSettings::get_global(cx)
+            .provider_model_name()
+            .log_err()
+        else {
+            return;
+        };
 
-        let prompt = cx.background_executor().spawn(async move {
-            let snippets = snippets.await?;
+        let prompt = cx.background_executor().spawn({
+            let model_name = model_name.clone();
+            async move {
+                let snippets = snippets.await?;
 
-            let language_name = language_name.as_deref();
-            generate_content_prompt(
-                user_prompt,
-                language_name,
-                buffer,
-                range,
-                snippets,
-                model_name,
-                project_name,
-            )
+                let language_name = language_name.as_deref();
+                generate_content_prompt(
+                    user_prompt,
+                    language_name,
+                    buffer,
+                    range,
+                    snippets,
+                    &model_name,
+                    project_name,
+                )
+            }
         });
 
         let mut messages = Vec::new();
@@ -722,7 +729,7 @@
                 .messages(cx)
                 .map(|message| message.to_open_ai_message(buffer)),
             );
-            model = conversation.model.clone();
+            model_name = conversation.model.full_name().to_string();
         }
 
         cx.spawn(|_, mut cx| async move {
@@ -735,7 +742,7 @@
             });
 
             let request = Box::new(OpenAiRequest {
-                model: model.full_name().into(),
+                model: model_name,
                 messages,
                 stream: true,
                 stop: vec!["|END|>".to_string()],
@@ -1454,8 +1461,14 @@
         });
 
         let settings = AssistantSettings::get_global(cx);
-        let model = settings.default_open_ai_model.clone();
-        let api_url = settings.openai_api_url.clone();
+        let model = settings
+            .provider_model()
+            .log_err()
+            .unwrap_or(OpenAiModel::FourTurbo);
+        let api_url = settings
+            .provider_api_url()
+            .log_err()
+            .unwrap_or_else(|| OPEN_AI_API_URL.to_string());
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -3655,9 +3668,9 @@
     let client = workspace.read(cx).project().read(cx).client();
     let telemetry = client.telemetry();
 
-    let model = AssistantSettings::get_global(cx)
-        .default_open_ai_model
-        .clone();
+    let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else {
+        return;
+    };
 
-    telemetry.report_assistant_event(conversation_id, assistant_kind, model.full_name())
+    telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name)
 }
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 16c2d452c2..4b8f3feb85 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -1,10 +1,14 @@
-use anyhow;
+use ai::providers::open_ai::{
+    AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
+};
+use anyhow::anyhow;
 use gpui::Pixels;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
 
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(rename_all = "snake_case")]
 pub enum OpenAiModel {
     #[serde(rename = "gpt-3.5-turbo-0613")]
     ThreePointFiveTurbo,
@@ -17,25 +21,25 @@
 impl OpenAiModel {
     pub fn full_name(&self) -> &'static str {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
-            OpenAiModel::Four => "gpt-4-0613",
-            OpenAiModel::FourTurbo => "gpt-4-1106-preview",
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            Self::Four => "gpt-4-0613",
+            Self::FourTurbo => "gpt-4-1106-preview",
         }
     }
 
     pub fn short_name(&self) -> &'static str {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
-            OpenAiModel::Four => "gpt-4",
-            OpenAiModel::FourTurbo => "gpt-4-turbo",
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::Four => "gpt-4",
+            Self::FourTurbo => "gpt-4-turbo",
         }
     }
 
     pub fn cycle(&self) -> Self {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
-            OpenAiModel::Four => OpenAiModel::FourTurbo,
-            OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
+            Self::ThreePointFiveTurbo => Self::Four,
+            Self::Four => Self::FourTurbo,
+            Self::FourTurbo => Self::ThreePointFiveTurbo,
         }
     }
 }
@@ -48,14 +52,99 @@ pub enum AssistantDockPosition {
     Bottom,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Deserialize)]
 pub struct AssistantSettings {
+    /// Whether to show the assistant panel button in the status bar.
     pub button: bool,
+    /// Where to dock the assistant.
     pub dock: AssistantDockPosition,
+    /// Default width in pixels when the assistant is docked to the left or right.
     pub default_width: Pixels,
+    /// Default height in pixels when the assistant is docked to the bottom.
     pub default_height: Pixels,
+    /// The default OpenAI model to use when starting new conversations.
+    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: OpenAiModel,
+    /// OpenAI API base URL to use when starting new conversations.
+    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: String,
+    /// The settings for the AI provider.
+    pub provider: AiProviderSettings,
+}
+
+impl AssistantSettings {
+    pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
+            AiProviderSettings::AzureOpenAi(settings) => {
+                let deployment_id = settings
+                    .deployment_id
+                    .clone()
+                    .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
+                let api_version = settings
+                    .api_version
+                    .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
+
+                Ok(OpenAiCompletionProviderKind::AzureOpenAi {
+                    deployment_id,
+                    api_version,
+                })
+            }
+        }
+    }
+
+    pub fn provider_api_url(&self) -> anyhow::Result<String> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => Ok(settings
+                .api_url
+                .clone()
+                .unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
+            AiProviderSettings::AzureOpenAi(settings) => settings
+                .api_url
+                .clone()
+                .ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
+        }
+    }
+
+    pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => {
+                Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
+            }
+            AiProviderSettings::AzureOpenAi(_settings) => {
+                // TODO: We need to use an Azure OpenAI model here.
+                Ok(OpenAiModel::FourTurbo)
+            }
+        }
+    }
+
+    pub fn provider_model_name(&self) -> anyhow::Result<String> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => Ok(settings
+                .default_model
+                .unwrap_or(OpenAiModel::FourTurbo)
+                .full_name()
+                .to_string()),
+            AiProviderSettings::AzureOpenAi(settings) => settings
+                .deployment_id
+                .clone()
+                .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
+        }
+    }
+}
+
+impl Settings for AssistantSettings {
+    const KEY: Option<&'static str> = Some("assistant");
+
+    type FileContent = AssistantSettingsContent;
+
+    fn load(
+        default_value: &Self::FileContent,
+        user_values: &[&Self::FileContent],
+        _: &mut gpui::AppContext,
+    ) -> anyhow::Result<Self> {
+        Self::load_via_json_merge(default_value, user_values)
+    }
 }
 
 /// Assistant panel settings
@@ -77,26 +166,88 @@
     ///
     /// Default: 320
     pub default_height: Option<f32>,
+    /// Deprecated: Please use `provider.default_model` instead.
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
+    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// Deprecated: Please use `provider.api_url` instead.
     /// OpenAI API base URL to use when starting new conversations.
     ///
     /// Default: https://api.openai.com/v1
+    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: Option<String>,
+    /// The settings for the AI provider.
+    #[serde(default)]
+    pub provider: AiProviderSettingsContent,
 }
 
-impl Settings for AssistantSettings {
-    const KEY: Option<&'static str> = Some("assistant");
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettings {
+    /// The settings for the OpenAI provider.
+    #[serde(rename = "openai")]
+    OpenAi(OpenAiProviderSettings),
+    /// The settings for the Azure OpenAI provider.
+    #[serde(rename = "azure_openai")]
+    AzureOpenAi(AzureOpenAiProviderSettings),
+}
 
-    type FileContent = AssistantSettingsContent;
+/// The settings for the AI provider used by the Zed Assistant.
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettingsContent {
+    /// The settings for the OpenAI provider.
+    #[serde(rename = "openai")]
+    OpenAi(OpenAiProviderSettingsContent),
+    /// The settings for the Azure OpenAI provider.
+    #[serde(rename = "azure_openai")]
+    AzureOpenAi(AzureOpenAiProviderSettingsContent),
+}
 
-    fn load(
-        default_value: &Self::FileContent,
-        user_values: &[&Self::FileContent],
-        _: &mut gpui::AppContext,
-    ) -> anyhow::Result<Self> {
-        Self::load_via_json_merge(default_value, user_values)
+impl Default for AiProviderSettingsContent {
+    fn default() -> Self {
+        Self::OpenAi(OpenAiProviderSettingsContent::default())
     }
 }
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct OpenAiProviderSettings {
+    /// The OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The default OpenAI model to use when starting new conversations.
+    pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct OpenAiProviderSettingsContent {
+    /// The OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub api_url: Option<String>,
+    /// The default OpenAI model to use when starting new conversations.
+    ///
+    /// Default: gpt-4-1106-preview
+    pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AzureOpenAiProviderSettings {
+    /// The Azure OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The Azure OpenAI API version.
+    pub api_version: Option<AzureOpenAiApiVersion>,
+    /// The Azure OpenAI API deployment ID.
+    pub deployment_id: Option<String>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct AzureOpenAiProviderSettingsContent {
+    /// The Azure OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The Azure OpenAI API version.
+    pub api_version: Option<AzureOpenAiApiVersion>,
+    /// The Azure OpenAI deployment ID.
+    pub deployment_id: Option<String>,
+}
diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs
index 4bddb2841b..946e5da407 100644
--- a/crates/client/src/telemetry.rs
+++ b/crates/client/src/telemetry.rs
@@ -263,7 +263,7 @@ impl Telemetry {
         self: &Arc<Self>,
         conversation_id: Option<String>,
         kind: AssistantKind,
-        model: &'static str,
+        model: &str,
     ) {
         let event = Event::Assistant(AssistantEvent {
             conversation_id,