Apply rate limits in LLM service (#15997)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-08-08 15:46:33 -07:00 committed by GitHub
parent 2bc503771b
commit 06625bfe94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 983 additions and 227 deletions

View File

@ -85,6 +85,11 @@ spec:
secretKeyRef: secretKeyRef:
name: database name: database
key: url key: url
- name: LLM_DATABASE_URL
valueFrom:
secretKeyRef:
name: llm-database
key: url
- name: DATABASE_MAX_CONNECTIONS - name: DATABASE_MAX_CONNECTIONS
value: "${DATABASE_MAX_CONNECTIONS}" value: "${DATABASE_MAX_CONNECTIONS}"
- name: API_TOKEN - name: API_TOKEN

View File

@ -12,7 +12,7 @@ metadata:
spec: spec:
type: LoadBalancer type: LoadBalancer
selector: selector:
app: postgrest app: nginx
ports: ports:
- name: web - name: web
protocol: TCP protocol: TCP
@ -24,17 +24,99 @@ apiVersion: apps/v1
kind: Deployment kind: Deployment
metadata: metadata:
namespace: ${ZED_KUBE_NAMESPACE} namespace: ${ZED_KUBE_NAMESPACE}
name: postgrest name: nginx
spec: spec:
replicas: 1 replicas: 1
selector: selector:
matchLabels: matchLabels:
app: postgrest app: nginx
template: template:
metadata: metadata:
labels: labels:
app: postgrest app: nginx
spec:
containers:
- name: nginx
image: nginx:latest
ports:
- containerPort: 8080
protocol: TCP
volumeMounts:
- name: nginx-config
mountPath: /etc/nginx/nginx.conf
subPath: nginx.conf
volumes:
- name: nginx-config
configMap:
name: nginx-config
---
apiVersion: v1
kind: ConfigMap
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
name: nginx-config
data:
nginx.conf: |
events {}
http {
server {
listen 8080;
location /app/ {
proxy_pass http://postgrest-app:8080/;
}
location /llm/ {
proxy_pass http://postgrest-llm:8080/;
}
}
}
---
apiVersion: v1
kind: Service
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
name: postgrest-app
spec:
selector:
app: postgrest-app
ports:
- protocol: TCP
port: 8080
targetPort: 8080
---
apiVersion: v1
kind: Service
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
name: postgrest-llm
spec:
selector:
app: postgrest-llm
ports:
- protocol: TCP
port: 8080
targetPort: 8080
---
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
name: postgrest-app
spec:
replicas: 1
selector:
matchLabels:
app: postgrest-app
template:
metadata:
labels:
app: postgrest-app
spec: spec:
containers: containers:
- name: postgrest - name: postgrest
@ -55,3 +137,39 @@ spec:
secretKeyRef: secretKeyRef:
name: postgrest name: postgrest
key: jwt_secret key: jwt_secret
---
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: ${ZED_KUBE_NAMESPACE}
name: postgrest-llm
spec:
replicas: 1
selector:
matchLabels:
app: postgrest-llm
template:
metadata:
labels:
app: postgrest-llm
spec:
containers:
- name: postgrest
image: "postgrest/postgrest"
ports:
- containerPort: 8080
protocol: TCP
env:
- name: PGRST_SERVER_PORT
value: "8080"
- name: PGRST_DB_URI
valueFrom:
secretKeyRef:
name: llm-database
key: url
- name: PGRST_JWT_SECRET
valueFrom:
secretKeyRef:
name: postgrest
key: jwt_secret

View File

@ -1,32 +0,0 @@
create table providers (
id integer primary key autoincrement,
name text not null
);
create unique index uix_providers_on_name on providers (name);
create table models (
id integer primary key autoincrement,
provider_id integer not null references providers (id) on delete cascade,
name text not null
);
create unique index uix_models_on_provider_id_name on models (provider_id, name);
create index ix_models_on_provider_id on models (provider_id);
create index ix_models_on_name on models (name);
create table if not exists usages (
id integer primary key autoincrement,
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
requests_this_minute integer not null default 0,
tokens_this_minute integer not null default 0,
requests_this_day integer not null default 0,
tokens_this_day integer not null default 0,
requests_this_month integer not null default 0,
tokens_this_month integer not null default 0
);
create index ix_usages_on_user_id on usages (user_id);
create index ix_usages_on_model_id on usages (model_id);
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);

