mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
collab: Add support for more providers to the LLM service (#15832)
This PR adds support for additional providers to the LLM service:

- OpenAI
- Google
- Custom Zed models (through Hugging Face)

Release Notes:

- N/A
This commit is contained in:
parent
8e9c2b1125
commit
ca9511393b
@ -12,7 +12,7 @@ use axum::{
|
||||
};
|
||||
use futures::StreamExt as _;
|
||||
use http_client::IsahcHttpClient;
|
||||
use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use token::*;
|
||||
@ -94,6 +94,8 @@ async fn perform_completion(
|
||||
Extension(_claims): Extension<LlmTokenClaims>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
match params.provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
let api_key = state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
@ -120,3 +122,90 @@ async fn perform_completion(
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::OpenAi => {
|
||||
let api_key = state
|
||||
.config
|
||||
.openai_api_key
|
||||
.as_ref()
|
||||
.context("no OpenAI API key configured on the server")?;
|
||||
let chunks = open_ai::stream_completion(
|
||||
&state.http_client,
|
||||
open_ai::OPEN_AI_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&params.provider_request.get())?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::Google => {
|
||||
let api_key = state
|
||||
.config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
let chunks = google_ai::stream_generate_content(
|
||||
&state.http_client,
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&params.provider_request.get())?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::Zed => {
|
||||
let api_key = state
|
||||
.config
|
||||
.qwen2_7b_api_key
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B API key configured on the server")?;
|
||||
let api_url = state
|
||||
.config
|
||||
.qwen2_7b_api_url
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B URL configured on the server")?;
|
||||
let chunks = open_ai::stream_completion(
|
||||
&state.http_client,
|
||||
&api_url,
|
||||
api_key,
|
||||
serde_json::from_str(&params.provider_request.get())?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
|
||||
#[derive(Clone, Default)]
|
||||
struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
impl CloudLanguageModel {
|
||||
async fn perform_llm_completion(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
body: PerformCompletionParams,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut did_retry = false;
|
||||
|
||||
let response = loop {
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
break response;
|
||||
} else if !did_retry
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else {
|
||||
break Err(anyhow!(
|
||||
"cloud language model completion failed with status {}",
|
||||
response.status()
|
||||
))?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let http_client = self.client.http_client();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut did_retry = false;
|
||||
|
||||
let response = loop {
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(
|
||||
serde_json::to_string(&PerformCompletionParams {
|
||||
provider_request: RawValue::from_string(request.clone())?,
|
||||
})?
|
||||
.into(),
|
||||
)?;
|
||||
let response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
break response;
|
||||
} else if !did_retry
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else {
|
||||
break Err(anyhow!(
|
||||
"cloud language model completion failed with status {}",
|
||||
response.status()
|
||||
))?;
|
||||
}
|
||||
};
|
||||
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::Anthropic,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(serde_json::to_string(
|
||||
&request,
|
||||
)?)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
@ -389,6 +405,44 @@ impl LanguageModel for CloudLanguageModel {
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_open_ai(model.id().into());
|
||||
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::OpenAi,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(serde_json::to_string(
|
||||
&request,
|
||||
)?)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(open_ai::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
} else {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
@ -403,9 +457,48 @@ impl LanguageModel for CloudLanguageModel {
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::Google,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(serde_json::to_string(
|
||||
&request,
|
||||
)?)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: google_ai::GenerateContentResponse =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(google_ai::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
} else {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
@ -420,10 +513,49 @@ impl LanguageModel for CloudLanguageModel {
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::Zed(model) => {
|
||||
let client = self.client.clone();
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
request.max_tokens = Some(4000);
|
||||
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::Zed,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(serde_json::to_string(
|
||||
&request,
|
||||
)?)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream =
|
||||
futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(open_ai::extract_text_from_events(stream))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
} else {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
@ -440,6 +572,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
|
@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LanguageModelProvider {
|
||||
Anthropic,
|
||||
OpenAi,
|
||||
Google,
|
||||
Zed,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct PerformCompletionParams {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub provider_request: Box<serde_json::value::RawValue>,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user