mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-18 18:08:07 +03:00
Enable Claude 3 models to be used via the Zed server if "language-models" feature flag is enabled for user (#10015)
Release Notes: - N/A
This commit is contained in:
parent
b1ccead0f6
commit
9b673089db
13
Cargo.lock
generated
13
Cargo.lock
generated
@ -213,6 +213,18 @@ dependencies = [
|
|||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anthropic"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"futures 0.3.28",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
"util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyhow"
|
name = "anyhow"
|
||||||
version = "1.0.75"
|
version = "1.0.75"
|
||||||
@ -2214,6 +2226,7 @@ dependencies = [
|
|||||||
name = "collab"
|
name = "collab"
|
||||||
version = "0.44.0"
|
version = "0.44.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anthropic",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"async-tungstenite",
|
"async-tungstenite",
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"crates/activity_indicator",
|
"crates/activity_indicator",
|
||||||
|
"crates/anthropic",
|
||||||
"crates/assets",
|
"crates/assets",
|
||||||
"crates/assistant",
|
"crates/assistant",
|
||||||
"crates/audio",
|
"crates/audio",
|
||||||
@ -119,6 +120,7 @@ resolver = "2"
|
|||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
activity_indicator = { path = "crates/activity_indicator" }
|
activity_indicator = { path = "crates/activity_indicator" }
|
||||||
ai = { path = "crates/ai" }
|
ai = { path = "crates/ai" }
|
||||||
|
anthropic = { path = "crates/anthropic" }
|
||||||
assets = { path = "crates/assets" }
|
assets = { path = "crates/assets" }
|
||||||
assistant = { path = "crates/assistant" }
|
assistant = { path = "crates/assistant" }
|
||||||
audio = { path = "crates/audio" }
|
audio = { path = "crates/audio" }
|
||||||
|
22
crates/anthropic/Cargo.toml
Normal file
22
crates/anthropic/Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
[package]
|
||||||
|
name = "anthropic"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/anthropic.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
util.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio.workspace = true
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
234
crates/anthropic/src/anthropic.rs
Normal file
234
crates/anthropic/src/anthropic.rs
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub enum Model {
|
||||||
|
#[default]
|
||||||
|
#[serde(rename = "claude-3-opus-20240229")]
|
||||||
|
Claude3Opus,
|
||||||
|
#[serde(rename = "claude-3-sonnet-20240229")]
|
||||||
|
Claude3Sonnet,
|
||||||
|
#[serde(rename = "claude-3-haiku-20240307")]
|
||||||
|
Claude3Haiku,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn from_id(id: &str) -> Result<Self> {
|
||||||
|
if id.starts_with("claude-3-opus") {
|
||||||
|
Ok(Self::Claude3Opus)
|
||||||
|
} else if id.starts_with("claude-3-sonnet") {
|
||||||
|
Ok(Self::Claude3Sonnet)
|
||||||
|
} else if id.starts_with("claude-3-haiku") {
|
||||||
|
Ok(Self::Claude3Haiku)
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("Invalid model id: {}", id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn display_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Claude3Opus => "Claude 3 Opus",
|
||||||
|
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||||
|
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max_token_count(&self) -> usize {
|
||||||
|
200_000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<String> for Role {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: String) -> Result<Self> {
|
||||||
|
match value.as_str() {
|
||||||
|
"user" => Ok(Self::User),
|
||||||
|
"assistant" => Ok(Self::Assistant),
|
||||||
|
_ => Err(anyhow!("invalid role '{value}'")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Role> for String {
|
||||||
|
fn from(val: Role) -> Self {
|
||||||
|
match val {
|
||||||
|
Role::User => "user".to_owned(),
|
||||||
|
Role::Assistant => "assistant".to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct Request {
|
||||||
|
pub model: Model,
|
||||||
|
pub messages: Vec<RequestMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
pub system: String,
|
||||||
|
pub max_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct RequestMessage {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ResponseEvent {
|
||||||
|
MessageStart {
|
||||||
|
message: ResponseMessage,
|
||||||
|
},
|
||||||
|
ContentBlockStart {
|
||||||
|
index: u32,
|
||||||
|
content_block: ContentBlock,
|
||||||
|
},
|
||||||
|
Ping {},
|
||||||
|
ContentBlockDelta {
|
||||||
|
index: u32,
|
||||||
|
delta: TextDelta,
|
||||||
|
},
|
||||||
|
ContentBlockStop {
|
||||||
|
index: u32,
|
||||||
|
},
|
||||||
|
MessageDelta {
|
||||||
|
delta: ResponseMessage,
|
||||||
|
usage: Usage,
|
||||||
|
},
|
||||||
|
MessageStop {},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct ResponseMessage {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub message_type: Option<String>,
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub role: Option<String>,
|
||||||
|
pub content: Option<Vec<String>>,
|
||||||
|
pub model: Option<String>,
|
||||||
|
pub stop_reason: Option<String>,
|
||||||
|
pub stop_sequence: Option<String>,
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub input_tokens: Option<u32>,
|
||||||
|
pub output_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ContentBlock {
|
||||||
|
Text { text: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum TextDelta {
|
||||||
|
TextDelta { text: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stream_completion(
|
||||||
|
client: &dyn HttpClient,
|
||||||
|
api_url: &str,
|
||||||
|
api_key: &str,
|
||||||
|
request: Request,
|
||||||
|
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||||
|
let uri = format!("{api_url}/v1/messages");
|
||||||
|
let request = HttpRequest::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(uri)
|
||||||
|
.header("Anthropic-Version", "2023-06-01")
|
||||||
|
.header("Anthropic-Beta", "messages-2023-12-15")
|
||||||
|
.header("X-Api-Key", api_key)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||||
|
let mut response = client.send(request).await?;
|
||||||
|
if response.status().is_success() {
|
||||||
|
let reader = BufReader::new(response.into_body());
|
||||||
|
Ok(reader
|
||||||
|
.lines()
|
||||||
|
.filter_map(|line| async move {
|
||||||
|
match line {
|
||||||
|
Ok(line) => {
|
||||||
|
let line = line.strip_prefix("data: ")?;
|
||||||
|
match serde_json::from_str(line) {
|
||||||
|
Ok(response) => Some(Ok(response)),
|
||||||
|
Err(error) => Some(Err(anyhow!(error))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => Some(Err(anyhow!(error))),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed())
|
||||||
|
} else {
|
||||||
|
let mut body = Vec::new();
|
||||||
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
|
|
||||||
|
let body_str = std::str::from_utf8(&body)?;
|
||||||
|
|
||||||
|
match serde_json::from_str::<ResponseEvent>(body_str) {
|
||||||
|
Ok(_) => Err(anyhow!(
|
||||||
|
"Unexpected success response while expecting an error: {}",
|
||||||
|
body_str,
|
||||||
|
)),
|
||||||
|
Err(_) => Err(anyhow!(
|
||||||
|
"Failed to connect to API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body_str,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[cfg(test)]
|
||||||
|
// mod tests {
|
||||||
|
// use super::*;
|
||||||
|
// use util::http::IsahcHttpClient;
|
||||||
|
|
||||||
|
// #[tokio::test]
|
||||||
|
// async fn stream_completion_success() {
|
||||||
|
// let http_client = IsahcHttpClient::new().unwrap();
|
||||||
|
|
||||||
|
// let request = Request {
|
||||||
|
// model: Model::Claude3Opus,
|
||||||
|
// messages: vec![RequestMessage {
|
||||||
|
// role: Role::User,
|
||||||
|
// content: "Ping".to_string(),
|
||||||
|
// }],
|
||||||
|
// stream: true,
|
||||||
|
// system: "Respond to ping with pong".to_string(),
|
||||||
|
// max_tokens: 4096,
|
||||||
|
// };
|
||||||
|
|
||||||
|
// let stream = stream_completion(
|
||||||
|
// &http_client,
|
||||||
|
// "https://api.anthropic.com",
|
||||||
|
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
||||||
|
// request,
|
||||||
|
// )
|
||||||
|
// .await
|
||||||
|
// .unwrap();
|
||||||
|
|
||||||
|
// stream
|
||||||
|
// .for_each(|event| async {
|
||||||
|
// match event {
|
||||||
|
// Ok(event) => println!("{:?}", event),
|
||||||
|
// Err(e) => eprintln!("Error: {:?}", e),
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// .await;
|
||||||
|
// }
|
||||||
|
// }
|
@ -768,15 +768,18 @@ impl AssistantPanel {
|
|||||||
open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
|
open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
|
||||||
}),
|
}),
|
||||||
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
||||||
ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
|
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
||||||
ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
|
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
||||||
ZedDotDevModel::GptFourTurbo => {
|
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
|
||||||
|
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
|
||||||
|
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
|
||||||
|
ZedDotDevModel::Claude3Haiku => {
|
||||||
match CompletionProvider::global(cx).default_model() {
|
match CompletionProvider::global(cx).default_model() {
|
||||||
LanguageModel::ZedDotDev(custom) => custom,
|
LanguageModel::ZedDotDev(custom) => custom,
|
||||||
_ => ZedDotDevModel::GptThreePointFiveTurbo,
|
_ => ZedDotDevModel::Gpt3Point5Turbo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
|
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,10 +14,13 @@ use settings::Settings;
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Default, PartialEq)]
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
pub enum ZedDotDevModel {
|
pub enum ZedDotDevModel {
|
||||||
GptThreePointFiveTurbo,
|
Gpt3Point5Turbo,
|
||||||
GptFour,
|
Gpt4,
|
||||||
#[default]
|
#[default]
|
||||||
GptFourTurbo,
|
Gpt4Turbo,
|
||||||
|
Claude3Opus,
|
||||||
|
Claude3Sonnet,
|
||||||
|
Claude3Haiku,
|
||||||
Custom(String),
|
Custom(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
|||||||
E: de::Error,
|
E: de::Error,
|
||||||
{
|
{
|
||||||
match value {
|
match value {
|
||||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
|
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
|
||||||
"gpt-4" => Ok(ZedDotDevModel::GptFour),
|
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
|
||||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
|
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
|
||||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel {
|
|||||||
impl ZedDotDevModel {
|
impl ZedDotDevModel {
|
||||||
pub fn id(&self) -> &str {
|
pub fn id(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||||
Self::GptFour => "gpt-4",
|
Self::Gpt4 => "gpt-4",
|
||||||
Self::GptFourTurbo => "gpt-4-turbo-preview",
|
Self::Gpt4Turbo => "gpt-4-turbo-preview",
|
||||||
|
Self::Claude3Opus => "claude-3-opus",
|
||||||
|
Self::Claude3Sonnet => "claude-3-sonnet",
|
||||||
|
Self::Claude3Haiku => "claude-3-haiku",
|
||||||
Self::Custom(id) => id,
|
Self::Custom(id) => id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn display_name(&self) -> &str {
|
pub fn display_name(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
|
||||||
Self::GptFour => "gpt-4",
|
Self::Gpt4 => "GPT 4",
|
||||||
Self::GptFourTurbo => "gpt-4-turbo",
|
Self::Gpt4Turbo => "GPT 4 Turbo",
|
||||||
|
Self::Claude3Opus => "Claude 3 Opus",
|
||||||
|
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||||
|
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||||
Self::Custom(id) => id.as_str(),
|
Self::Custom(id) => id.as_str(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn max_token_count(&self) -> usize {
|
pub fn max_token_count(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
Self::GptThreePointFiveTurbo => 2048,
|
Self::Gpt3Point5Turbo => 2048,
|
||||||
Self::GptFour => 4096,
|
Self::Gpt4 => 4096,
|
||||||
Self::GptFourTurbo => 128000,
|
Self::Gpt4Turbo => 128000,
|
||||||
|
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
|
||||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
|
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
|
||||||
LanguageModelRequest,
|
LanguageModelRequest,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider {
|
|||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
match request.model {
|
match request.model {
|
||||||
crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
|
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
||||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
||||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
count_open_ai_tokens(request, cx.background_executor())
|
||||||
}
|
}
|
||||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
LanguageModel::ZedDotDev(
|
||||||
|
ZedDotDevModel::Claude3Opus
|
||||||
|
| ZedDotDevModel::Claude3Sonnet
|
||||||
|
| ZedDotDevModel::Claude3Haiku,
|
||||||
|
) => {
|
||||||
|
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||||
|
count_open_ai_tokens(request, cx.background_executor())
|
||||||
|
}
|
||||||
|
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
||||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||||
model,
|
model,
|
||||||
messages: request
|
messages: request
|
||||||
|
@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
|
|||||||
test-support = ["sqlite"]
|
test-support = ["sqlite"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anthropic.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-tungstenite = "0.16"
|
async-tungstenite = "0.16"
|
||||||
aws-config = { version = "1.1.5" }
|
aws-config = { version = "1.1.5" }
|
||||||
|
@ -130,6 +130,11 @@ spec:
|
|||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
name: openai
|
name: openai
|
||||||
key: api_key
|
key: api_key
|
||||||
|
- name: ANTHROPIC_API_KEY
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: anthropic
|
||||||
|
key: api_key
|
||||||
- name: BLOB_STORE_ACCESS_KEY
|
- name: BLOB_STORE_ACCESS_KEY
|
||||||
valueFrom:
|
valueFrom:
|
||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
|
@ -134,6 +134,7 @@ pub struct Config {
|
|||||||
pub zed_environment: Arc<str>,
|
pub zed_environment: Arc<str>,
|
||||||
pub openai_api_key: Option<Arc<str>>,
|
pub openai_api_key: Option<Arc<str>>,
|
||||||
pub google_ai_api_key: Option<Arc<str>>,
|
pub google_ai_api_key: Option<Arc<str>>,
|
||||||
|
pub anthropic_api_key: Option<Arc<str>>,
|
||||||
pub zed_client_checksum_seed: Option<String>,
|
pub zed_client_checksum_seed: Option<String>,
|
||||||
pub slack_panics_webhook: Option<String>,
|
pub slack_panics_webhook: Option<String>,
|
||||||
pub auto_join_channel_id: Option<ChannelId>,
|
pub auto_join_channel_id: Option<ChannelId>,
|
||||||
|
@ -419,6 +419,7 @@ impl Server {
|
|||||||
session,
|
session,
|
||||||
app_state.config.openai_api_key.clone(),
|
app_state.config.openai_api_key.clone(),
|
||||||
app_state.config.google_ai_api_key.clone(),
|
app_state.config.google_ai_api_key.clone(),
|
||||||
|
app_state.config.anthropic_api_key.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -3506,6 +3507,7 @@ async fn complete_with_language_model(
|
|||||||
session: Session,
|
session: Session,
|
||||||
open_ai_api_key: Option<Arc<str>>,
|
open_ai_api_key: Option<Arc<str>>,
|
||||||
google_ai_api_key: Option<Arc<str>>,
|
google_ai_api_key: Option<Arc<str>>,
|
||||||
|
anthropic_api_key: Option<Arc<str>>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let Some(session) = session.for_user() else {
|
let Some(session) = session.for_user() else {
|
||||||
return Err(anyhow!("user not found"))?;
|
return Err(anyhow!("user not found"))?;
|
||||||
@ -3524,6 +3526,10 @@ async fn complete_with_language_model(
|
|||||||
let api_key = google_ai_api_key
|
let api_key = google_ai_api_key
|
||||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||||
complete_with_google_ai(request, response, session, api_key).await?;
|
complete_with_google_ai(request, response, session, api_key).await?;
|
||||||
|
} else if request.model.starts_with("claude") {
|
||||||
|
let api_key = anthropic_api_key
|
||||||
|
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
|
||||||
|
complete_with_anthropic(request, response, session, api_key).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -3621,6 +3627,121 @@ async fn complete_with_google_ai(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn complete_with_anthropic(
|
||||||
|
request: proto::CompleteWithLanguageModel,
|
||||||
|
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||||
|
session: UserSession,
|
||||||
|
api_key: Arc<str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let model = anthropic::Model::from_id(&request.model)?;
|
||||||
|
|
||||||
|
let mut system_message = String::new();
|
||||||
|
let messages = request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|message| match message.role() {
|
||||||
|
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
|
||||||
|
role: anthropic::Role::User,
|
||||||
|
content: message.content,
|
||||||
|
}),
|
||||||
|
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
|
||||||
|
role: anthropic::Role::Assistant,
|
||||||
|
content: message.content,
|
||||||
|
}),
|
||||||
|
// Anthropic's API breaks system instructions out as a separate field rather
|
||||||
|
// than having a system message role.
|
||||||
|
LanguageModelRole::LanguageModelSystem => {
|
||||||
|
if !system_message.is_empty() {
|
||||||
|
system_message.push_str("\n\n");
|
||||||
|
}
|
||||||
|
system_message.push_str(&message.content);
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut stream = anthropic::stream_completion(
|
||||||
|
&session.http_client,
|
||||||
|
"https://api.anthropic.com",
|
||||||
|
&api_key,
|
||||||
|
anthropic::Request {
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream: true,
|
||||||
|
system: system_message,
|
||||||
|
max_tokens: 4092,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||||
|
|
||||||
|
while let Some(event) = stream.next().await {
|
||||||
|
let event = event?;
|
||||||
|
|
||||||
|
match event {
|
||||||
|
anthropic::ResponseEvent::MessageStart { message } => {
|
||||||
|
if let Some(role) = message.role {
|
||||||
|
if role == "assistant" {
|
||||||
|
current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||||
|
} else if role == "user" {
|
||||||
|
current_role = proto::LanguageModelRole::LanguageModelUser;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
|
||||||
|
match content_block {
|
||||||
|
anthropic::ContentBlock::Text { text } => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
response.send(proto::LanguageModelResponse {
|
||||||
|
choices: vec![proto::LanguageModelChoiceDelta {
|
||||||
|
index: 0,
|
||||||
|
delta: Some(proto::LanguageModelResponseMessage {
|
||||||
|
role: Some(current_role as i32),
|
||||||
|
content: Some(text),
|
||||||
|
}),
|
||||||
|
finish_reason: None,
|
||||||
|
}],
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||||
|
anthropic::TextDelta::TextDelta { text } => {
|
||||||
|
response.send(proto::LanguageModelResponse {
|
||||||
|
choices: vec![proto::LanguageModelChoiceDelta {
|
||||||
|
index: 0,
|
||||||
|
delta: Some(proto::LanguageModelResponseMessage {
|
||||||
|
role: Some(current_role as i32),
|
||||||
|
content: Some(text),
|
||||||
|
}),
|
||||||
|
finish_reason: None,
|
||||||
|
}],
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
|
||||||
|
if let Some(stop_reason) = delta.stop_reason {
|
||||||
|
response.send(proto::LanguageModelResponse {
|
||||||
|
choices: vec![proto::LanguageModelChoiceDelta {
|
||||||
|
index: 0,
|
||||||
|
delta: None,
|
||||||
|
finish_reason: Some(stop_reason),
|
||||||
|
}],
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
|
||||||
|
anthropic::ResponseEvent::MessageStop {} => {}
|
||||||
|
anthropic::ResponseEvent::Ping {} => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
struct CountTokensWithLanguageModelRateLimit;
|
struct CountTokensWithLanguageModelRateLimit;
|
||||||
|
|
||||||
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||||
|
@ -512,6 +512,7 @@ impl TestServer {
|
|||||||
blob_store_bucket: None,
|
blob_store_bucket: None,
|
||||||
openai_api_key: None,
|
openai_api_key: None,
|
||||||
google_ai_api_key: None,
|
google_ai_api_key: None,
|
||||||
|
anthropic_api_key: None,
|
||||||
clickhouse_url: None,
|
clickhouse_url: None,
|
||||||
clickhouse_user: None,
|
clickhouse_user: None,
|
||||||
clickhouse_password: None,
|
clickhouse_password: None,
|
||||||
|
Loading…
Reference in New Issue
Block a user