mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Introduce a separate backend service for LLM calls (#15831)
This PR introduces a separate backend service for making LLM calls. It exposes an HTTP interface that can be called by Zed clients. To call these endpoints, the client must provide a `Bearer` token. These tokens are issued/refreshed by the collab service over RPC. We're adding this in a backwards-compatible way. Right now the access tokens can only be minted for Zed staff, and calling this separate LLM service is behind the `llm-service` feature flag (which is not automatically enabled for Zed staff). Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
4ed43e6e6f
commit
8e9c2b1125
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2466,6 +2466,7 @@ dependencies = [
|
|||||||
"hex",
|
"hex",
|
||||||
"http_client",
|
"http_client",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
"jsonwebtoken",
|
||||||
"language",
|
"language",
|
||||||
"language_model",
|
"language_model",
|
||||||
"live_kit_client",
|
"live_kit_client",
|
||||||
@ -2507,6 +2508,7 @@ dependencies = [
|
|||||||
"telemetry_events",
|
"telemetry_events",
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
"theme",
|
||||||
|
"thiserror",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
"toml 0.8.16",
|
"toml 0.8.16",
|
||||||
|
@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
|
|||||||
BLOB_STORE_REGION = "the-region"
|
BLOB_STORE_REGION = "the-region"
|
||||||
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
|
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
|
||||||
SEED_PATH = "crates/collab/seed.default.json"
|
SEED_PATH = "crates/collab/seed.default.json"
|
||||||
|
LLM_API_SECRET = "llm-secret"
|
||||||
|
|
||||||
# CLICKHOUSE_URL = ""
|
# CLICKHOUSE_URL = ""
|
||||||
# CLICKHOUSE_USER = "default"
|
# CLICKHOUSE_USER = "default"
|
||||||
|
@ -37,6 +37,7 @@ futures.workspace = true
|
|||||||
google_ai.workspace = true
|
google_ai.workspace = true
|
||||||
hex.workspace = true
|
hex.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
|
jsonwebtoken.workspace = true
|
||||||
live_kit_server.workspace = true
|
live_kit_server.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
nanoid.workspace = true
|
nanoid.workspace = true
|
||||||
@ -61,6 +62,7 @@ subtle.workspace = true
|
|||||||
rustc-demangle.workspace = true
|
rustc-demangle.workspace = true
|
||||||
telemetry_events.workspace = true
|
telemetry_events.workspace = true
|
||||||
text.workspace = true
|
text.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
time.workspace = true
|
time.workspace = true
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
|||||||
.get(http::header::AUTHORIZATION)
|
.get(http::header::AUTHORIZATION)
|
||||||
.and_then(|header| header.to_str().ok())
|
.and_then(|header| header.to_str().ok())
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"missing authorization header".to_string(),
|
"missing authorization header".to_string(),
|
||||||
)
|
)
|
||||||
})?
|
})?
|
||||||
.strip_prefix("token ")
|
.strip_prefix("token ")
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"invalid authorization header".to_string(),
|
"invalid authorization header".to_string(),
|
||||||
)
|
)
|
||||||
@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
|||||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||||
|
|
||||||
if token != state.config.api_token {
|
if token != state.config.api_token {
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
"invalid authorization token".to_string(),
|
"invalid authorization token".to_string(),
|
||||||
))?
|
))?
|
||||||
@ -185,13 +185,13 @@ async fn create_access_token(
|
|||||||
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
|
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
|
||||||
impersonated_user_id = Some(impersonated_user.id);
|
impersonated_user_id = Some(impersonated_user.id);
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
format!("user {impersonate} does not exist"),
|
format!("user {impersonate} does not exist"),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
"you do not have permission to impersonate other users".to_string(),
|
"you do not have permission to impersonate other users".to_string(),
|
||||||
));
|
));
|
||||||
|
@ -120,7 +120,7 @@ async fn create_billing_subscription(
|
|||||||
.zip(app.config.stripe_price_id.clone())
|
.zip(app.config.stripe_price_id.clone())
|
||||||
else {
|
else {
|
||||||
log::error!("failed to retrieve Stripe client or price ID");
|
log::error!("failed to retrieve Stripe client or price ID");
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"not supported".into(),
|
"not supported".into(),
|
||||||
))?
|
))?
|
||||||
@ -201,7 +201,7 @@ async fn manage_billing_subscription(
|
|||||||
|
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||||
log::error!("failed to retrieve Stripe client");
|
log::error!("failed to retrieve Stripe client");
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"not supported".into(),
|
"not supported".into(),
|
||||||
))?
|
))?
|
||||||
|
@ -206,14 +206,14 @@ pub async fn post_hang(
|
|||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"events not enabled".into(),
|
"events not enabled".into(),
|
||||||
))?;
|
))?;
|
||||||
};
|
};
|
||||||
|
|
||||||
if checksum != expected {
|
if checksum != expected {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"invalid checksum".into(),
|
"invalid checksum".into(),
|
||||||
))?;
|
))?;
|
||||||
@ -265,25 +265,25 @@ pub async fn post_panic(
|
|||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"events not enabled".into(),
|
"events not enabled".into(),
|
||||||
))?;
|
))?;
|
||||||
};
|
};
|
||||||
|
|
||||||
if checksum != expected {
|
if checksum != expected {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"invalid checksum".into(),
|
"invalid checksum".into(),
|
||||||
))?;
|
))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
|
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
|
||||||
.map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
|
.map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
|
||||||
let panic = report.panic;
|
let panic = report.panic;
|
||||||
|
|
||||||
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
|
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"invalid os version".into(),
|
"invalid os version".into(),
|
||||||
))?;
|
))?;
|
||||||
@ -362,14 +362,14 @@ pub async fn post_events(
|
|||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
|
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"not supported".into(),
|
"not supported".into(),
|
||||||
))?
|
))?
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||||
return Err(Error::Http(
|
return Err(Error::http(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"events not enabled".into(),
|
"events not enabled".into(),
|
||||||
))?;
|
))?;
|
||||||
@ -385,7 +385,7 @@ pub async fn post_events(
|
|||||||
|
|
||||||
let mut to_upload = ToUpload::default();
|
let mut to_upload = ToUpload::default();
|
||||||
let Some(last_event) = request_body.events.last() else {
|
let Some(last_event) = request_body.events.last() else {
|
||||||
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
||||||
};
|
};
|
||||||
let country_code = country_code_header.map(|h| h.to_string());
|
let country_code = country_code_header.map(|h| h.to_string());
|
||||||
|
|
||||||
|
@ -185,7 +185,7 @@ async fn download_extension(
|
|||||||
.clone()
|
.clone()
|
||||||
.zip(app.config.blob_store_bucket.clone())
|
.zip(app.config.blob_store_bucket.clone())
|
||||||
else {
|
else {
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"not supported".into(),
|
"not supported".into(),
|
||||||
))?
|
))?
|
||||||
@ -202,7 +202,7 @@ async fn download_extension(
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if !version_exists {
|
if !version_exists {
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
"unknown extension version".into(),
|
"unknown extension version".into(),
|
||||||
))?;
|
))?;
|
||||||
|
@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||||||
.get(http::header::AUTHORIZATION)
|
.get(http::header::AUTHORIZATION)
|
||||||
.and_then(|header| header.to_str().ok())
|
.and_then(|header| header.to_str().ok())
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
"missing authorization header".to_string(),
|
"missing authorization header".to_string(),
|
||||||
)
|
)
|
||||||
@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||||||
let first = auth_header.next().unwrap_or("");
|
let first = auth_header.next().unwrap_or("");
|
||||||
if first == "dev-server-token" {
|
if first == "dev-server-token" {
|
||||||
let dev_server_token = auth_header.next().ok_or_else(|| {
|
let dev_server_token = auth_header.next().ok_or_else(|| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"missing dev-server-token token in authorization header".to_string(),
|
"missing dev-server-token token in authorization header".to_string(),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
|
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
|
.map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
|
||||||
|
|
||||||
req.extensions_mut()
|
req.extensions_mut()
|
||||||
.insert(Principal::DevServer(dev_server));
|
.insert(Principal::DevServer(dev_server));
|
||||||
@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||||||
}
|
}
|
||||||
|
|
||||||
let user_id = UserId(first.parse().map_err(|_| {
|
let user_id = UserId(first.parse().map_err(|_| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"missing user id in authorization header".to_string(),
|
"missing user id in authorization header".to_string(),
|
||||||
)
|
)
|
||||||
})?);
|
})?);
|
||||||
|
|
||||||
let access_token = auth_header.next().ok_or_else(|| {
|
let access_token = auth_header.next().ok_or_else(|| {
|
||||||
Error::Http(
|
Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"missing access token in authorization header".to_string(),
|
"missing access token in authorization header".to_string(),
|
||||||
)
|
)
|
||||||
@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(Error::Http(
|
Err(Error::http(
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
"invalid credentials".to_string(),
|
"invalid credentials".to_string(),
|
||||||
))
|
))
|
||||||
|
@ -13,7 +13,10 @@ mod tests;
|
|||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use aws_config::{BehaviorVersion, Region};
|
use aws_config::{BehaviorVersion, Region};
|
||||||
use axum::{http::StatusCode, response::IntoResponse};
|
use axum::{
|
||||||
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
use db::{ChannelId, Database};
|
use db::{ChannelId, Database};
|
||||||
use executor::Executor;
|
use executor::Executor;
|
||||||
pub use rate_limiter::*;
|
pub use rate_limiter::*;
|
||||||
@ -24,7 +27,7 @@ use util::ResultExt;
|
|||||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||||
|
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
Http(StatusCode, String),
|
Http(StatusCode, String, HeaderMap),
|
||||||
Database(sea_orm::error::DbErr),
|
Database(sea_orm::error::DbErr),
|
||||||
Internal(anyhow::Error),
|
Internal(anyhow::Error),
|
||||||
Stripe(stripe::StripeError),
|
Stripe(stripe::StripeError),
|
||||||
@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
fn http(code: StatusCode, message: String) -> Self {
|
||||||
|
Self::Http(code, message, HeaderMap::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
match self {
|
match self {
|
||||||
Error::Http(code, message) => {
|
Error::Http(code, message, headers) => {
|
||||||
log::error!("HTTP error {}: {}", code, &message);
|
log::error!("HTTP error {}: {}", code, &message);
|
||||||
(code, message).into_response()
|
(code, headers, message).into_response()
|
||||||
}
|
}
|
||||||
Error::Database(error) => {
|
Error::Database(error) => {
|
||||||
log::error!(
|
log::error!(
|
||||||
@ -104,7 +113,7 @@ impl IntoResponse for Error {
|
|||||||
impl std::fmt::Debug for Error {
|
impl std::fmt::Debug for Error {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Error::Http(code, message) => (code, message).fmt(f),
|
Error::Http(code, message, _headers) => (code, message).fmt(f),
|
||||||
Error::Database(error) => error.fmt(f),
|
Error::Database(error) => error.fmt(f),
|
||||||
Error::Internal(error) => error.fmt(f),
|
Error::Internal(error) => error.fmt(f),
|
||||||
Error::Stripe(error) => error.fmt(f),
|
Error::Stripe(error) => error.fmt(f),
|
||||||
@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
|
|||||||
impl std::fmt::Display for Error {
|
impl std::fmt::Display for Error {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Error::Http(code, message) => write!(f, "{code}: {message}"),
|
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
|
||||||
Error::Database(error) => error.fmt(f),
|
Error::Database(error) => error.fmt(f),
|
||||||
Error::Internal(error) => error.fmt(f),
|
Error::Internal(error) => error.fmt(f),
|
||||||
Error::Stripe(error) => error.fmt(f),
|
Error::Stripe(error) => error.fmt(f),
|
||||||
@ -141,6 +150,7 @@ pub struct Config {
|
|||||||
pub live_kit_server: Option<String>,
|
pub live_kit_server: Option<String>,
|
||||||
pub live_kit_key: Option<String>,
|
pub live_kit_key: Option<String>,
|
||||||
pub live_kit_secret: Option<String>,
|
pub live_kit_secret: Option<String>,
|
||||||
|
pub llm_api_secret: Option<String>,
|
||||||
pub rust_log: Option<String>,
|
pub rust_log: Option<String>,
|
||||||
pub log_json: Option<bool>,
|
pub log_json: Option<bool>,
|
||||||
pub blob_store_url: Option<String>,
|
pub blob_store_url: Option<String>,
|
||||||
|
@ -1,16 +1,122 @@
|
|||||||
|
mod token;
|
||||||
|
|
||||||
|
use crate::{executor::Executor, Config, Error, Result};
|
||||||
|
use anyhow::Context as _;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||||
|
middleware::{self, Next},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
routing::post,
|
||||||
|
Extension, Json, Router,
|
||||||
|
};
|
||||||
|
use futures::StreamExt as _;
|
||||||
|
use http_client::IsahcHttpClient;
|
||||||
|
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{executor::Executor, Config, Result};
|
pub use token::*;
|
||||||
|
|
||||||
pub struct LlmState {
|
pub struct LlmState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub executor: Executor,
|
pub executor: Executor,
|
||||||
|
pub http_client: IsahcHttpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmState {
|
impl LlmState {
|
||||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||||
let this = Self { config, executor };
|
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||||
|
let http_client = IsahcHttpClient::builder()
|
||||||
|
.default_header("User-Agent", user_agent)
|
||||||
|
.build()
|
||||||
|
.context("failed to construct http client")?;
|
||||||
|
|
||||||
|
let this = Self {
|
||||||
|
config,
|
||||||
|
executor,
|
||||||
|
http_client,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Arc::new(this))
|
Ok(Arc::new(this))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn routes() -> Router<(), Body> {
|
||||||
|
Router::new()
|
||||||
|
.route("/completion", post(perform_completion))
|
||||||
|
.layer(middleware::from_fn(validate_api_token))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||||
|
let token = req
|
||||||
|
.headers()
|
||||||
|
.get(http::header::AUTHORIZATION)
|
||||||
|
.and_then(|header| header.to_str().ok())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
Error::http(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"missing authorization header".to_string(),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.strip_prefix("Bearer ")
|
||||||
|
.ok_or_else(|| {
|
||||||
|
Error::http(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"invalid authorization header".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
|
||||||
|
match LlmTokenClaims::validate(&token, &state.config) {
|
||||||
|
Ok(claims) => {
|
||||||
|
req.extensions_mut().insert(claims);
|
||||||
|
Ok::<_, Error>(next.run(req).await.into_response())
|
||||||
|
}
|
||||||
|
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"unauthorized".to_string(),
|
||||||
|
[(
|
||||||
|
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
|
||||||
|
HeaderValue::from_static("true"),
|
||||||
|
)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
)),
|
||||||
|
Err(_err) => Err(Error::http(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"unauthorized".to_string(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn perform_completion(
|
||||||
|
Extension(state): Extension<Arc<LlmState>>,
|
||||||
|
Extension(_claims): Extension<LlmTokenClaims>,
|
||||||
|
Json(params): Json<PerformCompletionParams>,
|
||||||
|
) -> Result<impl IntoResponse> {
|
||||||
|
let api_key = state
|
||||||
|
.config
|
||||||
|
.anthropic_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no Anthropic AI API key configured on the server")?;
|
||||||
|
let chunks = anthropic::stream_completion(
|
||||||
|
&state.http_client,
|
||||||
|
anthropic::ANTHROPIC_API_URL,
|
||||||
|
api_key,
|
||||||
|
serde_json::from_str(¶ms.provider_request.get())?,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let stream = chunks.map(|event| {
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
event.map(|chunk| {
|
||||||
|
buffer.clear();
|
||||||
|
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||||
|
buffer.push(b'\n');
|
||||||
|
buffer
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Response::new(Body::wrap_stream(stream)))
|
||||||
|
}
|
||||||
|
75
crates/collab/src/llm/token.rs
Normal file
75
crates/collab/src/llm/token.rs
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
use crate::{db::UserId, Config};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use chrono::Utc;
|
||||||
|
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct LlmTokenClaims {
|
||||||
|
pub iat: u64,
|
||||||
|
pub exp: u64,
|
||||||
|
pub jti: String,
|
||||||
|
pub user_id: u64,
|
||||||
|
pub plan: rpc::proto::Plan,
|
||||||
|
}
|
||||||
|
|
||||||
|
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||||
|
|
||||||
|
impl LlmTokenClaims {
|
||||||
|
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
|
||||||
|
let secret = config
|
||||||
|
.llm_api_secret
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("no LLM API secret"))?;
|
||||||
|
|
||||||
|
let now = Utc::now();
|
||||||
|
let claims = Self {
|
||||||
|
iat: now.timestamp() as u64,
|
||||||
|
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||||
|
jti: uuid::Uuid::new_v4().to_string(),
|
||||||
|
user_id: user_id.to_proto(),
|
||||||
|
plan,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(jsonwebtoken::encode(
|
||||||
|
&Header::default(),
|
||||||
|
&claims,
|
||||||
|
&EncodingKey::from_secret(secret.as_ref()),
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
|
||||||
|
let secret = config
|
||||||
|
.llm_api_secret
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("no LLM API secret"))?;
|
||||||
|
|
||||||
|
match jsonwebtoken::decode::<Self>(
|
||||||
|
token,
|
||||||
|
&DecodingKey::from_secret(secret.as_ref()),
|
||||||
|
&Validation::default(),
|
||||||
|
) {
|
||||||
|
Ok(token) => Ok(token.claims),
|
||||||
|
Err(e) => {
|
||||||
|
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
||||||
|
Err(ValidateLlmTokenError::Expired)
|
||||||
|
} else {
|
||||||
|
Err(ValidateLlmTokenError::JwtError(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum ValidateLlmTokenError {
|
||||||
|
#[error("access token is expired")]
|
||||||
|
Expired,
|
||||||
|
#[error("access token validation error: {0}")]
|
||||||
|
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||||
|
#[error("{0}")]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
@ -83,7 +83,9 @@ async fn main() -> Result<()> {
|
|||||||
if mode.is_llm() {
|
if mode.is_llm() {
|
||||||
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||||
|
|
||||||
app = app.layer(Extension(state.clone()));
|
app = app
|
||||||
|
.merge(collab::llm::routes())
|
||||||
|
.layer(Extension(state.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if mode.is_collab() || mode.is_api() {
|
if mode.is_collab() || mode.is_api() {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
mod connection_pool;
|
mod connection_pool;
|
||||||
|
|
||||||
use crate::api::CloudflareIpCountryHeader;
|
use crate::api::CloudflareIpCountryHeader;
|
||||||
|
use crate::llm::LlmTokenClaims;
|
||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
auth,
|
||||||
db::{
|
db::{
|
||||||
@ -11,7 +12,7 @@ use crate::{
|
|||||||
ServerId, UpdatedChannelMessage, User, UserId,
|
ServerId, UpdatedChannelMessage, User, UserId,
|
||||||
},
|
},
|
||||||
executor::Executor,
|
executor::Executor,
|
||||||
AppState, Config, Error, RateLimit, RateLimiter, Result,
|
AppState, Config, Error, RateLimit, Result,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, bail, Context as _};
|
use anyhow::{anyhow, bail, Context as _};
|
||||||
use async_tungstenite::tungstenite::{
|
use async_tungstenite::tungstenite::{
|
||||||
@ -149,10 +150,9 @@ struct Session {
|
|||||||
db: Arc<tokio::sync::Mutex<DbHandle>>,
|
db: Arc<tokio::sync::Mutex<DbHandle>>,
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||||
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
app_state: Arc<AppState>,
|
||||||
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
||||||
http_client: Arc<IsahcHttpClient>,
|
http_client: Arc<IsahcHttpClient>,
|
||||||
rate_limiter: Arc<RateLimiter>,
|
|
||||||
/// The GeoIP country code for the user.
|
/// The GeoIP country code for the user.
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
geoip_country_code: Option<String>,
|
geoip_country_code: Option<String>,
|
||||||
@ -615,6 +615,7 @@ impl Server {
|
|||||||
.add_message_handler(user_message_handler(unfollow))
|
.add_message_handler(user_message_handler(unfollow))
|
||||||
.add_message_handler(user_message_handler(update_followers))
|
.add_message_handler(user_message_handler(update_followers))
|
||||||
.add_request_handler(user_handler(get_private_user_info))
|
.add_request_handler(user_handler(get_private_user_info))
|
||||||
|
.add_request_handler(user_handler(get_llm_api_token))
|
||||||
.add_message_handler(user_message_handler(acknowledge_channel_message))
|
.add_message_handler(user_message_handler(acknowledge_channel_message))
|
||||||
.add_message_handler(user_message_handler(acknowledge_buffer_version))
|
.add_message_handler(user_message_handler(acknowledge_buffer_version))
|
||||||
.add_request_handler(user_handler(get_supermaven_api_key))
|
.add_request_handler(user_handler(get_supermaven_api_key))
|
||||||
@ -1046,9 +1047,8 @@ impl Server {
|
|||||||
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
|
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
|
||||||
peer: this.peer.clone(),
|
peer: this.peer.clone(),
|
||||||
connection_pool: this.connection_pool.clone(),
|
connection_pool: this.connection_pool.clone(),
|
||||||
live_kit_client: this.app_state.live_kit_client.clone(),
|
app_state: this.app_state.clone(),
|
||||||
http_client,
|
http_client,
|
||||||
rate_limiter: this.app_state.rate_limiter.clone(),
|
|
||||||
geoip_country_code,
|
geoip_country_code,
|
||||||
_executor: executor.clone(),
|
_executor: executor.clone(),
|
||||||
supermaven_client,
|
supermaven_client,
|
||||||
@ -1559,7 +1559,7 @@ async fn create_room(
|
|||||||
let live_kit_room = nanoid::nanoid!(30);
|
let live_kit_room = nanoid::nanoid!(30);
|
||||||
|
|
||||||
let live_kit_connection_info = util::maybe!(async {
|
let live_kit_connection_info = util::maybe!(async {
|
||||||
let live_kit = session.live_kit_client.as_ref();
|
let live_kit = session.app_state.live_kit_client.as_ref();
|
||||||
let live_kit = live_kit?;
|
let live_kit = live_kit?;
|
||||||
let user_id = session.user_id().to_string();
|
let user_id = session.user_id().to_string();
|
||||||
|
|
||||||
@ -1630,7 +1630,8 @@ async fn join_room(
|
|||||||
.trace_err();
|
.trace_err();
|
||||||
}
|
}
|
||||||
|
|
||||||
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
|
let live_kit_connection_info =
|
||||||
|
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||||
if let Some(token) = live_kit
|
if let Some(token) = live_kit
|
||||||
.room_token(
|
.room_token(
|
||||||
&joined_room.room.live_kit_room,
|
&joined_room.room.live_kit_room,
|
||||||
@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
|
|||||||
(live_kit_room, can_publish)
|
(live_kit_room, can_publish)
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(live_kit) = session.live_kit_client.as_ref() {
|
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||||
live_kit
|
live_kit
|
||||||
.update_participant(
|
.update_participant(
|
||||||
live_kit_room.clone(),
|
live_kit_room.clone(),
|
||||||
@ -4048,7 +4049,12 @@ async fn join_channel_internal(
|
|||||||
.join_channel(channel_id, session.user_id(), session.connection_id)
|
.join_channel(channel_id, session.user_id(), session.connection_id)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
|
let live_kit_connection_info =
|
||||||
|
session
|
||||||
|
.app_state
|
||||||
|
.live_kit_client
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|live_kit| {
|
||||||
let (can_publish, token) = if role == ChannelRole::Guest {
|
let (can_publish, token) = if role == ChannelRole::Guest {
|
||||||
(
|
(
|
||||||
false,
|
false,
|
||||||
@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
|
|||||||
};
|
};
|
||||||
|
|
||||||
session
|
session
|
||||||
|
.app_state
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check(&*rate_limit, session.user_id())
|
.check(&*rate_limit, session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
|
|||||||
};
|
};
|
||||||
|
|
||||||
session
|
session
|
||||||
|
.app_state
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check(&*rate_limit, session.user_id())
|
.check(&*rate_limit, session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
|
|||||||
};
|
};
|
||||||
|
|
||||||
session
|
session
|
||||||
|
.app_state
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check(&*rate_limit, session.user_id())
|
.check(&*rate_limit, session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
@ -4885,6 +4894,7 @@ async fn compute_embeddings(
|
|||||||
};
|
};
|
||||||
|
|
||||||
session
|
session
|
||||||
|
.app_state
|
||||||
.rate_limiter
|
.rate_limiter
|
||||||
.check(&*rate_limit, session.user_id())
|
.check(&*rate_limit, session.user_id())
|
||||||
.await?;
|
.await?;
|
||||||
@ -5143,6 +5153,24 @@ async fn get_private_user_info(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_llm_api_token(
|
||||||
|
_request: proto::GetLlmToken,
|
||||||
|
response: Response<proto::GetLlmToken>,
|
||||||
|
session: UserSession,
|
||||||
|
) -> Result<()> {
|
||||||
|
if !session.is_staff() {
|
||||||
|
Err(anyhow!("permission denied"))?
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = LlmTokenClaims::create(
|
||||||
|
session.user_id(),
|
||||||
|
session.current_plan().await?,
|
||||||
|
&session.app_state.config,
|
||||||
|
)?;
|
||||||
|
response.send(proto::GetLlmTokenResponse { token })?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
|
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
|
||||||
let message = match message {
|
let message = match message {
|
||||||
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
|
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
|
||||||
@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
|
|||||||
update_user_contacts(contact_user_id, &session).await?;
|
update_user_contacts(contact_user_id, &session).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(live_kit) = session.live_kit_client.as_ref() {
|
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
|
||||||
live_kit
|
live_kit
|
||||||
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
|
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
|
||||||
.await
|
.await
|
||||||
|
@ -651,6 +651,7 @@ impl TestServer {
|
|||||||
live_kit_server: None,
|
live_kit_server: None,
|
||||||
live_kit_key: None,
|
live_kit_key: None,
|
||||||
live_kit_secret: None,
|
live_kit_secret: None,
|
||||||
|
llm_api_secret: None,
|
||||||
rust_log: None,
|
rust_log: None,
|
||||||
log_json: None,
|
log_json: None,
|
||||||
zed_environment: "test".into(),
|
zed_environment: "test".into(),
|
||||||
|
@ -175,6 +175,22 @@ impl HttpClientWithUrl {
|
|||||||
query,
|
query,
|
||||||
)?)
|
)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a Zed LLM URL using the given path.
|
||||||
|
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
|
||||||
|
let base_url = self.base_url();
|
||||||
|
let base_api_url = match base_url.as_ref() {
|
||||||
|
"https://zed.dev" => "https://llm.zed.dev",
|
||||||
|
"https://staging.zed.dev" => "https://llm-staging.zed.dev",
|
||||||
|
"http://localhost:3000" => "http://localhost:8080",
|
||||||
|
other => other,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Url::parse_with_params(
|
||||||
|
&format!("{}{}", base_api_url, path),
|
||||||
|
query,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HttpClient for Arc<HttpClientWithUrl> {
|
impl HttpClient for Arc<HttpClientWithUrl> {
|
||||||
|
@ -5,13 +5,20 @@ use crate::{
|
|||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::{Client, UserStore};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
||||||
|
use http_client::{HttpClient, Method};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::value::RawValue;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
|
use smol::{
|
||||||
|
io::BufReader,
|
||||||
|
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
|
||||||
|
};
|
||||||
use std::{future, sync::Arc};
|
use std::{future, sync::Arc};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
@ -46,6 +53,7 @@ pub struct AvailableModel {
|
|||||||
|
|
||||||
pub struct CloudLanguageModelProvider {
|
pub struct CloudLanguageModelProvider {
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
|
llm_api_token: LlmApiToken,
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
_maintain_client_status: Task<()>,
|
_maintain_client_status: Task<()>,
|
||||||
}
|
}
|
||||||
@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
|
|||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
state,
|
state,
|
||||||
|
llm_api_token: LlmApiToken::default(),
|
||||||
_maintain_client_status: maintain_client_status,
|
_maintain_client_status: maintain_client_status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
Arc::new(CloudLanguageModel {
|
Arc::new(CloudLanguageModel {
|
||||||
id: LanguageModelId::from(model.id().to_string()),
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
model,
|
model,
|
||||||
|
llm_api_token: self.llm_api_token.clone(),
|
||||||
client: self.client.clone(),
|
client: self.client.clone(),
|
||||||
request_limiter: RateLimiter::new(4),
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LlmServiceFeatureFlag;
|
||||||
|
|
||||||
|
impl FeatureFlag for LlmServiceFeatureFlag {
|
||||||
|
const NAME: &'static str = "llm-service";
|
||||||
|
|
||||||
|
fn enabled_for_staff() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct CloudLanguageModel {
|
pub struct CloudLanguageModel {
|
||||||
id: LanguageModelId,
|
id: LanguageModelId,
|
||||||
model: CloudModel,
|
model: CloudModel,
|
||||||
|
llm_api_token: LlmApiToken,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
request_limiter: RateLimiter,
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||||
|
|
||||||
impl LanguageModel for CloudLanguageModel {
|
impl LanguageModel for CloudLanguageModel {
|
||||||
fn id(&self) -> LanguageModelId {
|
fn id(&self) -> LanguageModelId {
|
||||||
self.id.clone()
|
self.id.clone()
|
||||||
@ -279,12 +303,75 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
_: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let client = self.client.clone();
|
|
||||||
let request = request.into_anthropic(model.id().into());
|
let request = request.into_anthropic(model.id().into());
|
||||||
|
let client = self.client.clone();
|
||||||
|
|
||||||
|
if cx
|
||||||
|
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let http_client = self.client.http_client();
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
let future = self.request_limiter.stream(async move {
|
||||||
|
let request = serde_json::to_string(&request)?;
|
||||||
|
let mut token = llm_api_token.acquire(&client).await?;
|
||||||
|
let mut did_retry = false;
|
||||||
|
|
||||||
|
let response = loop {
|
||||||
|
let request = http_client::Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&PerformCompletionParams {
|
||||||
|
provider_request: RawValue::from_string(request.clone())?,
|
||||||
|
})?
|
||||||
|
.into(),
|
||||||
|
)?;
|
||||||
|
let response = http_client.send(request).await?;
|
||||||
|
if response.status().is_success() {
|
||||||
|
break response;
|
||||||
|
} else if !did_retry
|
||||||
|
&& response
|
||||||
|
.headers()
|
||||||
|
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
did_retry = true;
|
||||||
|
token = llm_api_token.refresh(&client).await?;
|
||||||
|
} else {
|
||||||
|
break Err(anyhow!(
|
||||||
|
"cloud language model completion failed with status {}",
|
||||||
|
response.status()
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = BufReader::new(response.into_body());
|
||||||
|
|
||||||
|
let stream =
|
||||||
|
futures::stream::try_unfold(body, move |mut body| async move {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
match body.read_line(&mut buffer).await {
|
||||||
|
Ok(0) => Ok(None),
|
||||||
|
Ok(_) => {
|
||||||
|
let event: anthropic::Event =
|
||||||
|
serde_json::from_str(&buffer)?;
|
||||||
|
Ok(Some((event, body)))
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(anthropic::extract_text_from_events(stream))
|
||||||
|
});
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
|
} else {
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let stream = client
|
let stream = client
|
||||||
@ -292,13 +379,13 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||||
request,
|
request,
|
||||||
})
|
})
|
||||||
.await?;
|
.await?
|
||||||
Ok(anthropic::extract_text_from_events(
|
.map(|event| Ok(serde_json::from_str(&event?.event)?));
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
Ok(anthropic::extract_text_from_events(stream))
|
||||||
))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_open_ai(model.id().into());
|
let request = request.into_open_ai(model.id().into());
|
||||||
@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl LlmApiToken {
|
||||||
|
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||||
|
let lock = self.0.upgradable_read().await;
|
||||||
|
if let Some(token) = lock.as_ref() {
|
||||||
|
Ok(token.to_string())
|
||||||
|
} else {
|
||||||
|
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
|
||||||
|
Self::fetch(self.0.write().await, &client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch<'a>(
|
||||||
|
mut lock: RwLockWriteGuard<'a, Option<String>>,
|
||||||
|
client: &Arc<Client>,
|
||||||
|
) -> Result<String> {
|
||||||
|
let response = client.request(proto::GetLlmToken {}).await?;
|
||||||
|
*lock = Some(response.token.clone());
|
||||||
|
Ok(response.token.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ message Envelope {
|
|||||||
Unfollow unfollow = 101;
|
Unfollow unfollow = 101;
|
||||||
GetPrivateUserInfo get_private_user_info = 102;
|
GetPrivateUserInfo get_private_user_info = 102;
|
||||||
GetPrivateUserInfoResponse get_private_user_info_response = 103;
|
GetPrivateUserInfoResponse get_private_user_info_response = 103;
|
||||||
UpdateUserPlan update_user_plan = 234; // current max
|
UpdateUserPlan update_user_plan = 234;
|
||||||
UpdateDiffBase update_diff_base = 104;
|
UpdateDiffBase update_diff_base = 104;
|
||||||
|
|
||||||
OnTypeFormatting on_type_formatting = 105;
|
OnTypeFormatting on_type_formatting = 105;
|
||||||
@ -270,6 +270,9 @@ message Envelope {
|
|||||||
|
|
||||||
AddWorktree add_worktree = 222;
|
AddWorktree add_worktree = 222;
|
||||||
AddWorktreeResponse add_worktree_response = 223;
|
AddWorktreeResponse add_worktree_response = 223;
|
||||||
|
|
||||||
|
GetLlmToken get_llm_token = 235;
|
||||||
|
GetLlmTokenResponse get_llm_token_response = 236; // current max
|
||||||
}
|
}
|
||||||
|
|
||||||
reserved 158 to 161;
|
reserved 158 to 161;
|
||||||
@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse {
|
|||||||
repeated ContextVersion contexts = 1;
|
repeated ContextVersion contexts = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message GetLlmToken {}
|
||||||
|
|
||||||
|
message GetLlmTokenResponse {
|
||||||
|
string token = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Remote FS
|
// Remote FS
|
||||||
|
|
||||||
message AddWorktree {
|
message AddWorktree {
|
||||||
|
@ -259,6 +259,8 @@ messages!(
|
|||||||
(GetTypeDefinitionResponse, Background),
|
(GetTypeDefinitionResponse, Background),
|
||||||
(GetImplementation, Background),
|
(GetImplementation, Background),
|
||||||
(GetImplementationResponse, Background),
|
(GetImplementationResponse, Background),
|
||||||
|
(GetLlmToken, Background),
|
||||||
|
(GetLlmTokenResponse, Background),
|
||||||
(GetUsers, Foreground),
|
(GetUsers, Foreground),
|
||||||
(Hello, Foreground),
|
(Hello, Foreground),
|
||||||
(IncomingCall, Foreground),
|
(IncomingCall, Foreground),
|
||||||
@ -438,6 +440,7 @@ request_messages!(
|
|||||||
(GetImplementation, GetImplementationResponse),
|
(GetImplementation, GetImplementationResponse),
|
||||||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||||
(GetHover, GetHoverResponse),
|
(GetHover, GetHoverResponse),
|
||||||
|
(GetLlmToken, GetLlmTokenResponse),
|
||||||
(GetNotifications, GetNotificationsResponse),
|
(GetNotifications, GetNotificationsResponse),
|
||||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||||
|
8
crates/rpc/src/llm.rs
Normal file
8
crates/rpc/src/llm.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct PerformCompletionParams {
|
||||||
|
pub provider_request: Box<serde_json::value::RawValue>,
|
||||||
|
}
|
@ -1,12 +1,14 @@
|
|||||||
pub mod auth;
|
pub mod auth;
|
||||||
mod conn;
|
mod conn;
|
||||||
mod extension;
|
mod extension;
|
||||||
|
mod llm;
|
||||||
mod notification;
|
mod notification;
|
||||||
mod peer;
|
mod peer;
|
||||||
pub mod proto;
|
pub mod proto;
|
||||||
|
|
||||||
pub use conn::Connection;
|
pub use conn::Connection;
|
||||||
pub use extension::*;
|
pub use extension::*;
|
||||||
|
pub use llm::*;
|
||||||
pub use notification::*;
|
pub use notification::*;
|
||||||
pub use peer::*;
|
pub use peer::*;
|
||||||
pub use proto::{error::*, Receipt, TypedEnvelope};
|
pub use proto::{error::*, Receipt, TypedEnvelope};
|
||||||
|
Loading…
Reference in New Issue
Block a user