collab: Add support for more providers to the LLM service (#15832)

This PR adds support for additional providers to the LLM service:

- OpenAI
- Google
- Custom Zed models (through Hugging Face)

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-05 21:16:18 -04:00 committed by GitHub
parent 8e9c2b1125
commit ca9511393b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 331 additions and 98 deletions

View File

@ -12,7 +12,7 @@ use axum::{
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;
pub use token::*;
@ -94,6 +94,8 @@ async fn perform_completion(
Extension(_claims): Extension<LlmTokenClaims>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = state
.config
.anthropic_api_key
@ -120,3 +122,90 @@ async fn perform_completion(
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::OpenAi => {
let api_key = state
.config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::Google => {
let api_key = state
.config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let chunks = google_ai::stream_generate_content(
&state.http_client,
google_ai::API_URL,
api_key,
serde_json::from_str(&params.provider_request.get())?,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
LanguageModelProvider::Zed => {
let api_key = state
.config
.qwen2_7b_api_key
.as_ref()
.context("no Qwen2-7B API key configured on the server")?;
let api_url = state
.config
.qwen2_7b_api_url
.as_ref()
.context("no Qwen2-7B URL configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
&api_url,
api_key,
serde_json::from_str(&params.provider_request.get())?,
None,
)
.await?;
let stream = chunks.map(|event| {
let mut buffer = Vec::new();
event.map(|chunk| {
buffer.clear();
serde_json::to_writer(&mut buffer, &chunk).unwrap();
buffer.push(b'\n');
buffer
})
});
Ok(Response::new(Body::wrap_stream(stream)))
}
}
}

View File

@ -10,7 +10,7 @@ use collections::BTreeMap;
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{HttpClient, Method};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: PerformCompletionParams,
) -> Result<Response<AsyncBody>> {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
let mut did_retry = false;
let response = loop {
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let response = http_client.send(request).await?;
if response.status().is_success() {
break response;
} else if !did_retry
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else {
break Err(anyhow!(
"cloud language model completion failed with status {}",
response.status()
))?;
}
};
Ok(response)
}
}
impl LanguageModel for CloudLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let http_client = self.client.http_client();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let mut token = llm_api_token.acquire(&client).await?;
let mut did_retry = false;
let response = loop {
let request = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(
serde_json::to_string(&PerformCompletionParams {
provider_request: RawValue::from_string(request.clone())?,
})?
.into(),
)?;
let response = http_client.send(request).await?;
if response.status().is_success() {
break response;
} else if !did_retry
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_api_token.refresh(&client).await?;
} else {
break Err(anyhow!(
"cloud language model completion failed with status {}",
response.status()
))?;
}
};
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string(
&request,
)?)?,
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
@ -389,6 +405,44 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into());
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string(
&request,
)?)?,
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: open_ai::ResponseStreamEvent =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(open_ai::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
@ -403,9 +457,48 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Google,
model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string(
&request,
)?)?,
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: google_ai::GenerateContentResponse =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(google_ai::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
@ -420,10 +513,49 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
CloudModel::Zed(model) => {
let client = self.client.clone();
let mut request = request.into_open_ai(model.id().into());
request.max_tokens = Some(4000);
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Zed,
model: request.model.clone(),
provider_request: RawValue::from_string(serde_json::to_string(
&request,
)?)?,
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream =
futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: open_ai::ResponseStreamEvent =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(open_ai::extract_text_from_events(stream))
});
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
@ -440,6 +572,7 @@ impl LanguageModel for CloudLanguageModel {
}
}
}
}
fn use_any_tool(
&self,

View File

@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LanguageModelProvider {
Anthropic,
OpenAi,
Google,
Zed,
}
#[derive(Serialize, Deserialize)]
pub struct PerformCompletionParams {
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: Box<serde_json::value::RawValue>,
}