mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-20 02:47:34 +03:00
Add feature-flagged access to LLM service (#16136)
This PR adds feature-flagged access to the LLM service.

We've repurposed the `language-models` feature flag to be used for providing access to Claude 3.5 Sonnet through the Zed provider.

The remaining RPC endpoints that were previously behind the `language-models` feature flag are now behind a staff check.

We also put some Zed Pro related messaging behind a feature flag.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
3bebb8b401
commit
8a148f3a13
@ -4501,7 +4501,7 @@ async fn count_language_model_tokens(
|
|||||||
let Some(session) = session.for_user() else {
|
let Some(session) = session.for_user() else {
|
||||||
return Err(anyhow!("user not found"))?;
|
return Err(anyhow!("user not found"))?;
|
||||||
};
|
};
|
||||||
authorize_access_to_language_models(&session).await?;
|
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||||
|
|
||||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||||
@ -4621,7 +4621,7 @@ async fn compute_embeddings(
|
|||||||
api_key: Option<Arc<str>>,
|
api_key: Option<Arc<str>>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||||
authorize_access_to_language_models(&session).await?;
|
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||||
|
|
||||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||||
@ -4685,7 +4685,7 @@ async fn get_cached_embeddings(
|
|||||||
response: Response<proto::GetCachedEmbeddings>,
|
response: Response<proto::GetCachedEmbeddings>,
|
||||||
session: UserSession,
|
session: UserSession,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
authorize_access_to_language_models(&session).await?;
|
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||||
|
|
||||||
let db = session.db().await;
|
let db = session.db().await;
|
||||||
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
|
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
|
||||||
@ -4699,14 +4699,15 @@ async fn get_cached_embeddings(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
|
/// This is leftover from before the LLM service.
|
||||||
let db = session.db().await;
|
///
|
||||||
let flags = db.get_user_flags(session.user_id()).await?;
|
/// The endpoints protected by this check will be moved there eventually.
|
||||||
if flags.iter().any(|flag| flag == "language-models") {
|
async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
|
||||||
return Ok(());
|
if session.is_staff() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("permission denied"))?
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(anyhow!("permission denied"))?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a Supermaven API key for the user
|
/// Get a Supermaven API key for the user
|
||||||
@ -4915,12 +4916,13 @@ async fn get_llm_api_token(
|
|||||||
response: Response<proto::GetLlmToken>,
|
response: Response<proto::GetLlmToken>,
|
||||||
session: UserSession,
|
session: UserSession,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if !session.is_staff() {
|
let db = session.db().await;
|
||||||
|
|
||||||
|
let flags = db.get_user_flags(session.user_id()).await?;
|
||||||
|
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
|
||||||
Err(anyhow!("permission denied"))?
|
Err(anyhow!("permission denied"))?
|
||||||
}
|
}
|
||||||
|
|
||||||
let db = session.db().await;
|
|
||||||
|
|
||||||
let user_id = session.user_id();
|
let user_id = session.user_id();
|
||||||
let user = db
|
let user = db
|
||||||
.get_user_by_id(user_id)
|
.get_user_by_id(user_id)
|
||||||
|
@ -8,7 +8,7 @@ use anthropic::AnthropicError;
|
|||||||
use anyhow::{anyhow, bail, Context as _, Result};
|
use anyhow::{anyhow, bail, Context as _, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlagAppExt, LanguageModels};
|
use feature_flags::{FeatureFlagAppExt, ZedPro};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
||||||
@ -168,13 +168,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
let mut models = BTreeMap::default();
|
let mut models = BTreeMap::default();
|
||||||
|
|
||||||
let is_user = !cx.has_flag::<LanguageModels>();
|
if cx.is_staff() {
|
||||||
if is_user {
|
|
||||||
models.insert(
|
|
||||||
anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
|
||||||
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
for model in anthropic::Model::iter() {
|
for model in anthropic::Model::iter() {
|
||||||
if !matches!(model, anthropic::Model::Custom { .. }) {
|
if !matches!(model, anthropic::Model::Custom { .. }) {
|
||||||
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
|
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
|
||||||
@ -218,6 +212,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
};
|
};
|
||||||
models.insert(model.id().to_string(), model.clone());
|
models.insert(model.id().to_string(), model.clone());
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
models.insert(
|
||||||
|
anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||||
|
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
models
|
models
|
||||||
@ -869,34 +868,39 @@ impl Render for ConfigurationView {
|
|||||||
if is_pro {
|
if is_pro {
|
||||||
"You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
|
"You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
|
||||||
} else {
|
} else {
|
||||||
"You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
|
"You have basic access to models from Anthropic through the Zed AI Free plan."
|
||||||
}))
|
}))
|
||||||
.child(
|
.children(if is_pro {
|
||||||
if is_pro {
|
Some(
|
||||||
h_flex().child(
|
h_flex().child(
|
||||||
Button::new("manage_settings", "Manage Subscription")
|
Button::new("manage_settings", "Manage Subscription")
|
||||||
.style(ButtonStyle::Filled)
|
.style(ButtonStyle::Filled)
|
||||||
.on_click(cx.listener(|_, _, cx| {
|
.on_click(
|
||||||
cx.open_url(ACCOUNT_SETTINGS_URL)
|
cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
|
||||||
})))
|
),
|
||||||
} else {
|
),
|
||||||
|
)
|
||||||
|
} else if cx.has_flag::<ZedPro>() {
|
||||||
|
Some(
|
||||||
h_flex()
|
h_flex()
|
||||||
.gap_2()
|
.gap_2()
|
||||||
.child(
|
.child(
|
||||||
Button::new("learn_more", "Learn more")
|
Button::new("learn_more", "Learn more")
|
||||||
.style(ButtonStyle::Subtle)
|
.style(ButtonStyle::Subtle)
|
||||||
.on_click(cx.listener(|_, _, cx| {
|
.on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
|
||||||
cx.open_url(ZED_AI_URL)
|
)
|
||||||
})))
|
|
||||||
.child(
|
.child(
|
||||||
Button::new("upgrade", "Upgrade")
|
Button::new("upgrade", "Upgrade")
|
||||||
.style(ButtonStyle::Subtle)
|
.style(ButtonStyle::Subtle)
|
||||||
.color(Color::Accent)
|
.color(Color::Accent)
|
||||||
.on_click(cx.listener(|_, _, cx| {
|
.on_click(
|
||||||
cx.open_url(ACCOUNT_SETTINGS_URL)
|
cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
|
||||||
})))
|
),
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
v_flex()
|
v_flex()
|
||||||
.gap_6()
|
.gap_6()
|
||||||
|
Loading…
Reference in New Issue
Block a user