From ca9511393b2e5cc5f68311def63669e438209eb5 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Mon, 5 Aug 2024 21:16:18 -0400
Subject: [PATCH] collab: Add support for more providers to the LLM service
 (#15832)

This PR adds support for additional providers to the LLM service:

- OpenAI
- Google
- Custom Zed models (through Hugging Face)

Release Notes:

- N/A
---
 crates/collab/src/llm.rs                    | 137 ++++++++--
 crates/language_model/src/provider/cloud.rs | 281 ++++++++++++++------
 crates/rpc/src/llm.rs                       |  11 +
 3 files changed, 331 insertions(+), 98 deletions(-)

diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
index e3e17562fa..1c5cf8625b 100644
--- a/crates/collab/src/llm.rs
+++ b/crates/collab/src/llm.rs
@@ -12,7 +12,7 @@ use axum::{
 };
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
-use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
 
 pub use token::*;
@@ -94,29 +94,118 @@ async fn perform_completion(
     Extension(_claims): Extension<LlmTokenClaims>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let api_key = state
-        .config
-        .anthropic_api_key
-        .as_ref()
-        .context("no Anthropic AI API key configured on the server")?;
-    let chunks = anthropic::stream_completion(
-        &state.http_client,
-        anthropic::ANTHROPIC_API_URL,
-        api_key,
-        serde_json::from_str(&params.provider_request.get())?,
-        None,
-    )
-    .await?;
+    match params.provider {
+        LanguageModelProvider::Anthropic => {
+            let api_key = state
+                .config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            let chunks = anthropic::stream_completion(
+                &state.http_client,
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
 
-    let stream = chunks.map(|event| {
-        let mut buffer = Vec::new();
-        event.map(|chunk| {
-            buffer.clear();
-            serde_json::to_writer(&mut buffer, &chunk).unwrap();
-            buffer.push(b'\n');
-            buffer
-        })
-    });
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
 
-    Ok(Response::new(Body::wrap_stream(stream)))
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::OpenAi => {
+            let api_key = state
+                .config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                open_ai::OPEN_AI_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Google => {
+            let api_key = state
+                .config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let chunks = google_ai::stream_generate_content(
+                &state.http_client,
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Zed => {
+            let api_key = state
+                .config
+                .qwen2_7b_api_key
+                .as_ref()
+                .context("no Qwen2-7B API key configured on the server")?;
+            let api_url = state
+                .config
+                .qwen2_7b_api_url
+                .as_ref()
+                .context("no Qwen2-7B URL configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                &api_url,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+    }
 }
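Annotation, not part of the patch: every arm of the new `match params.provider` ends with the same adapter, re-serializing each upstream chunk as one JSON object per line so clients can frame the stream with a plain `read_line` loop. A minimal sketch of how that shared tail could be factored out, assuming the same `axum`/`futures` APIs the handler already uses (`ndjson_response` is a hypothetical helper, not defined by this patch):

use axum::{body::Body, response::Response};
use futures::{Stream, StreamExt as _};
use serde::Serialize;

// Hypothetical helper: turn any stream of serializable chunks into a
// newline-delimited JSON response body -- the framing each provider arm
// above builds by hand.
fn ndjson_response<T, E>(
    chunks: impl Stream<Item = Result<T, E>> + Send + 'static,
) -> Response<Body>
where
    T: Serialize,
    E: Into<axum::BoxError>,
{
    let stream = chunks.map(|event| {
        event.map(|chunk| {
            // One JSON object per line; clients read it back with `read_line`.
            let mut buffer = serde_json::to_vec(&chunk).unwrap();
            buffer.push(b'\n');
            buffer
        })
    });
    Response::new(Body::wrap_stream(stream))
}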
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 7862794e92..a7f7ca5164 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -10,7 +10,7 @@ use collections::BTreeMap;
 use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
 #[derive(Clone, Default)]
 struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
+impl CloudLanguageModel {
+    async fn perform_llm_completion(
+        client: Arc<Client>,
+        llm_api_token: LlmApiToken,
+        body: PerformCompletionParams,
+    ) -> Result<Response<AsyncBody>> {
+        let http_client = &client.http_client();
+
+        let mut token = llm_api_token.acquire(&client).await?;
+        let mut did_retry = false;
+
+        let response = loop {
+            let request = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {token}"))
+                .body(serde_json::to_string(&body)?.into())?;
+            let response = http_client.send(request).await?;
+            if response.status().is_success() {
+                break response;
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                did_retry = true;
+                token = llm_api_token.refresh(&client).await?;
+            } else {
+                break Err(anyhow!(
+                    "cloud language model completion failed with status {}",
+                    response.status()
+                ))?;
+            }
+        };
+
+        Ok(response)
+    }
+}
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
     }
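Annotation, not part of the patch: `acquire` and `refresh` are defined elsewhere in this file and are not shown in these hunks. Judging from the call sites, `acquire` returns the cached token (fetching one on first use) and `refresh` always fetches a new one. A hypothetical sketch of that contract, using this file's existing imports; the `proto::GetLlmToken` request and the async `RwLock` calls are assumptions:

// Hypothetical sketch -- the real implementation lives outside this patch.
impl LlmApiToken {
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        if let Some(token) = self.0.read().await.clone() {
            return Ok(token); // reuse the cached token
        }
        self.refresh(client).await
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        // Always fetch a fresh token from collab and cache it.
        // `proto::GetLlmToken` is an assumption about the RPC involved.
        let response = client.request(proto::GetLlmToken {}).await?;
        *self.0.write().await = Some(response.token.clone());
        Ok(response.token)
    }
}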
@@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
                     .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                     .unwrap_or(false)
                 {
-                    let http_client = self.client.http_client();
                     let llm_api_token = self.llm_api_token.clone();
                     let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let mut token = llm_api_token.acquire(&client).await?;
-                        let mut did_retry = false;
-
-                        let response = loop {
-                            let request = http_client::Request::builder()
-                                .method(Method::POST)
-                                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
-                                .header("Content-Type", "application/json")
-                                .header("Authorization", format!("Bearer {token}"))
-                                .body(
-                                    serde_json::to_string(&PerformCompletionParams {
-                                        provider_request: RawValue::from_string(request.clone())?,
-                                    })?
-                                    .into(),
-                                )?;
-                            let response = http_client.send(request).await?;
-                            if response.status().is_success() {
-                                break response;
-                            } else if !did_retry
-                                && response
-                                    .headers()
-                                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                                    .is_some()
-                            {
-                                did_retry = true;
-                                token = llm_api_token.refresh(&client).await?;
-                            } else {
-                                break Err(anyhow!(
-                                    "cloud language model completion failed with status {}",
-                                    response.status()
-                                ))?;
-                            }
-                        };
-
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Anthropic,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
                         let body = BufReader::new(response.into_body());
-
                         let stream = futures::stream::try_unfold(body, move |mut body| async move {
                             let mut buffer = String::new();
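Annotation, not part of the patch: this is the client-side counterpart of the newline-delimited framing produced by the collab handler above; `try_unfold` pulls one line at a time and deserializes it. The same pattern as a self-contained sketch (`ndjson_stream` is a hypothetical name; only the `futures` and `serde_json` APIs used in the patch are assumed):

use anyhow::Result;
use futures::{AsyncBufRead, AsyncBufReadExt as _, Stream};
use serde::de::DeserializeOwned;

// Decode a newline-delimited JSON body into a typed event stream.
fn ndjson_stream<T, R>(body: R) -> impl Stream<Item = Result<T>>
where
    T: DeserializeOwned,
    R: AsyncBufRead + Unpin,
{
    futures::stream::try_unfold(body, |mut body| async move {
        let mut buffer = String::new();
        match body.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // EOF ends the stream
            Ok(_) => Ok(Some((serde_json::from_str(&buffer)?, body))),
            Err(e) => Err(e.into()),
        }
    })
}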
@@ -389,54 +405,171 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::OpenAi as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Google as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Google,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(google_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: google_ai::GenerateContentResponse =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(google_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Google as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(google_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Zed as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Zed as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
         }
     }
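Annotation, not part of the patch: each provider arm now has two paths. With the feature flag set, requests go straight to the LLM service with a bearer token (and one automatic retry after a token refresh); otherwise they fall back to the existing collab RPC proxy via `proto::StreamCompleteWithLanguageModel`, so the rollout can be toggled per user without a client update. The flag type passed to `has_flag` is defined earlier in this file, outside these hunks; a flag of that shape looks roughly like this (the `NAME` string is an assumption):

use feature_flags::FeatureFlag;

// Assumed shape -- the actual flag type is defined elsewhere in this file,
// and its NAME string is a guess.
pub struct LlmServiceFeatureFlag;

impl FeatureFlag for LlmServiceFeatureFlag {
    const NAME: &'static str = "llm-service";
}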
diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs
index 64df4110ef..2b1f4b9f4d 100644
--- a/crates/rpc/src/llm.rs
+++ b/crates/rpc/src/llm.rs
@@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
 
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum LanguageModelProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+    Zed,
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct PerformCompletionParams {
+    pub provider: LanguageModelProvider,
+    pub model: String,
     pub provider_request: Box<serde_json::value::RawValue>,
 }
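Annotation, not part of the patch: with `#[serde(rename_all = "snake_case")]`, the variants serialize as "anthropic", "open_ai", "google", and "zed", and `provider_request` is embedded verbatim as raw JSON. A sketch of the resulting wire format (the model id and request body are illustrative only):

use rpc::{LanguageModelProvider, PerformCompletionParams};
use serde_json::value::RawValue;

fn main() -> Result<(), serde_json::Error> {
    let params = PerformCompletionParams {
        provider: LanguageModelProvider::OpenAi,
        model: "gpt-4o".to_string(), // illustrative model id
        provider_request: RawValue::from_string(r#"{"stream":true}"#.to_string())?,
    };
    // Prints:
    // {"provider":"open_ai","model":"gpt-4o","provider_request":{"stream":true}}
    println!("{}", serde_json::to_string(&params)?);
    Ok(())
}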