Enable Claude 3 models to be used via the Zed server if "language-models" feature flag is enabled for user (#10015)

Release Notes:

- N/A
This commit is contained in:
Nathan Sobo 2024-03-31 14:57:57 -07:00 committed by GitHub
parent b1ccead0f6
commit 9b673089db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 447 additions and 26 deletions

13
Cargo.lock generated
View File

@ -213,6 +213,18 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "anthropic"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
"serde",
"serde_json",
"tokio",
"util",
]
[[package]]
name = "anyhow"
version = "1.0.75"
@ -2214,6 +2226,7 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
"anthropic",
"anyhow",
"async-trait",
"async-tungstenite",

View File

@ -1,6 +1,7 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/anthropic",
"crates/assets",
"crates/assistant",
"crates/audio",
@ -119,6 +120,7 @@ resolver = "2"
[workspace.dependencies]
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
audio = { path = "crates/audio" }

View File

@ -0,0 +1,22 @@
[package]
name = "anthropic"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
util.workspace = true
[dev-dependencies]
tokio.workspace = true
[lints]
workspace = true

View File

@ -0,0 +1,234 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-opus-20240229")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet-20240229")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku-20240307")]
Claude3Haiku,
}
impl Model {
pub fn from_id(id: &str) -> Result<Self> {
if id.starts_with("claude-3-opus") {
Ok(Self::Claude3Opus)
} else if id.starts_with("claude-3-sonnet") {
Ok(Self::Claude3Sonnet)
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("Invalid model id: {}", id))
}
}
pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
}
}
pub fn max_token_count(&self) -> usize {
200_000
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
impl TryFrom<String> for Role {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
match value.as_str() {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
}
impl From<Role> for String {
fn from(val: Role) -> Self {
match val {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
}
}
}
#[derive(Debug, Serialize)]
pub struct Request {
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub system: String,
pub max_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
message: ResponseMessage,
},
ContentBlockStart {
index: u32,
content_block: ContentBlock,
},
Ping {},
ContentBlockDelta {
index: u32,
delta: TextDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: ResponseMessage,
usage: Usage,
},
MessageStop {},
}
#[derive(Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
pub id: Option<String>,
pub role: Option<String>,
pub content: Option<Vec<String>>,
pub model: Option<String>,
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let uri = format!("{api_url}/v1/messages");
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("Anthropic-Beta", "messages-2023-12-15")
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
let line = line.strip_prefix("data: ")?;
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
match serde_json::from_str::<ResponseEvent>(body_str) {
Ok(_) => Err(anyhow!(
"Unexpected success response while expecting an error: {}",
body_str,
)),
Err(_) => Err(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str,
)),
}
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use util::http::IsahcHttpClient;
// #[tokio::test]
// async fn stream_completion_success() {
// let http_client = IsahcHttpClient::new().unwrap();
// let request = Request {
// model: Model::Claude3Opus,
// messages: vec![RequestMessage {
// role: Role::User,
// content: "Ping".to_string(),
// }],
// stream: true,
// system: "Respond to ping with pong".to_string(),
// max_tokens: 4096,
// };
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// )
// .await
// .unwrap();
// stream
// .for_each(|event| async {
// match event {
// Ok(event) => println!("{:?}", event),
// Err(e) => eprintln!("Error: {:?}", e),
// }
// })
// .await;
// }
// }

View File

@ -768,15 +768,18 @@ impl AssistantPanel {
open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
}),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
ZedDotDevModel::GptFourTurbo => {
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
ZedDotDevModel::Claude3Haiku => {
match CompletionProvider::global(cx).default_model() {
LanguageModel::ZedDotDev(custom) => custom,
_ => ZedDotDevModel::GptThreePointFiveTurbo,
_ => ZedDotDevModel::Gpt3Point5Turbo,
}
}
ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
}),
};

View File

