From 27779e33fb662a904724d29444588bce6d66d6e1 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Mon, 5 Aug 2024 12:07:38 -0700
Subject: [PATCH] Refactor: Restructure collab main function to prepare for new subcommand: `serve llm` (#15824)

This is just a refactor that we're landing ahead of any functional changes to
make sure we haven't broken anything.

Release Notes:

- N/A

Co-authored-by: Marshall
Co-authored-by: Jason
---
 Procfile                  |   2 +-
 crates/collab/src/api.rs  |   9 +-
 crates/collab/src/lib.rs  |  26 +++++-
 crates/collab/src/llm.rs  |  16 ++++
 crates/collab/src/main.rs | 174 +++++++++++++++++++++-----------------
 5 files changed, 142 insertions(+), 85 deletions(-)
 create mode 100644 crates/collab/src/llm.rs

diff --git a/Procfile b/Procfile
index c74eeeac44..5f1231b90a 100644
--- a/Procfile
+++ b/Procfile
@@ -1,3 +1,3 @@
-collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
+collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
 livekit: livekit-server --dev
 blob_store: ./script/run-local-minio
diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs
index fd2973c770..69be39bd18 100644
--- a/crates/collab/src/api.rs
+++ b/crates/collab/src/api.rs
@@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
     }
 }

-pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
+pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
     Router::new()
         .route("/user", get(get_authenticated_user))
         .route("/users/:id/access_tokens", post(create_access_token))
@@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
         .merge(contributors::router())
         .layer(
             ServiceBuilder::new()
-                .layer(Extension(state))
                 .layer(Extension(rpc_server))
                 .layer(middleware::from_fn(validate_api_token)),
         )
@@ -152,12 +151,8 @@ struct CreateUserParams {
 }

 async fn get_rpc_server_snapshot(
-    Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
+    Extension(rpc_server): Extension<Arc<rpc::Server>>,
 ) -> Result<ErasedJson> {
-    let Some(rpc_server) = rpc_server else {
-        return Err(Error::Internal(anyhow!("rpc server is not available")));
-    };
-
     Ok(ErasedJson::pretty(rpc_server.snapshot().await))
 }

diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs
index 6925f62874..5201721383 100644
--- a/crates/collab/src/lib.rs
+++ b/crates/collab/src/lib.rs
@@ -3,6 +3,7 @@ pub mod auth;
 pub mod db;
 pub mod env;
 pub mod executor;
+pub mod llm;
 mod rate_limiter;
 pub mod rpc;
 pub mod seed;
@@ -124,7 +125,7 @@ impl std::fmt::Display for Error {

 impl std::error::Error for Error {}

-#[derive(Deserialize)]
+#[derive(Clone, Deserialize)]
 pub struct Config {
     pub http_port: u16,
     pub database_url: String,
@@ -176,6 +177,29 @@ impl Config {
     }
 }

+/// The service mode that collab should run in.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum ServiceMode {
+    Api,
+    Collab,
+    Llm,
+    All,
+}
+
+impl ServiceMode {
+    pub fn is_collab(&self) -> bool {
+        matches!(self, Self::Collab | Self::All)
+    }
+
+    pub fn is_api(&self) -> bool {
+        matches!(self, Self::Api | Self::All)
+    }
+
+    pub fn is_llm(&self) -> bool {
+        matches!(self, Self::Llm | Self::All)
+    }
+}
+
 pub struct AppState {
     pub db: Arc<Database>,
     pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
new file mode 100644
index 0000000000..305aee10c9
--- /dev/null
+++ b/crates/collab/src/llm.rs
@@ -0,0 +1,16 @@
+use std::sync::Arc;
+
+use crate::{executor::Executor, Config, Result};
+
+pub struct LlmState {
+    pub config: Config,
+    pub executor: Executor,
+}
+
+impl LlmState {
+    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
+        let this = Self { config, executor };
+
+        Ok(Arc::new(this))
+    }
+}
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index 7725988f99..31ed0c0f29 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -5,7 +5,7 @@ use axum::{
     routing::get,
     Extension, Router,
 };
-use collab::api::billing::poll_stripe_events_periodically;
+use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 use collab::{
     api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
     rpc::ResultExt, AppState, Config, RateLimiter, Result,
@@ -56,88 +56,99 @@ async fn main() -> Result<()> {
             collab::seed::seed(&config, &db, true).await?;
         }
         Some("serve") => {
-            let (is_api, is_collab) = if let Some(next) = args.next() {
-                (next == "api", next == "collab")
-            } else {
-                (true, true)
+            let mode = match args.next().as_deref() {
+                Some("collab") => ServiceMode::Collab,
+                Some("api") => ServiceMode::Api,
+                Some("llm") => ServiceMode::Llm,
+                Some("all") => ServiceMode::All,
+                _ => {
+                    return Err(anyhow!(
+                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+                    ))?;
+                }
             };
-            if !is_api && !is_collab {
-                Err(anyhow!(
-                    "usage: collab "
-                ))?;
-            }

             let config = envy::from_env::<Config>().expect("error loading config");
             init_tracing(&config);
+            let mut app = Router::new()
+                .route("/", get(handle_root))
+                .route("/healthz", get(handle_liveness_probe))
+                .layer(Extension(mode));

-            run_migrations(&config).await?;
-
-            let state = AppState::new(config, Executor::Production).await?;
-
-            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
+            let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
                 .expect("failed to bind TCP listener");

-            let rpc_server = if is_collab {
-                let epoch = state
-                    .db
-                    .create_server(&state.config.zed_environment)
-                    .await?;
-                let rpc_server = collab::rpc::Server::new(epoch, state.clone());
-                rpc_server.start().await?;
+            let mut on_shutdown = None;

-                Some(rpc_server)
-            } else {
-                None
-            };
+            if mode.is_llm() {
+                let state = LlmState::new(config.clone(), Executor::Production).await?;

-            if is_collab {
-                state.db.purge_old_embeddings().await.trace_err();
-                RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
+                app = app.layer(Extension(state.clone()));
             }

-            if is_api {
-                poll_stripe_events_periodically(state.clone());
-                fetch_extensions_from_blob_store_periodically(state.clone());
-            }
+            if mode.is_collab() || mode.is_api() {
+                run_migrations(&config).await?;

-            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
-            if let Some(rpc_server) = rpc_server.clone() {
-                app = app.merge(collab::rpc::routes(rpc_server))
-            }
-            app = app
-                .merge(
-                    Router::new()
-                        .route("/", get(handle_root))
-                        .route("/healthz", get(handle_liveness_probe))
-                        .merge(collab::api::extensions::router())
+                let state = AppState::new(config, Executor::Production).await?;
+
+                if mode.is_collab() {
+                    state.db.purge_old_embeddings().await.trace_err();
+                    RateLimiter::save_periodically(
+                        state.rate_limiter.clone(),
+                        state.executor.clone(),
+                    );
+
+                    let epoch = state
+                        .db
+                        .create_server(&state.config.zed_environment)
+                        .await?;
+                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
+                    rpc_server.start().await?;
+
+                    app = app
+                        .merge(collab::api::routes(rpc_server.clone()))
+                        .merge(collab::rpc::routes(rpc_server.clone()));
+
+                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
+                }
+
+                if mode.is_api() {
+                    poll_stripe_events_periodically(state.clone());
+                    fetch_extensions_from_blob_store_periodically(state.clone());
+
+                    app = app
                         .merge(collab::api::events::router())
-                        .layer(Extension(state.clone())),
-                )
-                .layer(
-                    TraceLayer::new_for_http()
-                        .make_span_with(|request: &Request<_>| {
-                            let matched_path = request
-                                .extensions()
-                                .get::<MatchedPath>()
-                                .map(MatchedPath::as_str);
+                        .merge(collab::api::extensions::router())
+                }

-                            tracing::info_span!(
-                                "http_request",
-                                method = ?request.method(),
-                                matched_path,
-                            )
-                        })
-                        .on_response(
-                            |response: &Response<_>, latency: Duration, _: &tracing::Span| {
-                                let duration_ms = latency.as_micros() as f64 / 1000.;
-                                tracing::info!(
-                                    duration_ms,
-                                    status = response.status().as_u16(),
-                                    "finished processing request"
-                                );
-                            },
-                        ),
-                );
+                app = app.layer(Extension(state.clone()));
+            }
+
+            app = app.layer(
+                TraceLayer::new_for_http()
+                    .make_span_with(|request: &Request<_>| {
+                        let matched_path = request
+                            .extensions()
+                            .get::<MatchedPath>()
+                            .map(MatchedPath::as_str);
+
+                        tracing::info_span!(
+                            "http_request",
+                            method = ?request.method(),
+                            matched_path,
+                        )
+                    })
+                    .on_response(
+                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
+                            let duration_ms = latency.as_micros() as f64 / 1000.;
+                            tracing::info!(
+                                duration_ms,
+                                status = response.status().as_u16(),
+                                "finished processing request"
+                            );
+                        },
+                    ),
+            );

             #[cfg(unix)]
             let signal = async move {
@@ -174,8 +185,8 @@
                 signal.await;
                 tracing::info!("Received interrupt signal");

-                if let Some(rpc_server) = rpc_server {
-                    rpc_server.teardown();
+                if let Some(on_shutdown) = on_shutdown {
+                    on_shutdown();
                 }
             })
             .await
@@ -183,7 +194,7 @@
         }
         _ => {
             Err(anyhow!(
-                "usage: collab "
+                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
             ))?;
         }
     }
@@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
     return Ok(());
 }

-async fn handle_root() -> String {
-    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
+async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
+    format!(
+        "collab {mode:?} v{VERSION} ({})",
+        REVISION.unwrap_or("unknown")
+    )
 }

-async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
-    state.db.get_all_users(0, 1).await?;
+async fn handle_liveness_probe(
+    app_state: Option<Extension<Arc<AppState>>>,
+    llm_state: Option<Extension<Arc<LlmState>>>,
+) -> Result<String> {
+    if let Some(state) = app_state {
+        state.db.get_all_users(0, 1).await?;
+    }
+
+    if let Some(_llm_state) = llm_state {}
+
     Ok("ok".to_string())
 }
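
For reference, the reworked `serve` subcommand can be exercised locally the same way the Procfile above now does. These invocations are illustrative only; they simply mirror the Procfile change and the `ServiceMode` match arms in `main`:

    RUST_LOG=info cargo run --package=collab serve all     # API + collab RPC + LLM in one process (Procfile default)
    RUST_LOG=info cargo run --package=collab serve collab  # collab RPC server only
    RUST_LOG=info cargo run --package=collab serve api     # HTTP API only
    RUST_LOG=info cargo run --package=collab serve llm     # the new, currently minimal LlmState-backed service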