View File

@ -8,7 +8,10 @@ create unique index uix_providers_on_name on providers (name);
create table if not exists models ( create table if not exists models (
id serial primary key, id serial primary key,
provider_id integer not null references providers (id) on delete cascade, provider_id integer not null references providers (id) on delete cascade,
name text not null name text not null,
max_requests_per_minute integer not null,
max_tokens_per_minute integer not null,
max_tokens_per_day integer not null
); );
create unique index uix_models_on_provider_id_name on models (provider_id, name); create unique index uix_models_on_provider_id_name on models (provider_id, name);

View File

@ -1,15 +1,19 @@
create table usage_measures (
id serial primary key,
name text not null
);
create unique index uix_usage_measures_on_name on usage_measures (name);
create table if not exists usages ( create table if not exists usages (
id serial primary key, id serial primary key,
user_id integer not null, user_id integer not null,
model_id integer not null references models (id) on delete cascade, model_id integer not null references models (id) on delete cascade,
requests_this_minute integer not null default 0, measure_id integer not null references usage_measures (id) on delete cascade,
tokens_this_minute bigint not null default 0, timestamp timestamp without time zone not null,
requests_this_day integer not null default 0, buckets bigint[] not null
tokens_this_day bigint not null default 0,
requests_this_month integer not null default 0,
tokens_this_month bigint not null default 0
); );
create index ix_usages_on_user_id on usages (user_id); create index ix_usages_on_user_id on usages (user_id);
create index ix_usages_on_model_id on usages (model_id); create index ix_usages_on_model_id on usages (model_id);
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id); create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id);

View File

