Introduce a separate backend service for LLM calls (#15831)

This PR introduces a separate backend service for making LLM calls.

It exposes an HTTP interface that can be called by Zed clients. To call
these endpoints, the client must provide a `Bearer` token. These tokens
are issued/refreshed by the collab service over RPC.

We're adding this in a backwards-compatible way. Right now the access
tokens can only be minted for Zed staff, and calling this separate LLM
service is behind the `llm-service` feature flag (which is not
automatically enabled for Zed staff).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-08-05 17:26:21 -07:00 committed by GitHub
parent 4ed43e6e6f
commit 8e9c2b1125
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 478 additions and 102 deletions

2
Cargo.lock generated
View File

@ -2466,6 +2466,7 @@ dependencies = [
"hex",
"http_client",
"indoc",
"jsonwebtoken",
"language",
"language_model",
"live_kit_client",
@ -2507,6 +2508,7 @@ dependencies = [
"telemetry_events",
"text",
"theme",
"thiserror",
"time",
"tokio",
"toml 0.8.16",

View File

@ -15,6 +15,7 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
SEED_PATH = "crates/collab/seed.default.json"
LLM_API_SECRET = "llm-secret"
# CLICKHOUSE_URL = ""
# CLICKHOUSE_USER = "default"

View File

@ -37,6 +37,7 @@ futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid.workspace = true
@ -61,6 +62,7 @@ subtle.workspace = true
rustc-demangle.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tokio.workspace = true
toml.workspace = true

View File

@ -81,14 +81,14 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("token ")
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
@ -97,7 +97,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
let state = req.extensions().get::<Arc<AppState>>().unwrap();
if token != state.config.api_token {
Err(Error::Http(
Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid authorization token".to_string(),
))?
@ -185,13 +185,13 @@ async fn create_access_token(
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
impersonated_user_id = Some(impersonated_user.id);
} else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::UNPROCESSABLE_ENTITY,
format!("user {impersonate} does not exist"),
));
}
} else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::UNAUTHORIZED,
"you do not have permission to impersonate other users".to_string(),
));

View File

