Add feature-flagged access to LLM service (#16136)

This PR adds feature-flagged access to the LLM service.

We've repurposed the `language-models` feature flag to provide access to
Claude 3.5 Sonnet through the Zed provider.

The remaining RPC endpoints that were previously behind the
`language-models` feature flag are now behind a staff check.

We also put some Zed Pro-related messaging behind a feature flag.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Marshall Bowers 2024-08-12 18:13:40 -04:00 committed by GitHub
parent 3bebb8b401
commit 8a148f3a13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 43 deletions

View File

@ -4501,7 +4501,7 @@ async fn count_language_model_tokens(
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
@ -4621,7 +4621,7 @@ async fn compute_embeddings(
api_key: Option<Arc<str>>,
) -> Result<()> {
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_language_models(&session).await?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
@ -4685,7 +4685,7 @@ async fn get_cached_embeddings(
response: Response<proto::GetCachedEmbeddings>,
session: UserSession,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
let db = session.db().await;
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
@ -4699,14 +4699,15 @@ async fn get_cached_embeddings(
Ok(())
}
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
if flags.iter().any(|flag| flag == "language-models") {
return Ok(());
/// This is leftover from before the LLM service.
///
/// The endpoints protected by this check will be moved there eventually.
async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
if session.is_staff() {
Ok(())
} else {
Err(anyhow!("permission denied"))?
}
Err(anyhow!("permission denied"))?
}
/// Get a Supermaven API key for the user
@ -4915,12 +4916,13 @@ async fn get_llm_api_token(
response: Response<proto::GetLlmToken>,
session: UserSession,
) -> Result<()> {
if !session.is_staff() {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
Err(anyhow!("permission denied"))?
}
let db = session.db().await;
let user_id = session.user_id();
let user = db
.get_user_by_id(user_id)

View File

@ -8,7 +8,7 @@ use anthropic::AnthropicError;
use anyhow::{anyhow, bail, Context as _, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LanguageModels};
use feature_flags::{FeatureFlagAppExt, ZedPro};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
@ -168,13 +168,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
let is_user = !cx.has_flag::<LanguageModels>();
if is_user {
models.insert(
anthropic::Model::Claude3_5Sonnet.id().to_string(),
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
);
} else {
if cx.is_staff() {
for model in anthropic::Model::iter() {
if !matches!(model, anthropic::Model::Custom { .. }) {
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
@ -218,6 +212,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
};
models.insert(model.id().to_string(), model.clone());
}
} else {
models.insert(
anthropic::Model::Claude3_5Sonnet.id().to_string(),
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
);
}
models
@ -869,34 +868,39 @@ impl Render for ConfigurationView {
if is_pro {
"You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
} else {
"You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
"You have basic access to models from Anthropic through the Zed AI Free plan."
}))
.child(
if is_pro {
.children(if is_pro {
Some(
h_flex().child(
Button::new("manage_settings", "Manage Subscription")
.style(ButtonStyle::Filled)
.on_click(cx.listener(|_, _, cx| {
cx.open_url(ACCOUNT_SETTINGS_URL)
})))
} else {
Button::new("manage_settings", "Manage Subscription")
.style(ButtonStyle::Filled)
.on_click(
cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
),
),
)
} else if cx.has_flag::<ZedPro>() {
Some(
h_flex()
.gap_2()
.child(
Button::new("learn_more", "Learn more")
.style(ButtonStyle::Subtle)
.on_click(cx.listener(|_, _, cx| {
cx.open_url(ZED_AI_URL)
})))
Button::new("learn_more", "Learn more")
.style(ButtonStyle::Subtle)
.on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
)
.child(
Button::new("upgrade", "Upgrade")
.style(ButtonStyle::Subtle)
.color(Color::Accent)
.on_click(cx.listener(|_, _, cx| {
cx.open_url(ACCOUNT_SETTINGS_URL)
})))
},
)
Button::new("upgrade", "Upgrade")
.style(ButtonStyle::Subtle)
.color(Color::Accent)
.on_click(
cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
),
),
)
} else {
None
})
} else {
v_flex()
.gap_6()