Add the ability to customize available models for OpenAI-compatible services (#13276)

Closes #11984, closes #11075.

Release Notes:

- Added the ability to customize available models for OpenAI-compatible
services ([#11984](https://github.com/zed-industries/zed/issues/11984))
([#11075](https://github.com/zed-industries/zed/issues/11075)).


![image](https://github.com/zed-industries/zed/assets/32017007/01057e7b-1f21-49ad-a3ad-abc5282ffaf0)
This commit is contained in:
ᴀᴍᴛᴏᴀᴇʀ 2024-06-26 04:37:02 +08:00 committed by GitHub
parent 9f88460870
commit 922fcaf5a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 79 additions and 11 deletions

View File

@ -169,6 +169,7 @@ pub enum AssistantProvider {
model: OpenAiModel, model: OpenAiModel,
api_url: String, api_url: String,
low_speed_timeout_in_seconds: Option<u64>, low_speed_timeout_in_seconds: Option<u64>,
available_models: Vec<OpenAiModel>,
}, },
Anthropic { Anthropic {
model: AnthropicModel, model: AnthropicModel,
@ -188,6 +189,7 @@ impl Default for AssistantProvider {
model: OpenAiModel::default(), model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(), api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Default::default(),
} }
} }
} }
@ -202,6 +204,7 @@ pub enum AssistantProviderContent {
default_model: Option<OpenAiModel>, default_model: Option<OpenAiModel>,
api_url: Option<String>, api_url: Option<String>,
low_speed_timeout_in_seconds: Option<u64>, low_speed_timeout_in_seconds: Option<u64>,
available_models: Option<Vec<OpenAiModel>>,
}, },
#[serde(rename = "anthropic")] #[serde(rename = "anthropic")]
Anthropic { Anthropic {
@ -272,6 +275,7 @@ impl AssistantSettingsContent {
default_model: settings.default_open_ai_model.clone(), default_model: settings.default_open_ai_model.clone(),
api_url: Some(open_ai_api_url.clone()), api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
}) })
} else { } else {
settings.default_open_ai_model.clone().map(|open_ai_model| { settings.default_open_ai_model.clone().map(|open_ai_model| {
@ -279,6 +283,7 @@ impl AssistantSettingsContent {
default_model: Some(open_ai_model), default_model: Some(open_ai_model),
api_url: None, api_url: None,
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
} }
}) })
}, },
@ -345,6 +350,7 @@ impl AssistantSettingsContent {
default_model: Some(model), default_model: Some(model),
api_url: None, api_url: None,
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
}) })
} }
LanguageModel::Anthropic(model) => { LanguageModel::Anthropic(model) => {
@ -489,15 +495,18 @@ impl Settings for AssistantSettings {
model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models,
}, },
AssistantProviderContent::OpenAi { AssistantProviderContent::OpenAi {
default_model: model_override, default_model: model_override,
api_url: api_url_override, api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
available_models: available_models_override,
}, },
) => { ) => {
merge(model, model_override); merge(model, model_override);
merge(api_url, api_url_override); merge(api_url, api_url_override);
merge(available_models, available_models_override);
if let Some(low_speed_timeout_in_seconds_override) = if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override low_speed_timeout_in_seconds_override
{ {
@ -558,10 +567,12 @@ impl Settings for AssistantSettings {
default_model: model, default_model: model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models,
} => AssistantProvider::OpenAi { } => AssistantProvider::OpenAi {
model: model.unwrap_or_default(), model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()), api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models: available_models.unwrap_or_default(),
}, },
AssistantProviderContent::Anthropic { AssistantProviderContent::Anthropic {
default_model: model, default_model: model,
@ -618,6 +629,7 @@ mod tests {
model: OpenAiModel::FourOmni, model: OpenAiModel::FourOmni,
api_url: open_ai::OPEN_AI_API_URL.into(), api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Default::default(),
} }
); );
@ -640,6 +652,7 @@ mod tests {
model: OpenAiModel::FourOmni, model: OpenAiModel::FourOmni,
api_url: "test-url".into(), api_url: "test-url".into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Default::default(),
} }
); );
SettingsStore::update_global(cx, |store, cx| { SettingsStore::update_global(cx, |store, cx| {
@ -660,6 +673,7 @@ mod tests {
model: OpenAiModel::Four, model: OpenAiModel::Four,
api_url: open_ai::OPEN_AI_API_URL.into(), api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None, low_speed_timeout_in_seconds: None,
available_models: Default::default(),
} }
); );

View File

@ -24,6 +24,20 @@ use settings::{Settings, SettingsStore};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
/// Choose which model to use for the OpenAI provider.
///
/// Resolution order:
/// 1. If `model` appears in `available_models`, use it.
/// 2. Otherwise fall back to the first entry of `available_models`.
/// 3. If `available_models` is empty, keep the original `model`.
///
/// Always returns an owned clone so callers can hand the result to the
/// completion provider without borrowing from settings.
fn choose_openai_model(
    model: &::open_ai::Model,
    available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
    available_models
        .iter()
        .find(|&m| m == model)
        .or_else(|| available_models.first())
        // `model` is already a reference in hand — no lazy evaluation needed,
        // so `unwrap_or` is the idiomatic form here (clippy: unnecessary_lazy_evaluations).
        .unwrap_or(model)
        .clone()
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0; let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider { let provider = match &AssistantSettings::get_global(cx).provider {
@ -34,8 +48,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new( } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
model.clone(), choose_openai_model(model, available_models),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -77,10 +92,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models,
}, },
) => { ) => {
provider.update( provider.update(
model.clone(), choose_openai_model(model, available_models),
api_url.clone(), api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version, settings_version,
@ -136,10 +152,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
model, model,
api_url, api_url,
low_speed_timeout_in_seconds, low_speed_timeout_in_seconds,
available_models,
}, },
) => { ) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new( *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
model.clone(), choose_openai_model(model, available_models),
api_url.clone(), api_url.clone(),
client.http_client(), client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs), low_speed_timeout_in_seconds.map(Duration::from_secs),
@ -201,10 +218,10 @@ impl CompletionProvider {
cx.global::<Self>() cx.global::<Self>()
} }
pub fn available_models(&self) -> Vec<LanguageModel> { pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
match self { match self {
CompletionProvider::OpenAi(provider) => provider CompletionProvider::OpenAi(provider) => provider
.available_models() .available_models(cx)
.map(LanguageModel::OpenAi) .map(LanguageModel::OpenAi)
.collect(), .collect(),
CompletionProvider::Anthropic(provider) => provider CompletionProvider::Anthropic(provider) => provider

View File

@ -1,4 +1,5 @@
use crate::assistant_settings::CloudModel; use crate::assistant_settings::CloudModel;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
use crate::{ use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
}; };
@ -56,8 +57,26 @@ impl OpenAiCompletionProvider {
self.settings_version = settings_version; self.settings_version = settings_version;
} }
pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> { pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
if let AssistantProvider::OpenAi {
available_models, ..
} = &AssistantSettings::get_global(cx).provider
{
if !available_models.is_empty() {
// available_models is set, just return it
return available_models.clone().into_iter();
}
}
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
// available_models is not set but the default model is set to custom, only show custom
vec![self.model.clone()]
} else {
// default case, use all models except custom
OpenAiModel::iter() OpenAiModel::iter()
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.collect()
};
available_models.into_iter()
} }
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
@ -213,7 +232,8 @@ pub fn count_open_ai_tokens(
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet) | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Opus) | LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet) | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku) => { | LanguageModel::Cloud(CloudModel::Claude3Haiku)
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the // Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4. // same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)

View File

@ -1298,7 +1298,8 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher") PopoverMenu::new("model-switcher")
.menu(move |cx| { .menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| { ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() { for model in CompletionProvider::global(cx).available_models(cx)
{
menu = menu.custom_entry( menu = menu.custom_entry(
{ {
let model = model.clone(); let model = model.clone();

View File

@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
.with_handle(self.handle) .with_handle(self.handle)
.menu(move |cx| { .menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| { ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() { for model in CompletionProvider::global(cx).available_models(cx) {
menu = menu.custom_entry( menu = menu.custom_entry(
{ {
let model = model.clone(); let model = model.clone();

View File

@ -55,6 +55,8 @@ pub enum Model {
#[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")] #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
#[default] #[default]
FourOmni, FourOmni,
#[serde(rename = "custom")]
Custom { name: String, max_tokens: usize },
} }
impl Model { impl Model {
@ -74,15 +76,17 @@ impl Model {
Self::Four => "gpt-4", Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo-preview", Self::FourTurbo => "gpt-4-turbo-preview",
Self::FourOmni => "gpt-4o", Self::FourOmni => "gpt-4o",
Self::Custom { .. } => "custom",
} }
} }
pub fn display_name(&self) -> &'static str { pub fn display_name(&self) -> &str {
match self { match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo", Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4", Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo", Self::FourTurbo => "gpt-4-turbo",
Self::FourOmni => "gpt-4o", Self::FourOmni => "gpt-4o",
Self::Custom { name, .. } => name,
} }
} }
@ -92,12 +96,24 @@ impl Model {
Model::Four => 8192, Model::Four => 8192,
Model::FourTurbo => 128000, Model::FourTurbo => 128000,
Model::FourOmni => 128000, Model::FourOmni => 128000,
Model::Custom { max_tokens, .. } => *max_tokens,
} }
} }
} }
/// Serialize a [`Model`] as the string the OpenAI-compatible API expects:
/// custom models serialize as their user-provided name, built-in models as
/// their canonical API id.
fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let id = if let Model::Custom { name, .. } = model {
        name.as_str()
    } else {
        model.id()
    };
    serializer.serialize_str(id)
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct Request { pub struct Request {
#[serde(serialize_with = "serialize_model")]
pub model: Model, pub model: Model,
pub messages: Vec<RequestMessage>, pub messages: Vec<RequestMessage>,
pub stream: bool, pub stream: bool,