@ -120,7 +120,7 @@ async fn create_billing_subscription(
.zip(app.config.stripe_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@ -201,7 +201,7 @@ async fn manage_billing_subscription(
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?

View File

@ -206,14 +206,14 @@ pub async fn post_hang(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
@ -265,25 +265,25 @@ pub async fn post_panic(
body: Bytes,
) -> Result<()> {
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
};
if checksum != expected {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid checksum".into(),
))?;
}
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
.map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
.map_err(|_| Error::http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
let panic = report.panic;
if panic.os_name == "Linux" && panic.os_version == Some("1.0.0".to_string()) {
return Err(Error::Http(
return Err(Error::http(
StatusCode::BAD_REQUEST,
"invalid os version".into(),
))?;
@ -362,14 +362,14 @@ pub async fn post_events(
body: Bytes,
) -> Result<()> {
let Some(clickhouse_client) = app.clickhouse_client.clone() else {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
return Err(Error::Http(
return Err(Error::http(
StatusCode::INTERNAL_SERVER_ERROR,
"events not enabled".into(),
))?;
@ -385,7 +385,7 @@ pub async fn post_events(
let mut to_upload = ToUpload::default();
let Some(last_event) = request_body.events.last() else {
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
return Err(Error::http(StatusCode::BAD_REQUEST, "no events".into()))?;
};
let country_code = country_code_header.map(|h| h.to_string());

View File

@ -185,7 +185,7 @@ async fn download_extension(
.clone()
.zip(app.config.blob_store_bucket.clone())
else {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
@ -202,7 +202,7 @@ async fn download_extension(
.await?;
if !version_exists {
Err(Error::Http(
Err(Error::http(
StatusCode::NOT_FOUND,
"unknown extension version".into(),
))?;

View File

@ -33,7 +33,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::UNAUTHORIZED,
"missing authorization header".to_string(),
)
@ -45,14 +45,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
let first = auth_header.next().unwrap_or("");
if first == "dev-server-token" {
let dev_server_token = auth_header.next().ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing dev-server-token token in authorization header".to_string(),
)
})?;
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
.await
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
.map_err(|e| Error::http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
req.extensions_mut()
.insert(Principal::DevServer(dev_server));
@ -60,14 +60,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
let user_id = UserId(first.parse().map_err(|_| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing user id in authorization header".to_string(),
)
})?);
let access_token = auth_header.next().ok_or_else(|| {
Error::Http(
Error::http(
StatusCode::BAD_REQUEST,
"missing access token in authorization header".to_string(),
)
@ -111,7 +111,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
}
}
Err(Error::Http(
Err(Error::http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))

View File

@ -13,7 +13,10 @@ mod tests;
use anyhow::anyhow;
use aws_config::{BehaviorVersion, Region};
use axum::{http::StatusCode, response::IntoResponse};
use axum::{
http::{HeaderMap, StatusCode},
response::IntoResponse,
};
use db::{ChannelId, Database};
use executor::Executor;
pub use rate_limiter::*;
@ -24,7 +27,7 @@ use util::ResultExt;
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
Http(StatusCode, String),
Http(StatusCode, String, HeaderMap),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
@ -66,12 +69,18 @@ impl From<serde_json::Error> for Error {
}
}
impl Error {
fn http(code: StatusCode, message: String) -> Self {
Self::Http(code, message, HeaderMap::default())
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
match self {
Error::Http(code, message) => {
Error::Http(code, message, headers) => {
log::error!("HTTP error {}: {}", code, &message);
(code, message).into_response()
(code, headers, message).into_response()
}
Error::Database(error) => {
log::error!(
@ -104,7 +113,7 @@ impl IntoResponse for Error {
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Http(code, message) => (code, message).fmt(f),
Error::Http(code, message, _headers) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
@ -115,7 +124,7 @@ impl std::fmt::Debug for Error {
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Http(code, message) => write!(f, "{code}: {message}"),
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
@ -141,6 +150,7 @@ pub struct Config {
pub live_kit_server: Option<String>,
pub live_kit_key: Option<String>,
pub live_kit_secret: Option<String>,
pub llm_api_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
pub blob_store_url: Option<String>,

View File

@ -1,16 +1,122 @@
mod token;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::Context as _;
use axum::{
body::Body,
http::{self, HeaderName, HeaderValue, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::post,
Extension, Json, Router,
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
use crate::{executor::Executor, Config, Result};
pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub http_client: IsahcHttpClient,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let this = Self { config, executor };
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
.default_header("User-Agent", user_agent)
.build()
.context("failed to construct http client")?;
let this = Self {
config,
executor,
http_client,
};
Ok(Arc::new(this))
}
}
pub fn routes() -> Router<(), Body> {
Router::new()
.route("/completion", post(perform_completion))
.layer(middleware::from_fn(validate_api_token))
}
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let token = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("Bearer ")
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
})?;
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(&token, &state.config) {
Ok(claims) => {
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
}
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
[(
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
)),
Err(_err) => Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
)),
}
}
async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>,
Extension(_claims): Extension<LlmTokenClaims>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
let api_key = state
.config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
let chunks = anthropic::stream_completion(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}

View File

@ -0,0 +1,75 @@
use crate::{db::UserId, Config};
use anyhow::{anyhow, Result};
use chrono::Utc;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
pub iat: u64,
pub exp: u64,
pub jti: String,
pub user_id: u64,
pub plan: rpc::proto::Plan,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
.ok_or_else(|| anyhow!("no LLM API secret"))?;
let now = Utc::now();
let claims = Self {
iat: now.timestamp() as u64,
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user_id.to_proto(),
plan,
};
Ok(jsonwebtoken::encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)?)
}
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
let secret = config
.llm_api_secret
.as_ref()
.ok_or_else(|| anyhow!("no LLM API secret"))?;
match jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::default(),
) {
Ok(token) => Ok(token.claims),
Err(e) => {
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
Err(ValidateLlmTokenError::Expired)
} else {
Err(ValidateLlmTokenError::JwtError(e))
}
}
}
}
}
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
#[error("access token is expired")]
Expired,
#[error("access token validation error: {0}")]
JwtError(#[from] jsonwebtoken::errors::Error),
#[error("{0}")]
Other(#[from] anyhow::Error),
}

View File

@ -83,7 +83,9 @@ async fn main() -> Result<()> {
if mode.is_llm() {
let state = LlmState::new(config.clone(), Executor::Production).await?;
app = app.layer(Extension(state.clone()));
app = app
.merge(collab::llm::routes())
.layer(Extension(state.clone()));
}
if mode.is_collab() || mode.is_api() {

View File

@ -1,6 +1,7 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
use crate::llm::LlmTokenClaims;
use crate::{
auth,
db::{
@ -11,7 +12,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
AppState, Config, Error, RateLimit, RateLimiter, Result,
AppState, Config, Error, RateLimit, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@ -149,10 +150,9 @@ struct Session {
db: Arc<tokio::sync::Mutex<DbHandle>>,
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@ -615,6 +615,7 @@ impl Server {
.add_message_handler(user_message_handler(unfollow))
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@ -1046,9 +1047,8 @@ impl Server {
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
app_state: this.app_state.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
@ -1559,7 +1559,7 @@ async fn create_room(
let live_kit_room = nanoid::nanoid!(30);
let live_kit_connection_info = util::maybe!(async {
let live_kit = session.live_kit_client.as_ref();
let live_kit = session.app_state.live_kit_client.as_ref();
let live_kit = live_kit?;
let user_id = session.user_id().to_string();
@ -1630,25 +1630,26 @@ async fn join_room(
.trace_err();
}
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
let live_kit_connection_info =
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
if let Some(token) = live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()
{
Some(proto::LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish: true,
})
} else {
None
}
} else {
None
}
} else {
None
};
};
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room),
@ -1877,7 +1878,7 @@ async fn set_room_participant_role(
(live_kit_room, can_publish)
};
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.update_participant(
live_kit_room.clone(),
@ -4048,35 +4049,40 @@ async fn join_channel_internal(
.join_channel(channel_id, session.user_id(), session.connection_id)
.await?;
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
let live_kit_connection_info =
session
.app_state
.live_kit_client
.as_ref()
.and_then(|live_kit| {
let (can_publish, token) = if role == ChannelRole::Guest {
(
false,
live_kit
.guest_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
} else {
(
true,
live_kit
.room_token(
&joined_room.room.live_kit_room,
&session.user_id().to_string(),
)
.trace_err()?,
)
.trace_err()?,
)
};
};
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
Some(LiveKitConnectionInfo {
server_url: live_kit.url().into(),
token,
can_publish,
})
});
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room.clone()),
@ -4610,6 +4616,7 @@ async fn complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4655,6 +4662,7 @@ async fn stream_complete_with_language_model(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4766,6 +4774,7 @@ async fn count_language_model_tokens(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -4885,6 +4894,7 @@ async fn compute_embeddings(
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
@ -5143,6 +5153,24 @@ async fn get_private_user_info(
Ok(())
}
async fn get_llm_api_token(
_request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>,
session: UserSession,
) -> Result<()> {
if !session.is_staff() {
Err(anyhow!("permission denied"))?
}
let token = LlmTokenClaims::create(
session.user_id(),
session.current_plan().await?,
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
@ -5486,7 +5514,7 @@ async fn leave_room_for_session(session: &UserSession, connection_id: Connection
update_user_contacts(contact_user_id, &session).await?;
}
if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
live_kit
.remove_participant(live_kit_room.clone(), session.user_id().to_string())
.await

View File

@ -651,6 +651,7 @@ impl TestServer {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
zed_environment: "test".into(),

View File

@ -175,6 +175,22 @@ impl HttpClientWithUrl {
query,
)?)
}
/// Builds a Zed LLM URL using the given path.
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
let base_url = self.base_url();
let base_api_url = match base_url.as_ref() {
"https://zed.dev" => "https://llm.zed.dev",
"https://staging.zed.dev" => "https://llm-staging.zed.dev",
"http://localhost:3000" => "http://localhost:8080",
other => other,
};
Ok(Url::parse_with_params(
&format!("{}{}", base_api_url, path),
query,
)?)
}
}
impl HttpClient for Arc<HttpClientWithUrl> {

View File

@ -5,13 +5,20 @@ use crate::{
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use client::{Client, UserStore};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{HttpClient, Method};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
io::BufReader,
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@ -46,6 +53,7 @@ pub struct AvailableModel {
pub struct CloudLanguageModelProvider {
client: Arc<Client>,
llm_api_token: LlmApiToken,
state: gpui::Model<State>,
_maintain_client_status: Task<()>,
}
@ -104,6 +112,7 @@ impl CloudLanguageModelProvider {
Self {
client,
state,
llm_api_token: LlmApiToken::default(),
_maintain_client_status: maintain_client_status,
}
}
@ -181,6 +190,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: self.llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@ -208,13 +218,27 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
}
struct LlmServiceFeatureFlag;
impl FeatureFlag for LlmServiceFeatureFlag {
const NAME: &'static str = "llm-service";
fn enabled_for_staff() -> bool {
false
}
}
pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
llm_api_token: LlmApiToken,
client: Arc<Client>,
request_limiter: RateLimiter,
}
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LanguageModel for CloudLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@ -279,25 +303,88 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
_: &AsyncAppContext,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
Ok(anthropic::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
))
});
async move { Ok(future.await?.boxed()) }.boxed()
let client = self.client.clone();
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let http_client = self.client.http_client();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let mut token = llm_api_token.acquire(&client).await?;
let mut did_retry = false;
let response = loop {
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(
serde_json::to_string(&PerformCompletionParams {
provider_request: RawValue::from_string(request.clone())?,
})?
.into(),
)?;
let response = http_client.send(request).await?;
if response.status().is_success() {
break response;
} else if !did_retry
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else {
break Err(anyhow!(
"cloud language model completion failed with status {}",
response.status()
))?;
}
};
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: anthropic::Event =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(anthropic::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?
.map(|event| Ok(serde_json::from_str(&event?.event)?));
Ok(anthropic::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
@ -417,6 +504,30 @@ impl LanguageModel for CloudLanguageModel {
}
}
impl LlmApiToken {
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
}
}
async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
Self::fetch(self.0.write().await, &client).await
}
async fn fetch<'a>(
mut lock: RwLockWriteGuard<'a, Option<String>>,
client: &Arc<Client>,
) -> Result<String> {
let response = client.request(proto::GetLlmToken {}).await?;
*lock = Some(response.token.clone());
Ok(response.token.clone())
}
}
struct ConfigurationView {
state: gpui::Model<State>,
}

View File

@ -126,7 +126,7 @@ message Envelope {
Unfollow unfollow = 101;
GetPrivateUserInfo get_private_user_info = 102;
GetPrivateUserInfoResponse get_private_user_info_response = 103;
UpdateUserPlan update_user_plan = 234; // current max
UpdateUserPlan update_user_plan = 234;
UpdateDiffBase update_diff_base = 104;
OnTypeFormatting on_type_formatting = 105;
@ -270,6 +270,9 @@ message Envelope {
AddWorktree add_worktree = 222;
AddWorktreeResponse add_worktree_response = 223;
GetLlmToken get_llm_token = 235;
GetLlmTokenResponse get_llm_token_response = 236; // current max
}
reserved 158 to 161;
@ -2425,6 +2428,12 @@ message SynchronizeContextsResponse {
repeated ContextVersion contexts = 1;
}
message GetLlmToken {}
message GetLlmTokenResponse {
string token = 1;
}
// Remote FS
message AddWorktree {

View File

@ -259,6 +259,8 @@ messages!(
(GetTypeDefinitionResponse, Background),
(GetImplementation, Background),
(GetImplementationResponse, Background),
(GetLlmToken, Background),
(GetLlmTokenResponse, Background),
(GetUsers, Foreground),
(Hello, Foreground),
(IncomingCall, Foreground),
@ -438,6 +440,7 @@ request_messages!(
(GetImplementation, GetImplementationResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetHover, GetHoverResponse),
(GetLlmToken, GetLlmTokenResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),

8
crates/rpc/src/llm.rs Normal file
View File

@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
#[derive(Serialize, Deserialize)]
pub struct PerformCompletionParams {
pub provider_request: Box<serde_json::value::RawValue>,
}

View File

@ -1,12 +1,14 @@
pub mod auth;
mod conn;
mod extension;
mod llm;
mod notification;
mod peer;
pub mod proto;
pub use conn::Connection;
pub use extension::*;
pub use llm::*;
pub use notification::*;
pub use peer::*;
pub use proto::{error::*, Receipt, TypedEnvelope};