diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index ab12e3e61a..a8dcd23c4b 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -85,6 +85,11 @@ spec: secretKeyRef: name: database key: url + - name: LLM_DATABASE_URL + valueFrom: + secretKeyRef: + name: llm-database + key: url - name: DATABASE_MAX_CONNECTIONS value: "${DATABASE_MAX_CONNECTIONS}" - name: API_TOKEN diff --git a/crates/collab/k8s/postgrest.template.yml b/crates/collab/k8s/postgrest.template.yml index ff83880a95..4819408bff 100644 --- a/crates/collab/k8s/postgrest.template.yml +++ b/crates/collab/k8s/postgrest.template.yml @@ -12,7 +12,7 @@ metadata: spec: type: LoadBalancer selector: - app: postgrest + app: nginx ports: - name: web protocol: TCP @@ -24,17 +24,99 @@ apiVersion: apps/v1 kind: Deployment metadata: namespace: ${ZED_KUBE_NAMESPACE} - name: postgrest - + name: nginx spec: replicas: 1 selector: matchLabels: - app: postgrest + app: nginx template: metadata: labels: - app: postgrest + app: nginx + spec: + containers: + - name: nginx + image: nginx:latest + ports: + - containerPort: 8080 + protocol: TCP + volumeMounts: + - name: nginx-config + mountPath: /etc/nginx/nginx.conf + subPath: nginx.conf + volumes: + - name: nginx-config + configMap: + name: nginx-config + +--- +apiVersion: v1 +kind: ConfigMap +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: nginx-config +data: + nginx.conf: | + events {} + + http { + server { + listen 8080; + + location /app/ { + proxy_pass http://postgrest-app:8080/; + } + + location /llm/ { + proxy_pass http://postgrest-llm:8080/; + } + } + } + +--- +apiVersion: v1 +kind: Service +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: postgrest-app +spec: + selector: + app: postgrest-app + ports: + - protocol: TCP + port: 8080 + targetPort: 8080 + +--- +apiVersion: v1 +kind: Service +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: postgrest-llm +spec: + selector: + app: postgrest-llm + ports: + - protocol: TCP + port: 8080 + targetPort: 8080 + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: postgrest-app +spec: + replicas: 1 + selector: + matchLabels: + app: postgrest-app + template: + metadata: + labels: + app: postgrest-app spec: containers: - name: postgrest @@ -55,3 +137,39 @@ spec: secretKeyRef: name: postgrest key: jwt_secret + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: postgrest-llm +spec: + replicas: 1 + selector: + matchLabels: + app: postgrest-llm + template: + metadata: + labels: + app: postgrest-llm + spec: + containers: + - name: postgrest + image: "postgrest/postgrest" + ports: + - containerPort: 8080 + protocol: TCP + env: + - name: PGRST_SERVER_PORT + value: "8080" + - name: PGRST_DB_URI + valueFrom: + secretKeyRef: + name: llm-database + key: url + - name: PGRST_JWT_SECRET + valueFrom: + secretKeyRef: + name: postgrest + key: jwt_secret diff --git a/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql b/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql deleted file mode 100644 index ab854d3e43..0000000000 --- a/crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql +++ /dev/null @@ -1,32 +0,0 @@ -create table providers ( - id integer primary key autoincrement, - name text not null -); - -create unique index uix_providers_on_name on providers (name); - -create table models ( - id integer primary key autoincrement, - provider_id integer not null references providers (id) on delete cascade, - name text not null -); - -create unique index uix_models_on_provider_id_name on models (provider_id, name); -create index ix_models_on_provider_id on models (provider_id); -create index ix_models_on_name on models (name); - -create table if not exists usages ( - id integer primary key autoincrement, - user_id integer not null, - model_id integer not null references models (id) on delete cascade, - requests_this_minute integer not null default 0, - tokens_this_minute integer not null default 0, - requests_this_day integer not null default 0, - tokens_this_day integer not null default 0, - requests_this_month integer not null default 0, - tokens_this_month integer not null default 0 -); - -create index ix_usages_on_user_id on usages (user_id); -create index ix_usages_on_model_id on usages (model_id); -create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id); diff --git a/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql b/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql index 059e6059dc..b81ab7567f 100644 --- a/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql +++ b/crates/collab/migrations_llm/20240806182921_create_providers_and_models.sql @@ -8,7 +8,10 @@ create unique index uix_providers_on_name on providers (name); create table if not exists models ( id serial primary key, provider_id integer not null references providers (id) on delete cascade, - name text not null + name text not null, + max_requests_per_minute integer not null, + max_tokens_per_minute integer not null, + max_tokens_per_day integer not null ); create unique index uix_models_on_provider_id_name on models (provider_id, name); diff --git a/crates/collab/migrations_llm/20240806213401_create_usages.sql b/crates/collab/migrations_llm/20240806213401_create_usages.sql index 913f0a1add..da2245d4b9 100644 --- a/crates/collab/migrations_llm/20240806213401_create_usages.sql +++ b/crates/collab/migrations_llm/20240806213401_create_usages.sql @@ -1,15 +1,19 @@ +create table usage_measures ( + id serial primary key, + name text not null +); + +create unique index uix_usage_measures_on_name on usage_measures (name); + create table if not exists usages ( id serial primary key, user_id integer not null, model_id integer not null references models (id) on delete cascade, - requests_this_minute integer not null default 0, - tokens_this_minute bigint not null default 0, - requests_this_day integer not null default 0, - tokens_this_day bigint not null default 0, - requests_this_month integer not null default 0, - tokens_this_month bigint not null default 0 + measure_id integer not null references usage_measures (id) on delete cascade, + timestamp timestamp without time zone not null, + buckets bigint[] not null ); create index ix_usages_on_user_id on usages (user_id); create index ix_usages_on_model_id on usages (model_id); -create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id); +create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id); diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 43eaba572b..2072861363 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -2,24 +2,25 @@ mod authorization; pub mod db; mod token; -use crate::api::CloudflareIpCountryHeader; -use crate::llm::authorization::authorize_access_to_language_model; -use crate::llm::db::LlmDatabase; -use crate::{executor::Executor, Config, Error, Result}; +use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result}; use anyhow::{anyhow, Context as _}; -use axum::TypedHeader; +use authorization::authorize_access_to_language_model; use axum::{ body::Body, http::{self, HeaderName, HeaderValue, Request, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, routing::post, - Extension, Json, Router, + Extension, Json, Router, TypedHeader, }; +use chrono::{DateTime, Duration, Utc}; +use db::{ActiveUserCount, LlmDatabase}; use futures::StreamExt as _; use http_client::IsahcHttpClient; use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME}; use std::sync::Arc; +use tokio::sync::RwLock; +use util::ResultExt; pub use token::*; @@ -28,8 +29,11 @@ pub struct LlmState { pub executor: Executor, pub db: Option>, pub http_client: IsahcHttpClient, + active_user_count: RwLock, ActiveUserCount)>>, } +const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); + impl LlmState { pub async fn new(config: Config, executor: Executor) -> Result> { // TODO: This is temporary until we have the LLM database stood up. @@ -44,7 +48,8 @@ impl LlmState { let mut db_options = db::ConnectOptions::new(database_url); db_options.max_connections(max_connections); - let db = LlmDatabase::new(db_options, executor.clone()).await?; + let mut db = LlmDatabase::new(db_options, executor.clone()).await?; + db.initialize().await?; Some(Arc::new(db)) } else { @@ -57,15 +62,41 @@ impl LlmState { .build() .context("failed to construct http client")?; + let initial_active_user_count = if let Some(db) = &db { + Some((Utc::now(), db.get_active_user_count(Utc::now()).await?)) + } else { + None + }; + let this = Self { config, executor, db, http_client, + active_user_count: RwLock::new(initial_active_user_count), }; Ok(Arc::new(this)) } + + pub async fn get_active_user_count(&self) -> Result { + let now = Utc::now(); + + if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() { + if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION { + return Ok(*count); + } + } + + if let Some(db) = &self.db { + let mut cache = self.active_user_count.write().await; + let new_count = db.get_active_user_count(now).await?; + *cache = Some((now, new_count)); + Ok(new_count) + } else { + Ok(ActiveUserCount::default()) + } + } } pub fn routes() -> Router<(), Body> { @@ -122,14 +153,22 @@ async fn perform_completion( country_code_header: Option>, Json(params): Json, ) -> Result { + let model = normalize_model_name(params.provider, params.model); + authorize_access_to_language_model( &state.config, &claims, country_code_header.map(|header| header.to_string()), params.provider, - ¶ms.model, + &model, )?; + let user_id = claims.user_id as i32; + + if state.db.is_some() { + check_usage_limit(&state, params.provider, &model, &claims).await?; + } + match params.provider { LanguageModelProvider::Anthropic => { let api_key = state @@ -160,9 +199,31 @@ async fn perform_completion( ) .await?; - let stream = chunks.map(|event| { + let mut recorder = state.db.clone().map(|db| UsageRecorder { + db, + executor: state.executor.clone(), + user_id, + provider: params.provider, + model, + token_count: 0, + }); + + let stream = chunks.map(move |event| { let mut buffer = Vec::new(); event.map(|chunk| { + match &chunk { + anthropic::Event::MessageStart { + message: anthropic::Response { usage, .. }, + } + | anthropic::Event::MessageDelta { usage, .. } => { + if let Some(recorder) = &mut recorder { + recorder.token_count += usage.input_tokens.unwrap_or(0) as usize; + recorder.token_count += usage.output_tokens.unwrap_or(0) as usize; + } + } + _ => {} + } + buffer.clear(); serde_json::to_writer(&mut buffer, &chunk).unwrap(); buffer.push(b'\n'); @@ -259,3 +320,102 @@ async fn perform_completion( } } } + +fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String { + match provider { + LanguageModelProvider::Anthropic => { + for prefix in &[ + "claude-3-5-sonnet", + "claude-3-haiku", + "claude-3-opus", + "claude-3-sonnet", + ] { + if name.starts_with(prefix) { + return prefix.to_string(); + } + } + } + LanguageModelProvider::OpenAi => {} + LanguageModelProvider::Google => {} + LanguageModelProvider::Zed => {} + } + + name +} + +async fn check_usage_limit( + state: &Arc, + provider: LanguageModelProvider, + model_name: &str, + claims: &LlmTokenClaims, +) -> Result<()> { + let db = state + .db + .as_ref() + .ok_or_else(|| anyhow!("LLM database not configured"))?; + let model = db.model(provider, model_name)?; + let usage = db + .get_usage(claims.user_id as i32, provider, model_name, Utc::now()) + .await?; + + let active_users = state.get_active_user_count().await?; + + let per_user_max_requests_per_minute = + model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1); + let per_user_max_tokens_per_minute = + model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1); + let per_user_max_tokens_per_day = + model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1); + + let checks = [ + ( + usage.requests_this_minute, + per_user_max_requests_per_minute, + "requests per minute", + ), + ( + usage.tokens_this_minute, + per_user_max_tokens_per_minute, + "tokens per minute", + ), + ( + usage.tokens_this_day, + per_user_max_tokens_per_day, + "tokens per day", + ), + ]; + + for (usage, limit, resource) in checks { + if usage > limit { + return Err(Error::http( + StatusCode::TOO_MANY_REQUESTS, + format!("Rate limit exceeded. Maximum {} reached.", resource), + )); + } + } + + Ok(()) +} +struct UsageRecorder { + db: Arc, + executor: Executor, + user_id: i32, + provider: LanguageModelProvider, + model: String, + token_count: usize, +} + +impl Drop for UsageRecorder { + fn drop(&mut self) { + let db = self.db.clone(); + let user_id = self.user_id; + let provider = self.provider; + let model = std::mem::take(&mut self.model); + let token_count = self.token_count; + self.executor.spawn_detached(async move { + db.record_usage(user_id, provider, &model, token_count, Utc::now()) + .await + .log_err(); + }) + } +} diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index c0eff088bd..b3144eeecd 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -1,20 +1,26 @@ mod ids; mod queries; +mod seed; mod tables; #[cfg(test)] mod tests; +use collections::HashMap; pub use ids::*; +use rpc::LanguageModelProvider; +pub use seed::*; pub use tables::*; #[cfg(test)] pub use tests::TestLlmDb; +use usage_measure::UsageMeasure; use std::future::Future; use std::sync::Arc; use anyhow::anyhow; +pub use queries::usages::ActiveUserCount; use sea_orm::prelude::*; pub use sea_orm::ConnectOptions; use sea_orm::{ @@ -31,6 +37,9 @@ pub struct LlmDatabase { pool: DatabaseConnection, #[allow(unused)] executor: Executor, + provider_ids: HashMap, + models: HashMap<(LanguageModelProvider, String), model::Model>, + usage_measure_ids: HashMap, #[cfg(test)] runtime: Option, } @@ -43,11 +52,28 @@ impl LlmDatabase { options: options.clone(), pool: sea_orm::Database::connect(options).await?, executor, + provider_ids: HashMap::default(), + models: HashMap::default(), + usage_measure_ids: HashMap::default(), #[cfg(test)] runtime: None, }) } + pub async fn initialize(&mut self) -> Result<()> { + self.initialize_providers().await?; + self.initialize_models().await?; + self.initialize_usage_measures().await?; + Ok(()) + } + + pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> { + Ok(self + .models + .get(&(provider, name.to_string())) + .ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?) + } + pub fn options(&self) -> &ConnectOptions { &self.options } diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs index 2b256651f8..d0705024df 100644 --- a/crates/collab/src/llm/db/ids.rs +++ b/crates/collab/src/llm/db/ids.rs @@ -6,3 +6,4 @@ use crate::id_type; id_type!(ModelId); id_type!(ProviderId); id_type!(UsageId); +id_type!(UsageMeasureId); diff --git a/crates/collab/src/llm/db/queries/providers.rs b/crates/collab/src/llm/db/queries/providers.rs index d96f8453e2..975bf607ce 100644 --- a/crates/collab/src/llm/db/queries/providers.rs +++ b/crates/collab/src/llm/db/queries/providers.rs @@ -1,66 +1,115 @@ -use sea_orm::sea_query::OnConflict; -use sea_orm::QueryOrder; - use super::*; +use sea_orm::QueryOrder; +use std::str::FromStr; +use strum::IntoEnumIterator as _; + +pub struct ModelRateLimits { + pub max_requests_per_minute: i32, + pub max_tokens_per_minute: i32, + pub max_tokens_per_day: i32, +} impl LlmDatabase { - pub async fn initialize_providers(&self) -> Result<()> { + pub async fn initialize_providers(&mut self) -> Result<()> { + self.provider_ids = self + .transaction(|tx| async move { + let existing_providers = provider::Entity::find().all(&*tx).await?; + + let mut new_providers = LanguageModelProvider::iter() + .filter(|provider| { + !existing_providers + .iter() + .any(|p| p.name == provider.to_string()) + }) + .map(|provider| provider::ActiveModel { + name: ActiveValue::set(provider.to_string()), + ..Default::default() + }) + .peekable(); + + if new_providers.peek().is_some() { + provider::Entity::insert_many(new_providers) + .exec(&*tx) + .await?; + } + + let all_providers: HashMap<_, _> = provider::Entity::find() + .all(&*tx) + .await? + .iter() + .filter_map(|provider| { + LanguageModelProvider::from_str(&provider.name) + .ok() + .map(|p| (p, provider.id)) + }) + .collect(); + + Ok(all_providers) + }) + .await?; + Ok(()) + } + + pub async fn initialize_models(&mut self) -> Result<()> { + let all_provider_ids = &self.provider_ids; + self.models = self + .transaction(|tx| async move { + let all_models: HashMap<_, _> = model::Entity::find() + .all(&*tx) + .await? + .into_iter() + .filter_map(|model| { + let provider = all_provider_ids.iter().find_map(|(provider, id)| { + if *id == model.provider_id { + Some(provider) + } else { + None + } + })?; + Some(((*provider, model.name.clone()), model)) + }) + .collect(); + Ok(all_models) + }) + .await?; + Ok(()) + } + + pub async fn insert_models( + &mut self, + models: &[(LanguageModelProvider, String, ModelRateLimits)], + ) -> Result<()> { + let all_provider_ids = &self.provider_ids; self.transaction(|tx| async move { - let providers_and_models = vec![ - ("anthropic", "claude-3-5-sonnet"), - ("anthropic", "claude-3-opus"), - ("anthropic", "claude-3-sonnet"), - ("anthropic", "claude-3-haiku"), - ]; - - for (provider_name, model_name) in providers_and_models { - let insert_provider = provider::Entity::insert(provider::ActiveModel { - name: ActiveValue::set(provider_name.to_owned()), + model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| { + let provider_id = all_provider_ids[&provider]; + model::ActiveModel { + provider_id: ActiveValue::set(provider_id), + name: ActiveValue::set(name.clone()), + max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute), + max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute), + max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day), ..Default::default() - }) - .on_conflict( - OnConflict::columns([provider::Column::Name]) - .update_column(provider::Column::Name) - .to_owned(), - ); - - let provider = if tx.support_returning() { - insert_provider.exec_with_returning(&*tx).await? - } else { - insert_provider.exec_without_returning(&*tx).await?; - provider::Entity::find() - .filter(provider::Column::Name.eq(provider_name)) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("failed to insert provider"))? - }; - - model::Entity::insert(model::ActiveModel { - provider_id: ActiveValue::set(provider.id), - name: ActiveValue::set(model_name.to_owned()), - ..Default::default() - }) - .on_conflict( - OnConflict::columns([model::Column::ProviderId, model::Column::Name]) - .update_column(model::Column::Name) - .to_owned(), - ) - .exec_without_returning(&*tx) - .await?; - } - + } + })) + .exec_without_returning(&*tx) + .await?; Ok(()) }) - .await + .await?; + self.initialize_models().await } /// Returns the list of LLM providers. - pub async fn list_providers(&self) -> Result> { + pub async fn list_providers(&self) -> Result> { self.transaction(|tx| async move { Ok(provider::Entity::find() .order_by_asc(provider::Column::Name) .all(&*tx) - .await?) + .await? + .into_iter() + .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok()) + .collect()) }) .await } diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 4b672fa6ac..206e5d39ab 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,57 +1,318 @@ +use chrono::Duration; use rpc::LanguageModelProvider; +use sea_orm::QuerySelect; +use std::{iter, str::FromStr}; +use strum::IntoEnumIterator as _; use super::*; +#[derive(Debug, PartialEq, Clone, Copy)] +pub struct Usage { + pub requests_this_minute: usize, + pub tokens_this_minute: usize, + pub tokens_this_day: usize, + pub tokens_this_month: usize, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct ActiveUserCount { + pub users_in_recent_minutes: usize, + pub users_in_recent_days: usize, +} + impl LlmDatabase { - pub async fn find_or_create_usage( + pub async fn initialize_usage_measures(&mut self) -> Result<()> { + let all_measures = self + .transaction(|tx| async move { + let existing_measures = usage_measure::Entity::find().all(&*tx).await?; + + let new_measures = UsageMeasure::iter() + .filter(|measure| { + !existing_measures + .iter() + .any(|m| m.name == measure.to_string()) + }) + .map(|measure| usage_measure::ActiveModel { + name: ActiveValue::set(measure.to_string()), + ..Default::default() + }) + .collect::>(); + + if !new_measures.is_empty() { + usage_measure::Entity::insert_many(new_measures) + .exec(&*tx) + .await?; + } + + Ok(usage_measure::Entity::find().all(&*tx).await?) + }) + .await?; + + self.usage_measure_ids = all_measures + .into_iter() + .filter_map(|measure| { + UsageMeasure::from_str(&measure.name) + .ok() + .map(|um| (um, measure.id)) + }) + .collect(); + Ok(()) + } + + pub async fn get_usage( &self, user_id: i32, provider: LanguageModelProvider, model_name: &str, - ) -> Result { + now: DateTimeUtc, + ) -> Result { self.transaction(|tx| async move { - let provider_name = match provider { - LanguageModelProvider::Anthropic => "anthropic", - LanguageModelProvider::OpenAi => "open_ai", - LanguageModelProvider::Google => "google", - LanguageModelProvider::Zed => "zed", - }; + let model = self + .models + .get(&(provider, model_name.to_string())) + .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; - let model = model::Entity::find() - .inner_join(provider::Entity) - .filter( - provider::Column::Name - .eq(provider_name) - .and(model::Column::Name.eq(model_name)), - ) - .one(&*tx) - .await? - // TODO: Create the model, if one doesn't exist. - .ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?; - let model_id = model.id; - - let existing_usage = usage::Entity::find() + let usages = usage::Entity::find() .filter( usage::Column::UserId .eq(user_id) - .and(usage::Column::ModelId.eq(model_id)), + .and(usage::Column::ModelId.eq(model.id)), ) - .one(&*tx) + .all(&*tx) .await?; - if let Some(usage) = existing_usage { - return Ok(usage); - } - let usage = usage::Entity::insert(usage::ActiveModel { - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model_id), - ..Default::default() + let requests_this_minute = + self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?; + let tokens_this_minute = + self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?; + let tokens_this_day = + self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?; + let tokens_this_month = + self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMonth)?; + + Ok(Usage { + requests_this_minute, + tokens_this_minute, + tokens_this_day, + tokens_this_month, }) - .exec_with_returning(&*tx) - .await?; - - Ok(usage) }) .await } + + pub async fn record_usage( + &self, + user_id: i32, + provider: LanguageModelProvider, + model_name: &str, + token_count: usize, + now: DateTimeUtc, + ) -> Result<()> { + self.transaction(|tx| async move { + let model = self.model(provider, model_name)?; + + let usages = usage::Entity::find() + .filter( + usage::Column::UserId + .eq(user_id) + .and(usage::Column::ModelId.eq(model.id)), + ) + .all(&*tx) + .await?; + + self.update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::RequestsPerMinute, + now, + 1, + &tx, + ) + .await?; + self.update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerMinute, + now, + token_count, + &tx, + ) + .await?; + self.update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerDay, + now, + token_count, + &tx, + ) + .await?; + self.update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerMonth, + now, + token_count, + &tx, + ) + .await?; + + Ok(()) + }) + .await + } + + pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result { + self.transaction(|tx| async move { + let minute_since = now - Duration::minutes(5); + let day_since = now - Duration::days(5); + + let users_in_recent_minutes = usage::Entity::find() + .filter(usage::Column::Timestamp.gte(minute_since.naive_utc())) + .group_by(usage::Column::UserId) + .count(&*tx) + .await? as usize; + + let users_in_recent_days = usage::Entity::find() + .filter(usage::Column::Timestamp.gte(day_since.naive_utc())) + .group_by(usage::Column::UserId) + .count(&*tx) + .await? as usize; + + Ok(ActiveUserCount { + users_in_recent_minutes, + users_in_recent_days, + }) + }) + .await + } + + #[allow(clippy::too_many_arguments)] + async fn update_usage_for_measure( + &self, + user_id: i32, + model_id: ModelId, + usages: &[usage::Model], + usage_measure: UsageMeasure, + now: DateTimeUtc, + usage_to_add: usize, + tx: &DatabaseTransaction, + ) -> Result<()> { + let now = now.naive_utc(); + let measure_id = *self + .usage_measure_ids + .get(&usage_measure) + .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?; + + let mut id = None; + let mut timestamp = now; + let mut buckets = vec![0_i64]; + + if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) { + id = Some(old_usage.id); + let (live_buckets, buckets_since) = + Self::get_live_buckets(old_usage, now, usage_measure); + if !live_buckets.is_empty() { + buckets.clear(); + buckets.extend_from_slice(live_buckets); + buckets.extend(iter::repeat(0).take(buckets_since)); + timestamp = + old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32); + } + } + + *buckets.last_mut().unwrap() += usage_to_add as i64; + + let mut model = usage::ActiveModel { + user_id: ActiveValue::set(user_id), + model_id: ActiveValue::set(model_id), + measure_id: ActiveValue::set(measure_id), + timestamp: ActiveValue::set(timestamp), + buckets: ActiveValue::set(buckets), + ..Default::default() + }; + + if let Some(id) = id { + model.id = ActiveValue::unchanged(id); + model.update(tx).await?; + } else { + usage::Entity::insert(model) + .exec_without_returning(tx) + .await?; + } + + Ok(()) + } + + fn get_usage_for_measure( + &self, + usages: &[usage::Model], + now: DateTimeUtc, + usage_measure: UsageMeasure, + ) -> Result { + let now = now.naive_utc(); + let measure_id = *self + .usage_measure_ids + .get(&usage_measure) + .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?; + let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else { + return Ok(0); + }; + + let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure); + Ok(live_buckets.iter().sum::() as _) + } + + fn get_live_buckets( + usage: &usage::Model, + now: chrono::NaiveDateTime, + measure: UsageMeasure, + ) -> (&[i64], usize) { + let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0); + let buckets_since_usage = + seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32; + let buckets_since_usage = buckets_since_usage.ceil() as usize; + let mut live_buckets = &[] as &[i64]; + if buckets_since_usage < measure.bucket_count() { + let expired_bucket_count = + (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count()); + live_buckets = &usage.buckets[expired_bucket_count..]; + while live_buckets.first() == Some(&0) { + live_buckets = &live_buckets[1..]; + } + } + (live_buckets, buckets_since_usage) + } +} + +const MINUTE_BUCKET_COUNT: usize = 12; +const DAY_BUCKET_COUNT: usize = 48; +const MONTH_BUCKET_COUNT: usize = 30; + +impl UsageMeasure { + fn bucket_count(&self) -> usize { + match self { + UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT, + UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT, + UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT, + UsageMeasure::TokensPerMonth => MONTH_BUCKET_COUNT, + } + } + + fn total_duration(&self) -> Duration { + match self { + UsageMeasure::RequestsPerMinute => Duration::minutes(1), + UsageMeasure::TokensPerMinute => Duration::minutes(1), + UsageMeasure::TokensPerDay => Duration::hours(24), + UsageMeasure::TokensPerMonth => Duration::days(30), + } + } + + fn bucket_duration(&self) -> Duration { + self.total_duration() / self.bucket_count() as i32 + } } diff --git a/crates/collab/src/llm/db/seed.rs b/crates/collab/src/llm/db/seed.rs new file mode 100644 index 0000000000..fe1a073b15 --- /dev/null +++ b/crates/collab/src/llm/db/seed.rs @@ -0,0 +1,45 @@ +use super::*; +use crate::{Config, Result}; +use queries::providers::ModelRateLimits; + +pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> { + db.insert_models(&[ + ( + LanguageModelProvider::Anthropic, + "claude-3-5-sonnet".into(), + ModelRateLimits { + max_requests_per_minute: 5, + max_tokens_per_minute: 20_000, + max_tokens_per_day: 300_000, + }, + ), + ( + LanguageModelProvider::Anthropic, + "claude-3-opus".into(), + ModelRateLimits { + max_requests_per_minute: 5, + max_tokens_per_minute: 10_000, + max_tokens_per_day: 300_000, + }, + ), + ( + LanguageModelProvider::Anthropic, + "claude-3-sonnet".into(), + ModelRateLimits { + max_requests_per_minute: 5, + max_tokens_per_minute: 20_000, + max_tokens_per_day: 300_000, + }, + ), + ( + LanguageModelProvider::Anthropic, + "claude-3-haiku".into(), + ModelRateLimits { + max_requests_per_minute: 5, + max_tokens_per_minute: 25_000, + max_tokens_per_day: 300_000, + }, + ), + ]) + .await +} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 87307eacfa..603e7f91a4 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,3 +1,4 @@ pub mod model; pub mod provider; pub mod usage; +pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs index c8ff1ce47e..eb07ab9473 100644 --- a/crates/collab/src/llm/db/tables/model.rs +++ b/crates/collab/src/llm/db/tables/model.rs @@ -10,6 +10,9 @@ pub struct Model { pub id: ModelId, pub provider_id: ProviderId, pub name: String, + pub max_requests_per_minute: i32, + pub max_tokens_per_minute: i32, + pub max_tokens_per_day: i32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/llm/db/tables/provider.rs b/crates/collab/src/llm/db/tables/provider.rs index 7f9aa8ee0d..90838f7c65 100644 --- a/crates/collab/src/llm/db/tables/provider.rs +++ b/crates/collab/src/llm/db/tables/provider.rs @@ -1,6 +1,5 @@ -use sea_orm::entity::prelude::*; - use crate::llm::db::ProviderId; +use sea_orm::entity::prelude::*; /// An LLM provider. #[derive(Clone, Debug, PartialEq, DeriveEntityModel)] diff --git a/crates/collab/src/llm/db/tables/usage.rs b/crates/collab/src/llm/db/tables/usage.rs index afb4f7e03a..5d131133c3 100644 --- a/crates/collab/src/llm/db/tables/usage.rs +++ b/crates/collab/src/llm/db/tables/usage.rs @@ -1,24 +1,20 @@ +use crate::llm::db::{ModelId, UsageId, UsageMeasureId}; use sea_orm::entity::prelude::*; -use crate::llm::db::ModelId; - /// An LLM usage record. #[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[sea_orm(table_name = "usages")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, + pub id: UsageId, /// The ID of the Zed user. /// /// Corresponds to the `users` table in the primary collab database. pub user_id: i32, pub model_id: ModelId, - pub requests_this_minute: i32, - pub tokens_this_minute: i64, - pub requests_this_day: i32, - pub tokens_this_day: i64, - pub requests_this_month: i32, - pub tokens_this_month: i64, + pub measure_id: UsageMeasureId, + pub timestamp: DateTime, + pub buckets: Vec, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -29,6 +25,12 @@ pub enum Relation { to = "super::model::Column::Id" )] Model, + #[sea_orm( + belongs_to = "super::usage_measure::Entity", + from = "Column::MeasureId", + to = "super::usage_measure::Column::Id" + )] + UsageMeasure, } impl Related for Entity { @@ -37,4 +39,10 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::UsageMeasure.def() + } +} + impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/usage_measure.rs b/crates/collab/src/llm/db/tables/usage_measure.rs new file mode 100644 index 0000000000..6462f24907 --- /dev/null +++ b/crates/collab/src/llm/db/tables/usage_measure.rs @@ -0,0 +1,35 @@ +use crate::llm::db::UsageMeasureId; +use sea_orm::entity::prelude::*; + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter, +)] +#[strum(serialize_all = "snake_case")] +pub enum UsageMeasure { + RequestsPerMinute, + TokensPerMinute, + TokensPerDay, + TokensPerMonth, +} + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "usage_measures")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: UsageMeasureId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::usage::Entity")] + Usages, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Usages.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs index 1e76d85522..5fba2a24fd 100644 --- a/crates/collab/src/llm/db/tests.rs +++ b/crates/collab/src/llm/db/tests.rs @@ -6,7 +6,6 @@ use parking_lot::Mutex; use rand::prelude::*; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::Arc; use std::time::Duration; use crate::migrations::run_database_migrations; @@ -14,47 +13,11 @@ use crate::migrations::run_database_migrations; use super::*; pub struct TestLlmDb { - pub db: Option>, + pub db: Option, pub connection: Option, } impl TestLlmDb { - pub fn sqlite(background: BackgroundExecutor) -> Self { - let url = "sqlite::memory:"; - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - let mut options = ConnectOptions::new(url); - options.max_connections(5); - let db = LlmDatabase::new(options, Executor::Deterministic(background)) - .await - .unwrap(); - let sql = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations_llm.sqlite/20240806182921_test_schema.sql" - )); - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - sql, - )) - .await - .unwrap(); - db - }); - - db.runtime = Some(runtime); - - Self { - db: Some(Arc::new(db)), - connection: None, - } - } - pub fn postgres(background: BackgroundExecutor) -> Self { static LOCK: Mutex<()> = Mutex::new(()); @@ -91,29 +54,26 @@ impl TestLlmDb { db.runtime = Some(runtime); Self { - db: Some(Arc::new(db)), + db: Some(db), connection: None, } } - pub fn db(&self) -> &Arc { - self.db.as_ref().unwrap() + pub fn db(&mut self) -> &mut LlmDatabase { + self.db.as_mut().unwrap() } } #[macro_export] -macro_rules! test_both_llm_dbs { - ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => { - #[cfg(target_os = "macos")] +macro_rules! test_llm_db { + ($test_name:ident, $postgres_test_name:ident) => { #[gpui::test] async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { - let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); - $test_name(test_db.db()).await; - } + if !cfg!(target_os = "macos") { + return; + } - #[gpui::test] - async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) { - let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone()); + let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); $test_name(test_db.db()).await; } }; diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs index 2b0d692bb8..ef0da1c373 100644 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ b/crates/collab/src/llm/db/tests/provider_tests.rs @@ -1,17 +1,15 @@ -use std::sync::Arc; - use pretty_assertions::assert_eq; +use rpc::LanguageModelProvider; use crate::llm::db::LlmDatabase; -use crate::test_both_llm_dbs; +use crate::test_llm_db; -test_both_llm_dbs!( +test_llm_db!( test_initialize_providers, - test_initialize_providers_postgres, - test_initialize_providers_sqlite + test_initialize_providers_postgres ); -async fn test_initialize_providers(db: &Arc) { +async fn test_initialize_providers(db: &mut LlmDatabase) { let initial_providers = db.list_providers().await.unwrap(); assert_eq!(initial_providers, vec![]); @@ -22,9 +20,13 @@ async fn test_initialize_providers(db: &Arc) { let providers = db.list_providers().await.unwrap(); - let provider_names = providers - .into_iter() - .map(|provider| provider.name) - .collect::>(); - assert_eq!(provider_names, vec!["anthropic".to_string()]); + assert_eq!( + providers, + &[ + LanguageModelProvider::Anthropic, + LanguageModelProvider::Google, + LanguageModelProvider::OpenAi, + LanguageModelProvider::Zed + ] + ) } diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index ee2bcdbe01..081b333afc 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -1,24 +1,120 @@ -use std::sync::Arc; - +use crate::{ + llm::db::{queries::providers::ModelRateLimits, queries::usages::Usage, LlmDatabase}, + test_llm_db, +}; +use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; use rpc::LanguageModelProvider; -use crate::llm::db::LlmDatabase; -use crate::test_both_llm_dbs; +test_llm_db!(test_tracking_usage, test_tracking_usage_postgres); -test_both_llm_dbs!( - test_find_or_create_usage, - test_find_or_create_usage_postgres, - test_find_or_create_usage_sqlite -); +async fn test_tracking_usage(db: &mut LlmDatabase) { + let provider = LanguageModelProvider::Anthropic; + let model = "claude-3-5-sonnet"; -async fn test_find_or_create_usage(db: &Arc) { - db.initialize_providers().await.unwrap(); + db.initialize().await.unwrap(); + db.insert_models(&[( + provider, + model.to_string(), + ModelRateLimits { + max_requests_per_minute: 5, + max_tokens_per_minute: 10_000, + max_tokens_per_day: 50_000, + }, + )]) + .await + .unwrap(); - let usage = db - .find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet") + let t0 = Utc::now(); + let user_id = 123; + + let now = t0; + db.record_usage(user_id, provider, model, 1000, now) .await .unwrap(); - assert_eq!(usage.user_id, 123); + let now = t0 + Duration::seconds(10); + db.record_usage(user_id, provider, model, 2000, now) + .await + .unwrap(); + + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 2, + tokens_this_minute: 3000, + tokens_this_day: 3000, + tokens_this_month: 3000, + } + ); + + let now = t0 + Duration::seconds(60); + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 1, + tokens_this_minute: 2000, + tokens_this_day: 3000, + tokens_this_month: 3000, + } + ); + + let now = t0 + Duration::seconds(60); + db.record_usage(user_id, provider, model, 3000, now) + .await + .unwrap(); + + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 2, + tokens_this_minute: 5000, + tokens_this_day: 6000, + tokens_this_month: 6000, + } + ); + + let t1 = t0 + Duration::hours(24); + let now = t1; + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 0, + tokens_this_minute: 0, + tokens_this_day: 5000, + tokens_this_month: 6000, + } + ); + + db.record_usage(user_id, provider, model, 4000, now) + .await + .unwrap(); + + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 1, + tokens_this_minute: 4000, + tokens_this_day: 9000, + tokens_this_month: 10000, + } + ); + + let t2 = t0 + Duration::days(30); + let now = t2; + let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); + assert_eq!( + usage, + Usage { + requests_this_minute: 0, + tokens_this_minute: 0, + tokens_this_day: 0, + tokens_this_month: 9000, + } + ); } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index ec11bb6543..18515192d5 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -52,10 +52,18 @@ async fn main() -> Result<()> { Some("seed") => { let config = envy::from_env::().expect("error loading config"); let db_options = db::ConnectOptions::new(config.database_url.clone()); + let mut db = Database::new(db_options, Executor::Production).await?; db.initialize_notification_kinds().await?; - collab::seed::seed(&config, &db, true).await?; + collab::seed::seed(&config, &db, false).await?; + + if let Some(llm_database_url) = config.llm_database_url.clone() { + let db_options = db::ConnectOptions::new(llm_database_url); + let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?; + db.initialize().await?; + collab::llm::db::seed_database(&config, &mut db, true).await?; + } } Some("serve") => { let mode = match args.next().as_deref() { diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 2b1f4b9f4d..7f97b02df7 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -1,9 +1,13 @@ use serde::{Deserialize, Serialize}; +use strum::{Display, EnumIter, EnumString}; pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +#[derive( + Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display, +)] #[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] pub enum LanguageModelProvider { Anthropic, OpenAi,