@ -2,24 +2,25 @@ mod authorization;
pub mod db; pub mod db;
mod token; mod token;
use crate::api::CloudflareIpCountryHeader; use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
use crate::llm::authorization::authorize_access_to_language_model;
use crate::llm::db::LlmDatabase;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::{anyhow, Context as _}; use anyhow::{anyhow, Context as _};
use axum::TypedHeader; use authorization::authorize_access_to_language_model;
use axum::{ use axum::{
body::Body, body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode}, http::{self, HeaderName, HeaderValue, Request, StatusCode},
middleware::{self, Next}, middleware::{self, Next},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::post, routing::post,
Extension, Json, Router, Extension, Json, Router, TypedHeader,
}; };
use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::StreamExt as _; use futures::StreamExt as _;
use http_client::IsahcHttpClient; use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME}; use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock;
use util::ResultExt;
pub use token::*; pub use token::*;
@ -28,8 +29,11 @@ pub struct LlmState {
pub executor: Executor, pub executor: Executor,
pub db: Option<Arc<LlmDatabase>>, pub db: Option<Arc<LlmDatabase>>,
pub http_client: IsahcHttpClient, pub http_client: IsahcHttpClient,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
} }
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
impl LlmState { impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> { pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
// TODO: This is temporary until we have the LLM database stood up. // TODO: This is temporary until we have the LLM database stood up.
@ -44,7 +48,8 @@ impl LlmState {
let mut db_options = db::ConnectOptions::new(database_url); let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections); db_options.max_connections(max_connections);
let db = LlmDatabase::new(db_options, executor.clone()).await?; let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;
Some(Arc::new(db)) Some(Arc::new(db))
} else { } else {
@ -57,15 +62,41 @@ impl LlmState {
.build() .build()
.context("failed to construct http client")?; .context("failed to construct http client")?;
let initial_active_user_count = if let Some(db) = &db {
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
} else {
None
};
let this = Self { let this = Self {
config, config,
executor, executor,
db, db,
http_client, http_client,
active_user_count: RwLock::new(initial_active_user_count),
}; };
Ok(Arc::new(this)) Ok(Arc::new(this))
} }
pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
let now = Utc::now();
if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
return Ok(*count);
}
}
if let Some(db) = &self.db {
let mut cache = self.active_user_count.write().await;
let new_count = db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
Ok(new_count)
} else {
Ok(ActiveUserCount::default())
}
}
} }
pub fn routes() -> Router<(), Body> { pub fn routes() -> Router<(), Body> {
@ -122,14 +153,22 @@ async fn perform_completion(
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>, country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>, Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> { ) -> Result<impl IntoResponse> {
let model = normalize_model_name(params.provider, params.model);
authorize_access_to_language_model( authorize_access_to_language_model(
&state.config, &state.config,
&claims, &claims,
country_code_header.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()),
params.provider, params.provider,
&params.model, &model,
)?; )?;
let user_id = claims.user_id as i32;
if state.db.is_some() {
check_usage_limit(&state, params.provider, &model, &claims).await?;
}
match params.provider { match params.provider {
LanguageModelProvider::Anthropic => { LanguageModelProvider::Anthropic => {
let api_key = state let api_key = state
@ -160,9 +199,31 @@ async fn perform_completion(
) )
.await?; .await?;
let stream = chunks.map(|event| { let mut recorder = state.db.clone().map(|db| UsageRecorder {
db,
executor: state.executor.clone(),
user_id,
provider: params.provider,
model,
token_count: 0,
});
let stream = chunks.map(move |event| {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
event.map(|chunk| { event.map(|chunk| {
match &chunk {
anthropic::Event::MessageStart {
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => {
if let Some(recorder) = &mut recorder {
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
}
_ => {}
}
buffer.clear(); buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap(); serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n'); buffer.push(b'\n');
@ -259,3 +320,102 @@ async fn perform_completion(
} }
} }
} }
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
match provider {
LanguageModelProvider::Anthropic => {
for prefix in &[
"claude-3-5-sonnet",
"claude-3-haiku",
"claude-3-opus",
"claude-3-sonnet",
] {
if name.starts_with(prefix) {
return prefix.to_string();
}
}
}
LanguageModelProvider::OpenAi => {}
LanguageModelProvider::Google => {}
LanguageModelProvider::Zed => {}
}
name
}
async fn check_usage_limit(
state: &Arc<LlmState>,
provider: LanguageModelProvider,
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
let db = state
.db
.as_ref()
.ok_or_else(|| anyhow!("LLM database not configured"))?;
let model = db.model(provider, model_name)?;
let usage = db
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
.await?;
let active_users = state.get_active_user_count().await?;
let per_user_max_requests_per_minute =
model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
let per_user_max_tokens_per_minute =
model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
let per_user_max_tokens_per_day =
model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
let checks = [
(
usage.requests_this_minute,
per_user_max_requests_per_minute,
"requests per minute",
),
(
usage.tokens_this_minute,
per_user_max_tokens_per_minute,
"tokens per minute",
),
(
usage.tokens_this_day,
per_user_max_tokens_per_day,
"tokens per day",
),
];
for (usage, limit, resource) in checks {
if usage > limit {
return Err(Error::http(
StatusCode::TOO_MANY_REQUESTS,
format!("Rate limit exceeded. Maximum {} reached.", resource),
));
}
}
Ok(())
}
struct UsageRecorder {
db: Arc<LlmDatabase>,
executor: Executor,
user_id: i32,
provider: LanguageModelProvider,
model: String,
token_count: usize,
}
impl Drop for UsageRecorder {
fn drop(&mut self) {
let db = self.db.clone();
let user_id = self.user_id;
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let token_count = self.token_count;
self.executor.spawn_detached(async move {
db.record_usage(user_id, provider, &model, token_count, Utc::now())
.await
.log_err();
})
}
}

View File

@ -1,20 +1,26 @@
mod ids; mod ids;
mod queries; mod queries;
mod seed;
mod tables; mod tables;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
use collections::HashMap;
pub use ids::*; pub use ids::*;
use rpc::LanguageModelProvider;
pub use seed::*;
pub use tables::*; pub use tables::*;
#[cfg(test)] #[cfg(test)]
pub use tests::TestLlmDb; pub use tests::TestLlmDb;
use usage_measure::UsageMeasure;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use anyhow::anyhow; use anyhow::anyhow;
pub use queries::usages::ActiveUserCount;
use sea_orm::prelude::*; use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions; pub use sea_orm::ConnectOptions;
use sea_orm::{ use sea_orm::{
@ -31,6 +37,9 @@ pub struct LlmDatabase {
pool: DatabaseConnection, pool: DatabaseConnection,
#[allow(unused)] #[allow(unused)]
executor: Executor, executor: Executor,
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
models: HashMap<(LanguageModelProvider, String), model::Model>,
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
#[cfg(test)] #[cfg(test)]
runtime: Option<tokio::runtime::Runtime>, runtime: Option<tokio::runtime::Runtime>,
} }
@ -43,11 +52,28 @@ impl LlmDatabase {
options: options.clone(), options: options.clone(),
pool: sea_orm::Database::connect(options).await?, pool: sea_orm::Database::connect(options).await?,
executor, executor,
provider_ids: HashMap::default(),
models: HashMap::default(),
usage_measure_ids: HashMap::default(),
#[cfg(test)] #[cfg(test)]
runtime: None, runtime: None,
}) })
} }
pub async fn initialize(&mut self) -> Result<()> {
self.initialize_providers().await?;
self.initialize_models().await?;
self.initialize_usage_measures().await?;
Ok(())
}
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models
.get(&(provider, name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
}
pub fn options(&self) -> &ConnectOptions { pub fn options(&self) -> &ConnectOptions {
&self.options &self.options
} }

View File

@ -6,3 +6,4 @@ use crate::id_type;
id_type!(ModelId); id_type!(ModelId);
id_type!(ProviderId); id_type!(ProviderId);
id_type!(UsageId); id_type!(UsageId);
id_type!(UsageMeasureId);

View File

@ -1,66 +1,115 @@
use sea_orm::sea_query::OnConflict;
use sea_orm::QueryOrder;
use super::*; use super::*;
use sea_orm::QueryOrder;
use std::str::FromStr;
use strum::IntoEnumIterator as _;
pub struct ModelRateLimits {
pub max_requests_per_minute: i32,
pub max_tokens_per_minute: i32,
pub max_tokens_per_day: i32,
}
impl LlmDatabase { impl LlmDatabase {
pub async fn initialize_providers(&self) -> Result<()> { pub async fn initialize_providers(&mut self) -> Result<()> {
self.provider_ids = self
.transaction(|tx| async move {
let existing_providers = provider::Entity::find().all(&*tx).await?;
let mut new_providers = LanguageModelProvider::iter()
.filter(|provider| {
!existing_providers
.iter()
.any(|p| p.name == provider.to_string())
})
.map(|provider| provider::ActiveModel {
name: ActiveValue::set(provider.to_string()),
..Default::default()
})
.peekable();
if new_providers.peek().is_some() {
provider::Entity::insert_many(new_providers)
.exec(&*tx)
.await?;
}
let all_providers: HashMap<_, _> = provider::Entity::find()
.all(&*tx)
.await?
.iter()
.filter_map(|provider| {
LanguageModelProvider::from_str(&provider.name)
.ok()
.map(|p| (p, provider.id))
})
.collect();
Ok(all_providers)
})
.await?;
Ok(())
}
pub async fn initialize_models(&mut self) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.models = self
.transaction(|tx| async move {
let all_models: HashMap<_, _> = model::Entity::find()
.all(&*tx)
.await?
.into_iter()
.filter_map(|model| {
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
if *id == model.provider_id {
Some(provider)
} else {
None
}
})?;
Some(((*provider, model.name.clone()), model))
})
.collect();
Ok(all_models)
})
.await?;
Ok(())
}
pub async fn insert_models(
&mut self,
models: &[(LanguageModelProvider, String, ModelRateLimits)],
) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.transaction(|tx| async move { self.transaction(|tx| async move {
let providers_and_models = vec![ model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| {
("anthropic", "claude-3-5-sonnet"), let provider_id = all_provider_ids[&provider];
("anthropic", "claude-3-opus"), model::ActiveModel {
("anthropic", "claude-3-sonnet"), provider_id: ActiveValue::set(provider_id),
("anthropic", "claude-3-haiku"), name: ActiveValue::set(name.clone()),
]; max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute),
max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute),
for (provider_name, model_name) in providers_and_models { max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day),
let insert_provider = provider::Entity::insert(provider::ActiveModel {
name: ActiveValue::set(provider_name.to_owned()),
..Default::default() ..Default::default()
}) }
.on_conflict( }))
OnConflict::columns([provider::Column::Name]) .exec_without_returning(&*tx)
.update_column(provider::Column::Name) .await?;
.to_owned(),
);
let provider = if tx.support_returning() {
insert_provider.exec_with_returning(&*tx).await?
} else {
insert_provider.exec_without_returning(&*tx).await?;
provider::Entity::find()
.filter(provider::Column::Name.eq(provider_name))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert provider"))?
};
model::Entity::insert(model::ActiveModel {
provider_id: ActiveValue::set(provider.id),
name: ActiveValue::set(model_name.to_owned()),
..Default::default()
})
.on_conflict(
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
.update_column(model::Column::Name)
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
}
Ok(()) Ok(())
}) })
.await .await?;
self.initialize_models().await
} }
/// Returns the list of LLM providers. /// Returns the list of LLM providers.
pub async fn list_providers(&self) -> Result<Vec<provider::Model>> { pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
Ok(provider::Entity::find() Ok(provider::Entity::find()
.order_by_asc(provider::Column::Name) .order_by_asc(provider::Column::Name)
.all(&*tx) .all(&*tx)
.await?) .await?
.into_iter()
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
.collect())
}) })
.await .await
} }

