diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 77be722228..aed2516557 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -127,6 +127,16 @@ spec: secretKeyRef: name: google-ai key: api_key + - name: QWEN2_7B_API_KEY + valueFrom: + secretKeyRef: + name: hugging-face + key: api_key + - name: QWEN2_7B_API_URL + valueFrom: + secretKeyRef: + name: hugging-face + key: qwen2_api_url - name: BLOB_STORE_ACCESS_KEY valueFrom: secretKeyRef: diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index b45c961efd..6925f62874 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -151,6 +151,8 @@ pub struct Config { pub openai_api_key: Option<Arc<str>>, pub google_ai_api_key: Option<Arc<str>>, pub anthropic_api_key: Option<Arc<str>>, + pub qwen2_7b_api_key: Option<Arc<str>>, + pub qwen2_7b_api_url: Option<Arc<str>>, pub zed_client_checksum_seed: Option<String>, pub slack_panics_webhook: Option<String>, pub auto_join_channel_id: Option<String>, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3683cdc5c8..939ab55110 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4706,6 +4706,30 @@ async fn stream_complete_with_language_model( })?; } } + Some(proto::LanguageModelProvider::Zed) => { + let api_key = config + .qwen2_7b_api_key + .as_ref() + .context("no Qwen2-7B API key configured on the server")?; + let api_url = config + .qwen2_7b_api_url + .as_ref() + .context("no Qwen2-7B URL configured on the server")?; + let mut events = open_ai::stream_completion( + session.http_client.as_ref(), + &api_url, + api_key, + serde_json::from_str(&request.request)?, + None, + ) + .await?; + while let Some(event) = events.next().await { + let event = event?; + response.send(proto::StreamCompleteWithLanguageModelResponse { + event: serde_json::to_string(&event)?, + })?; + } + } None => return Err(anyhow!("unknown provider"))?, } diff --git a/crates/collab/src/tests/test_server.rs
b/crates/collab/src/tests/test_server.rs index 76174f5953..d1aa42f28b 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -672,6 +672,8 @@ impl TestServer { stripe_api_key: None, stripe_price_id: None, supermaven_admin_api_key: None, + qwen2_7b_api_key: None, + qwen2_7b_api_url: None, }, }) } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 1023ee337a..76d530d909 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,5 +1,6 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use strum::EnumIter; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "provider", rename_all = "lowercase")] @@ -7,6 +8,33 @@ pub enum CloudModel { Anthropic(anthropic::Model), OpenAi(open_ai::Model), Google(google_ai::Model), + Zed(ZedModel), +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)] +pub enum ZedModel { + #[serde(rename = "qwen2-7b-instruct")] + Qwen2_7bInstruct, +} + +impl ZedModel { + pub fn id(&self) -> &str { + match self { + ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct", + } + } + + pub fn display_name(&self) -> &str { + match self { + ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct", + } + } + + pub fn max_token_count(&self) -> usize { + match self { + ZedModel::Qwen2_7bInstruct => 8192, + } + } } impl Default for CloudModel { @@ -21,6 +49,7 @@ impl CloudModel { CloudModel::Anthropic(model) => model.id(), CloudModel::OpenAi(model) => model.id(), CloudModel::Google(model) => model.id(), + CloudModel::Zed(model) => model.id(), } } @@ -29,6 +58,7 @@ impl CloudModel { CloudModel::Anthropic(model) => model.display_name(), CloudModel::OpenAi(model) => model.display_name(), CloudModel::Google(model) => model.display_name(), + CloudModel::Zed(model) => model.display_name(), } } @@ -37,6 +67,7 @@ impl CloudModel { 
CloudModel::Anthropic(model) => model.max_token_count(), CloudModel::OpenAi(model) => model.max_token_count(), CloudModel::Google(model) => model.max_token_count(), + CloudModel::Zed(model) => model.max_token_count(), } } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index b16dc36be8..0c6402c7ab 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens; use crate::{ settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel, }; use anyhow::{anyhow, Context as _, Result}; use client::{Client, UserStore}; @@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider { models.insert(model.id().to_string(), CloudModel::Google(model)); } } + for model in ZedModel::iter() { + models.insert(model.id().to_string(), CloudModel::Zed(model)); + } // Override with available models from settings for model in &AllLanguageModelSettings::get_global(cx) @@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel { } .boxed() } + CloudModel::Zed(_) => { + count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx) + } } } @@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel { }); async move { Ok(future.await?.boxed()) }.boxed() } + CloudModel::Zed(model) => { + let client = self.client.clone(); + let mut request = request.into_open_ai(model.id().into()); + request.max_tokens = Some(4000); + let future = self.request_limiter.stream(async move { + let request = serde_json::to_string(&request)?; + let stream = client + .request_stream(proto::StreamCompleteWithLanguageModel { + provider: proto::LanguageModelProvider::Zed as i32, + request, + }) + .await?; + 
Ok(open_ai::extract_text_from_events( + stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } } } @@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel { CloudModel::Google(_) => { future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed() } + CloudModel::Zed(_) => { + future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed() + } } } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index ca57706f15..243dcf906b 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -37,6 +37,7 @@ impl LanguageModelRequest { stream: true, stop: self.stop, temperature: self.temperature, + max_tokens: None, tools: Vec::new(), tool_choice: None, } diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 13a6eb11d1..eb7769c4b6 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -116,6 +116,8 @@ pub struct Request { pub model: String, pub messages: Vec<RequestMessage>, pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option<u32>, pub stop: Vec<String>, pub temperature: f32, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -216,6 +218,13 @@ pub struct ChoiceDelta { pub finish_reason: Option<String>, } +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum ResponseStreamResult { + Ok(ResponseStreamEvent), + Err { error: String }, +} + #[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, @@ -256,7 +265,10 @@ pub async fn stream_completion( None } else { match serde_json::from_str(line) { - Ok(response) => Some(Ok(response)), + Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)), + Ok(ResponseStreamResult::Err { error }) => { + Some(Err(anyhow!(error))) + } Err(error) => Some(Err(anyhow!(error))), } } diff --git a/crates/proto/proto/zed.proto
b/crates/proto/proto/zed.proto index b396abe8e3..55cfc77e30 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -2099,6 +2099,7 @@ enum LanguageModelProvider { Anthropic = 0; OpenAI = 1; Google = 2; + Zed = 3; } message GetCachedEmbeddings {