mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Apply rate limits in LLM service (#15997)
Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
2bc503771b
commit
06625bfe94
@ -85,6 +85,11 @@ spec:
|
|||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
name: database
|
name: database
|
||||||
key: url
|
key: url
|
||||||
|
- name: LLM_DATABASE_URL
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: llm-database
|
||||||
|
key: url
|
||||||
- name: DATABASE_MAX_CONNECTIONS
|
- name: DATABASE_MAX_CONNECTIONS
|
||||||
value: "${DATABASE_MAX_CONNECTIONS}"
|
value: "${DATABASE_MAX_CONNECTIONS}"
|
||||||
- name: API_TOKEN
|
- name: API_TOKEN
|
||||||
|
@ -12,7 +12,7 @@ metadata:
|
|||||||
spec:
|
spec:
|
||||||
type: LoadBalancer
|
type: LoadBalancer
|
||||||
selector:
|
selector:
|
||||||
app: postgrest
|
app: nginx
|
||||||
ports:
|
ports:
|
||||||
- name: web
|
- name: web
|
||||||
protocol: TCP
|
protocol: TCP
|
||||||
@ -24,17 +24,99 @@ apiVersion: apps/v1
|
|||||||
kind: Deployment
|
kind: Deployment
|
||||||
metadata:
|
metadata:
|
||||||
namespace: ${ZED_KUBE_NAMESPACE}
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
name: postgrest
|
name: nginx
|
||||||
|
|
||||||
spec:
|
spec:
|
||||||
replicas: 1
|
replicas: 1
|
||||||
selector:
|
selector:
|
||||||
matchLabels:
|
matchLabels:
|
||||||
app: postgrest
|
app: nginx
|
||||||
template:
|
template:
|
||||||
metadata:
|
metadata:
|
||||||
labels:
|
labels:
|
||||||
app: postgrest
|
app: nginx
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: nginx
|
||||||
|
image: nginx:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8080
|
||||||
|
protocol: TCP
|
||||||
|
volumeMounts:
|
||||||
|
- name: nginx-config
|
||||||
|
mountPath: /etc/nginx/nginx.conf
|
||||||
|
subPath: nginx.conf
|
||||||
|
volumes:
|
||||||
|
- name: nginx-config
|
||||||
|
configMap:
|
||||||
|
name: nginx-config
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: ConfigMap
|
||||||
|
metadata:
|
||||||
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
|
name: nginx-config
|
||||||
|
data:
|
||||||
|
nginx.conf: |
|
||||||
|
events {}
|
||||||
|
|
||||||
|
http {
|
||||||
|
server {
|
||||||
|
listen 8080;
|
||||||
|
|
||||||
|
location /app/ {
|
||||||
|
proxy_pass http://postgrest-app:8080/;
|
||||||
|
}
|
||||||
|
|
||||||
|
location /llm/ {
|
||||||
|
proxy_pass http://postgrest-llm:8080/;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
|
name: postgrest-app
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: postgrest-app
|
||||||
|
ports:
|
||||||
|
- protocol: TCP
|
||||||
|
port: 8080
|
||||||
|
targetPort: 8080
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
|
name: postgrest-llm
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: postgrest-llm
|
||||||
|
ports:
|
||||||
|
- protocol: TCP
|
||||||
|
port: 8080
|
||||||
|
targetPort: 8080
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
|
name: postgrest-app
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: postgrest-app
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: postgrest-app
|
||||||
spec:
|
spec:
|
||||||
containers:
|
containers:
|
||||||
- name: postgrest
|
- name: postgrest
|
||||||
@ -55,3 +137,39 @@ spec:
|
|||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
name: postgrest
|
name: postgrest
|
||||||
key: jwt_secret
|
key: jwt_secret
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
namespace: ${ZED_KUBE_NAMESPACE}
|
||||||
|
name: postgrest-llm
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: postgrest-llm
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: postgrest-llm
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: postgrest
|
||||||
|
image: "postgrest/postgrest"
|
||||||
|
ports:
|
||||||
|
- containerPort: 8080
|
||||||
|
protocol: TCP
|
||||||
|
env:
|
||||||
|
- name: PGRST_SERVER_PORT
|
||||||
|
value: "8080"
|
||||||
|
- name: PGRST_DB_URI
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: llm-database
|
||||||
|
key: url
|
||||||
|
- name: PGRST_JWT_SECRET
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: postgrest
|
||||||
|
key: jwt_secret
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
create table providers (
|
|
||||||
id integer primary key autoincrement,
|
|
||||||
name text not null
|
|
||||||
);
|
|
||||||
|
|
||||||
create unique index uix_providers_on_name on providers (name);
|
|
||||||
|
|
||||||
create table models (
|
|
||||||
id integer primary key autoincrement,
|
|
||||||
provider_id integer not null references providers (id) on delete cascade,
|
|
||||||
name text not null
|
|
||||||
);
|
|
||||||
|
|
||||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
|
||||||
create index ix_models_on_provider_id on models (provider_id);
|
|
||||||
create index ix_models_on_name on models (name);
|
|
||||||
|
|
||||||
create table if not exists usages (
|
|
||||||
id integer primary key autoincrement,
|
|
||||||
user_id integer not null,
|
|
||||||
model_id integer not null references models (id) on delete cascade,
|
|
||||||
requests_this_minute integer not null default 0,
|
|
||||||
tokens_this_minute integer not null default 0,
|
|
||||||
requests_this_day integer not null default 0,
|
|
||||||
tokens_this_day integer not null default 0,
|
|
||||||
requests_this_month integer not null default 0,
|
|
||||||
tokens_this_month integer not null default 0
|
|
||||||
);
|
|
||||||
|
|
||||||
create index ix_usages_on_user_id on usages (user_id);
|
|
||||||
create index ix_usages_on_model_id on usages (model_id);
|
|
||||||
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
|
|
@ -8,7 +8,10 @@ create unique index uix_providers_on_name on providers (name);
|
|||||||
create table if not exists models (
|
create table if not exists models (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
provider_id integer not null references providers (id) on delete cascade,
|
provider_id integer not null references providers (id) on delete cascade,
|
||||||
name text not null
|
name text not null,
|
||||||
|
max_requests_per_minute integer not null,
|
||||||
|
max_tokens_per_minute integer not null,
|
||||||
|
max_tokens_per_day integer not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
||||||
|
@ -1,15 +1,19 @@
|
|||||||
|
create table usage_measures (
|
||||||
|
id serial primary key,
|
||||||
|
name text not null
|
||||||
|
);
|
||||||
|
|
||||||
|
create unique index uix_usage_measures_on_name on usage_measures (name);
|
||||||
|
|
||||||
create table if not exists usages (
|
create table if not exists usages (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
user_id integer not null,
|
user_id integer not null,
|
||||||
model_id integer not null references models (id) on delete cascade,
|
model_id integer not null references models (id) on delete cascade,
|
||||||
requests_this_minute integer not null default 0,
|
measure_id integer not null references usage_measures (id) on delete cascade,
|
||||||
tokens_this_minute bigint not null default 0,
|
timestamp timestamp without time zone not null,
|
||||||
requests_this_day integer not null default 0,
|
buckets bigint[] not null
|
||||||
tokens_this_day bigint not null default 0,
|
|
||||||
requests_this_month integer not null default 0,
|
|
||||||
tokens_this_month bigint not null default 0
|
|
||||||
);
|
);
|
||||||
|
|
||||||
create index ix_usages_on_user_id on usages (user_id);
|
create index ix_usages_on_user_id on usages (user_id);
|
||||||
create index ix_usages_on_model_id on usages (model_id);
|
create index ix_usages_on_model_id on usages (model_id);
|
||||||
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
|
create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id);
|
||||||
|
@ -2,24 +2,25 @@ mod authorization;
|
|||||||
pub mod db;
|
pub mod db;
|
||||||
mod token;
|
mod token;
|
||||||
|
|
||||||
use crate::api::CloudflareIpCountryHeader;
|
use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
|
||||||
use crate::llm::authorization::authorize_access_to_language_model;
|
|
||||||
use crate::llm::db::LlmDatabase;
|
|
||||||
use crate::{executor::Executor, Config, Error, Result};
|
|
||||||
use anyhow::{anyhow, Context as _};
|
use anyhow::{anyhow, Context as _};
|
||||||
use axum::TypedHeader;
|
use authorization::authorize_access_to_language_model;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||||
middleware::{self, Next},
|
middleware::{self, Next},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
routing::post,
|
routing::post,
|
||||||
Extension, Json, Router,
|
Extension, Json, Router, TypedHeader,
|
||||||
};
|
};
|
||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
use db::{ActiveUserCount, LlmDatabase};
|
||||||
use futures::StreamExt as _;
|
use futures::StreamExt as _;
|
||||||
use http_client::IsahcHttpClient;
|
use http_client::IsahcHttpClient;
|
||||||
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
pub use token::*;
|
pub use token::*;
|
||||||
|
|
||||||
@ -28,8 +29,11 @@ pub struct LlmState {
|
|||||||
pub executor: Executor,
|
pub executor: Executor,
|
||||||
pub db: Option<Arc<LlmDatabase>>,
|
pub db: Option<Arc<LlmDatabase>>,
|
||||||
pub http_client: IsahcHttpClient,
|
pub http_client: IsahcHttpClient,
|
||||||
|
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
|
||||||
|
|
||||||
impl LlmState {
|
impl LlmState {
|
||||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||||
// TODO: This is temporary until we have the LLM database stood up.
|
// TODO: This is temporary until we have the LLM database stood up.
|
||||||
@ -44,7 +48,8 @@ impl LlmState {
|
|||||||
|
|
||||||
let mut db_options = db::ConnectOptions::new(database_url);
|
let mut db_options = db::ConnectOptions::new(database_url);
|
||||||
db_options.max_connections(max_connections);
|
db_options.max_connections(max_connections);
|
||||||
let db = LlmDatabase::new(db_options, executor.clone()).await?;
|
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||||
|
db.initialize().await?;
|
||||||
|
|
||||||
Some(Arc::new(db))
|
Some(Arc::new(db))
|
||||||
} else {
|
} else {
|
||||||
@ -57,15 +62,41 @@ impl LlmState {
|
|||||||
.build()
|
.build()
|
||||||
.context("failed to construct http client")?;
|
.context("failed to construct http client")?;
|
||||||
|
|
||||||
|
let initial_active_user_count = if let Some(db) = &db {
|
||||||
|
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let this = Self {
|
let this = Self {
|
||||||
config,
|
config,
|
||||||
executor,
|
executor,
|
||||||
db,
|
db,
|
||||||
http_client,
|
http_client,
|
||||||
|
active_user_count: RwLock::new(initial_active_user_count),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Arc::new(this))
|
Ok(Arc::new(this))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
|
||||||
|
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
|
||||||
|
return Ok(*count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(db) = &self.db {
|
||||||
|
let mut cache = self.active_user_count.write().await;
|
||||||
|
let new_count = db.get_active_user_count(now).await?;
|
||||||
|
*cache = Some((now, new_count));
|
||||||
|
Ok(new_count)
|
||||||
|
} else {
|
||||||
|
Ok(ActiveUserCount::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes() -> Router<(), Body> {
|
pub fn routes() -> Router<(), Body> {
|
||||||
@ -122,14 +153,22 @@ async fn perform_completion(
|
|||||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||||
Json(params): Json<PerformCompletionParams>,
|
Json(params): Json<PerformCompletionParams>,
|
||||||
) -> Result<impl IntoResponse> {
|
) -> Result<impl IntoResponse> {
|
||||||
|
let model = normalize_model_name(params.provider, params.model);
|
||||||
|
|
||||||
authorize_access_to_language_model(
|
authorize_access_to_language_model(
|
||||||
&state.config,
|
&state.config,
|
||||||
&claims,
|
&claims,
|
||||||
country_code_header.map(|header| header.to_string()),
|
country_code_header.map(|header| header.to_string()),
|
||||||
params.provider,
|
params.provider,
|
||||||
¶ms.model,
|
&model,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
let user_id = claims.user_id as i32;
|
||||||
|
|
||||||
|
if state.db.is_some() {
|
||||||
|
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||||
|
}
|
||||||
|
|
||||||
match params.provider {
|
match params.provider {
|
||||||
LanguageModelProvider::Anthropic => {
|
LanguageModelProvider::Anthropic => {
|
||||||
let api_key = state
|
let api_key = state
|
||||||
@ -160,9 +199,31 @@ async fn perform_completion(
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let stream = chunks.map(|event| {
|
let mut recorder = state.db.clone().map(|db| UsageRecorder {
|
||||||
|
db,
|
||||||
|
executor: state.executor.clone(),
|
||||||
|
user_id,
|
||||||
|
provider: params.provider,
|
||||||
|
model,
|
||||||
|
token_count: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = chunks.map(move |event| {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
event.map(|chunk| {
|
event.map(|chunk| {
|
||||||
|
match &chunk {
|
||||||
|
anthropic::Event::MessageStart {
|
||||||
|
message: anthropic::Response { usage, .. },
|
||||||
|
}
|
||||||
|
| anthropic::Event::MessageDelta { usage, .. } => {
|
||||||
|
if let Some(recorder) = &mut recorder {
|
||||||
|
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
|
||||||
|
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||||
buffer.push(b'\n');
|
buffer.push(b'\n');
|
||||||
@ -259,3 +320,102 @@ async fn perform_completion(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||||
|
match provider {
|
||||||
|
LanguageModelProvider::Anthropic => {
|
||||||
|
for prefix in &[
|
||||||
|
"claude-3-5-sonnet",
|
||||||
|
"claude-3-haiku",
|
||||||
|
"claude-3-opus",
|
||||||
|
"claude-3-sonnet",
|
||||||
|
] {
|
||||||
|
if name.starts_with(prefix) {
|
||||||
|
return prefix.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LanguageModelProvider::OpenAi => {}
|
||||||
|
LanguageModelProvider::Google => {}
|
||||||
|
LanguageModelProvider::Zed => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
name
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_usage_limit(
|
||||||
|
state: &Arc<LlmState>,
|
||||||
|
provider: LanguageModelProvider,
|
||||||
|
model_name: &str,
|
||||||
|
claims: &LlmTokenClaims,
|
||||||
|
) -> Result<()> {
|
||||||
|
let db = state
|
||||||
|
.db
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("LLM database not configured"))?;
|
||||||
|
let model = db.model(provider, model_name)?;
|
||||||
|
let usage = db
|
||||||
|
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let active_users = state.get_active_user_count().await?;
|
||||||
|
|
||||||
|
let per_user_max_requests_per_minute =
|
||||||
|
model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
|
||||||
|
let per_user_max_tokens_per_minute =
|
||||||
|
model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
|
||||||
|
let per_user_max_tokens_per_day =
|
||||||
|
model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
|
||||||
|
|
||||||
|
let checks = [
|
||||||
|
(
|
||||||
|
usage.requests_this_minute,
|
||||||
|
per_user_max_requests_per_minute,
|
||||||
|
"requests per minute",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
usage.tokens_this_minute,
|
||||||
|
per_user_max_tokens_per_minute,
|
||||||
|
"tokens per minute",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
usage.tokens_this_day,
|
||||||
|
per_user_max_tokens_per_day,
|
||||||
|
"tokens per day",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (usage, limit, resource) in checks {
|
||||||
|
if usage > limit {
|
||||||
|
return Err(Error::http(
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
format!("Rate limit exceeded. Maximum {} reached.", resource),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
struct UsageRecorder {
|
||||||
|
db: Arc<LlmDatabase>,
|
||||||
|
executor: Executor,
|
||||||
|
user_id: i32,
|
||||||
|
provider: LanguageModelProvider,
|
||||||
|
model: String,
|
||||||
|
token_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UsageRecorder {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let db = self.db.clone();
|
||||||
|
let user_id = self.user_id;
|
||||||
|
let provider = self.provider;
|
||||||
|
let model = std::mem::take(&mut self.model);
|
||||||
|
let token_count = self.token_count;
|
||||||
|
self.executor.spawn_detached(async move {
|
||||||
|
db.record_usage(user_id, provider, &model, token_count, Utc::now())
|
||||||
|
.await
|
||||||
|
.log_err();
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,20 +1,26 @@
|
|||||||
mod ids;
|
mod ids;
|
||||||
mod queries;
|
mod queries;
|
||||||
|
mod seed;
|
||||||
mod tables;
|
mod tables;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
|
use collections::HashMap;
|
||||||
pub use ids::*;
|
pub use ids::*;
|
||||||
|
use rpc::LanguageModelProvider;
|
||||||
|
pub use seed::*;
|
||||||
pub use tables::*;
|
pub use tables::*;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub use tests::TestLlmDb;
|
pub use tests::TestLlmDb;
|
||||||
|
use usage_measure::UsageMeasure;
|
||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
|
pub use queries::usages::ActiveUserCount;
|
||||||
use sea_orm::prelude::*;
|
use sea_orm::prelude::*;
|
||||||
pub use sea_orm::ConnectOptions;
|
pub use sea_orm::ConnectOptions;
|
||||||
use sea_orm::{
|
use sea_orm::{
|
||||||
@ -31,6 +37,9 @@ pub struct LlmDatabase {
|
|||||||
pool: DatabaseConnection,
|
pool: DatabaseConnection,
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
executor: Executor,
|
executor: Executor,
|
||||||
|
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
|
||||||
|
models: HashMap<(LanguageModelProvider, String), model::Model>,
|
||||||
|
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
runtime: Option<tokio::runtime::Runtime>,
|
runtime: Option<tokio::runtime::Runtime>,
|
||||||
}
|
}
|
||||||
@ -43,11 +52,28 @@ impl LlmDatabase {
|
|||||||
options: options.clone(),
|
options: options.clone(),
|
||||||
pool: sea_orm::Database::connect(options).await?,
|
pool: sea_orm::Database::connect(options).await?,
|
||||||
executor,
|
executor,
|
||||||
|
provider_ids: HashMap::default(),
|
||||||
|
models: HashMap::default(),
|
||||||
|
usage_measure_ids: HashMap::default(),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
runtime: None,
|
runtime: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn initialize(&mut self) -> Result<()> {
|
||||||
|
self.initialize_providers().await?;
|
||||||
|
self.initialize_models().await?;
|
||||||
|
self.initialize_usage_measures().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||||
|
Ok(self
|
||||||
|
.models
|
||||||
|
.get(&(provider, name.to_string()))
|
||||||
|
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn options(&self) -> &ConnectOptions {
|
pub fn options(&self) -> &ConnectOptions {
|
||||||
&self.options
|
&self.options
|
||||||
}
|
}
|
||||||
|
@ -6,3 +6,4 @@ use crate::id_type;
|
|||||||
id_type!(ModelId);
|
id_type!(ModelId);
|
||||||
id_type!(ProviderId);
|
id_type!(ProviderId);
|
||||||
id_type!(UsageId);
|
id_type!(UsageId);
|
||||||
|
id_type!(UsageMeasureId);
|
||||||
|
@ -1,66 +1,115 @@
|
|||||||
use sea_orm::sea_query::OnConflict;
|
|
||||||
use sea_orm::QueryOrder;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use sea_orm::QueryOrder;
|
||||||
|
use std::str::FromStr;
|
||||||
|
use strum::IntoEnumIterator as _;
|
||||||
|
|
||||||
|
pub struct ModelRateLimits {
|
||||||
|
pub max_requests_per_minute: i32,
|
||||||
|
pub max_tokens_per_minute: i32,
|
||||||
|
pub max_tokens_per_day: i32,
|
||||||
|
}
|
||||||
|
|
||||||
impl LlmDatabase {
|
impl LlmDatabase {
|
||||||
pub async fn initialize_providers(&self) -> Result<()> {
|
pub async fn initialize_providers(&mut self) -> Result<()> {
|
||||||
|
self.provider_ids = self
|
||||||
|
.transaction(|tx| async move {
|
||||||
|
let existing_providers = provider::Entity::find().all(&*tx).await?;
|
||||||
|
|
||||||
|
let mut new_providers = LanguageModelProvider::iter()
|
||||||
|
.filter(|provider| {
|
||||||
|
!existing_providers
|
||||||
|
.iter()
|
||||||
|
.any(|p| p.name == provider.to_string())
|
||||||
|
})
|
||||||
|
.map(|provider| provider::ActiveModel {
|
||||||
|
name: ActiveValue::set(provider.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.peekable();
|
||||||
|
|
||||||
|
if new_providers.peek().is_some() {
|
||||||
|
provider::Entity::insert_many(new_providers)
|
||||||
|
.exec(&*tx)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let all_providers: HashMap<_, _> = provider::Entity::find()
|
||||||
|
.all(&*tx)
|
||||||
|
.await?
|
||||||
|
.iter()
|
||||||
|
.filter_map(|provider| {
|
||||||
|
LanguageModelProvider::from_str(&provider.name)
|
||||||
|
.ok()
|
||||||
|
.map(|p| (p, provider.id))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(all_providers)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn initialize_models(&mut self) -> Result<()> {
|
||||||
|
let all_provider_ids = &self.provider_ids;
|
||||||
|
self.models = self
|
||||||
|
.transaction(|tx| async move {
|
||||||
|
let all_models: HashMap<_, _> = model::Entity::find()
|
||||||
|
.all(&*tx)
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|model| {
|
||||||
|
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
|
||||||
|
if *id == model.provider_id {
|
||||||
|
Some(provider)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
Some(((*provider, model.name.clone()), model))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(all_models)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_models(
|
||||||
|
&mut self,
|
||||||
|
models: &[(LanguageModelProvider, String, ModelRateLimits)],
|
||||||
|
) -> Result<()> {
|
||||||
|
let all_provider_ids = &self.provider_ids;
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
let providers_and_models = vec![
|
model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| {
|
||||||
("anthropic", "claude-3-5-sonnet"),
|
let provider_id = all_provider_ids[&provider];
|
||||||
("anthropic", "claude-3-opus"),
|
model::ActiveModel {
|
||||||
("anthropic", "claude-3-sonnet"),
|
provider_id: ActiveValue::set(provider_id),
|
||||||
("anthropic", "claude-3-haiku"),
|
name: ActiveValue::set(name.clone()),
|
||||||
];
|
max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute),
|
||||||
|
max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute),
|
||||||
for (provider_name, model_name) in providers_and_models {
|
max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day),
|
||||||
let insert_provider = provider::Entity::insert(provider::ActiveModel {
|
|
||||||
name: ActiveValue::set(provider_name.to_owned()),
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
}
|
||||||
.on_conflict(
|
}))
|
||||||
OnConflict::columns([provider::Column::Name])
|
.exec_without_returning(&*tx)
|
||||||
.update_column(provider::Column::Name)
|
.await?;
|
||||||
.to_owned(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let provider = if tx.support_returning() {
|
|
||||||
insert_provider.exec_with_returning(&*tx).await?
|
|
||||||
} else {
|
|
||||||
insert_provider.exec_without_returning(&*tx).await?;
|
|
||||||
provider::Entity::find()
|
|
||||||
.filter(provider::Column::Name.eq(provider_name))
|
|
||||||
.one(&*tx)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| anyhow!("failed to insert provider"))?
|
|
||||||
};
|
|
||||||
|
|
||||||
model::Entity::insert(model::ActiveModel {
|
|
||||||
provider_id: ActiveValue::set(provider.id),
|
|
||||||
name: ActiveValue::set(model_name.to_owned()),
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
.on_conflict(
|
|
||||||
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
|
|
||||||
.update_column(model::Column::Name)
|
|
||||||
.to_owned(),
|
|
||||||
)
|
|
||||||
.exec_without_returning(&*tx)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
.await
|
.await?;
|
||||||
|
self.initialize_models().await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the list of LLM providers.
|
/// Returns the list of LLM providers.
|
||||||
pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
|
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
Ok(provider::Entity::find()
|
Ok(provider::Entity::find()
|
||||||
.order_by_asc(provider::Column::Name)
|
.order_by_asc(provider::Column::Name)
|
||||||
.all(&*tx)
|
.all(&*tx)
|
||||||
.await?)
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
|
||||||
|
.collect())
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
@ -1,57 +1,318 @@
|
|||||||
|
use chrono::Duration;
|
||||||
use rpc::LanguageModelProvider;
|
use rpc::LanguageModelProvider;
|
||||||
|
use sea_orm::QuerySelect;
|
||||||
|
use std::{iter, str::FromStr};
|
||||||
|
use strum::IntoEnumIterator as _;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub requests_this_minute: usize,
|
||||||
|
pub tokens_this_minute: usize,
|
||||||
|
pub tokens_this_day: usize,
|
||||||
|
pub tokens_this_month: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
|
pub struct ActiveUserCount {
|
||||||
|
pub users_in_recent_minutes: usize,
|
||||||
|
pub users_in_recent_days: usize,
|
||||||
|
}
|
||||||
|
|
||||||
impl LlmDatabase {
|
impl LlmDatabase {
|
||||||
pub async fn find_or_create_usage(
|
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
|
||||||
|
let all_measures = self
|
||||||
|
.transaction(|tx| async move {
|
||||||
|
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
|
||||||
|
|
||||||
|
let new_measures = UsageMeasure::iter()
|
||||||
|
.filter(|measure| {
|
||||||
|
!existing_measures
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.name == measure.to_string())
|
||||||
|
})
|
||||||
|
.map(|measure| usage_measure::ActiveModel {
|
||||||
|
name: ActiveValue::set(measure.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if !new_measures.is_empty() {
|
||||||
|
usage_measure::Entity::insert_many(new_measures)
|
||||||
|
.exec(&*tx)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(usage_measure::Entity::find().all(&*tx).await?)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.usage_measure_ids = all_measures
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|measure| {
|
||||||
|
UsageMeasure::from_str(&measure.name)
|
||||||
|
.ok()
|
||||||
|
.map(|um| (um, measure.id))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_usage(
|
||||||
&self,
|
&self,
|
||||||
user_id: i32,
|
user_id: i32,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model_name: &str,
|
model_name: &str,
|
||||||
) -> Result<usage::Model> {
|
now: DateTimeUtc,
|
||||||
|
) -> Result<Usage> {
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
let provider_name = match provider {
|
let model = self
|
||||||
LanguageModelProvider::Anthropic => "anthropic",
|
.models
|
||||||
LanguageModelProvider::OpenAi => "open_ai",
|
.get(&(provider, model_name.to_string()))
|
||||||
LanguageModelProvider::Google => "google",
|
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||||
LanguageModelProvider::Zed => "zed",
|
|
||||||
};
|
|
||||||
|
|
||||||
let model = model::Entity::find()
|
let usages = usage::Entity::find()
|
||||||
.inner_join(provider::Entity)
|
|
||||||
.filter(
|
|
||||||
provider::Column::Name
|
|
||||||
.eq(provider_name)
|
|
||||||
.and(model::Column::Name.eq(model_name)),
|
|
||||||
)
|
|
||||||
.one(&*tx)
|
|
||||||
.await?
|
|
||||||
// TODO: Create the model, if one doesn't exist.
|
|
||||||
.ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?;
|
|
||||||
let model_id = model.id;
|
|
||||||
|
|
||||||
let existing_usage = usage::Entity::find()
|
|
||||||
.filter(
|
.filter(
|
||||||
usage::Column::UserId
|
usage::Column::UserId
|
||||||
.eq(user_id)
|
.eq(user_id)
|
||||||
.and(usage::Column::ModelId.eq(model_id)),
|
.and(usage::Column::ModelId.eq(model.id)),
|
||||||
)
|
)
|
||||||
.one(&*tx)
|
.all(&*tx)
|
||||||
.await?;
|
.await?;
|
||||||
if let Some(usage) = existing_usage {
|
|
||||||
return Ok(usage);
|
|
||||||
}
|
|
||||||
|
|
||||||
let usage = usage::Entity::insert(usage::ActiveModel {
|
let requests_this_minute =
|
||||||
user_id: ActiveValue::set(user_id),
|
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
|
||||||
model_id: ActiveValue::set(model_id),
|
let tokens_this_minute =
|
||||||
..Default::default()
|
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
|
||||||
|
let tokens_this_day =
|
||||||
|
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
|
||||||
|
let tokens_this_month =
|
||||||
|
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMonth)?;
|
||||||
|
|
||||||
|
Ok(Usage {
|
||||||
|
requests_this_minute,
|
||||||
|
tokens_this_minute,
|
||||||
|
tokens_this_day,
|
||||||
|
tokens_this_month,
|
||||||
})
|
})
|
||||||
.exec_with_returning(&*tx)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(usage)
|
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn record_usage(
|
||||||
|
&self,
|
||||||
|
user_id: i32,
|
||||||
|
provider: LanguageModelProvider,
|
||||||
|
model_name: &str,
|
||||||
|
token_count: usize,
|
||||||
|
now: DateTimeUtc,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.transaction(|tx| async move {
|
||||||
|
let model = self.model(provider, model_name)?;
|
||||||
|
|
||||||
|
let usages = usage::Entity::find()
|
||||||
|
.filter(
|
||||||
|
usage::Column::UserId
|
||||||
|
.eq(user_id)
|
||||||
|
.and(usage::Column::ModelId.eq(model.id)),
|
||||||
|
)
|
||||||
|
.all(&*tx)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.update_usage_for_measure(
|
||||||
|
user_id,
|
||||||
|
model.id,
|
||||||
|
&usages,
|
||||||
|
UsageMeasure::RequestsPerMinute,
|
||||||
|
now,
|
||||||
|
1,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
self.update_usage_for_measure(
|
||||||
|
user_id,
|
||||||
|
model.id,
|
||||||
|
&usages,
|
||||||
|
UsageMeasure::TokensPerMinute,
|
||||||
|
now,
|
||||||
|
token_count,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
self.update_usage_for_measure(
|
||||||
|
user_id,
|
||||||
|
model.id,
|
||||||
|
&usages,
|
||||||
|
UsageMeasure::TokensPerDay,
|
||||||
|
now,
|
||||||
|
token_count,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
self.update_usage_for_measure(
|
||||||
|
user_id,
|
||||||
|
model.id,
|
||||||
|
&usages,
|
||||||
|
UsageMeasure::TokensPerMonth,
|
||||||
|
now,
|
||||||
|
token_count,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
|
||||||
|
self.transaction(|tx| async move {
|
||||||
|
let minute_since = now - Duration::minutes(5);
|
||||||
|
let day_since = now - Duration::days(5);
|
||||||
|
|
||||||
|
let users_in_recent_minutes = usage::Entity::find()
|
||||||
|
.filter(usage::Column::Timestamp.gte(minute_since.naive_utc()))
|
||||||
|
.group_by(usage::Column::UserId)
|
||||||
|
.count(&*tx)
|
||||||
|
.await? as usize;
|
||||||
|
|
||||||
|
let users_in_recent_days = usage::Entity::find()
|
||||||
|
.filter(usage::Column::Timestamp.gte(day_since.naive_utc()))
|
||||||
|
.group_by(usage::Column::UserId)
|
||||||
|
.count(&*tx)
|
||||||
|
.await? as usize;
|
||||||
|
|
||||||
|
Ok(ActiveUserCount {
|
||||||
|
users_in_recent_minutes,
|
||||||
|
users_in_recent_days,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
async fn update_usage_for_measure(
|
||||||
|
&self,
|
||||||
|
user_id: i32,
|
||||||
|
model_id: ModelId,
|
||||||
|
usages: &[usage::Model],
|
||||||
|
usage_measure: UsageMeasure,
|
||||||
|
now: DateTimeUtc,
|
||||||
|
usage_to_add: usize,
|
||||||
|
tx: &DatabaseTransaction,
|
||||||
|
) -> Result<()> {
|
||||||
|
let now = now.naive_utc();
|
||||||
|
let measure_id = *self
|
||||||
|
.usage_measure_ids
|
||||||
|
.get(&usage_measure)
|
||||||
|
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||||
|
|
||||||
|
let mut id = None;
|
||||||
|
let mut timestamp = now;
|
||||||
|
let mut buckets = vec![0_i64];
|
||||||
|
|
||||||
|
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
|
||||||
|
id = Some(old_usage.id);
|
||||||
|
let (live_buckets, buckets_since) =
|
||||||
|
Self::get_live_buckets(old_usage, now, usage_measure);
|
||||||
|
if !live_buckets.is_empty() {
|
||||||
|
buckets.clear();
|
||||||
|
buckets.extend_from_slice(live_buckets);
|
||||||
|
buckets.extend(iter::repeat(0).take(buckets_since));
|
||||||
|
timestamp =
|
||||||
|
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*buckets.last_mut().unwrap() += usage_to_add as i64;
|
||||||
|
|
||||||
|
let mut model = usage::ActiveModel {
|
||||||
|
user_id: ActiveValue::set(user_id),
|
||||||
|
model_id: ActiveValue::set(model_id),
|
||||||
|
measure_id: ActiveValue::set(measure_id),
|
||||||
|
timestamp: ActiveValue::set(timestamp),
|
||||||
|
buckets: ActiveValue::set(buckets),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(id) = id {
|
||||||
|
model.id = ActiveValue::unchanged(id);
|
||||||
|
model.update(tx).await?;
|
||||||
|
} else {
|
||||||
|
usage::Entity::insert(model)
|
||||||
|
.exec_without_returning(tx)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_usage_for_measure(
|
||||||
|
&self,
|
||||||
|
usages: &[usage::Model],
|
||||||
|
now: DateTimeUtc,
|
||||||
|
usage_measure: UsageMeasure,
|
||||||
|
) -> Result<usize> {
|
||||||
|
let now = now.naive_utc();
|
||||||
|
let measure_id = *self
|
||||||
|
.usage_measure_ids
|
||||||
|
.get(&usage_measure)
|
||||||
|
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||||
|
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
|
||||||
|
return Ok(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
|
||||||
|
Ok(live_buckets.iter().sum::<i64>() as _)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_live_buckets(
|
||||||
|
usage: &usage::Model,
|
||||||
|
now: chrono::NaiveDateTime,
|
||||||
|
measure: UsageMeasure,
|
||||||
|
) -> (&[i64], usize) {
|
||||||
|
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
|
||||||
|
let buckets_since_usage =
|
||||||
|
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
|
||||||
|
let buckets_since_usage = buckets_since_usage.ceil() as usize;
|
||||||
|
let mut live_buckets = &[] as &[i64];
|
||||||
|
if buckets_since_usage < measure.bucket_count() {
|
||||||
|
let expired_bucket_count =
|
||||||
|
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
|
||||||
|
live_buckets = &usage.buckets[expired_bucket_count..];
|
||||||
|
while live_buckets.first() == Some(&0) {
|
||||||
|
live_buckets = &live_buckets[1..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(live_buckets, buckets_since_usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||||
|
const DAY_BUCKET_COUNT: usize = 48;
|
||||||
|
const MONTH_BUCKET_COUNT: usize = 30;
|
||||||
|
|
||||||
|
impl UsageMeasure {
|
||||||
|
fn bucket_count(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
|
||||||
|
UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
|
||||||
|
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
|
||||||
|
UsageMeasure::TokensPerMonth => MONTH_BUCKET_COUNT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_duration(&self) -> Duration {
|
||||||
|
match self {
|
||||||
|
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
|
||||||
|
UsageMeasure::TokensPerMinute => Duration::minutes(1),
|
||||||
|
UsageMeasure::TokensPerDay => Duration::hours(24),
|
||||||
|
UsageMeasure::TokensPerMonth => Duration::days(30),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bucket_duration(&self) -> Duration {
|
||||||
|
self.total_duration() / self.bucket_count() as i32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
45
crates/collab/src/llm/db/seed.rs
Normal file
45
crates/collab/src/llm/db/seed.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use super::*;
|
||||||
|
use crate::{Config, Result};
|
||||||
|
use queries::providers::ModelRateLimits;
|
||||||
|
|
||||||
|
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
|
||||||
|
db.insert_models(&[
|
||||||
|
(
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3-5-sonnet".into(),
|
||||||
|
ModelRateLimits {
|
||||||
|
max_requests_per_minute: 5,
|
||||||
|
max_tokens_per_minute: 20_000,
|
||||||
|
max_tokens_per_day: 300_000,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3-opus".into(),
|
||||||
|
ModelRateLimits {
|
||||||
|
max_requests_per_minute: 5,
|
||||||
|
max_tokens_per_minute: 10_000,
|
||||||
|
max_tokens_per_day: 300_000,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3-sonnet".into(),
|
||||||
|
ModelRateLimits {
|
||||||
|
max_requests_per_minute: 5,
|
||||||
|
max_tokens_per_minute: 20_000,
|
||||||
|
max_tokens_per_day: 300_000,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3-haiku".into(),
|
||||||
|
ModelRateLimits {
|
||||||
|
max_requests_per_minute: 5,
|
||||||
|
max_tokens_per_minute: 25_000,
|
||||||
|
max_tokens_per_day: 300_000,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
])
|
||||||
|
.await
|
||||||
|
}
|
@ -1,3 +1,4 @@
|
|||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
pub mod usage;
|
pub mod usage;
|
||||||
|
pub mod usage_measure;
|
||||||
|
@ -10,6 +10,9 @@ pub struct Model {
|
|||||||
pub id: ModelId,
|
pub id: ModelId,
|
||||||
pub provider_id: ProviderId,
|
pub provider_id: ProviderId,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
pub max_requests_per_minute: i32,
|
||||||
|
pub max_tokens_per_minute: i32,
|
||||||
|
pub max_tokens_per_day: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
|
|
||||||
use crate::llm::db::ProviderId;
|
use crate::llm::db::ProviderId;
|
||||||
|
use sea_orm::entity::prelude::*;
|
||||||
|
|
||||||
/// An LLM provider.
|
/// An LLM provider.
|
||||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
@ -1,24 +1,20 @@
|
|||||||
|
use crate::llm::db::{ModelId, UsageId, UsageMeasureId};
|
||||||
use sea_orm::entity::prelude::*;
|
use sea_orm::entity::prelude::*;
|
||||||
|
|
||||||
use crate::llm::db::ModelId;
|
|
||||||
|
|
||||||
/// An LLM usage record.
|
/// An LLM usage record.
|
||||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
#[sea_orm(table_name = "usages")]
|
#[sea_orm(table_name = "usages")]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
#[sea_orm(primary_key)]
|
#[sea_orm(primary_key)]
|
||||||
pub id: i32,
|
pub id: UsageId,
|
||||||
/// The ID of the Zed user.
|
/// The ID of the Zed user.
|
||||||
///
|
///
|
||||||
/// Corresponds to the `users` table in the primary collab database.
|
/// Corresponds to the `users` table in the primary collab database.
|
||||||
pub user_id: i32,
|
pub user_id: i32,
|
||||||
pub model_id: ModelId,
|
pub model_id: ModelId,
|
||||||
pub requests_this_minute: i32,
|
pub measure_id: UsageMeasureId,
|
||||||
pub tokens_this_minute: i64,
|
pub timestamp: DateTime,
|
||||||
pub requests_this_day: i32,
|
pub buckets: Vec<i64>,
|
||||||
pub tokens_this_day: i64,
|
|
||||||
pub requests_this_month: i32,
|
|
||||||
pub tokens_this_month: i64,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
@ -29,6 +25,12 @@ pub enum Relation {
|
|||||||
to = "super::model::Column::Id"
|
to = "super::model::Column::Id"
|
||||||
)]
|
)]
|
||||||
Model,
|
Model,
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::usage_measure::Entity",
|
||||||
|
from = "Column::MeasureId",
|
||||||
|
to = "super::usage_measure::Column::Id"
|
||||||
|
)]
|
||||||
|
UsageMeasure,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Related<super::model::Entity> for Entity {
|
impl Related<super::model::Entity> for Entity {
|
||||||
@ -37,4 +39,10 @@ impl Related<super::model::Entity> for Entity {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Related<super::usage_measure::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::UsageMeasure.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
impl ActiveModelBehavior for ActiveModel {}
|
||||||
|
35
crates/collab/src/llm/db/tables/usage_measure.rs
Normal file
35
crates/collab/src/llm/db/tables/usage_measure.rs
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
use crate::llm::db::UsageMeasureId;
|
||||||
|
use sea_orm::entity::prelude::*;
|
||||||
|
|
||||||
|
#[derive(
|
||||||
|
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
|
||||||
|
)]
|
||||||
|
#[strum(serialize_all = "snake_case")]
|
||||||
|
pub enum UsageMeasure {
|
||||||
|
RequestsPerMinute,
|
||||||
|
TokensPerMinute,
|
||||||
|
TokensPerDay,
|
||||||
|
TokensPerMonth,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "usage_measures")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key)]
|
||||||
|
pub id: UsageMeasureId,
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(has_many = "super::usage::Entity")]
|
||||||
|
Usages,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::usage::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Usages.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {}
|
@ -6,7 +6,6 @@ use parking_lot::Mutex;
|
|||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use sea_orm::ConnectionTrait;
|
use sea_orm::ConnectionTrait;
|
||||||
use sqlx::migrate::MigrateDatabase;
|
use sqlx::migrate::MigrateDatabase;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::migrations::run_database_migrations;
|
use crate::migrations::run_database_migrations;
|
||||||
@ -14,47 +13,11 @@ use crate::migrations::run_database_migrations;
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub struct TestLlmDb {
|
pub struct TestLlmDb {
|
||||||
pub db: Option<Arc<LlmDatabase>>,
|
pub db: Option<LlmDatabase>,
|
||||||
pub connection: Option<sqlx::AnyConnection>,
|
pub connection: Option<sqlx::AnyConnection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestLlmDb {
|
impl TestLlmDb {
|
||||||
pub fn sqlite(background: BackgroundExecutor) -> Self {
|
|
||||||
let url = "sqlite::memory:";
|
|
||||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
|
||||||
.enable_io()
|
|
||||||
.enable_time()
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let mut db = runtime.block_on(async {
|
|
||||||
let mut options = ConnectOptions::new(url);
|
|
||||||
options.max_connections(5);
|
|
||||||
let db = LlmDatabase::new(options, Executor::Deterministic(background))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let sql = include_str!(concat!(
|
|
||||||
env!("CARGO_MANIFEST_DIR"),
|
|
||||||
"/migrations_llm.sqlite/20240806182921_test_schema.sql"
|
|
||||||
));
|
|
||||||
db.pool
|
|
||||||
.execute(sea_orm::Statement::from_string(
|
|
||||||
db.pool.get_database_backend(),
|
|
||||||
sql,
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
db
|
|
||||||
});
|
|
||||||
|
|
||||||
db.runtime = Some(runtime);
|
|
||||||
|
|
||||||
Self {
|
|
||||||
db: Some(Arc::new(db)),
|
|
||||||
connection: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn postgres(background: BackgroundExecutor) -> Self {
|
pub fn postgres(background: BackgroundExecutor) -> Self {
|
||||||
static LOCK: Mutex<()> = Mutex::new(());
|
static LOCK: Mutex<()> = Mutex::new(());
|
||||||
|
|
||||||
@ -91,29 +54,26 @@ impl TestLlmDb {
|
|||||||
db.runtime = Some(runtime);
|
db.runtime = Some(runtime);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
db: Some(Arc::new(db)),
|
db: Some(db),
|
||||||
connection: None,
|
connection: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn db(&self) -> &Arc<LlmDatabase> {
|
pub fn db(&mut self) -> &mut LlmDatabase {
|
||||||
self.db.as_ref().unwrap()
|
self.db.as_mut().unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! test_both_llm_dbs {
|
macro_rules! test_llm_db {
|
||||||
($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
|
($test_name:ident, $postgres_test_name:ident) => {
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
|
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
|
||||||
let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
if !cfg!(target_os = "macos") {
|
||||||
$test_name(test_db.db()).await;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
||||||
async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
|
|
||||||
let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
|
|
||||||
$test_name(test_db.db()).await;
|
$test_name(test_db.db()).await;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
use rpc::LanguageModelProvider;
|
||||||
|
|
||||||
use crate::llm::db::LlmDatabase;
|
use crate::llm::db::LlmDatabase;
|
||||||
use crate::test_both_llm_dbs;
|
use crate::test_llm_db;
|
||||||
|
|
||||||
test_both_llm_dbs!(
|
test_llm_db!(
|
||||||
test_initialize_providers,
|
test_initialize_providers,
|
||||||
test_initialize_providers_postgres,
|
test_initialize_providers_postgres
|
||||||
test_initialize_providers_sqlite
|
|
||||||
);
|
);
|
||||||
|
|
||||||
async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
|
async fn test_initialize_providers(db: &mut LlmDatabase) {
|
||||||
let initial_providers = db.list_providers().await.unwrap();
|
let initial_providers = db.list_providers().await.unwrap();
|
||||||
assert_eq!(initial_providers, vec![]);
|
assert_eq!(initial_providers, vec![]);
|
||||||
|
|
||||||
@ -22,9 +20,13 @@ async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
|
|||||||
|
|
||||||
let providers = db.list_providers().await.unwrap();
|
let providers = db.list_providers().await.unwrap();
|
||||||
|
|
||||||
let provider_names = providers
|
assert_eq!(
|
||||||
.into_iter()
|
providers,
|
||||||
.map(|provider| provider.name)
|
&[
|
||||||
.collect::<Vec<_>>();
|
LanguageModelProvider::Anthropic,
|
||||||
assert_eq!(provider_names, vec!["anthropic".to_string()]);
|
LanguageModelProvider::Google,
|
||||||
|
LanguageModelProvider::OpenAi,
|
||||||
|
LanguageModelProvider::Zed
|
||||||
|
]
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
@ -1,24 +1,120 @@
|
|||||||
use std::sync::Arc;
|
use crate::{
|
||||||
|
llm::db::{queries::providers::ModelRateLimits, queries::usages::Usage, LlmDatabase},
|
||||||
|
test_llm_db,
|
||||||
|
};
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
use rpc::LanguageModelProvider;
|
use rpc::LanguageModelProvider;
|
||||||
|
|
||||||
use crate::llm::db::LlmDatabase;
|
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
|
||||||
use crate::test_both_llm_dbs;
|
|
||||||
|
|
||||||
test_both_llm_dbs!(
|
async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
test_find_or_create_usage,
|
let provider = LanguageModelProvider::Anthropic;
|
||||||
test_find_or_create_usage_postgres,
|
let model = "claude-3-5-sonnet";
|
||||||
test_find_or_create_usage_sqlite
|
|
||||||
);
|
|
||||||
|
|
||||||
async fn test_find_or_create_usage(db: &Arc<LlmDatabase>) {
|
db.initialize().await.unwrap();
|
||||||
db.initialize_providers().await.unwrap();
|
db.insert_models(&[(
|
||||||
|
provider,
|
||||||
|
model.to_string(),
|
||||||
|
ModelRateLimits {
|
||||||
|
max_requests_per_minute: 5,
|
||||||
|
max_tokens_per_minute: 10_000,
|
||||||
|
max_tokens_per_day: 50_000,
|
||||||
|
},
|
||||||
|
)])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let usage = db
|
let t0 = Utc::now();
|
||||||
.find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet")
|
let user_id = 123;
|
||||||
|
|
||||||
|
let now = t0;
|
||||||
|
db.record_usage(user_id, provider, model, 1000, now)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(usage.user_id, 123);
|
let now = t0 + Duration::seconds(10);
|
||||||
|
db.record_usage(user_id, provider, model, 2000, now)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 2,
|
||||||
|
tokens_this_minute: 3000,
|
||||||
|
tokens_this_day: 3000,
|
||||||
|
tokens_this_month: 3000,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
let now = t0 + Duration::seconds(60);
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 1,
|
||||||
|
tokens_this_minute: 2000,
|
||||||
|
tokens_this_day: 3000,
|
||||||
|
tokens_this_month: 3000,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
let now = t0 + Duration::seconds(60);
|
||||||
|
db.record_usage(user_id, provider, model, 3000, now)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 2,
|
||||||
|
tokens_this_minute: 5000,
|
||||||
|
tokens_this_day: 6000,
|
||||||
|
tokens_this_month: 6000,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
let t1 = t0 + Duration::hours(24);
|
||||||
|
let now = t1;
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 0,
|
||||||
|
tokens_this_minute: 0,
|
||||||
|
tokens_this_day: 5000,
|
||||||
|
tokens_this_month: 6000,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
db.record_usage(user_id, provider, model, 4000, now)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 1,
|
||||||
|
tokens_this_minute: 4000,
|
||||||
|
tokens_this_day: 9000,
|
||||||
|
tokens_this_month: 10000,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
let t2 = t0 + Duration::days(30);
|
||||||
|
let now = t2;
|
||||||
|
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
usage,
|
||||||
|
Usage {
|
||||||
|
requests_this_minute: 0,
|
||||||
|
tokens_this_minute: 0,
|
||||||
|
tokens_this_day: 0,
|
||||||
|
tokens_this_month: 9000,
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
@ -52,10 +52,18 @@ async fn main() -> Result<()> {
|
|||||||
Some("seed") => {
|
Some("seed") => {
|
||||||
let config = envy::from_env::<Config>().expect("error loading config");
|
let config = envy::from_env::<Config>().expect("error loading config");
|
||||||
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||||
|
|
||||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||||
db.initialize_notification_kinds().await?;
|
db.initialize_notification_kinds().await?;
|
||||||
|
|
||||||
collab::seed::seed(&config, &db, true).await?;
|
collab::seed::seed(&config, &db, false).await?;
|
||||||
|
|
||||||
|
if let Some(llm_database_url) = config.llm_database_url.clone() {
|
||||||
|
let db_options = db::ConnectOptions::new(llm_database_url);
|
||||||
|
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
|
||||||
|
db.initialize().await?;
|
||||||
|
collab::llm::db::seed_database(&config, &mut db, true).await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some("serve") => {
|
Some("serve") => {
|
||||||
let mode = match args.next().as_deref() {
|
let mode = match args.next().as_deref() {
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use strum::{Display, EnumIter, EnumString};
|
||||||
|
|
||||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
#[derive(
|
||||||
|
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
||||||
|
)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
|
#[strum(serialize_all = "snake_case")]
|
||||||
pub enum LanguageModelProvider {
|
pub enum LanguageModelProvider {
|
||||||
Anthropic,
|
Anthropic,
|
||||||
OpenAi,
|
OpenAi,
|
||||||
|
Loading…
Reference in New Issue
Block a user