mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
collab: Add GET /models
endpoint to LLM service (#17307)
This PR adds a `GET /models` endpoint to the LLM service. This endpoint returns the models that the authenticated user has access to. This is the first step towards populating the models for the hosted service from the server. Release Notes: - N/A
This commit is contained in:
parent
122f01f9e5
commit
30056254f3
@ -9,6 +9,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _};
|
use anyhow::{anyhow, Context as _};
|
||||||
use authorization::authorize_access_to_language_model;
|
use authorization::authorize_access_to_language_model;
|
||||||
|
use axum::routing::get;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||||
@ -22,6 +23,7 @@ use collections::HashMap;
|
|||||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||||
use futures::{Stream, StreamExt as _};
|
use futures::{Stream, StreamExt as _};
|
||||||
use http_client::IsahcHttpClient;
|
use http_client::IsahcHttpClient;
|
||||||
|
use rpc::ListModelsResponse;
|
||||||
use rpc::{
|
use rpc::{
|
||||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||||
};
|
};
|
||||||
@ -114,6 +116,7 @@ impl LlmState {
|
|||||||
|
|
||||||
pub fn routes() -> Router<(), Body> {
|
pub fn routes() -> Router<(), Body> {
|
||||||
Router::new()
|
Router::new()
|
||||||
|
.route("/models", get(list_models))
|
||||||
.route("/completion", post(perform_completion))
|
.route("/completion", post(perform_completion))
|
||||||
.layer(middleware::from_fn(validate_api_token))
|
.layer(middleware::from_fn(validate_api_token))
|
||||||
}
|
}
|
||||||
@ -173,6 +176,37 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn list_models(
|
||||||
|
Extension(state): Extension<Arc<LlmState>>,
|
||||||
|
Extension(claims): Extension<LlmTokenClaims>,
|
||||||
|
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||||
|
) -> Result<Json<ListModelsResponse>> {
|
||||||
|
let country_code = country_code_header.map(|header| header.to_string());
|
||||||
|
|
||||||
|
let mut accessible_models = Vec::new();
|
||||||
|
|
||||||
|
for (provider, model) in state.db.all_models() {
|
||||||
|
let authorize_result = authorize_access_to_language_model(
|
||||||
|
&state.config,
|
||||||
|
&claims,
|
||||||
|
country_code.as_deref(),
|
||||||
|
provider,
|
||||||
|
&model.name,
|
||||||
|
);
|
||||||
|
|
||||||
|
if authorize_result.is_ok() {
|
||||||
|
accessible_models.push(rpc::LanguageModel {
|
||||||
|
provider,
|
||||||
|
name: model.name,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(ListModelsResponse {
|
||||||
|
models: accessible_models,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
async fn perform_completion(
|
async fn perform_completion(
|
||||||
Extension(state): Extension<Arc<LlmState>>,
|
Extension(state): Extension<Arc<LlmState>>,
|
||||||
Extension(claims): Extension<LlmTokenClaims>,
|
Extension(claims): Extension<LlmTokenClaims>,
|
||||||
@ -187,7 +221,9 @@ async fn perform_completion(
|
|||||||
authorize_access_to_language_model(
|
authorize_access_to_language_model(
|
||||||
&state.config,
|
&state.config,
|
||||||
&claims,
|
&claims,
|
||||||
country_code_header.map(|header| header.to_string()),
|
country_code_header
|
||||||
|
.map(|header| header.to_string())
|
||||||
|
.as_deref(),
|
||||||
params.provider,
|
params.provider,
|
||||||
&model,
|
&model,
|
||||||
)?;
|
)?;
|
||||||
|
@ -7,7 +7,7 @@ use crate::{Config, Error, Result};
|
|||||||
pub fn authorize_access_to_language_model(
|
pub fn authorize_access_to_language_model(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
claims: &LlmTokenClaims,
|
claims: &LlmTokenClaims,
|
||||||
country_code: Option<String>,
|
country_code: Option<&str>,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
@ -49,7 +49,7 @@ fn authorize_access_to_model(
|
|||||||
|
|
||||||
fn authorize_access_for_country(
|
fn authorize_access_for_country(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
country_code: Option<String>,
|
country_code: Option<&str>,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||||
@ -62,7 +62,7 @@ fn authorize_access_for_country(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
||||||
let country_code = match country_code.as_deref() {
|
let country_code = match country_code {
|
||||||
// `XX` - Used for clients without country code data.
|
// `XX` - Used for clients without country code data.
|
||||||
None | Some("XX") => Err(Error::http(
|
None | Some("XX") => Err(Error::http(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
@ -128,7 +128,7 @@ mod tests {
|
|||||||
authorize_access_to_language_model(
|
authorize_access_to_language_model(
|
||||||
&config,
|
&config,
|
||||||
&claims,
|
&claims,
|
||||||
Some(country_code.into()),
|
Some(country_code),
|
||||||
provider,
|
provider,
|
||||||
"the-model",
|
"the-model",
|
||||||
)
|
)
|
||||||
@ -178,7 +178,7 @@ mod tests {
|
|||||||
let error_response = authorize_access_to_language_model(
|
let error_response = authorize_access_to_language_model(
|
||||||
&config,
|
&config,
|
||||||
&claims,
|
&claims,
|
||||||
Some(country_code.into()),
|
Some(country_code),
|
||||||
provider,
|
provider,
|
||||||
"the-model",
|
"the-model",
|
||||||
)
|
)
|
||||||
@ -223,7 +223,7 @@ mod tests {
|
|||||||
let error_response = authorize_access_to_language_model(
|
let error_response = authorize_access_to_language_model(
|
||||||
&config,
|
&config,
|
||||||
&claims,
|
&claims,
|
||||||
Some(country_code.into()),
|
Some(country_code),
|
||||||
provider,
|
provider,
|
||||||
"the-model",
|
"the-model",
|
||||||
)
|
)
|
||||||
@ -278,13 +278,8 @@ mod tests {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = authorize_access_to_language_model(
|
let result =
|
||||||
&config,
|
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||||
&claims,
|
|
||||||
Some("US".into()),
|
|
||||||
provider,
|
|
||||||
model,
|
|
||||||
);
|
|
||||||
|
|
||||||
if expected_access {
|
if expected_access {
|
||||||
assert!(
|
assert!(
|
||||||
@ -324,13 +319,8 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
for (provider, model) in test_cases {
|
for (provider, model) in test_cases {
|
||||||
let result = authorize_access_to_language_model(
|
let result =
|
||||||
&config,
|
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||||
&claims,
|
|
||||||
Some("US".into()),
|
|
||||||
provider,
|
|
||||||
model,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
result.is_ok(),
|
result.is_ok(),
|
||||||
|
@ -67,6 +67,14 @@ impl LlmDatabase {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the list of all known models, with their [`LanguageModelProvider`].
|
||||||
|
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
|
||||||
|
self.models
|
||||||
|
.iter()
|
||||||
|
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||||
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||||
self.models
|
self.models
|
||||||
|
@ -15,7 +15,18 @@ pub enum LanguageModelProvider {
|
|||||||
Zed,
|
Zed,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
/// A language model, identified by its provider and its name.
///
/// This is the element type of [`ListModelsResponse::models`].
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct LanguageModel {
|
||||||
|
/// The provider that serves this model.
pub provider: LanguageModelProvider,
|
||||||
|
/// The name of the model.
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response body of the LLM service's `GET /models` endpoint.
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ListModelsResponse {
|
||||||
|
/// The models the authenticated user has access to.
pub models: Vec<LanguageModel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct PerformCompletionParams {
|
pub struct PerformCompletionParams {
|
||||||
pub provider: LanguageModelProvider,
|
pub provider: LanguageModelProvider,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
Loading…
Reference in New Issue
Block a user