mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Add Qwen2-7B to the list of zed.dev models (#15649)
Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
60127f2a8d
commit
21816d1ff5
@ -127,6 +127,16 @@ spec:
|
||||
secretKeyRef:
|
||||
name: google-ai
|
||||
key: api_key
|
||||
- name: QWEN2_7B_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: hugging-face
|
||||
key: api_key
|
||||
- name: QWEN2_7B_API_URL
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: hugging-face
|
||||
key: qwen2_api_url
|
||||
- name: BLOB_STORE_ACCESS_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
|
@ -151,6 +151,8 @@ pub struct Config {
|
||||
pub openai_api_key: Option<Arc<str>>,
|
||||
pub google_ai_api_key: Option<Arc<str>>,
|
||||
pub anthropic_api_key: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_key: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_url: Option<Arc<str>>,
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
|
@ -4706,6 +4706,30 @@ async fn stream_complete_with_language_model(
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::Zed) => {
|
||||
let api_key = config
|
||||
.qwen2_7b_api_key
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B API key configured on the server")?;
|
||||
let api_url = config
|
||||
.qwen2_7b_api_url
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B URL configured on the server")?;
|
||||
let mut events = open_ai::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
&api_url,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
None => return Err(anyhow!("unknown provider"))?,
|
||||
}
|
||||
|
||||
|
@ -672,6 +672,8 @@ impl TestServer {
|
||||
stripe_api_key: None,
|
||||
stripe_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
qwen2_7b_api_key: None,
|
||||
qwen2_7b_api_url: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "provider", rename_all = "lowercase")]
|
||||
@ -7,6 +8,33 @@ pub enum CloudModel {
|
||||
Anthropic(anthropic::Model),
|
||||
OpenAi(open_ai::Model),
|
||||
Google(google_ai::Model),
|
||||
Zed(ZedModel),
|
||||
}
|
||||
|
||||
/// Models carried by the `Zed` variant of `CloudModel`.
///
/// `EnumIter` is derived so that callers can enumerate every variant via
/// `ZedModel::iter()` (the cloud provider does this to register each model
/// under its `id()`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
||||
pub enum ZedModel {
|
||||
/// Serialized and deserialized under the name `"qwen2-7b-instruct"`, the
/// same string that `ZedModel::id` returns for this variant.
#[serde(rename = "qwen2-7b-instruct")]
|
||||
Qwen2_7bInstruct,
|
||||
}
|
||||
|
||||
/// Per-model metadata accessors; each method is an exhaustive `match` over
/// the variants, so adding a variant forces every accessor to be updated.
impl ZedModel {
|
||||
/// Returns the model identifier string (`"qwen2-7b-instruct"` for
/// `Qwen2_7bInstruct`).
///
/// This matches the variant's `#[serde(rename = ...)]` value on the enum.
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the human-readable model name (`"Qwen2 7B Instruct"` for
/// `Qwen2_7bInstruct`).
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the token limit associated with the model (8192 for
/// `Qwen2_7bInstruct`).
///
/// NOTE(review): presumably the model's context-window size in tokens;
/// confirm against the hosting endpoint's configuration.
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
ZedModel::Qwen2_7bInstruct => 8192,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CloudModel {
|
||||
@ -21,6 +49,7 @@ impl CloudModel {
|
||||
CloudModel::Anthropic(model) => model.id(),
|
||||
CloudModel::OpenAi(model) => model.id(),
|
||||
CloudModel::Google(model) => model.id(),
|
||||
CloudModel::Zed(model) => model.id(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,6 +58,7 @@ impl CloudModel {
|
||||
CloudModel::Anthropic(model) => model.display_name(),
|
||||
CloudModel::OpenAi(model) => model.display_name(),
|
||||
CloudModel::Google(model) => model.display_name(),
|
||||
CloudModel::Zed(model) => model.display_name(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +67,7 @@ impl CloudModel {
|
||||
CloudModel::Anthropic(model) => model.max_token_count(),
|
||||
CloudModel::OpenAi(model) => model.max_token_count(),
|
||||
CloudModel::Google(model) => model.max_token_count(),
|
||||
CloudModel::Zed(model) => model.max_token_count(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::{Client, UserStore};
|
||||
@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
models.insert(model.id().to_string(), CloudModel::Google(model));
|
||||
}
|
||||
}
|
||||
for model in ZedModel::iter() {
|
||||
models.insert(model.id().to_string(), CloudModel::Zed(model));
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::Zed(_) => {
|
||||
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::Zed(model) => {
|
||||
let client = self.client.clone();
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
request.max_tokens = Some(4000);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Zed as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
Ok(open_ai::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||
CloudModel::Google(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||
}
|
||||
CloudModel::Zed(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ impl LanguageModelRequest {
|
||||
stream: true,
|
||||
stop: self.stop,
|
||||
temperature: self.temperature,
|
||||
max_tokens: None,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
|
@ -116,6 +116,8 @@ pub struct Request {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<usize>,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
@ -216,6 +218,13 @@ pub struct ChoiceDelta {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of parsing a single line of a streamed completion response.
///
/// The enum is `untagged`, so serde selects the variant from the shape of the
/// JSON itself, trying the variants in declaration order: a payload that
/// deserializes as a `ResponseStreamEvent` becomes `Ok`, otherwise an object
/// with a string `error` field becomes `Err`.
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponseStreamResult {
|
||||
/// A regular stream event.
Ok(ResponseStreamEvent),
|
||||
/// An error payload of the form `{ "error": "<message>" }`;
/// `stream_completion` converts this into an `anyhow` error.
Err { error: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ResponseStreamEvent {
|
||||
pub created: u32,
|
||||
@ -256,7 +265,10 @@ pub async fn stream_completion(
|
||||
None
|
||||
} else {
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
|
||||
Ok(ResponseStreamResult::Err { error }) => {
|
||||
Some(Err(anyhow!(error)))
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
|
@ -2099,6 +2099,7 @@ enum LanguageModelProvider {
|
||||
Anthropic = 0;
|
||||
OpenAI = 1;
|
||||
Google = 2;
|
||||
Zed = 3;
|
||||
}
|
||||
|
||||
message GetCachedEmbeddings {
|
||||
|
Loading…
Reference in New Issue
Block a user