diff --git a/Cargo.lock b/Cargo.lock index 0dd8f1eda7..96456ae65c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2466,6 +2466,7 @@ dependencies = [ "hex", "http_client", "indoc", + "jsonwebtoken", "language", "language_model", "live_kit_client", @@ -2507,6 +2508,7 @@ dependencies = [ "telemetry_events", "text", "theme", + "thiserror", "time", "tokio", "toml 0.8.16", diff --git a/crates/collab/.env.toml b/crates/collab/.env.toml index 9bfdf294e4..9646d0c921 100644 --- a/crates/collab/.env.toml +++ b/crates/collab/.env.toml @@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000" BLOB_STORE_REGION = "the-region" ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed" SEED_PATH = "crates/collab/seed.default.json" +LLM_API_SECRET = "llm-secret" # CLICKHOUSE_URL = "" # CLICKHOUSE_USER = "default" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 8ebeb3e555..2b6583f970 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -37,6 +37,7 @@ futures.workspace = true google_ai.workspace = true hex.workspace = true http_client.workspace = true +jsonwebtoken.workspace = true live_kit_server.workspace = true log.workspace = true nanoid.workspace = true @@ -61,6 +62,7 @@ subtle.workspace = true rustc-demangle.workspace = true telemetry_events.workspace = true text.workspace = true +thiserror.workspace = true time.workspace = true tokio.workspace = true toml.workspace = true diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 69be39bd18..5504005f8d 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -81,14 +81,14 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR .get(http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()) .ok_or_else(|| { - Error::Http( + Error::http( StatusCode::BAD_REQUEST, "missing authorization header".to_string(), ) })? .strip_prefix("token ") .ok_or_else(|| { - Error::Http( + Error::http( StatusCode::BAD_REQUEST, "invalid authorization header".to_string(), ) @@ -97,7 +97,7 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR let state = req.extensions().get::>().unwrap(); if token != state.config.api_token { - Err(Error::Http( + Err(Error::http( StatusCode::UNAUTHORIZED, "invalid authorization token".to_string(), ))? @@ -185,13 +185,13 @@ async fn create_access_token( if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { impersonated_user_id = Some(impersonated_user.id); } else { - return Err(Error::Http( + return Err(Error::http( StatusCode::UNPROCESSABLE_ENTITY, format!("user {impersonate} does not exist"), )); } } else { - return Err(Error::Http( + return Err(Error::http( StatusCode::UNAUTHORIZED, "you do not have permission to impersonate other users".to_string(), )); diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index bac0258ead..c8f328ad5d 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -120,7 +120,7 @@ async fn create_billing_subscription( .zip(app.config.stripe_price_id.clone()) else { log::error!("failed to retrieve Stripe client or price ID"); - Err(Error::Http( + Err(Error::http( StatusCode::NOT_IMPLEMENTED, "not supported".into(), ))? @@ -201,7 +201,7 @@ async fn manage_billing_subscription( let Some(stripe_client) = app.stripe_client.clone() else { log::error!("failed to retrieve Stripe client"); - Err(Error::Http( + Err(Error::http( StatusCode::NOT_IMPLEMENTED, "not supported".into(), ))? diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index e0cf79bb88..a6afc98bfc 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -206,14 +206,14 @@ pub async fn post_hang( body: Bytes, ) -> Result<()> { let Some(expected) = calculate_json_checksum(app.clone(), &body) else { - return Err(Error::Http( + return Err(Error::http( StatusCode::INTERNAL_SERVER_ERROR, "events not enabled".into(), ))?; }; if checksum != expected { - return Err(Error::Http( + return Err(Error::http( StatusCode::BAD_REQUEST, "invalid checksum".into(), ))?; @@ -265,25 +265,25 @@ pub async fn post_panic( body: Bytes, ) -> Result<()> { let Some(expected) = calculate_json_checksum(app.clone(), &body) else { - return Err(Error::Http( + return Err(Error::http( StatusCode::INTERNAL_SERVER_ERROR, "events not enabled".into(), ))?; }; if checksum != expected { - return Err(Error::Http( + return Err(Error::http( StatusCode::BAD_REQUEST, "invalid checksum".into(), ))?; } let report: telemetry_events::PanicRequest = serde_json::from_slice(&body) - .map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?; + .map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?; let panic = report.panic; if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) { - return Err(Error::Http( + return Err(Error::http( StatusCode::BAD_REQUEST, "invalid os version".into(), ))?; @@ -362,14 +362,14 @@ pub async fn post_events( body: Bytes, ) -> Result<()> { let Some(clickhouse_client) = app.clickhouse_client.clone() else { - Err(Error::Http( + Err(Error::http( StatusCode::NOT_IMPLEMENTED, "not supported".into(), ))? }; let Some(expected) = calculate_json_checksum(app.clone(), &body) else { - return Err(Error::Http( + return Err(Error::http( StatusCode::INTERNAL_SERVER_ERROR, "events not enabled".into(), ))?; @@ -385,7 +385,7 @@ pub async fn post_events( let mut to_upload = ToUpload::default(); let Some(last_event) = request_body.events.last() else { - return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?; + return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?; }; let country_code = country_code_header.map(|h| h.to_string()); diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index d0532504ed..1665cf0a0f 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -185,7 +185,7 @@ async fn download_extension( .clone() .zip(app.config.blob_store_bucket.clone()) else { - Err(Error::Http( + Err(Error::http( StatusCode::NOT_IMPLEMENTED, "not supported".into(), ))? @@ -202,7 +202,7 @@ async fn download_extension( .await?; if !version_exists { - Err(Error::Http( + Err(Error::http( StatusCode::NOT_FOUND, "unknown extension version".into(), ))?; diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 261fe9b850..283ab9dbc2 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -33,7 +33,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into .get(http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()) .ok_or_else(|| { - Error::Http( + Error::http( StatusCode::UNAUTHORIZED, "missing authorization header".to_string(), ) @@ -45,14 +45,14 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into let first = auth_header.next().unwrap_or(""); if first == "dev-server-token" { let dev_server_token = auth_header.next().ok_or_else(|| { - Error::Http( + Error::http( StatusCode::BAD_REQUEST, "missing dev-server-token token in authorization header".to_string(), ) })?; let dev_server = verify_dev_server_token(dev_server_token, &state.db) .await - .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?; + .map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?; req.extensions_mut() .insert(Principal::DevServer(dev_server)); @@ -60,14 +60,14 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into } let user_id = UserId(first.parse().map_err(|_| { - Error::Http( + Error::http( StatusCode::BAD_REQUEST, "missing user id in authorization header".to_string(), ) })?); let access_token = auth_header.next().ok_or_else(|| { - Error::Http( + Error::http( StatusCode::BAD_REQUEST, "missing access token in authorization header".to_string(), ) @@ -111,7 +111,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into } } - Err(Error::Http( + Err(Error::http( StatusCode::UNAUTHORIZED, "invalid credentials".to_string(), )) diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 5201721383..d88f37354e 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -13,7 +13,10 @@ mod tests; use anyhow::anyhow; use aws_config::{BehaviorVersion, Region}; -use axum::{http::StatusCode, response::IntoResponse}; +use axum::{ + http::{HeaderMap, StatusCode}, + response::IntoResponse, +}; use db::{ChannelId, Database}; use executor::Executor; pub use rate_limiter::*; @@ -24,7 +27,7 @@ use util::ResultExt; pub type Result = std::result::Result; pub enum Error { - Http(StatusCode, String), + Http(StatusCode, String, HeaderMap), Database(sea_orm::error::DbErr), Internal(anyhow::Error), Stripe(stripe::StripeError), @@ -66,12 +69,18 @@ impl From for Error { } } +impl Error { + fn http(code: StatusCode, message: String) -> Self { + Self::Http(code, message, HeaderMap::default()) + } +} + impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { match self { - Error::Http(code, message) => { + Error::Http(code, message, headers) => { log::error!("HTTP error {}: {}", code, &message); - (code, message).into_response() + (code, headers, message).into_response() } Error::Database(error) => { log::error!( @@ -104,7 +113,7 @@ impl IntoResponse for Error { impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::Http(code, message) => (code, message).fmt(f), + Error::Http(code, message, _headers) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), Error::Stripe(error) => error.fmt(f), @@ -115,7 +124,7 @@ impl std::fmt::Debug for Error { impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::Http(code, message) => write!(f, "{code}: {message}"), + Error::Http(code, message, _) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), Error::Stripe(error) => error.fmt(f), @@ -141,6 +150,7 @@ pub struct Config { pub live_kit_server: Option, pub live_kit_key: Option, pub live_kit_secret: Option, + pub llm_api_secret: Option, pub rust_log: Option, pub log_json: Option, pub blob_store_url: Option, diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 305aee10c9..e3e17562fa 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,16 +1,122 @@ +mod token; + +use crate::{executor::Executor, Config, Error, Result}; +use anyhow::Context as _; +use axum::{ + body::Body, + http::{self, HeaderName, HeaderValue, Request, StatusCode}, + middleware::{self, Next}, + response::{IntoResponse, Response}, + routing::post, + Extension, Json, Router, +}; +use futures::StreamExt as _; +use http_client::IsahcHttpClient; +use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME}; use std::sync::Arc; -use crate::{executor::Executor, Config, Result}; +pub use token::*; pub struct LlmState { pub config: Config, pub executor: Executor, + pub http_client: IsahcHttpClient, } impl LlmState { pub async fn new(config: Config, executor: Executor) -> Result> { - let this = Self { config, executor }; + let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); + let http_client = IsahcHttpClient::builder() + .default_header("User-Agent", user_agent) + .build() + .context("failed to construct http client")?; + + let this = Self { + config, + executor, + http_client, + }; Ok(Arc::new(this)) } } + +pub fn routes() -> Router<(), Body> { + Router::new() + .route("/completion", post(perform_completion)) + .layer(middleware::from_fn(validate_api_token)) +} + +async fn validate_api_token(mut req: Request, next: Next) -> impl IntoResponse { + let token = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or_else(|| { + Error::http( + StatusCode::BAD_REQUEST, + "missing authorization header".to_string(), + ) + })? + .strip_prefix("Bearer ") + .ok_or_else(|| { + Error::http( + StatusCode::BAD_REQUEST, + "invalid authorization header".to_string(), + ) + })?; + + let state = req.extensions().get::>().unwrap(); + match LlmTokenClaims::validate(&token, &state.config) { + Ok(claims) => { + req.extensions_mut().insert(claims); + Ok::<_, Error>(next.run(req).await.into_response()) + } + Err(ValidateLlmTokenError::Expired) => Err(Error::Http( + StatusCode::UNAUTHORIZED, + "unauthorized".to_string(), + [( + HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME), + HeaderValue::from_static("true"), + )] + .into_iter() + .collect(), + )), + Err(_err) => Err(Error::http( + StatusCode::UNAUTHORIZED, + "unauthorized".to_string(), + )), + } +} + +async fn perform_completion( + Extension(state): Extension>, + Extension(_claims): Extension, + Json(params): Json, +) -> Result { + let api_key = state + .config + .anthropic_api_key + .as_ref() + .context("no Anthropic AI API key configured on the server")?; + let chunks = anthropic::stream_completion( + &state.http_client, + anthropic::ANTHROPIC_API_URL, + api_key, + serde_json::from_str(¶ms.provider_request.get())?, + None, + ) + .await?; + + let stream = chunks.map(|event| { + let mut buffer = Vec::new(); + event.map(|chunk| { + buffer.clear(); + serde_json::to_writer(&mut buffer, &chunk).unwrap(); + buffer.push(b'\n'); + buffer + }) + }); + + Ok(Response::new(Body::wrap_stream(stream))) +} diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs new file mode 100644 index 0000000000..99386443eb --- /dev/null +++ b/crates/collab/src/llm/token.rs @@ -0,0 +1,75 @@ +use crate::{db::UserId, Config}; +use anyhow::{anyhow, Result}; +use chrono::Utc; +use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use thiserror::Error; + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmTokenClaims { + pub iat: u64, + pub exp: u64, + pub jti: String, + pub user_id: u64, + pub plan: rpc::proto::Plan, +} + +const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); + +impl LlmTokenClaims { + pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result { + let secret = config + .llm_api_secret + .as_ref() + .ok_or_else(|| anyhow!("no LLM API secret"))?; + + let now = Utc::now(); + let claims = Self { + iat: now.timestamp() as u64, + exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, + jti: uuid::Uuid::new_v4().to_string(), + user_id: user_id.to_proto(), + plan, + }; + + Ok(jsonwebtoken::encode( + &Header::default(), + &claims, + &EncodingKey::from_secret(secret.as_ref()), + )?) + } + + pub fn validate(token: &str, config: &Config) -> Result { + let secret = config + .llm_api_secret + .as_ref() + .ok_or_else(|| anyhow!("no LLM API secret"))?; + + match jsonwebtoken::decode::( + token, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::default(), + ) { + Ok(token) => Ok(token.claims), + Err(e) => { + if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { + Err(ValidateLlmTokenError::Expired) + } else { + Err(ValidateLlmTokenError::JwtError(e)) + } + } + } + } +} + +#[derive(Error, Debug)] +pub enum ValidateLlmTokenError { + #[error("access token is expired")] + Expired, + #[error("access token validation error: {0}")] + JwtError(#[from] jsonwebtoken::errors::Error), + #[error("{0}")] + Other(#[from] anyhow::Error), +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 31ed0c0f29..60a6967ca2 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -83,7 +83,9 @@ async fn main() -> Result<()> { if mode.is_llm() { let state = LlmState::new(config.clone(), Executor::Production).await?; - app = app.layer(Extension(state.clone())); + app = app + .merge(collab::llm::routes()) + .layer(Extension(state.clone())); } if mode.is_collab() || mode.is_api() { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index a75d132b57..361e4fe237 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,6 +1,7 @@ mod connection_pool; use crate::api::CloudflareIpCountryHeader; +use crate::llm::LlmTokenClaims; use crate::{ auth, db::{ @@ -11,7 +12,7 @@ use crate::{ ServerId, UpdatedChannelMessage, User, UserId, }, executor::Executor, - AppState, Config, Error, RateLimit, RateLimiter, Result, + AppState, Config, Error, RateLimit, Result, }; use anyhow::{anyhow, bail, Context as _}; use async_tungstenite::tungstenite::{ @@ -149,10 +150,9 @@ struct Session { db: Arc>, peer: Arc, connection_pool: Arc>, - live_kit_client: Option>, + app_state: Arc, supermaven_client: Option>, http_client: Arc, - rate_limiter: Arc, /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, @@ -615,6 +615,7 @@ impl Server { .add_message_handler(user_message_handler(unfollow)) .add_message_handler(user_message_handler(update_followers)) .add_request_handler(user_handler(get_private_user_info)) + .add_request_handler(user_handler(get_llm_api_token)) .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_request_handler(user_handler(get_supermaven_api_key)) @@ -1046,9 +1047,8 @@ impl Server { db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))), peer: this.peer.clone(), connection_pool: this.connection_pool.clone(), - live_kit_client: this.app_state.live_kit_client.clone(), + app_state: this.app_state.clone(), http_client, - rate_limiter: this.app_state.rate_limiter.clone(), geoip_country_code, _executor: executor.clone(), supermaven_client, @@ -1559,7 +1559,7 @@ async fn create_room( let live_kit_room = nanoid::nanoid!(30); let live_kit_connection_info = util::maybe!(async { - let live_kit = session.live_kit_client.as_ref(); + let live_kit = session.app_state.live_kit_client.as_ref(); let live_kit = live_kit?; let user_id = session.user_id().to_string(); @@ -1630,25 +1630,26 @@ async fn join_room( .trace_err(); } - let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { - if let Some(token) = live_kit - .room_token( - &joined_room.room.live_kit_room, - &session.user_id().to_string(), - ) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - can_publish: true, - }) + let live_kit_connection_info = + if let Some(live_kit) = session.app_state.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id().to_string(), + ) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + can_publish: true, + }) + } else { + None + } } else { None - } - } else { - None - }; + }; response.send(proto::JoinRoomResponse { room: Some(joined_room.room), @@ -1877,7 +1878,7 @@ async fn set_room_participant_role( (live_kit_room, can_publish) }; - if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(live_kit) = session.app_state.live_kit_client.as_ref() { live_kit .update_participant( live_kit_room.clone(), @@ -4048,35 +4049,40 @@ async fn join_channel_internal( .join_channel(channel_id, session.user_id(), session.connection_id) .await?; - let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { - let (can_publish, token) = if role == ChannelRole::Guest { - ( - false, - live_kit - .guest_token( - &joined_room.room.live_kit_room, - &session.user_id().to_string(), + let live_kit_connection_info = + session + .app_state + .live_kit_client + .as_ref() + .and_then(|live_kit| { + let (can_publish, token) = if role == ChannelRole::Guest { + ( + false, + live_kit + .guest_token( + &joined_room.room.live_kit_room, + &session.user_id().to_string(), + ) + .trace_err()?, ) - .trace_err()?, - ) - } else { - ( - true, - live_kit - .room_token( - &joined_room.room.live_kit_room, - &session.user_id().to_string(), + } else { + ( + true, + live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id().to_string(), + ) + .trace_err()?, ) - .trace_err()?, - ) - }; + }; - Some(LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - can_publish, - }) - }); + Some(LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + can_publish, + }) + }); response.send(proto::JoinRoomResponse { room: Some(joined_room.room.clone()), @@ -4610,6 +4616,7 @@ async fn complete_with_language_model( }; session + .app_state .rate_limiter .check(&*rate_limit, session.user_id()) .await?; @@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model( }; session + .app_state .rate_limiter .check(&*rate_limit, session.user_id()) .await?; @@ -4766,6 +4774,7 @@ async fn count_language_model_tokens( }; session + .app_state .rate_limiter .check(&*rate_limit, session.user_id()) .await?; @@ -4885,6 +4894,7 @@ async fn compute_embeddings( }; session + .app_state .rate_limiter .check(&*rate_limit, session.user_id()) .await?; @@ -5143,6 +5153,24 @@ async fn get_private_user_info( Ok(()) } +async fn get_llm_api_token( + _request: proto::GetLlmToken, + response: Response, + session: UserSession, +) -> Result<()> { + if !session.is_staff() { + Err(anyhow!("permission denied"))? + } + + let token = LlmTokenClaims::create( + session.user_id(), + session.current_plan().await?, + &session.app_state.config, + )?; + response.send(proto::GetLlmTokenResponse { token })?; + Ok(()) +} + fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload), @@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection update_user_contacts(contact_user_id, &session).await?; } - if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(live_kit) = session.app_state.live_kit_client.as_ref() { live_kit .remove_participant(live_kit_room.clone(), session.user_id().to_string()) .await diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index fde3082102..bf8b031e5e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -651,6 +651,7 @@ impl TestServer { live_kit_server: None, live_kit_key: None, live_kit_secret: None, + llm_api_secret: None, rust_log: None, log_json: None, zed_environment: "test".into(), diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index fe5174a982..452be0a243 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -175,6 +175,22 @@ impl HttpClientWithUrl { query, )?) } + + /// Builds a Zed LLM URL using the given path. + pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result { + let base_url = self.base_url(); + let base_api_url = match base_url.as_ref() { + "https://zed.dev" => "https://llm.zed.dev", + "https://staging.zed.dev" => "https://llm-staging.zed.dev", + "http://localhost:3000" => "http://localhost:8080", + other => other, + }; + + Ok(Url::parse_with_params( + &format!("{}{}", base_api_url, path), + query, + )?) + } } impl HttpClient for Arc { diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 5614e6d98a..7862794e92 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -5,13 +5,20 @@ use crate::{ LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel, }; use anyhow::{anyhow, Context as _, Result}; -use client::{Client, UserStore}; +use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; use collections::BTreeMap; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt}; +use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task}; +use http_client::{HttpClient, Method}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue; use settings::{Settings, SettingsStore}; +use smol::{ + io::BufReader, + lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}, +}; use std::{future, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; @@ -46,6 +53,7 @@ pub struct AvailableModel { pub struct CloudLanguageModelProvider { client: Arc, + llm_api_token: LlmApiToken, state: gpui::Model, _maintain_client_status: Task<()>, } @@ -104,6 +112,7 @@ impl CloudLanguageModelProvider { Self { client, state, + llm_api_token: LlmApiToken::default(), _maintain_client_status: maintain_client_status, } } @@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { Arc::new(CloudLanguageModel { id: LanguageModelId::from(model.id().to_string()), model, + llm_api_token: self.llm_api_token.clone(), client: self.client.clone(), request_limiter: RateLimiter::new(4), }) as Arc @@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } +struct LlmServiceFeatureFlag; + +impl FeatureFlag for LlmServiceFeatureFlag { + const NAME: &'static str = "llm-service"; + + fn enabled_for_staff() -> bool { + false + } +} + pub struct CloudLanguageModel { id: LanguageModelId, model: CloudModel, + llm_api_token: LlmApiToken, client: Arc, request_limiter: RateLimiter, } +#[derive(Clone, Default)] +struct LlmApiToken(Arc>>); + impl LanguageModel for CloudLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, request: LanguageModelRequest, - _: &AsyncAppContext, + cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { match &self.model { CloudModel::Anthropic(model) => { - let client = self.client.clone(); let request = request.into_anthropic(model.id().into()); - let future = self.request_limiter.stream(async move { - let request = serde_json::to_string(&request)?; - let stream = client - .request_stream(proto::StreamCompleteWithLanguageModel { - provider: proto::LanguageModelProvider::Anthropic as i32, - request, - }) - .await?; - Ok(anthropic::extract_text_from_events( - stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() + let client = self.client.clone(); + + if cx + .update(|cx| cx.has_flag::()) + .unwrap_or(false) + { + let http_client = self.client.http_client(); + let llm_api_token = self.llm_api_token.clone(); + let future = self.request_limiter.stream(async move { + let request = serde_json::to_string(&request)?; + let mut token = llm_api_token.acquire(&client).await?; + let mut did_retry = false; + + let response = loop { + let request = http_client::Request::builder() + .method(Method::POST) + .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref()) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body( + serde_json::to_string(&PerformCompletionParams { + provider_request: RawValue::from_string(request.clone())?, + })? + .into(), + )?; + let response = http_client.send(request).await?; + if response.status().is_success() { + break response; + } else if !did_retry + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_api_token.refresh(&client).await?; + } else { + break Err(anyhow!( + "cloud language model completion failed with status {}", + response.status() + ))?; + } + }; + + let body = BufReader::new(response.into_body()); + + let stream = + futures::stream::try_unfold(body, move |mut body| async move { + let mut buffer = String::new(); + match body.read_line(&mut buffer).await { + Ok(0) => Ok(None), + Ok(_) => { + let event: anthropic::Event = + serde_json::from_str(&buffer)?; + Ok(Some((event, body))) + } + Err(e) => Err(e.into()), + } + }); + + Ok(anthropic::extract_text_from_events(stream)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } else { + let future = self.request_limiter.stream(async move { + let request = serde_json::to_string(&request)?; + let stream = client + .request_stream(proto::StreamCompleteWithLanguageModel { + provider: proto::LanguageModelProvider::Anthropic as i32, + request, + }) + .await? + .map(|event| Ok(serde_json::from_str(&event?.event)?)); + Ok(anthropic::extract_text_from_events(stream)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } } CloudModel::OpenAi(model) => { let client = self.client.clone(); @@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel { } } +impl LlmApiToken { + async fn acquire(&self, client: &Arc) -> Result { + let lock = self.0.upgradable_read().await; + if let Some(token) = lock.as_ref() { + Ok(token.to_string()) + } else { + Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await + } + } + + async fn refresh(&self, client: &Arc) -> Result { + Self::fetch(self.0.write().await, &client).await + } + + async fn fetch<'a>( + mut lock: RwLockWriteGuard<'a, Option>, + client: &Arc, + ) -> Result { + let response = client.request(proto::GetLlmToken {}).await?; + *lock = Some(response.token.clone()); + Ok(response.token.clone()) + } +} + struct ConfigurationView { state: gpui::Model, } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 55cfc77e30..2697069b17 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -126,7 +126,7 @@ message Envelope { Unfollow unfollow = 101; GetPrivateUserInfo get_private_user_info = 102; GetPrivateUserInfoResponse get_private_user_info_response = 103; - UpdateUserPlan update_user_plan = 234; // current max + UpdateUserPlan update_user_plan = 234; UpdateDiffBase update_diff_base = 104; OnTypeFormatting on_type_formatting = 105; @@ -270,6 +270,9 @@ message Envelope { AddWorktree add_worktree = 222; AddWorktreeResponse add_worktree_response = 223; + + GetLlmToken get_llm_token = 235; + GetLlmTokenResponse get_llm_token_response = 236; // current max } reserved 158 to 161; @@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse { repeated ContextVersion contexts = 1; } +message GetLlmToken {} + +message GetLlmTokenResponse { + string token = 1; +} + // Remote FS message AddWorktree { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 17bf73e0bd..aca40cad26 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -259,6 +259,8 @@ messages!( (GetTypeDefinitionResponse, Background), (GetImplementation, Background), (GetImplementationResponse, Background), + (GetLlmToken, Background), + (GetLlmTokenResponse, Background), (GetUsers, Foreground), (Hello, Foreground), (IncomingCall, Foreground), @@ -438,6 +440,7 @@ request_messages!( (GetImplementation, GetImplementationResponse), (GetDocumentHighlights, GetDocumentHighlightsResponse), (GetHover, GetHoverResponse), + (GetLlmToken, GetLlmTokenResponse), (GetNotifications, GetNotificationsResponse), (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs new file mode 100644 index 0000000000..64df4110ef --- /dev/null +++ b/crates/rpc/src/llm.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token"; + +#[derive(Serialize, Deserialize)] +pub struct PerformCompletionParams { + pub provider_request: Box, +} diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index b8741fd805..2e8b1ef6b7 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -1,12 +1,14 @@ pub mod auth; mod conn; mod extension; +mod llm; mod notification; mod peer; pub mod proto; pub use conn::Connection; pub use extension::*; +pub use llm::*; pub use notification::*; pub use peer::*; pub use proto::{error::*, Receipt, TypedEnvelope};