mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Add Qwen2-7B to the list of zed.dev models (#15649)
Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
60127f2a8d
commit
21816d1ff5
@ -127,6 +127,16 @@ spec:
|
|||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
name: google-ai
|
name: google-ai
|
||||||
key: api_key
|
key: api_key
|
||||||
|
- name: QWEN2_7B_API_KEY
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: hugging-face
|
||||||
|
key: api_key
|
||||||
|
- name: QWEN2_7B_API_URL
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: hugging-face
|
||||||
|
key: qwen2_api_url
|
||||||
- name: BLOB_STORE_ACCESS_KEY
|
- name: BLOB_STORE_ACCESS_KEY
|
||||||
valueFrom:
|
valueFrom:
|
||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
|
@ -151,6 +151,8 @@ pub struct Config {
|
|||||||
pub openai_api_key: Option<Arc<str>>,
|
pub openai_api_key: Option<Arc<str>>,
|
||||||
pub google_ai_api_key: Option<Arc<str>>,
|
pub google_ai_api_key: Option<Arc<str>>,
|
||||||
pub anthropic_api_key: Option<Arc<str>>,
|
pub anthropic_api_key: Option<Arc<str>>,
|
||||||
|
pub qwen2_7b_api_key: Option<Arc<str>>,
|
||||||
|
pub qwen2_7b_api_url: Option<Arc<str>>,
|
||||||
pub zed_client_checksum_seed: Option<String>,
|
pub zed_client_checksum_seed: Option<String>,
|
||||||
pub slack_panics_webhook: Option<String>,
|
pub slack_panics_webhook: Option<String>,
|
||||||
pub auto_join_channel_id: Option<ChannelId>,
|
pub auto_join_channel_id: Option<ChannelId>,
|
||||||
|
@ -4706,6 +4706,30 @@ async fn stream_complete_with_language_model(
|
|||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some(proto::LanguageModelProvider::Zed) => {
|
||||||
|
let api_key = config
|
||||||
|
.qwen2_7b_api_key
|
||||||
|
.as_ref()
|
||||||
|
.context("no Qwen2-7B API key configured on the server")?;
|
||||||
|
let api_url = config
|
||||||
|
.qwen2_7b_api_url
|
||||||
|
.as_ref()
|
||||||
|
.context("no Qwen2-7B URL configured on the server")?;
|
||||||
|
let mut events = open_ai::stream_completion(
|
||||||
|
session.http_client.as_ref(),
|
||||||
|
&api_url,
|
||||||
|
api_key,
|
||||||
|
serde_json::from_str(&request.request)?,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
while let Some(event) = events.next().await {
|
||||||
|
let event = event?;
|
||||||
|
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||||
|
event: serde_json::to_string(&event)?,
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
None => return Err(anyhow!("unknown provider"))?,
|
None => return Err(anyhow!("unknown provider"))?,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -672,6 +672,8 @@ impl TestServer {
|
|||||||
stripe_api_key: None,
|
stripe_api_key: None,
|
||||||
stripe_price_id: None,
|
stripe_price_id: None,
|
||||||
supermaven_admin_api_key: None,
|
supermaven_admin_api_key: None,
|
||||||
|
qwen2_7b_api_key: None,
|
||||||
|
qwen2_7b_api_url: None,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use strum::EnumIter;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||||
#[serde(tag = "provider", rename_all = "lowercase")]
|
#[serde(tag = "provider", rename_all = "lowercase")]
|
||||||
@ -7,6 +8,33 @@ pub enum CloudModel {
|
|||||||
Anthropic(anthropic::Model),
|
Anthropic(anthropic::Model),
|
||||||
OpenAi(open_ai::Model),
|
OpenAi(open_ai::Model),
|
||||||
Google(google_ai::Model),
|
Google(google_ai::Model),
|
||||||
|
Zed(ZedModel),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
||||||
|
pub enum ZedModel {
|
||||||
|
#[serde(rename = "qwen2-7b-instruct")]
|
||||||
|
Qwen2_7bInstruct,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZedModel {
|
||||||
|
pub fn id(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn display_name(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max_token_count(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
ZedModel::Qwen2_7bInstruct => 8192,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for CloudModel {
|
impl Default for CloudModel {
|
||||||
@ -21,6 +49,7 @@ impl CloudModel {
|
|||||||
CloudModel::Anthropic(model) => model.id(),
|
CloudModel::Anthropic(model) => model.id(),
|
||||||
CloudModel::OpenAi(model) => model.id(),
|
CloudModel::OpenAi(model) => model.id(),
|
||||||
CloudModel::Google(model) => model.id(),
|
CloudModel::Google(model) => model.id(),
|
||||||
|
CloudModel::Zed(model) => model.id(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,6 +58,7 @@ impl CloudModel {
|
|||||||
CloudModel::Anthropic(model) => model.display_name(),
|
CloudModel::Anthropic(model) => model.display_name(),
|
||||||
CloudModel::OpenAi(model) => model.display_name(),
|
CloudModel::OpenAi(model) => model.display_name(),
|
||||||
CloudModel::Google(model) => model.display_name(),
|
CloudModel::Google(model) => model.display_name(),
|
||||||
|
CloudModel::Zed(model) => model.display_name(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +67,7 @@ impl CloudModel {
|
|||||||
CloudModel::Anthropic(model) => model.max_token_count(),
|
CloudModel::Anthropic(model) => model.max_token_count(),
|
||||||
CloudModel::OpenAi(model) => model.max_token_count(),
|
CloudModel::OpenAi(model) => model.max_token_count(),
|
||||||
CloudModel::Google(model) => model.max_token_count(),
|
CloudModel::Google(model) => model.max_token_count(),
|
||||||
|
CloudModel::Zed(model) => model.max_token_count(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
models.insert(model.id().to_string(), CloudModel::Google(model));
|
models.insert(model.id().to_string(), CloudModel::Google(model));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for model in ZedModel::iter() {
|
||||||
|
models.insert(model.id().to_string(), CloudModel::Zed(model));
|
||||||
|
}
|
||||||
|
|
||||||
// Override with available models from settings
|
// Override with available models from settings
|
||||||
for model in &AllLanguageModelSettings::get_global(cx)
|
for model in &AllLanguageModelSettings::get_global(cx)
|
||||||
@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
CloudModel::Zed(_) => {
|
||||||
|
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
CloudModel::Zed(model) => {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
|
request.max_tokens = Some(4000);
|
||||||
|
let future = self.request_limiter.stream(async move {
|
||||||
|
let request = serde_json::to_string(&request)?;
|
||||||
|
let stream = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
|
provider: proto::LanguageModelProvider::Zed as i32,
|
||||||
|
request,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(open_ai::extract_text_from_events(
|
||||||
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
CloudModel::Google(_) => {
|
CloudModel::Google(_) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
}
|
}
|
||||||
|
CloudModel::Zed(_) => {
|
||||||
|
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,7 @@ impl LanguageModelRequest {
|
|||||||
stream: true,
|
stream: true,
|
||||||
stop: self.stop,
|
stop: self.stop,
|
||||||
temperature: self.temperature,
|
temperature: self.temperature,
|
||||||
|
max_tokens: None,
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
}
|
}
|
||||||
|
@ -116,6 +116,8 @@ pub struct Request {
|
|||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<RequestMessage>,
|
pub messages: Vec<RequestMessage>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<usize>,
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
@ -216,6 +218,13 @@ pub struct ChoiceDelta {
|
|||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ResponseStreamResult {
|
||||||
|
Ok(ResponseStreamEvent),
|
||||||
|
Err { error: String },
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct ResponseStreamEvent {
|
pub struct ResponseStreamEvent {
|
||||||
pub created: u32,
|
pub created: u32,
|
||||||
@ -256,7 +265,10 @@ pub async fn stream_completion(
|
|||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
match serde_json::from_str(line) {
|
match serde_json::from_str(line) {
|
||||||
Ok(response) => Some(Ok(response)),
|
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
|
||||||
|
Ok(ResponseStreamResult::Err { error }) => {
|
||||||
|
Some(Err(anyhow!(error)))
|
||||||
|
}
|
||||||
Err(error) => Some(Err(anyhow!(error))),
|
Err(error) => Some(Err(anyhow!(error))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2099,6 +2099,7 @@ enum LanguageModelProvider {
|
|||||||
Anthropic = 0;
|
Anthropic = 0;
|
||||||
OpenAI = 1;
|
OpenAI = 1;
|
||||||
Google = 2;
|
Google = 2;
|
||||||
|
Zed = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GetCachedEmbeddings {
|
message GetCachedEmbeddings {
|
||||||
|
Loading…
Reference in New Issue
Block a user