View File

@ -1,57 +1,318 @@
use chrono::Duration;
use rpc::LanguageModelProvider; use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
use std::{iter, str::FromStr};
use strum::IntoEnumIterator as _;
use super::*; use super::*;
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub tokens_this_day: usize,
pub tokens_this_month: usize,
}
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
pub users_in_recent_days: usize,
}
impl LlmDatabase { impl LlmDatabase {
pub async fn find_or_create_usage( pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
.transaction(|tx| async move {
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
let new_measures = UsageMeasure::iter()
.filter(|measure| {
!existing_measures
.iter()
.any(|m| m.name == measure.to_string())
})
.map(|measure| usage_measure::ActiveModel {
name: ActiveValue::set(measure.to_string()),
..Default::default()
})
.collect::<Vec<_>>();
if !new_measures.is_empty() {
usage_measure::Entity::insert_many(new_measures)
.exec(&*tx)
.await?;
}
Ok(usage_measure::Entity::find().all(&*tx).await?)
})
.await?;
self.usage_measure_ids = all_measures
.into_iter()
.filter_map(|measure| {
UsageMeasure::from_str(&measure.name)
.ok()
.map(|um| (um, measure.id))
})
.collect();
Ok(())
}
pub async fn get_usage(
&self, &self,
user_id: i32, user_id: i32,
provider: LanguageModelProvider, provider: LanguageModelProvider,
model_name: &str, model_name: &str,
) -> Result<usage::Model> { now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let provider_name = match provider { let model = self
LanguageModelProvider::Anthropic => "anthropic", .models
LanguageModelProvider::OpenAi => "open_ai", .get(&(provider, model_name.to_string()))
LanguageModelProvider::Google => "google", .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
LanguageModelProvider::Zed => "zed",
};
let model = model::Entity::find() let usages = usage::Entity::find()
.inner_join(provider::Entity)
.filter(
provider::Column::Name
.eq(provider_name)
.and(model::Column::Name.eq(model_name)),
)
.one(&*tx)
.await?
// TODO: Create the model, if one doesn't exist.
.ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?;
let model_id = model.id;
let existing_usage = usage::Entity::find()
.filter( .filter(
usage::Column::UserId usage::Column::UserId
.eq(user_id) .eq(user_id)
.and(usage::Column::ModelId.eq(model_id)), .and(usage::Column::ModelId.eq(model.id)),
) )
.one(&*tx) .all(&*tx)
.await?; .await?;
if let Some(usage) = existing_usage {
return Ok(usage);
}
let usage = usage::Entity::insert(usage::ActiveModel { let requests_this_minute =
user_id: ActiveValue::set(user_id), self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
model_id: ActiveValue::set(model_id), let tokens_this_minute =
..Default::default() self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
let tokens_this_day =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
let tokens_this_month =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMonth)?;
Ok(Usage {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
tokens_this_month,
}) })
.exec_with_returning(&*tx)
.await?;
Ok(usage)
}) })
.await .await
} }
pub async fn record_usage(
&self,
user_id: i32,
provider: LanguageModelProvider,
model_name: &str,
token_count: usize,
now: DateTimeUtc,
) -> Result<()> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
token_count,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
token_count,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMonth,
now,
token_count,
&tx,
)
.await?;
Ok(())
})
.await
}
pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
let users_in_recent_minutes = usage::Entity::find()
.filter(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
let users_in_recent_days = usage::Entity::find()
.filter(usage::Column::Timestamp.gte(day_since.naive_utc()))
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
Ok(ActiveUserCount {
users_in_recent_minutes,
users_in_recent_days,
})
})
.await
}
#[allow(clippy::too_many_arguments)]
async fn update_usage_for_measure(
&self,
user_id: i32,
model_id: ModelId,
usages: &[usage::Model],
usage_measure: UsageMeasure,
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
) -> Result<()> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let mut id = None;
let mut timestamp = now;
let mut buckets = vec![0_i64];
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
id = Some(old_usage.id);
let (live_buckets, buckets_since) =
Self::get_live_buckets(old_usage, now, usage_measure);
if !live_buckets.is_empty() {
buckets.clear();
buckets.extend_from_slice(live_buckets);
buckets.extend(iter::repeat(0).take(buckets_since));
timestamp =
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
}
}
*buckets.last_mut().unwrap() += usage_to_add as i64;
let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model_id),
measure_id: ActiveValue::set(measure_id),
timestamp: ActiveValue::set(timestamp),
buckets: ActiveValue::set(buckets),
..Default::default()
};
if let Some(id) = id {
model.id = ActiveValue::unchanged(id);
model.update(tx).await?;
} else {
usage::Entity::insert(model)
.exec_without_returning(tx)
.await?;
}
Ok(())
}
fn get_usage_for_measure(
&self,
usages: &[usage::Model],
now: DateTimeUtc,
usage_measure: UsageMeasure,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
return Ok(0);
};
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
Ok(live_buckets.iter().sum::<i64>() as _)
}
fn get_live_buckets(
usage: &usage::Model,
now: chrono::NaiveDateTime,
measure: UsageMeasure,
) -> (&[i64], usize) {
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
let buckets_since_usage =
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
let buckets_since_usage = buckets_since_usage.ceil() as usize;
let mut live_buckets = &[] as &[i64];
if buckets_since_usage < measure.bucket_count() {
let expired_bucket_count =
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
live_buckets = &usage.buckets[expired_bucket_count..];
while live_buckets.first() == Some(&0) {
live_buckets = &live_buckets[1..];
}
}
(live_buckets, buckets_since_usage)
}
}
const MINUTE_BUCKET_COUNT: usize = 12;
const DAY_BUCKET_COUNT: usize = 48;
const MONTH_BUCKET_COUNT: usize = 30;
impl UsageMeasure {
fn bucket_count(&self) -> usize {
match self {
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
UsageMeasure::TokensPerMonth => MONTH_BUCKET_COUNT,
}
}
fn total_duration(&self) -> Duration {
match self {
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerDay => Duration::hours(24),
UsageMeasure::TokensPerMonth => Duration::days(30),
}
}
fn bucket_duration(&self) -> Duration {
self.total_duration() / self.bucket_count() as i32
}
} }