@ -14,10 +14,13 @@ use settings::Settings;
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
GptThreePointFiveTurbo,
GptFour,
Gpt3Point5Turbo,
Gpt4,
#[default]
GptFourTurbo,
Gpt4Turbo,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Custom(String),
}
@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
E: de::Error,
{
match value {
"gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
"gpt-4" => Ok(ZedDotDevModel::GptFour),
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
}
@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel {
impl ZedDotDevModel {
pub fn id(&self) -> &str {
match self {
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
Self::GptFour => "gpt-4",
Self::GptFourTurbo => "gpt-4-turbo-preview",
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
Self::GptFour => "gpt-4",
Self::GptFourTurbo => "gpt-4-turbo",
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::GptThreePointFiveTurbo => 2048,
Self::GptFour => 4096,
Self::GptFourTurbo => 128000,
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo => 128000,
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}

View File

@ -1,5 +1,5 @@
use crate::{
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider {
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
LanguageModel::ZedDotDev(
ZedDotDevModel::Claude3Opus
| ZedDotDevModel::Claude3Sonnet
| ZedDotDevModel::Claude3Haiku,
) => {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request

View File

@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-tungstenite = "0.16"
aws-config = { version = "1.1.5" }

View File

@ -130,6 +130,11 @@ spec:
secretKeyRef:
name: openai
key: api_key
- name: ANTHROPIC_API_KEY
valueFrom:
secretKeyRef:
name: anthropic
key: api_key
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:

View File

@ -134,6 +134,7 @@ pub struct Config {
pub zed_environment: Arc<str>,
pub openai_api_key: Option<Arc<str>>,
pub google_ai_api_key: Option<Arc<str>>,
pub anthropic_api_key: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,

View File

@ -419,6 +419,7 @@ impl Server {
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
app_state.config.anthropic_api_key.clone(),
)
}
})
@ -3506,6 +3507,7 @@ async fn complete_with_language_model(
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
anthropic_api_key: Option<Arc<str>>,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
@ -3524,6 +3526,10 @@ async fn complete_with_language_model(
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
}
Ok(())
@ -3621,6 +3627,121 @@ async fn complete_with_google_ai(
Ok(())
}
async fn complete_with_anthropic(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let model = anthropic::Model::from_id(&request.model)?;
let mut system_message = String::new();
let messages = request
.messages
.into_iter()
.filter_map(|message| match message.role() {
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
role: anthropic::Role::User,
content: message.content,
}),
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
role: anthropic::Role::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
LanguageModelRole::LanguageModelSystem => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
None
}
})
.collect();
let mut stream = anthropic::stream_completion(
&session.http_client,
"https://api.anthropic.com",
&api_key,
anthropic::Request {
model,
messages,
stream: true,
system: system_message,
max_tokens: 4092,
},
)
.await?;
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
while let Some(event) = stream.next().await {
let event = event?;
match event {
anthropic::ResponseEvent::MessageStart { message } => {
if let Some(role) = message.role {
if role == "assistant" {
current_role = proto::LanguageModelRole::LanguageModelAssistant;
} else if role == "user" {
current_role = proto::LanguageModelRole::LanguageModelUser;
}
}
}
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
match content_block {
anthropic::ContentBlock::Text { text } => {
if !text.is_empty() {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
}),
finish_reason: None,
}],
})?;
}
}
}
}
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
anthropic::TextDelta::TextDelta { text } => {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
}),
finish_reason: None,
}],
})?;
}
},
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
if let Some(stop_reason) = delta.stop_reason {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: None,
finish_reason: Some(stop_reason),
}],
})?;
}
}
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
anthropic::ResponseEvent::MessageStop {} => {}
anthropic::ResponseEvent::Ping {} => {}
}
}
Ok(())
}
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {

View File

@ -512,6 +512,7 @@ impl TestServer {
blob_store_bucket: None,
openai_api_key: None,
google_ai_api_key: None,
anthropic_api_key: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,