diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index cf98604e20..320d7418ee 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -9,6 +9,7 @@ use crate::{ }; use anyhow::{anyhow, Context as _}; use authorization::authorize_access_to_language_model; +use axum::routing::get; use axum::{ body::Body, http::{self, HeaderName, HeaderValue, Request, StatusCode}, @@ -22,6 +23,7 @@ use collections::HashMap; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use http_client::IsahcHttpClient; +use rpc::ListModelsResponse; use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, }; @@ -114,6 +116,7 @@ impl LlmState { pub fn routes() -> Router<(), Body> { Router::new() + .route("/models", get(list_models)) .route("/completion", post(perform_completion)) .layer(middleware::from_fn(validate_api_token)) } @@ -173,6 +176,37 @@ async fn validate_api_token(mut req: Request, next: Next) -> impl IntoR } } +async fn list_models( + Extension(state): Extension>, + Extension(claims): Extension, + country_code_header: Option>, +) -> Result> { + let country_code = country_code_header.map(|header| header.to_string()); + + let mut accessible_models = Vec::new(); + + for (provider, model) in state.db.all_models() { + let authorize_result = authorize_access_to_language_model( + &state.config, + &claims, + country_code.as_deref(), + provider, + &model.name, + ); + + if authorize_result.is_ok() { + accessible_models.push(rpc::LanguageModel { + provider, + name: model.name, + }); + } + } + + Ok(Json(ListModelsResponse { + models: accessible_models, + })) +} + async fn perform_completion( Extension(state): Extension>, Extension(claims): Extension, @@ -187,7 +221,9 @@ async fn perform_completion( authorize_access_to_language_model( &state.config, &claims, - country_code_header.map(|header| header.to_string()), + country_code_header + .map(|header| header.to_string()) + .as_deref(), params.provider, &model, )?; diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 98ee1b7c6a..f6acff2685 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -7,7 +7,7 @@ use crate::{Config, Error, Result}; pub fn authorize_access_to_language_model( config: &Config, claims: &LlmTokenClaims, - country_code: Option, + country_code: Option<&str>, provider: LanguageModelProvider, model: &str, ) -> Result<()> { @@ -49,7 +49,7 @@ fn authorize_access_to_model( fn authorize_access_for_country( config: &Config, - country_code: Option, + country_code: Option<&str>, provider: LanguageModelProvider, ) -> Result<()> { // In development we won't have the `CF-IPCountry` header, so we can't check @@ -62,7 +62,7 @@ fn authorize_access_for_country( } // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry - let country_code = match country_code.as_deref() { + let country_code = match country_code { // `XX` - Used for clients without country code data. None | Some("XX") => Err(Error::http( StatusCode::BAD_REQUEST, @@ -128,7 +128,7 @@ mod tests { authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -178,7 +178,7 @@ mod tests { let error_response = authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -223,7 +223,7 @@ mod tests { let error_response = authorize_access_to_language_model( &config, &claims, - Some(country_code.into()), + Some(country_code), provider, "the-model", ) @@ -278,13 +278,8 @@ mod tests { ..Default::default() }; - let result = authorize_access_to_language_model( - &config, - &claims, - Some("US".into()), - provider, - model, - ); + let result = + authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); if expected_access { assert!( @@ -324,13 +319,8 @@ mod tests { ]; for (provider, model) in test_cases { - let result = authorize_access_to_language_model( - &config, - &claims, - Some("US".into()), - provider, - model, - ); + let result = + authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); assert!( result.is_ok(), diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index f76a722471..cd370b14b1 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -67,6 +67,14 @@ impl LlmDatabase { Ok(()) } + /// Returns the list of all known models, with their [`LanguageModelProvider`]. + pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> { + self.models + .iter() + .map(|((model_provider, _model_name), model)| (*model_provider, model.clone())) + .collect::>() + } + /// Returns the names of the known models for the given [`LanguageModelProvider`]. pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec { self.models diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 7f97b02df7..6cae54b309 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -15,7 +15,18 @@ pub enum LanguageModelProvider { Zed, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] +pub struct LanguageModel { + pub provider: LanguageModelProvider, + pub name: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] pub struct PerformCompletionParams { pub provider: LanguageModelProvider, pub model: String,