View File

@ -0,0 +1,45 @@
use super::*;
use crate::{Config, Result};
use queries::providers::ModelRateLimits;
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
db.insert_models(&[
(
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet".into(),
ModelRateLimits {
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
},
),
(
LanguageModelProvider::Anthropic,
"claude-3-opus".into(),
ModelRateLimits {
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 300_000,
},
),
(
LanguageModelProvider::Anthropic,
"claude-3-sonnet".into(),
ModelRateLimits {
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
},
),
(
LanguageModelProvider::Anthropic,
"claude-3-haiku".into(),
ModelRateLimits {
max_requests_per_minute: 5,
max_tokens_per_minute: 25_000,
max_tokens_per_day: 300_000,
},
),
])
.await
}

View File

@ -1,3 +1,4 @@
pub mod model; pub mod model;
pub mod provider; pub mod provider;
pub mod usage; pub mod usage;
pub mod usage_measure;

View File

@ -10,6 +10,9 @@ pub struct Model {
pub id: ModelId, pub id: ModelId,
pub provider_id: ProviderId, pub provider_id: ProviderId,
pub name: String, pub name: String,
pub max_requests_per_minute: i32,
pub max_tokens_per_minute: i32,
pub max_tokens_per_day: i32,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@ -1,6 +1,5 @@
use sea_orm::entity::prelude::*;
use crate::llm::db::ProviderId; use crate::llm::db::ProviderId;
use sea_orm::entity::prelude::*;
/// An LLM provider. /// An LLM provider.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]

View File

@ -1,24 +1,20 @@
use crate::llm::db::{ModelId, UsageId, UsageMeasureId};
use sea_orm::entity::prelude::*; use sea_orm::entity::prelude::*;
use crate::llm::db::ModelId;
/// An LLM usage record. /// An LLM usage record.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usages")] #[sea_orm(table_name = "usages")]
pub struct Model { pub struct Model {
#[sea_orm(primary_key)] #[sea_orm(primary_key)]
pub id: i32, pub id: UsageId,
/// The ID of the Zed user. /// The ID of the Zed user.
/// ///
/// Corresponds to the `users` table in the primary collab database. /// Corresponds to the `users` table in the primary collab database.
pub user_id: i32, pub user_id: i32,
pub model_id: ModelId, pub model_id: ModelId,
pub requests_this_minute: i32, pub measure_id: UsageMeasureId,
pub tokens_this_minute: i64, pub timestamp: DateTime,
pub requests_this_day: i32, pub buckets: Vec<i64>,
pub tokens_this_day: i64,
pub requests_this_month: i32,
pub tokens_this_month: i64,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@ -29,6 +25,12 @@ pub enum Relation {
to = "super::model::Column::Id" to = "super::model::Column::Id"
)] )]
Model, Model,
#[sea_orm(
belongs_to = "super::usage_measure::Entity",
from = "Column::MeasureId",
to = "super::usage_measure::Column::Id"
)]
UsageMeasure,
} }
impl Related<super::model::Entity> for Entity { impl Related<super::model::Entity> for Entity {
@ -37,4 +39,10 @@ impl Related<super::model::Entity> for Entity {
} }
} }
impl Related<super::usage_measure::Entity> for Entity {
fn to() -> RelationDef {
Relation::UsageMeasure.def()
}
}
impl ActiveModelBehavior for ActiveModel {} impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,35 @@
use crate::llm::db::UsageMeasureId;
use sea_orm::entity::prelude::*;
#[derive(
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
)]
#[strum(serialize_all = "snake_case")]
pub enum UsageMeasure {
RequestsPerMinute,
TokensPerMinute,
TokensPerDay,
TokensPerMonth,
}
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usage_measures")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: UsageMeasureId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
}
impl Related<super::usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::Usages.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -6,7 +6,6 @@ use parking_lot::Mutex;
use rand::prelude::*; use rand::prelude::*;
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase; use sqlx::migrate::MigrateDatabase;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use crate::migrations::run_database_migrations; use crate::migrations::run_database_migrations;
@ -14,47 +13,11 @@ use crate::migrations::run_database_migrations;
use super::*; use super::*;
pub struct TestLlmDb { pub struct TestLlmDb {
pub db: Option<Arc<LlmDatabase>>, pub db: Option<LlmDatabase>,
pub connection: Option<sqlx::AnyConnection>, pub connection: Option<sqlx::AnyConnection>,
} }
impl TestLlmDb { impl TestLlmDb {
pub fn sqlite(background: BackgroundExecutor) -> Self {
let url = "sqlite::memory:";
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap();
let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url);
options.max_connections(5);
let db = LlmDatabase::new(options, Executor::Deterministic(background))
.await
.unwrap();
let sql = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations_llm.sqlite/20240806182921_test_schema.sql"
));
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
sql,
))
.await
.unwrap();
db
});
db.runtime = Some(runtime);
Self {
db: Some(Arc::new(db)),
connection: None,
}
}
pub fn postgres(background: BackgroundExecutor) -> Self { pub fn postgres(background: BackgroundExecutor) -> Self {
static LOCK: Mutex<()> = Mutex::new(()); static LOCK: Mutex<()> = Mutex::new(());
@ -91,29 +54,26 @@ impl TestLlmDb {
db.runtime = Some(runtime); db.runtime = Some(runtime);
Self { Self {
db: Some(Arc::new(db)), db: Some(db),
connection: None, connection: None,
} }
} }
pub fn db(&self) -> &Arc<LlmDatabase> { pub fn db(&mut self) -> &mut LlmDatabase {
self.db.as_ref().unwrap() self.db.as_mut().unwrap()
} }
} }
#[macro_export] #[macro_export]
macro_rules! test_both_llm_dbs { macro_rules! test_llm_db {
($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => { ($test_name:ident, $postgres_test_name:ident) => {
#[cfg(target_os = "macos")]
#[gpui::test] #[gpui::test]
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); if !cfg!(target_os = "macos") {
$test_name(test_db.db()).await; return;
} }
#[gpui::test] let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
$test_name(test_db.db()).await; $test_name(test_db.db()).await;
} }
}; };

