mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Introduce a separate backend service for LLM calls (#15831)
This PR introduces a separate backend service for making LLM calls. It exposes an HTTP interface that can be called by Zed clients. To call these endpoints, the client must provide a `Bearer` token. These tokens are issued/refreshed by the collab service over RPC. We're adding this in a backwards-compatible way. Right now the access tokens can only be minted for Zed staff, and calling this separate LLM service is behind the `llm-service` feature flag (which is not automatically enabled for Zed staff). Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
4ed43e6e6f
commit
8e9c2b1125
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2466,6 +2466,7 @@ dependencies = [
|
||||
"hex",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"jsonwebtoken",
|
||||
"language",
|
||||
"language_model",
|
||||
"live_kit_client",
|
||||
@ -2507,6 +2508,7 @@ dependencies = [
|
||||
"telemetry_events",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror",
|
||||
"time",
|
||||
"tokio",
|
||||
"toml 0.8.16",
|
||||
|
@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
|
||||
BLOB_STORE_REGION = "the-region"
|
||||
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
|
||||
SEED_PATH = "crates/collab/seed.default.json"
|
||||
LLM_API_SECRET = "llm-secret"
|
||||
|
||||
# CLICKHOUSE_URL = ""
|
||||
# CLICKHOUSE_USER = "default"
|
||||
|
@ -37,6 +37,7 @@ futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
@ -61,6 +62,7 @@ subtle.workspace = true
|
||||
rustc-demangle.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio.workspace = true
|
||||
toml.workspace = true
|
||||
|
@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing authorization header".to_string(),
|
||||
)
|
||||
})?
|
||||
.strip_prefix("token ")
|
||||
.ok_or_else(|| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid authorization header".to_string(),
|
||||
)
|
||||
@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||
|
||||
if token != state.config.api_token {
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid authorization token".to_string(),
|
||||
))?
|
||||
@ -185,13 +185,13 @@ async fn create_access_token(
|
||||
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
|
||||
impersonated_user_id = Some(impersonated_user.id);
|
||||
} else {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
format!("user {impersonate} does not exist"),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"you do not have permission to impersonate other users".to_string(),
|
||||
));
|
||||
|
@ -120,7 +120,7 @@ async fn create_billing_subscription(
|
||||
.zip(app.config.stripe_price_id.clone())
|
||||
else {
|
||||
log::error!("failed to retrieve Stripe client or price ID");
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
@ -201,7 +201,7 @@ async fn manage_billing_subscription(
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
|
@ -206,14 +206,14 @@ pub async fn post_hang(
|
||||
body: Bytes,
|
||||
) -> Result<()> {
|
||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"events not enabled".into(),
|
||||
))?;
|
||||
};
|
||||
|
||||
if checksum != expected {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid checksum".into(),
|
||||
))?;
|
||||
@ -265,25 +265,25 @@ pub async fn post_panic(
|
||||
body: Bytes,
|
||||
) -> Result<()> {
|
||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"events not enabled".into(),
|
||||
))?;
|
||||
};
|
||||
|
||||
if checksum != expected {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid checksum".into(),
|
||||
))?;
|
||||
}
|
||||
|
||||
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
|
||||
.map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
|
||||
.map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
|
||||
let panic = report.panic;
|
||||
|
||||
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid os version".into(),
|
||||
))?;
|
||||
@ -362,14 +362,14 @@ pub async fn post_events(
|
||||
body: Bytes,
|
||||
) -> Result<()> {
|
||||
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||
return Err(Error::Http(
|
||||
return Err(Error::http(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"events not enabled".into(),
|
||||
))?;
|
||||
@ -385,7 +385,7 @@ pub async fn post_events(
|
||||
|
||||
let mut to_upload = ToUpload::default();
|
||||
let Some(last_event) = request_body.events.last() else {
|
||||
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
||||
return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
||||
};
|
||||
let country_code = country_code_header.map(|h| h.to_string());
|
||||
|
||||
|
@ -185,7 +185,7 @@ async fn download_extension(
|
||||
.clone()
|
||||
.zip(app.config.blob_store_bucket.clone())
|
||||
else {
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
@ -202,7 +202,7 @@ async fn download_extension(
|
||||
.await?;
|
||||
|
||||
if !version_exists {
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_FOUND,
|
||||
"unknown extension version".into(),
|
||||
))?;
|
||||
|
@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"missing authorization header".to_string(),
|
||||
)
|
||||
@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
let first = auth_header.next().unwrap_or("");
|
||||
if first == "dev-server-token" {
|
||||
let dev_server_token = auth_header.next().ok_or_else(|| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing dev-server-token token in authorization header".to_string(),
|
||||
)
|
||||
})?;
|
||||
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
|
||||
.await
|
||||
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
|
||||
.map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(Principal::DevServer(dev_server));
|
||||
@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
}
|
||||
|
||||
let user_id = UserId(first.parse().map_err(|_| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing user id in authorization header".to_string(),
|
||||
)
|
||||
})?);
|
||||
|
||||
let access_token = auth_header.next().ok_or_else(|| {
|
||||
Error::Http(
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing access token in authorization header".to_string(),
|
||||
)
|
||||
@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::Http(
|
||||
Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"invalid credentials".to_string(),
|
||||
))
|
||||
|
@ -13,7 +13,10 @@ mod tests;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use aws_config::{BehaviorVersion, Region};
|
||||
use axum::{http::StatusCode, response::IntoResponse};
|
||||
use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
pub use rate_limiter::*;
|
||||
@ -24,7 +27,7 @@ use util::ResultExt;
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
Http(StatusCode, String),
|
||||
Http(StatusCode, String, HeaderMap),
|
||||
Database(sea_orm::error::DbErr),
|
||||
Internal(anyhow::Error),
|
||||
Stripe(stripe::StripeError),
|
||||
@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
fn http(code: StatusCode, message: String) -> Self {
|
||||
Self::Http(code, message, HeaderMap::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Error::Http(code, message) => {
|
||||
Error::Http(code, message, headers) => {
|
||||
log::error!("HTTP error {}: {}", code, &message);
|
||||
(code, message).into_response()
|
||||
(code, headers, message).into_response()
|
||||
}
|
||||
Error::Database(error) => {
|
||||
log::error!(
|
||||
@ -104,7 +113,7 @@ impl IntoResponse for Error {
|
||||
impl std::fmt::Debug for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Error::Http(code, message) => (code, message).fmt(f),
|
||||
Error::Http(code, message, _headers) => (code, message).fmt(f),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Error::Http(code, message) => write!(f, "{code}: {message}"),
|
||||
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
@ -141,6 +150,7 @@ pub struct Config {
|
||||
pub live_kit_server: Option<String>,
|
||||
pub live_kit_key: Option<String>,
|
||||
pub live_kit_secret: Option<String>,
|
||||
pub llm_api_secret: Option<String>,
|
||||
pub rust_log: Option<String>,
|
||||
pub log_json: Option<bool>,
|
||||
pub blob_store_url: Option<String>,
|
||||
|
@ -1,16 +1,122 @@
|
||||
mod token;
|
||||
|
||||
use crate::{executor::Executor, Config, Error, Result};
|
||||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use futures::StreamExt as _;
|
||||
use http_client::IsahcHttpClient;
|
||||
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{executor::Executor, Config, Result};
|
||||
pub use token::*;
|
||||
|
||||
pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub http_client: IsahcHttpClient,
|
||||
}
|
||||
|
||||
impl LlmState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let this = Self { config, executor };
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = IsahcHttpClient::builder()
|
||||
.default_header("User-Agent", user_agent)
|
||||
.build()
|
||||
.context("failed to construct http client")?;
|
||||
|
||||
let this = Self {
|
||||
config,
|
||||
executor,
|
||||
http_client,
|
||||
};
|
||||
|
||||
Ok(Arc::new(this))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/completion", post(perform_completion))
|
||||
.layer(middleware::from_fn(validate_api_token))
|
||||
}
|
||||
|
||||
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let token = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing authorization header".to_string(),
|
||||
)
|
||||
})?
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid authorization header".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
|
||||
match LlmTokenClaims::validate(&token, &state.config) {
|
||||
Ok(claims) => {
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok::<_, Error>(next.run(req).await.into_response())
|
||||
}
|
||||
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)),
|
||||
Err(_err) => Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_completion(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(_claims): Extension<LlmTokenClaims>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let api_key = state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
let chunks = anthropic::stream_completion(
|
||||
&state.http_client,
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(¶ms.provider_request.get())?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
|
75
crates/collab/src/llm/token.rs
Normal file
75
crates/collab/src/llm/token.rs
Normal file
@ -0,0 +1,75 @@
|
||||
use crate::{db::UserId, Config};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LlmTokenClaims {
|
||||
pub iat: u64,
|
||||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no LLM API secret"))?;
|
||||
|
||||
let now = Utc::now();
|
||||
let claims = Self {
|
||||
iat: now.timestamp() as u64,
|
||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_proto(),
|
||||
plan,
|
||||
};
|
||||
|
||||
Ok(jsonwebtoken::encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)?)
|
||||
}
|
||||
|
||||
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no LLM API secret"))?;
|
||||
|
||||
match jsonwebtoken::decode::<Self>(
|
||||
token,
|
||||
&DecodingKey::from_secret(secret.as_ref()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(token) => Ok(token.claims),
|
||||
Err(e) => {
|
||||
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
||||
Err(ValidateLlmTokenError::Expired)
|
||||
} else {
|
||||
Err(ValidateLlmTokenError::JwtError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidateLlmTokenError {
|
||||
#[error("access token is expired")]
|
||||
Expired,
|
||||
#[error("access token validation error: {0}")]
|
||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||
#[error("{0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
@ -83,7 +83,9 @@ async fn main() -> Result<()> {
|
||||
if mode.is_llm() {
|
||||
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||
|
||||
app = app.layer(Extension(state.clone()));
|
||||
app = app
|
||||
.merge(collab::llm::routes())
|
||||
.layer(Extension(state.clone()));
|
||||
}
|
||||
|
||||
if mode.is_collab() || mode.is_api() {
|
||||
|
@ -1,6 +1,7 @@
|
||||
mod connection_pool;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{
|
||||
auth,
|
||||
db::{
|
||||
@ -11,7 +12,7 @@ use crate::{
|
||||
ServerId, UpdatedChannelMessage, User, UserId,
|
||||
},
|
||||
executor::Executor,
|
||||
AppState, Config, Error, RateLimit, RateLimiter, Result,
|
||||
AppState, Config, Error, RateLimit, Result,
|
||||
};
|
||||
use anyhow::{anyhow, bail, Context as _};
|
||||
use async_tungstenite::tungstenite::{
|
||||
@ -149,10 +150,9 @@ struct Session {
|
||||
db: Arc<tokio::sync::Mutex<DbHandle>>,
|
||||
peer: Arc<Peer>,
|
||||
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
app_state: Arc<AppState>,
|
||||
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
||||
http_client: Arc<IsahcHttpClient>,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
/// The GeoIP country code for the user.
|
||||
#[allow(unused)]
|
||||
geoip_country_code: Option<String>,
|
||||
@ -615,6 +615,7 @@ impl Server {
|
||||
.add_message_handler(user_message_handler(unfollow))
|
||||
.add_message_handler(user_message_handler(update_followers))
|
||||
.add_request_handler(user_handler(get_private_user_info))
|
||||
.add_request_handler(user_handler(get_llm_api_token))
|
||||
.add_message_handler(user_message_handler(acknowledge_channel_message))
|
||||
.add_message_handler(user_message_handler(acknowledge_buffer_version))
|
||||
.add_request_handler(user_handler(get_supermaven_api_key))
|
||||
@ -1046,9 +1047,8 @@ impl Server {
|
||||
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
|
||||
peer: this.peer.clone(),
|
||||
connection_pool: this.connection_pool.clone(),
|
||||
live_kit_client: this.app_state.live_kit_client.clone(),
|
||||
app_state: this.app_state.clone(),
|
||||
http_client,
|
||||
rate_limiter: this.app_state.rate_limiter.clone(),
|
||||
geoip_country_code,
|
||||
_executor: executor.clone(),
|
||||
supermaven_client,
|
||||
@ -1559,7 +1559,7 @@ async fn create_room(
|
||||
let live_kit_room = nanoid::nanoid!(30);
|
||||
|
||||
let live_kit_connection_info = util::maybe!(async {
|
||||
let live_kit = session.live_kit_client.as_ref();
|
||||
let live_kit = session.app_state.live_kit_client.as_ref();
|
||||
let live_kit = live_kit?;
|
||||
let user_id = session.user_id().to_string();
|
||||
|
||||
@ -1630,25 +1630,26 @@ async fn join_room(
|
||||
.trace_err();
|
||||
}
|
||||
|
||||
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
|
||||
if let Some(token) = live_kit
|
||||
.room_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
)
|
||||
.trace_err()
|
||||
{
|
||||
Some(proto::LiveKitConnectionInfo {
|
||||
server_url: live_kit.url().into(),
|
||||
token,
|
||||
can_publish: true,
|
||||
})
|
||||
let live_kit_connection_info =
|
||||
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||
if let Some(token) = live_kit
|
||||
.room_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
)
|
||||
.trace_err()
|
||||
{
|
||||
Some(proto::LiveKitConnectionInfo {
|
||||
server_url: live_kit.url().into(),
|
||||
token,
|
||||
can_publish: true,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
};
|
||||
|
||||
response.send(proto::JoinRoomResponse {
|
||||
room: Some(joined_room.room),
|
||||
@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
|
||||
(live_kit_room, can_publish)
|
||||
};
|
||||
|
||||
if let Some(live_kit) = session.live_kit_client.as_ref() {
|
||||
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||
live_kit
|
||||
.update_participant(
|
||||
live_kit_room.clone(),
|
||||
@ -4048,35 +4049,40 @@ async fn join_channel_internal(
|
||||
.join_channel(channel_id, session.user_id(), session.connection_id)
|
||||
.await?;
|
||||
|
||||
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
|
||||
let (can_publish, token) = if role == ChannelRole::Guest {
|
||||
(
|
||||
false,
|
||||
live_kit
|
||||
.guest_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
let live_kit_connection_info =
|
||||
session
|
||||
.app_state
|
||||
.live_kit_client
|
||||
.as_ref()
|
||||
.and_then(|live_kit| {
|
||||
let (can_publish, token) = if role == ChannelRole::Guest {
|
||||
(
|
||||
false,
|
||||
live_kit
|
||||
.guest_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
)
|
||||
.trace_err()?,
|
||||
)
|
||||
.trace_err()?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
true,
|
||||
live_kit
|
||||
.room_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
} else {
|
||||
(
|
||||
true,
|
||||
live_kit
|
||||
.room_token(
|
||||
&joined_room.room.live_kit_room,
|
||||
&session.user_id().to_string(),
|
||||
)
|
||||
.trace_err()?,
|
||||
)
|
||||
.trace_err()?,
|
||||
)
|
||||
};
|
||||
};
|
||||
|
||||
Some(LiveKitConnectionInfo {
|
||||
server_url: live_kit.url().into(),
|
||||
token,
|
||||
can_publish,
|
||||
})
|
||||
});
|
||||
Some(LiveKitConnectionInfo {
|
||||
server_url: live_kit.url().into(),
|
||||
token,
|
||||
can_publish,
|
||||
})
|
||||
});
|
||||
|
||||
response.send(proto::JoinRoomResponse {
|
||||
room: Some(joined_room.room.clone()),
|
||||
@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
@ -4885,6 +4894,7 @@ async fn compute_embeddings(
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
@ -5143,6 +5153,24 @@ async fn get_private_user_info(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_llm_api_token(
|
||||
_request: proto::GetLlmToken,
|
||||
response: Response<proto::GetLlmToken>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
if !session.is_staff() {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
session.user_id(),
|
||||
session.current_plan().await?,
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
response.send(proto::GetLlmTokenResponse { token })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
|
||||
let message = match message {
|
||||
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
|
||||
@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
|
||||
update_user_contacts(contact_user_id, &session).await?;
|
||||
}
|
||||
|
||||
if let Some(live_kit) = session.live_kit_client.as_ref() {
|
||||
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||
live_kit
|
||||
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
|
||||
.await
|
||||
|
@ -651,6 +651,7 @@ impl TestServer {
|
||||
live_kit_server: None,
|
||||
live_kit_key: None,
|
||||
live_kit_secret: None,
|
||||
llm_api_secret: None,
|
||||
rust_log: None,
|
||||
log_json: None,
|
||||
zed_environment: "test".into(),
|
||||
|
@ -175,6 +175,22 @@ impl HttpClientWithUrl {
|
||||
query,
|
||||
)?)
|
||||
}
|
||||
|
||||
/// Builds a Zed LLM URL using the given path.
|
||||
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
|
||||
let base_url = self.base_url();
|
||||
let base_api_url = match base_url.as_ref() {
|
||||
"https://zed.dev" => "https://llm.zed.dev",
|
||||
"https://staging.zed.dev" => "https://llm-staging.zed.dev",
|
||||
"http://localhost:3000" => "http://localhost:8080",
|
||||
other => other,
|
||||
};
|
||||
|
||||
Ok(Url::parse_with_params(
|
||||
&format!("{}{}", base_api_url, path),
|
||||
query,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for Arc<HttpClientWithUrl> {
|
||||
|
@ -5,13 +5,20 @@ use crate::{
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use collections::BTreeMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use smol::{
|
||||
io::BufReader,
|
||||
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
|
||||
};
|
||||
use std::{future, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
@ -46,6 +53,7 @@ pub struct AvailableModel {
|
||||
|
||||
pub struct CloudLanguageModelProvider {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
state: gpui::Model<State>,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
|
||||
Self {
|
||||
client,
|
||||
state,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
Arc::new(CloudLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
llm_api_token: self.llm_api_token.clone(),
|
||||
client: self.client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
}
|
||||
}
|
||||
|
||||
struct LlmServiceFeatureFlag;
|
||||
|
||||
impl FeatureFlag for LlmServiceFeatureFlag {
|
||||
const NAME: &'static str = "llm-service";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: CloudModel,
|
||||
llm_api_token: LlmApiToken,
|
||||
client: Arc<Client>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncAppContext,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_anthropic(model.id().into());
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
Ok(anthropic::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
let client = self.client.clone();
|
||||
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let http_client = self.client.http_client();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut did_retry = false;
|
||||
|
||||
let response = loop {
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(
|
||||
serde_json::to_string(&PerformCompletionParams {
|
||||
provider_request: RawValue::from_string(request.clone())?,
|
||||
})?
|
||||
.into(),
|
||||
)?;
|
||||
let response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
break response;
|
||||
} else if !did_retry
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else {
|
||||
break Err(anyhow!(
|
||||
"cloud language model completion failed with status {}",
|
||||
response.status()
|
||||
))?;
|
||||
}
|
||||
};
|
||||
|
||||
let body = BufReader::new(response.into_body());
|
||||
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: anthropic::Event =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(anthropic::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
} else {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?
|
||||
.map(|event| Ok(serde_json::from_str(&event?.event)?));
|
||||
Ok(anthropic::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmApiToken {
|
||||
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
if let Some(token) = lock.as_ref() {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
|
||||
Self::fetch(self.0.write().await, &client).await
|
||||
}
|
||||
|
||||
async fn fetch<'a>(
|
||||
mut lock: RwLockWriteGuard<'a, Option<String>>,
|
||||
client: &Arc<Client>,
|
||||
) -> Result<String> {
|
||||
let response = client.request(proto::GetLlmToken {}).await?;
|
||||
*lock = Some(response.token.clone());
|
||||
Ok(response.token.clone())
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: gpui::Model<State>,
|
||||
}
|
||||
|
@ -126,7 +126,7 @@ message Envelope {
|
||||
Unfollow unfollow = 101;
|
||||
GetPrivateUserInfo get_private_user_info = 102;
|
||||
GetPrivateUserInfoResponse get_private_user_info_response = 103;
|
||||
UpdateUserPlan update_user_plan = 234; // current max
|
||||
UpdateUserPlan update_user_plan = 234;
|
||||
UpdateDiffBase update_diff_base = 104;
|
||||
|
||||
OnTypeFormatting on_type_formatting = 105;
|
||||
@ -270,6 +270,9 @@ message Envelope {
|
||||
|
||||
AddWorktree add_worktree = 222;
|
||||
AddWorktreeResponse add_worktree_response = 223;
|
||||
|
||||
GetLlmToken get_llm_token = 235;
|
||||
GetLlmTokenResponse get_llm_token_response = 236; // current max
|
||||
}
|
||||
|
||||
reserved 158 to 161;
|
||||
@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse {
|
||||
repeated ContextVersion contexts = 1;
|
||||
}
|
||||
|
||||
message GetLlmToken {}
|
||||
|
||||
message GetLlmTokenResponse {
|
||||
string token = 1;
|
||||
}
|
||||
|
||||
// Remote FS
|
||||
|
||||
message AddWorktree {
|
||||
|
@ -259,6 +259,8 @@ messages!(
|
||||
(GetTypeDefinitionResponse, Background),
|
||||
(GetImplementation, Background),
|
||||
(GetImplementationResponse, Background),
|
||||
(GetLlmToken, Background),
|
||||
(GetLlmTokenResponse, Background),
|
||||
(GetUsers, Foreground),
|
||||
(Hello, Foreground),
|
||||
(IncomingCall, Foreground),
|
||||
@ -438,6 +440,7 @@ request_messages!(
|
||||
(GetImplementation, GetImplementationResponse),
|
||||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||
(GetHover, GetHoverResponse),
|
||||
(GetLlmToken, GetLlmTokenResponse),
|
||||
(GetNotifications, GetNotificationsResponse),
|
||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||
|
8
crates/rpc/src/llm.rs
Normal file
8
crates/rpc/src/llm.rs
Normal file
@ -0,0 +1,8 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct PerformCompletionParams {
|
||||
pub provider_request: Box<serde_json::value::RawValue>,
|
||||
}
|
@ -1,12 +1,14 @@
|
||||
pub mod auth;
|
||||
mod conn;
|
||||
mod extension;
|
||||
mod llm;
|
||||
mod notification;
|
||||
mod peer;
|
||||
pub mod proto;
|
||||
|
||||
pub use conn::Connection;
|
||||
pub use extension::*;
|
||||
pub use llm::*;
|
||||
pub use notification::*;
|
||||
pub use peer::*;
|
||||
pub use proto::{error::*, Receipt, TypedEnvelope};
|
||||
|
Loading…
Reference in New Issue
Block a user