diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index b5c6ec4920..ca5e1990f4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4501,7 +4501,7 @@ async fn count_language_model_tokens( let Some(session) = session.for_user() else { return Err(anyhow!("user not found"))?; }; - authorize_access_to_language_models(&session).await?; + authorize_access_to_legacy_llm_endpoints(&session).await?; let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? { proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit), @@ -4621,7 +4621,7 @@ async fn compute_embeddings( api_key: Option<Arc<str>>, ) -> Result<()> { let api_key = api_key.context("no OpenAI API key configured on the server")?; - authorize_access_to_language_models(&session).await?; + authorize_access_to_legacy_llm_endpoints(&session).await?; let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? { proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit), @@ -4685,7 +4685,7 @@ async fn get_cached_embeddings( response: Response<proto::GetCachedEmbeddings>, session: UserSession, ) -> Result<()> { - authorize_access_to_language_models(&session).await?; + authorize_access_to_legacy_llm_endpoints(&session).await?; let db = session.db().await; let embeddings = db.get_embeddings(&request.model, &request.digests).await?; @@ -4699,14 +4699,15 @@ async fn get_cached_embeddings( Ok(()) } -async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> { - let db = session.db().await; - let flags = db.get_user_flags(session.user_id()).await?; - if flags.iter().any(|flag| flag == "language-models") { - return Ok(()); +/// This is leftover from before the LLM service. +/// +/// The endpoints protected by this check will be moved there eventually. +async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> { + if session.is_staff() { + Ok(()) + } else { + Err(anyhow!("permission denied"))? } - - Err(anyhow!("permission denied"))?
} /// Get a Supermaven API key for the user @@ -4915,12 +4916,13 @@ async fn get_llm_api_token( response: Response<proto::GetLlmToken>, session: UserSession, ) -> Result<()> { - if !session.is_staff() { + let db = session.db().await; + + let flags = db.get_user_flags(session.user_id()).await?; + if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") { Err(anyhow!("permission denied"))? } - let db = session.db().await; - let user_id = session.user_id(); let user = db .get_user_by_id(user_id) diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 9c2d41e99a..a8418ff8f8 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -8,7 +8,7 @@ use anthropic::AnthropicError; use anyhow::{anyhow, bail, Context as _, Result}; use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; use collections::BTreeMap; -use feature_flags::{FeatureFlagAppExt, LanguageModels}; +use feature_flags::{FeatureFlagAppExt, ZedPro}; use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt}; use gpui::{ AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext, @@ -168,13 +168,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> { let mut models = BTreeMap::default(); - let is_user = !cx.has_flag::<LanguageModels>(); - if is_user { - models.insert( - anthropic::Model::Claude3_5Sonnet.id().to_string(), - CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet), - ); - } else { + if cx.is_staff() { for model in anthropic::Model::iter() { if !matches!(model, anthropic::Model::Custom { ..
}) { models.insert(model.id().to_string(), CloudModel::Anthropic(model)); @@ -218,6 +212,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider { }; models.insert(model.id().to_string(), model.clone()); } + } else { + models.insert( + anthropic::Model::Claude3_5Sonnet.id().to_string(), + CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet), + ); } models @@ -869,34 +868,39 @@ impl Render for ConfigurationView { if is_pro { "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro." } else { - "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan." + "You have basic access to models from Anthropic through the Zed AI Free plan." })) - .child( - if is_pro { + .children(if is_pro { + Some( h_flex().child( - Button::new("manage_settings", "Manage Subscription") - .style(ButtonStyle::Filled) - .on_click(cx.listener(|_, _, cx| { - cx.open_url(ACCOUNT_SETTINGS_URL) - }))) - } else { + Button::new("manage_settings", "Manage Subscription") + .style(ButtonStyle::Filled) + .on_click( + cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)), + ), + ), + ) + } else if cx.has_flag::<ZedPro>() { + Some( h_flex() .gap_2() .child( - Button::new("learn_more", "Learn more") - .style(ButtonStyle::Subtle) - .on_click(cx.listener(|_, _, cx| { - cx.open_url(ZED_AI_URL) - }))) + Button::new("learn_more", "Learn more") + .style(ButtonStyle::Subtle) + .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))), + ) .child( - Button::new("upgrade", "Upgrade") - .style(ButtonStyle::Subtle) - .color(Color::Accent) - .on_click(cx.listener(|_, _, cx| { - cx.open_url(ACCOUNT_SETTINGS_URL) - }))) - }, - ) + Button::new("upgrade", "Upgrade") + .style(ButtonStyle::Subtle) + .color(Color::Accent) + .on_click( + cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)), + ), + ), + ) + } else { + None + }) } else { v_flex() .gap_6()