View File

@ -1,17 +1,15 @@
use std::sync::Arc;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
use crate::llm::db::LlmDatabase; use crate::llm::db::LlmDatabase;
use crate::test_both_llm_dbs; use crate::test_llm_db;
test_both_llm_dbs!( test_llm_db!(
test_initialize_providers, test_initialize_providers,
test_initialize_providers_postgres, test_initialize_providers_postgres
test_initialize_providers_sqlite
); );
async fn test_initialize_providers(db: &Arc<LlmDatabase>) { async fn test_initialize_providers(db: &mut LlmDatabase) {
let initial_providers = db.list_providers().await.unwrap(); let initial_providers = db.list_providers().await.unwrap();
assert_eq!(initial_providers, vec![]); assert_eq!(initial_providers, vec![]);
@ -22,9 +20,13 @@ async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
let providers = db.list_providers().await.unwrap(); let providers = db.list_providers().await.unwrap();
let provider_names = providers assert_eq!(
.into_iter() providers,
.map(|provider| provider.name) &[
.collect::<Vec<_>>(); LanguageModelProvider::Anthropic,
assert_eq!(provider_names, vec!["anthropic".to_string()]); LanguageModelProvider::Google,
LanguageModelProvider::OpenAi,
LanguageModelProvider::Zed
]
)
} }

