Remove code paths that skip LLM db in prod (#16008)

Release Notes:

- N/A
This commit is contained in:
Max Brunsfeld 2024-08-09 07:41:50 -07:00 committed by GitHub
parent c1872e9cb0
commit 225726ba4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 52 deletions

View File

@ -27,7 +27,7 @@ pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub db: Option<Arc<LlmDatabase>>,
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
@ -36,25 +36,20 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
// TODO: This is temporary until we have the LLM database stood up.
let db = if config.is_development() {
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;
Some(Arc::new(db))
} else {
None
};
let db = Arc::new(db);
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
@ -62,11 +57,8 @@ impl LlmState {
.build()
.context("failed to construct http client")?;
let initial_active_user_count = if let Some(db) = &db {
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
} else {
None
};
let initial_active_user_count =
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
let this = Self {
config,
@ -88,14 +80,10 @@ impl LlmState {
}
}
if let Some(db) = &self.db {
let mut cache = self.active_user_count.write().await;
let new_count = db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
Ok(new_count)
} else {
Ok(ActiveUserCount::default())
}
let mut cache = self.active_user_count.write().await;
let new_count = self.db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
Ok(new_count)
}
}
@ -165,9 +153,7 @@ async fn perform_completion(
let user_id = claims.user_id as i32;
if state.db.is_some() {
check_usage_limit(&state, params.provider, &model, &claims).await?;
}
check_usage_limit(&state, params.provider, &model, &claims).await?;
match params.provider {
LanguageModelProvider::Anthropic => {
@ -199,14 +185,14 @@ async fn perform_completion(
)
.await?;
let mut recorder = state.db.clone().map(|db| UsageRecorder {
db,
let mut recorder = UsageRecorder {
db: state.db.clone(),
executor: state.executor.clone(),
user_id,
provider: params.provider,
model,
token_count: 0,
});
};
let stream = chunks.map(move |event| {
let mut buffer = Vec::new();
@ -216,10 +202,8 @@ async fn perform_completion(
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => {
if let Some(recorder) = &mut recorder {
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
_ => {}
}
@ -349,12 +333,9 @@ async fn check_usage_limit(
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
let db = state
let model = state.db.model(provider, model_name)?;
let usage = state
.db
.as_ref()
.ok_or_else(|| anyhow!("LLM database not configured"))?;
let model = db.model(provider, model_name)?;
let usage = db
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
.await?;

View File

@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> {
}
async fn setup_llm_database(config: &Config) -> Result<()> {
// TODO: This is temporary until we have the LLM database stood up.
if !config.is_development() {
return Ok(());
}
let database_url = config
.llm_database_url
.as_ref()
@ -298,7 +293,12 @@ async fn handle_liveness_probe(
state.db.get_all_users(0, 1).await?;
}
if let Some(_llm_state) = llm_state {}
if let Some(llm_state) = llm_state {
llm_state
.db
.get_active_user_count(chrono::Utc::now())
.await?;
}
Ok("ok".to_string())
}