From b5bd8a5c5d70eebda1206217b1d358a2e58dba7f Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Mon, 19 Aug 2024 11:09:52 -0700
Subject: [PATCH] Add logic for closed beta LLM models (#16482)

Release Notes:

- N/A

---------

Co-authored-by: Marshall
---
 crates/collab/k8s/collab.template.yml       |  5 ++
 crates/collab/src/lib.rs                    |  2 +
 crates/collab/src/llm/authorization.rs      | 27 ++++--
 crates/collab/src/llm/db/queries/usages.rs  |  7 +-
 crates/collab/src/llm/token.rs              |  4 +
 crates/collab/src/rpc.rs                    |  6 +-
 crates/collab/src/tests/test_server.rs      |  1 +
 crates/feature_flags/src/feature_flags.rs   |  5 ++
 crates/language_model/src/provider/cloud.rs | 94 +++++++++++++--------
 9 files changed, 104 insertions(+), 47 deletions(-)

diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml
index 741d00d0e7..6d0cafc0ac 100644
--- a/crates/collab/k8s/collab.template.yml
+++ b/crates/collab/k8s/collab.template.yml
@@ -139,6 +139,11 @@ spec:
                 secretKeyRef:
                   name: anthropic
                   key: staff_api_key
+            - name: LLM_CLOSED_BETA_MODEL_NAME
+              valueFrom:
+                secretKeyRef:
+                  name: llm-closed-beta
+                  key: model_name
             - name: GOOGLE_AI_API_KEY
               valueFrom:
                 secretKeyRef:
diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs
index 81cc334c43..06d741e2ca 100644
--- a/crates/collab/src/lib.rs
+++ b/crates/collab/src/lib.rs
@@ -168,6 +168,7 @@ pub struct Config {
     pub google_ai_api_key: Option<Arc<str>>,
     pub anthropic_api_key: Option<Arc<str>>,
     pub anthropic_staff_api_key: Option<Arc<str>>,
+    pub llm_closed_beta_model_name: Option<Arc<str>>,
     pub qwen2_7b_api_key: Option<Arc<str>>,
     pub qwen2_7b_api_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
@@ -219,6 +220,7 @@ impl Config {
             google_ai_api_key: None,
             anthropic_api_key: None,
             anthropic_staff_api_key: None,
+            llm_closed_beta_model_name: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,
diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs
index 0b62dd4e0a..d18979a970 100644
--- a/crates/collab/src/llm/authorization.rs
+++ b/crates/collab/src/llm/authorization.rs
@@ -12,11 +12,12 @@ pub fn authorize_access_to_language_model(
     model: &str,
 ) -> Result<()> {
     authorize_access_for_country(config, country_code, provider)?;
-    authorize_access_to_model(claims, provider, model)?;
+    authorize_access_to_model(config, claims, provider, model)?;
     Ok(())
 }
 
 fn authorize_access_to_model(
+    config: &Config,
     claims: &LlmTokenClaims,
     provider: LanguageModelProvider,
     model: &str,
@@ -25,13 +26,25 @@ fn authorize_access_to_model(
         return Ok(());
     }
 
-    match (provider, model) {
-        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
-        _ => Err(Error::http(
-            StatusCode::FORBIDDEN,
-            format!("access to model {model:?} is not included in your plan"),
-        ))?,
+    match provider {
+        LanguageModelProvider::Anthropic => {
+            if model == "claude-3-5-sonnet" {
+                return Ok(());
+            }
+
+            if claims.has_llm_closed_beta_feature_flag
+                && Some(model) == config.llm_closed_beta_model_name.as_deref()
+            {
+                return Ok(());
+            }
+        }
+        _ => {}
     }
+
+    Err(Error::http(
+        StatusCode::FORBIDDEN,
+        format!("access to model {model:?} is not included in your plan"),
+    ))
 }
 
 fn authorize_access_for_country(
diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs
index 0bfbb4c1b1..fbffca1c89 100644
--- a/crates/collab/src/llm/db/queries/usages.rs
+++ b/crates/collab/src/llm/db/queries/usages.rs
@@ -82,12 +82,13 @@ impl LlmDatabase {
             let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
 
             let mut results = Vec::new();
-            for (provider, model) in self.models.keys().cloned() {
+            for ((provider, model_name), model) in self.models.iter() {
                 let mut usages = usage::Entity::find()
                     .filter(
                         usage::Column::Timestamp
                             .gte(past_minute.naive_utc())
                             .and(usage::Column::IsStaff.eq(false))
+                            .and(usage::Column::ModelId.eq(model.id))
                             .and(
                                 usage::Column::MeasureId
                                     .eq(requests_per_minute)
@@ -125,8 +126,8 @@ impl LlmDatabase {
                 }
 
                 results.push(ApplicationWideUsage {
-                    provider,
-                    model,
+                    provider: *provider,
+                    model: model_name.clone(),
                     requests_this_minute,
                     tokens_this_minute,
                 })
diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs
index f789e5c220..e1e6c73326 100644
--- a/crates/collab/src/llm/token.rs
+++ b/crates/collab/src/llm/token.rs
@@ -20,6 +20,8 @@ pub struct LlmTokenClaims {
     #[serde(default)]
     pub github_user_login: Option<String>,
     pub is_staff: bool,
+    #[serde(default)]
+    pub has_llm_closed_beta_feature_flag: bool,
     pub plan: rpc::proto::Plan,
 }
 
@@ -30,6 +32,7 @@ impl LlmTokenClaims {
         user_id: UserId,
         github_user_login: String,
         is_staff: bool,
+        has_llm_closed_beta_feature_flag: bool,
         plan: rpc::proto::Plan,
         config: &Config,
     ) -> Result<String> {
@@ -46,6 +49,7 @@ impl LlmTokenClaims {
             user_id: user_id.to_proto(),
             github_user_login: Some(github_user_login),
             is_staff,
+            has_llm_closed_beta_feature_flag,
             plan,
         };
 
diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index b3dc904df1..f436c02e3e 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -4918,7 +4918,10 @@ async fn get_llm_api_token(
     let db = session.db().await;
 
     let flags = db.get_user_flags(session.user_id()).await?;
-    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
+    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
+    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
+
+    if !session.is_staff() && !has_language_models_feature_flag {
         Err(anyhow!("permission denied"))?
     }
 
@@ -4943,6 +4946,7 @@ async fn get_llm_api_token(
         user.id,
         user.github_login.clone(),
         session.is_staff(),
+        has_llm_closed_beta_feature_flag,
        session.current_plan(db).await?,
         &session.app_state.config,
     )?;
diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs
index 16243e1ff2..6fde54ff1c 100644
--- a/crates/collab/src/tests/test_server.rs
+++ b/crates/collab/src/tests/test_server.rs
@@ -667,6 +667,7 @@ impl TestServer {
             google_ai_api_key: None,
             anthropic_api_key: None,
             anthropic_staff_api_key: None,
+            llm_closed_beta_model_name: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,
diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs
index 0270e5b2c8..29768138af 100644
--- a/crates/feature_flags/src/feature_flags.rs
+++ b/crates/feature_flags/src/feature_flags.rs
@@ -43,6 +43,11 @@ impl FeatureFlag for LanguageModels {
     const NAME: &'static str = "language-models";
 }
 
+pub struct LlmClosedBeta {}
+impl FeatureFlag for LlmClosedBeta {
+    const NAME: &'static str = "llm-closed-beta";
+}
+
 pub struct ZedPro {}
 impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 8372129d5a..a42459b29b 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -8,7 +8,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, ZedPro};
+use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
 use futures::{
     future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
     TryStreamExt as _,
@@ -26,7 +26,10 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
-use std::{future, sync::Arc};
+use std::{
+    future,
+    sync::{Arc, LazyLock},
+};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
@@ -37,6 +40,18 @@ use super::anthropic::count_anthropic_tokens;
 pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "Zed";
 
+const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
+    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
+
+fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
+    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
+        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
+            .map(|json| serde_json::from_str(json).unwrap())
+            .unwrap_or(Vec::new())
+    });
+    ADDITIONAL_MODELS.as_slice()
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
     pub available_models: Vec<AvailableModel>,
 }
@@ -200,40 +215,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             for model in ZedModel::iter() {
                 models.insert(model.id().to_string(), CloudModel::Zed(model));
             }
-
-            // Override with available models from settings
-            for model in &AllLanguageModelSettings::get_global(cx)
-                .zed_dot_dev
-                .available_models
-            {
-                let model = match model.provider {
-                    AvailableProvider::Anthropic => {
-                        CloudModel::Anthropic(anthropic::Model::Custom {
-                            name: model.name.clone(),
-                            display_name: model.display_name.clone(),
-                            max_tokens: model.max_tokens,
-                            tool_override: model.tool_override.clone(),
-                            cache_configuration: model.cache_configuration.as_ref().map(|config| {
-                                anthropic::AnthropicModelCacheConfiguration {
-                                    max_cache_anchors: config.max_cache_anchors,
-                                    should_speculate: config.should_speculate,
-                                    min_total_token: config.min_total_token,
-                                }
-                            }),
-                            max_output_tokens: model.max_output_tokens,
-                        })
-                    }
-                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                };
-                models.insert(model.id().to_string(), model.clone());
-            }
         } else {
             models.insert(
                 anthropic::Model::Claude3_5Sonnet.id().to_string(),
@@ -241,6 +222,47 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             );
         }
 
+        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
+            zed_cloud_provider_additional_models()
+        } else {
+            &[]
+        };
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .zed_dot_dev
+            .available_models
+            .iter()
+            .chain(llm_closed_beta_models)
+            .cloned()
+        {
+            let model = match model.provider {
+                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
+                    max_output_tokens: model.max_output_tokens,
+                }),
+                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+            };
+            models.insert(model.id().to_string(), model.clone());
+        }
+
         models
             .into_values()
             .map(|model| {
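
Note on the new ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON hook: the value is baked in at compile time via option_env! and deserialized into a Vec<AvailableModel> with serde_json. Below is a minimal, self-contained sketch of the kind of payload it is expected to carry. The stand-in structs, field names, and the example model name are assumptions inferred from the fields this patch references (provider, name, display_name, max_tokens, tool_override, cache_configuration, max_output_tokens), not the actual definitions from the language_model crate.

// Standalone sketch (not part of the patch): shows the assumed shape of the
// ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON payload and how it deserializes.
// `AvailableModel` and `CacheConfiguration` here are simplified stand-ins for
// the real settings types; field names are inferred from the diff above.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct CacheConfiguration {
    max_cache_anchors: usize,
    should_speculate: bool,
    min_total_token: usize,
}

#[derive(Debug, Deserialize)]
struct AvailableModel {
    provider: String,
    name: String,
    display_name: Option<String>,
    max_tokens: usize,
    tool_override: Option<String>,
    cache_configuration: Option<CacheConfiguration>,
    max_output_tokens: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // Illustrative value only; the real closed-beta model name is configured
    // server-side via the llm-closed-beta secret and is not part of this patch.
    let json = r#"[
        {
            "provider": "anthropic",
            "name": "example-closed-beta-model",
            "display_name": "Example Closed Beta Model",
            "max_tokens": 200000,
            "max_output_tokens": 8192
        }
    ]"#;

    let models: Vec<AvailableModel> = serde_json::from_str(json)?;
    println!("parsed {} additional model(s): {models:?}", models.len());
    Ok(())
}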