Fix bugs preventing non-staff users from using the LLM service (#16307)

- a database mutex deadlock in the `GetLlmToken` handler for non-staff users: `Session::current_plan` re-acquired the database lock that the handler already held (see the sketch after this list)
- a typo in the allowed model name for non-staff users: the access check matched the prefix `claude-3.5-sonnet` instead of `claude-3-5-sonnet` (see the example after the first file's diff)
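
The deadlock is the classic re-entrant locking hazard: `tokio::sync::Mutex` is not re-entrant, so a task that already holds the guard and awaits a second `lock()` on the same mutex waits on itself forever. Here, `GetLlmToken` held the database guard while `Session::current_plan` tried to acquire it again; the fix threads the already-held `MutexGuard` through instead of re-locking. A minimal sketch of the hazard and the fix, using placeholder `Db`/`Session` types rather than the actual collab server code:

```rust
use std::sync::Arc;
use tokio::sync::{Mutex, MutexGuard};

struct Db; // placeholder for the real database handle

struct Session {
    db: Arc<Mutex<Db>>,
}

impl Session {
    // Buggy shape: if the caller already holds `self.db`, this second
    // `lock().await` never completes, because tokio's Mutex is not re-entrant.
    async fn current_plan_relocking(&self) -> &'static str {
        let _db = self.db.lock().await;
        "free"
    }

    // Fixed shape: accept the guard the caller already holds.
    async fn current_plan(&self, _db: MutexGuard<'_, Db>) -> &'static str {
        "free"
    }
}

#[tokio::main]
async fn main() {
    let session = Session {
        db: Arc::new(Mutex::new(Db)),
    };

    let db = session.db.lock().await; // the handler takes the lock first...
    // session.current_plan_relocking().await; // ...so this call would hang
    let plan = session.current_plan(db).await; // the guard is moved in instead
    println!("plan: {plan}");
}
```

At call sites that did not already hold the lock, the diff below creates and moves the guard in one expression, `session.current_plan(session.db().await).await?`, which also keeps the lock scoped to that single call.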

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Joseph <joseph@zed.dev>
Max Brunsfeld 2024-08-15 11:21:19 -07:00 committed by GitHub
parent 931883aca9
commit 6b7664ef4a
2 changed files with 10 additions and 12 deletions

@@ -26,7 +26,7 @@ fn authorize_access_to_model(
     }
     match (provider, model) {
-        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
+        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
             Ok(())
         }
         _ => Err(Error::http(
@@ -240,14 +240,14 @@ mod tests {
         (
             Plan::ZedPro,
             LanguageModelProvider::Anthropic,
-            "claude-3.5-sonnet",
+            "claude-3-5-sonnet",
             true,
         ),
         // Free plan should have access to claude-3.5-sonnet
         (
             Plan::Free,
             LanguageModelProvider::Anthropic,
-            "claude-3.5-sonnet",
+            "claude-3-5-sonnet",
             true,
         ),
         // Pro plan should NOT have access to other Anthropic models
@@ -303,7 +303,7 @@ mod tests {
         // Staff should have access to all models
         let test_cases = vec![
-            (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
+            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
            (LanguageModelProvider::Anthropic, "claude-2"),
            (LanguageModelProvider::Anthropic, "claude-123-agi"),
            (LanguageModelProvider::OpenAi, "gpt-4"),
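
The model-name typo above is the entire second bug: Anthropic's model IDs are dash-separated (for example `claude-3-5-sonnet-20240620`), so the dotted prefix `claude-3.5-sonnet` never matched and non-staff access was always denied. A standalone illustration of the check, with the model ID hard-coded here for the example:

```rust
fn main() {
    // Anthropic model IDs are dash-separated, e.g. "claude-3-5-sonnet-20240620".
    let model = "claude-3-5-sonnet-20240620";

    assert!(!model.starts_with("claude-3.5-sonnet")); // old prefix: never matched
    assert!(model.starts_with("claude-3-5-sonnet")); // fixed prefix: matches
    println!("{model} would be authorized");
}
```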

@@ -71,7 +71,7 @@ use std::{
     time::{Duration, Instant},
 };
 use time::OffsetDateTime;
-use tokio::sync::{watch, Semaphore};
+use tokio::sync::{watch, MutexGuard, Semaphore};
 use tower::ServiceBuilder;
 use tracing::{
     field::{self},
@@ -192,7 +192,7 @@ impl Session {
         }
     }

-    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
+    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
         if self.is_staff() {
             return Ok(proto::Plan::ZedPro);
         }
@@ -201,7 +201,6 @@
             return Ok(proto::Plan::Free);
         };

-        let db = self.db().await;
         if db.has_active_billing_subscription(user_id).await? {
             Ok(proto::Plan::ZedPro)
         } else {
@@ -3500,7 +3499,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
 }

 async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
-    let plan = session.current_plan().await?;
+    let plan = session.current_plan(session.db().await).await?;

     session
         .peer
@@ -4503,7 +4502,7 @@ async fn count_language_model_tokens(
     };

     authorize_access_to_legacy_llm_endpoints(&session).await?;
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
         proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
     };
@@ -4623,7 +4622,7 @@ async fn compute_embeddings(
     let api_key = api_key.context("no OpenAI API key configured on the server")?;

     authorize_access_to_legacy_llm_endpoints(&session).await?;
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
         proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
     };
@@ -4940,11 +4939,10 @@ async fn get_llm_api_token(
         if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
             Err(anyhow!("account too young"))?
         }

         let token = LlmTokenClaims::create(
             user.id,
             session.is_staff(),
-            session.current_plan().await?,
+            session.current_plan(db).await?,
             &session.app_state.config,
         )?;
         response.send(proto::GetLlmTokenResponse { token })?;