View File

@ -1,24 +1,120 @@
use std::sync::Arc; use crate::{
llm::db::{queries::providers::ModelRateLimits, queries::usages::Usage, LlmDatabase},
test_llm_db,
};
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider; use rpc::LanguageModelProvider;
use crate::llm::db::LlmDatabase; test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
use crate::test_both_llm_dbs;
test_both_llm_dbs!( async fn test_tracking_usage(db: &mut LlmDatabase) {
test_find_or_create_usage, let provider = LanguageModelProvider::Anthropic;
test_find_or_create_usage_postgres, let model = "claude-3-5-sonnet";
test_find_or_create_usage_sqlite
);
async fn test_find_or_create_usage(db: &Arc<LlmDatabase>) { db.initialize().await.unwrap();
db.initialize_providers().await.unwrap(); db.insert_models(&[(
provider,
model.to_string(),
ModelRateLimits {
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
},
)])
.await
.unwrap();
let usage = db let t0 = Utc::now();
.find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet") let user_id = 123;
let now = t0;
db.record_usage(user_id, provider, model, 1000, now)
.await .await
.unwrap(); .unwrap();
assert_eq!(usage.user_id, 123); let now = t0 + Duration::seconds(10);
db.record_usage(user_id, provider, model, 2000, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 3000,
tokens_this_day: 3000,
tokens_this_month: 3000,
}
);
let now = t0 + Duration::seconds(60);
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 2000,
tokens_this_day: 3000,
tokens_this_month: 3000,
}
);
let now = t0 + Duration::seconds(60);
db.record_usage(user_id, provider, model, 3000, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 5000,
tokens_this_day: 6000,
tokens_this_month: 6000,
}
);
let t1 = t0 + Duration::hours(24);
let now = t1;
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 0,
tokens_this_minute: 0,
tokens_this_day: 5000,
tokens_this_month: 6000,
}
);
db.record_usage(user_id, provider, model, 4000, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 4000,
tokens_this_day: 9000,
tokens_this_month: 10000,
}
);
let t2 = t0 + Duration::days(30);
let now = t2;
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 0,
tokens_this_minute: 0,
tokens_this_day: 0,
tokens_this_month: 9000,
}
);
} }

View File

@ -52,10 +52,18 @@ async fn main() -> Result<()> {
Some("seed") => { Some("seed") => {
let config = envy::from_env::<Config>().expect("error loading config"); let config = envy::from_env::<Config>().expect("error loading config");
let db_options = db::ConnectOptions::new(config.database_url.clone()); let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options, Executor::Production).await?; let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?; db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, true).await?; collab::seed::seed(&config, &db, false).await?;
if let Some(llm_database_url) = config.llm_database_url.clone() {
let db_options = db::ConnectOptions::new(llm_database_url);
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
db.initialize().await?;
collab::llm::db::seed_database(&config, &mut db, true).await?;
}
} }
Some("serve") => { Some("serve") => {
let mode = match args.next().as_deref() { let mode = match args.next().as_deref() {

View File

@ -1,9 +1,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] #[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider { pub enum LanguageModelProvider {
Anthropic, Anthropic,
OpenAi, OpenAi,