diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 036ae5c071..de810e0fc5 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -97,13 +97,8 @@ impl LanguageModel {
 
     pub fn max_token_count(&self) -> usize {
         match self {
-            LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
-            LanguageModel::ZedDotDev(model) => match model {
-                ZedDotDevModel::GptThreePointFiveTurbo
-                | ZedDotDevModel::GptFour
-                | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
-                ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
-            },
+            LanguageModel::OpenAi(model) => model.max_token_count(), // delegate to the wrapped model's own limit
+            LanguageModel::ZedDotDev(model) => model.max_token_count(), // delegate to the wrapped model's own limit
         }
     }
 
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index f338f7e8fb..f4e70f2548 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -109,6 +109,15 @@ impl ZedDotDevModel {
             Self::Custom(id) => id.as_str(),
         }
     }
+
+    pub fn max_token_count(&self) -> usize { // Returns a hard-coded token limit for each variant.
+        match self {
+            Self::GptThreePointFiveTurbo => 2048, // NOTE(review): open_ai `Model::ThreePointFiveTurbo` in this patch uses 4096 - confirm the lower value is intended
+            Self::GptFour => 4096, // NOTE(review): open_ai `Model::Four` in this patch uses 8192 - confirm the lower value is intended
+            Self::GptFourTurbo => 128000, // same value as open_ai `Model::FourTurbo` in this patch
+            Self::Custom(_) => 4096, // TODO: Make this configurable (the inline arm this patch removes returned 30720 for custom models)
+        }
+    }
 }
 
 #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 7bd7e19d5d..fcf4aa04bf 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -72,6 +72,14 @@ impl Model {
             Self::FourTurbo => "gpt-4-turbo",
         }
     }
+
+    pub fn max_token_count(&self) -> usize { // Returns a hard-coded token limit for each variant.
+        match self { // exhaustive on purpose: a new variant must be given a limit to compile
+            Model::ThreePointFiveTurbo => 4096,
+            Model::Four => 8192,
+            Model::FourTurbo => 128000,
+        }
+    }
 }
 
 #[derive(Debug, Serialize)]