From f1778dd9de5ae20e9ca3c6c3cf344087b88634fe Mon Sep 17 00:00:00 2001
From: 邻二氮杂菲 <40173605+Cupnfish@users.noreply.github.com>
Date: Wed, 21 Aug 2024 12:39:10 +0800
Subject: [PATCH] Add max_output_tokens to OpenAI models and integrate into requests (#16381)

### Pull Request Title

Introduce `max_output_tokens` Field for OpenAI Models

Motivated by DeepSeek's 8K `max_tokens` beta release:
https://platform.deepseek.com/api-docs/news/news0725/#4-8k-max_tokens-betarelease-longer-possibilities

### Description

This commit introduces a new `max_output_tokens` field to the OpenAI models, which specifies the maximum number of tokens the model may generate in its output. The field is threaded through request handling across multiple crates, so the output token limit is respected during language model completions.

Changes:
- Added `max_output_tokens` to the `Custom` variant of the `open_ai::Model` enum.
- Updated the `into_open_ai` method on `LanguageModelRequest` to accept and use `max_output_tokens`.
- Modified the `OpenAiLanguageModel` and `CloudLanguageModel` implementations to pass `max_output_tokens` when converting requests.
- Ensured that the `max_output_tokens` field is correctly serialized and deserialized in the relevant structures.

This gives users control over the output length of OpenAI-compatible model responses.

### Related Issue

https://github.com/zed-industries/zed/pull/16358

### Screenshots / Media

N/A

### Checklist

- [x] Code compiles correctly.
- [x] All tests pass.
- [ ] Additional tests have been added to cover new functionality.
- [ ] Relevant documentation has been updated or added.

### Release Notes

- Added a `max_output_tokens` field to OpenAI models for controlling output token length.
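Before the diff, a minimal self-contained sketch of the pattern this patch applies. The types are stand-ins, not the actual Zed definitions, and the `deepseek-chat` name and token limits are illustrative: a per-model output cap is threaded into the wire request and omitted from the JSON body entirely when unset.

```rust
use serde::Serialize;

/// Simplified stand-in for `open_ai::Model` (names and limits illustrative).
enum Model {
    FourOmni,
    Custom {
        name: String,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
    },
}

impl Model {
    /// Per-model output cap; `None` leaves the provider default in place.
    fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::FourOmni => Some(4096),
            Self::Custom { max_output_tokens, .. } => *max_output_tokens,
        }
    }
}

/// Simplified stand-in for `open_ai::Request`: the cap is skipped entirely
/// when unset, so the server-side default applies instead of `null`.
#[derive(Serialize)]
struct Request {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

fn main() {
    let model = Model::Custom {
        name: "deepseek-chat".into(), // illustrative custom model
        max_tokens: 128_000,          // illustrative context window
        max_output_tokens: Some(8192),
    };
    // Thread the model's cap into the wire request, mirroring `into_open_ai`.
    let request = Request {
        model: match &model {
            Model::Custom { name, .. } => name.clone(),
            Model::FourOmni => "gpt-4o".into(),
        },
        max_tokens: model.max_output_tokens(),
    };
    // Prints: {"model":"deepseek-chat","max_tokens":8192}
    println!("{}", serde_json::to_string(&request).unwrap());
}
```

Using `Option<u32>` together with `skip_serializing_if` means an unset cap falls back to the provider's server-side default rather than sending `max_tokens: null` on the wire.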
---
 crates/assistant/src/assistant_settings.rs    |  4 ++--
 crates/language_model/src/provider/cloud.rs   | 10 ++++++----
 crates/language_model/src/provider/open_ai.rs | 10 ++++++++--
 crates/language_model/src/request.rs          |  4 ++--
 crates/language_model/src/settings.rs         | 12 +++++++++---
 crates/open_ai/src/open_ai.rs                 | 21 +++++++++++++++++--
 6 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 180281e7d4..9d89fe2920 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -153,8 +153,8 @@ impl AssistantSettingsContent {
                 models
                     .into_iter()
                     .filter_map(|model| match model {
-                        open_ai::Model::Custom { name, max_tokens } => {
-                            Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
+                        open_ai::Model::Custom { name, max_tokens, max_output_tokens } => {
+                            Some(language_model::provider::open_ai::AvailableModel { name, max_tokens, max_output_tokens })
                         }
                         _ => None,
                     })
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index a42459b29b..0517986f28 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -254,6 +254,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                 name: model.name.clone(),
                 max_tokens: model.max_tokens,
+                max_output_tokens: model.max_output_tokens,
             }),
             AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                 name: model.name.clone(),
@@ -513,7 +514,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = request.into_open_ai(model.id().into());
+                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -557,7 +558,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.max_tokens = Some(4000);
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -629,7 +630,8 @@ impl LanguageModel for CloudLanguageModel {
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request =
+                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
@@ -676,7 +678,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 65941f5c55..af2c0eb41d 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -40,6 +40,7 @@ pub struct OpenAiSettings {
 pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
+    pub max_output_tokens: Option<u32>,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -170,6 +171,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                 open_ai::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    max_output_tokens: model.max_output_tokens,
                 },
             );
         }
@@ -275,6 +277,10 @@ impl LanguageModel for OpenAiLanguageModel {
         self.model.max_token_count()
     }
 
+    fn max_output_tokens(&self) -> Option<u32> {
+        self.model.max_output_tokens()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -288,7 +294,7 @@ impl LanguageModel for OpenAiLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = request.into_open_ai(self.model.id().into());
+        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
     }
@@ -301,7 +307,7 @@ impl LanguageModel for OpenAiLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_open_ai(self.model.id().into());
+        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
             function: FunctionDefinition {
                 name: tool_name.clone(),
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index ecebc5e868..ef30d1904b 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -229,7 +229,7 @@ pub struct LanguageModelRequest {
 }
 
 impl LanguageModelRequest {
-    pub fn into_open_ai(self, model: String) -> open_ai::Request {
+    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
         open_ai::Request {
             model,
             messages: self
@@ -251,7 +251,7 @@ impl LanguageModelRequest {
             stream: true,
             stop: self.stop,
             temperature: self.temperature,
-            max_tokens: None,
+            max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
         }
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index afe223f842..923ad5b5e9 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -172,9 +172,15 @@ impl OpenAiSettingsContent {
             models
                 .into_iter()
                 .filter_map(|model| match model {
-                    open_ai::Model::Custom { name, max_tokens } => {
-                        Some(provider::open_ai::AvailableModel { name, max_tokens })
-                    }
+                    open_ai::Model::Custom {
+                        name,
+                        max_tokens,
+                        max_output_tokens,
+                    } => Some(provider::open_ai::AvailableModel {
+                        name,
+                        max_tokens,
+                        max_output_tokens,
+                    }),
                     _ => None,
                 })
                 .collect()
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 291fc1d0ec..ecb0828ee6 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -66,7 +66,11 @@ pub enum Model {
     #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
     FourOmniMini,
     #[serde(rename = "custom")]
-    Custom { name: String, max_tokens: usize },
+    Custom {
+        name: String,
+        max_tokens: usize,
+        max_output_tokens: Option<u32>,
+    },
 }
 
 impl Model {
@@ -113,6 +117,19 @@ impl Model {
             Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn max_output_tokens(&self) -> Option<u32> {
+        match self {
+            Self::ThreePointFiveTurbo => Some(4096),
+            Self::Four => Some(8192),
+            Self::FourTurbo => Some(4096),
+            Self::FourOmni => Some(4096),
+            Self::FourOmniMini => Some(16384),
+            Self::Custom {
+                max_output_tokens, ..
+            } => *max_output_tokens,
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -121,7 +138,7 @@ pub struct Request {
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<usize>,
+    pub max_tokens: Option<u32>,
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
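
For completeness, a hedged sketch of how the new optional field behaves during settings deserialization. The struct mirrors the `AvailableModel` fields added above; the surrounding JSON shape and the model names are assumptions for illustration, not the exact Zed settings schema.

```rust
use serde::Deserialize;

/// Mirrors the `AvailableModel` fields added in this patch; the JSON shape
/// around it is assumed for illustration.
#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    max_tokens: usize,
    // Absent in the JSON => `None` => the outgoing request carries no cap.
    max_output_tokens: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // A model that declares an output cap.
    let with_cap: AvailableModel = serde_json::from_str(
        r#"{"name": "deepseek-chat", "max_tokens": 128000, "max_output_tokens": 8192}"#,
    )?;
    assert_eq!(with_cap.max_output_tokens, Some(8192));

    // Omitting the field deserializes to `None`, so existing custom-model
    // settings keep working unchanged.
    let without_cap: AvailableModel =
        serde_json::from_str(r#"{"name": "my-model", "max_tokens": 32000}"#)?;
    assert_eq!(without_cap.max_output_tokens, None);

    println!(
        "{}: context {}, output cap {:?}",
        with_cap.name, with_cap.max_tokens, with_cap.max_output_tokens
    );
    Ok(())
}
```

Because `Option` fields are optional by default in derived `Deserialize` implementations, no `#[serde(default)]` is needed for backward compatibility with settings written before this change.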