diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 180281e7d4..9d89fe2920 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -153,8 +153,8 @@ impl AssistantSettingsContent {
                         models
                             .into_iter()
                             .filter_map(|model| match model {
-                                open_ai::Model::Custom { name, max_tokens } => {
-                                    Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
+                                open_ai::Model::Custom { name, max_tokens, max_output_tokens } => {
+                                    Some(language_model::provider::open_ai::AvailableModel { name, max_tokens, max_output_tokens })
                                 }
                                 _ => None,
                             })
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index a42459b29b..0517986f28 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -254,6 +254,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    max_output_tokens: model.max_output_tokens,
                 }),
                 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                     name: model.name.clone(),
@@ -513,7 +514,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = request.into_open_ai(model.id().into());
+                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -557,7 +558,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.max_tokens = Some(4000);
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -629,7 +630,8 @@ impl LanguageModel for CloudLanguageModel {
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request =
+                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
@@ -676,7 +678,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 65941f5c55..af2c0eb41d 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -40,6 +40,7 @@ pub struct OpenAiSettings {
 pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
+    pub max_output_tokens: Option<u32>,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -170,6 +171,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                 open_ai::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    max_output_tokens: model.max_output_tokens,
                 },
             );
         }
@@ -275,6 +277,10 @@ impl LanguageModel for OpenAiLanguageModel {
         self.model.max_token_count()
     }
 
+    fn max_output_tokens(&self) -> Option<u32> {
+        self.model.max_output_tokens()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -288,7 +294,7 @@ impl LanguageModel for OpenAiLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = request.into_open_ai(self.model.id().into());
+        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
     }
@@ -301,7 +307,7 @@ impl LanguageModel for OpenAiLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_open_ai(self.model.id().into());
+        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
             function: FunctionDefinition {
                 name: tool_name.clone(),
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index ecebc5e868..ef30d1904b 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -229,7 +229,7 @@ pub struct LanguageModelRequest {
 }
 
 impl LanguageModelRequest {
-    pub fn into_open_ai(self, model: String) -> open_ai::Request {
+    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
         open_ai::Request {
             model,
             messages: self
@@ -251,7 +251,7 @@ impl LanguageModelRequest {
             stream: true,
             stop: self.stop,
             temperature: self.temperature,
-            max_tokens: None,
+            max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
         }
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index afe223f842..923ad5b5e9 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -172,9 +172,15 @@ impl OpenAiSettingsContent {
                 models
                     .into_iter()
                     .filter_map(|model| match model {
-                        open_ai::Model::Custom { name, max_tokens } => {
-                            Some(provider::open_ai::AvailableModel { name, max_tokens })
-                        }
+                        open_ai::Model::Custom {
+                            name,
+                            max_tokens,
+                            max_output_tokens,
+                        } => Some(provider::open_ai::AvailableModel {
+                            name,
+                            max_tokens,
+                            max_output_tokens,
+                        }),
                         _ => None,
                     })
                     .collect()
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 291fc1d0ec..ecb0828ee6 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -66,7 +66,11 @@ pub enum Model {
     #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
     FourOmniMini,
     #[serde(rename = "custom")]
-    Custom { name: String, max_tokens: usize },
+    Custom {
+        name: String,
+        max_tokens: usize,
+        max_output_tokens: Option<u32>,
+    },
 }
 
 impl Model {
@@ -113,6 +117,19 @@ impl Model {
             Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn max_output_tokens(&self) -> Option<u32> {
+        match self {
+            Self::ThreePointFiveTurbo => Some(4096),
+            Self::Four => Some(8192),
+            Self::FourTurbo => Some(4096),
+            Self::FourOmni => Some(4096),
+            Self::FourOmniMini => Some(16384),
+            Self::Custom {
+                max_output_tokens, ..
+            } => *max_output_tokens,
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -121,7 +138,7 @@ pub struct Request {
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<usize>,
+    pub max_tokens: Option<u32>,
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
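
Note on the serde behavior the request.rs change relies on: Request.max_tokens carries skip_serializing_if = "Option::is_none", so when a model has no configured cap (into_open_ai(.., None)), the field is omitted from the request body entirely and the OpenAI API falls back to its own default, rather than sending max_tokens: null. A minimal standalone sketch of that behavior (the struct below is an illustrative stand-in for open_ai::Request, not part of the diff; it needs only the serde and serde_json crates):

use serde::{Deserialize, Serialize};

// Stand-in for the two open_ai::Request fields relevant here.
#[derive(Debug, Serialize, Deserialize)]
struct Request {
    model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

fn main() {
    // Cap configured (e.g. a custom model with max_output_tokens set):
    // the field is serialized into the request body.
    let capped = Request { model: "gpt-4o".into(), max_tokens: Some(4096) };
    assert_eq!(
        serde_json::to_string(&capped).unwrap(),
        r#"{"model":"gpt-4o","max_tokens":4096}"#
    );

    // No cap configured: the field is dropped from the JSON, so the
    // API applies its own default output limit.
    let uncapped = Request { model: "gpt-4o".into(), max_tokens: None };
    assert_eq!(
        serde_json::to_string(&uncapped).unwrap(),
        r#"{"model":"gpt-4o"}"#
    );
}

Design-wise, the hard-coded values in Model::max_output_tokens mirror OpenAI's published completion limits for those models (e.g. 16384 for gpt-4o-mini), while Zed-hosted models skip them by passing None and then setting request.max_tokens = Some(4000) explicitly, as the cloud.rs hunks above show.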