mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Enable Claude 3 models to be used via the Zed server if "language-models" feature flag is enabled for user (#10015)
Release Notes: - N/A
This commit is contained in:
parent
b1ccead0f6
commit
9b673089db
13
Cargo.lock
generated
13
Cargo.lock
generated
@ -213,6 +213,18 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.75"
|
||||
@ -2214,6 +2226,7 @@ dependencies = [
|
||||
name = "collab"
|
||||
version = "0.44.0"
|
||||
dependencies = [
|
||||
"anthropic",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
|
@ -1,6 +1,7 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/anthropic",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/audio",
|
||||
@ -119,6 +120,7 @@ resolver = "2"
|
||||
[workspace.dependencies]
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
ai = { path = "crates/ai" }
|
||||
anthropic = { path = "crates/anthropic" }
|
||||
assets = { path = "crates/assets" }
|
||||
assistant = { path = "crates/assistant" }
|
||||
audio = { path = "crates/audio" }
|
||||
|
22
crates/anthropic/Cargo.toml
Normal file
22
crates/anthropic/Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "anthropic"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
path = "src/anthropic.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
234
crates/anthropic/src/anthropic.rs
Normal file
234
crates/anthropic/src/anthropic.rs
Normal file
@ -0,0 +1,234 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryFrom;
|
||||
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(rename = "claude-3-opus-20240229")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet-20240229")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
if id.starts_with("claude-3-opus") {
|
||||
Ok(Self::Claude3Opus)
|
||||
} else if id.starts_with("claude-3-sonnet") {
|
||||
Ok(Self::Claude3Sonnet)
|
||||
} else if id.starts_with("claude-3-haiku") {
|
||||
Ok(Self::Claude3Haiku)
|
||||
} else {
|
||||
Err(anyhow!("Invalid model id: {}", id))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
200_000
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub system: String,
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseEvent {
|
||||
MessageStart {
|
||||
message: ResponseMessage,
|
||||
},
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: ContentBlock,
|
||||
},
|
||||
Ping {},
|
||||
ContentBlockDelta {
|
||||
index: u32,
|
||||
delta: TextDelta,
|
||||
},
|
||||
ContentBlockStop {
|
||||
index: u32,
|
||||
},
|
||||
MessageDelta {
|
||||
delta: ResponseMessage,
|
||||
usage: Usage,
|
||||
},
|
||||
MessageStop {},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ResponseMessage {
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub role: Option<String>,
|
||||
pub content: Option<Vec<String>>,
|
||||
pub model: Option<String>,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<u32>,
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum TextDelta {
|
||||
TextDelta { text: String },
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("Anthropic-Beta", "messages-2023-12-15")
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
match serde_json::from_str::<ResponseEvent>(body_str) {
|
||||
Ok(_) => Err(anyhow!(
|
||||
"Unexpected success response while expecting an error: {}",
|
||||
body_str,
|
||||
)),
|
||||
Err(_) => Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use util::http::IsahcHttpClient;
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stream_completion_success() {
|
||||
// let http_client = IsahcHttpClient::new().unwrap();
|
||||
|
||||
// let request = Request {
|
||||
// model: Model::Claude3Opus,
|
||||
// messages: vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: "Ping".to_string(),
|
||||
// }],
|
||||
// stream: true,
|
||||
// system: "Respond to ping with pong".to_string(),
|
||||
// max_tokens: 4096,
|
||||
// };
|
||||
|
||||
// let stream = stream_completion(
|
||||
// &http_client,
|
||||
// "https://api.anthropic.com",
|
||||
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
||||
// request,
|
||||
// )
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
// stream
|
||||
// .for_each(|event| async {
|
||||
// match event {
|
||||
// Ok(event) => println!("{:?}", event),
|
||||
// Err(e) => eprintln!("Error: {:?}", e),
|
||||
// }
|
||||
// })
|
||||
// .await;
|
||||
// }
|
||||
// }
|
@ -768,15 +768,18 @@ impl AssistantPanel {
|
||||
open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
|
||||
}),
|
||||
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
||||
ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
|
||||
ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
|
||||
ZedDotDevModel::GptFourTurbo => {
|
||||
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
||||
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
||||
ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
|
||||
ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
|
||||
ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
|
||||
ZedDotDevModel::Claude3Haiku => {
|
||||
match CompletionProvider::global(cx).default_model() {
|
||||
LanguageModel::ZedDotDev(custom) => custom,
|
||||
_ => ZedDotDevModel::GptThreePointFiveTurbo,
|
||||
_ => ZedDotDevModel::Gpt3Point5Turbo,
|
||||
}
|
||||
}
|
||||
ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
|
||||
ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
|
||||
}),
|
||||
};
|
||||
|
||||
|
@ -14,10 +14,13 @@ use settings::Settings;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub enum ZedDotDevModel {
|
||||
GptThreePointFiveTurbo,
|
||||
GptFour,
|
||||
Gpt3Point5Turbo,
|
||||
Gpt4,
|
||||
#[default]
|
||||
GptFourTurbo,
|
||||
Gpt4Turbo,
|
||||
Claude3Opus,
|
||||
Claude3Sonnet,
|
||||
Claude3Haiku,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
||||
E: de::Error,
|
||||
{
|
||||
match value {
|
||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
|
||||
"gpt-4" => Ok(ZedDotDevModel::GptFour),
|
||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
|
||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
|
||||
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
|
||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
|
||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
||||
}
|
||||
}
|
||||
@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel {
|
||||
impl ZedDotDevModel {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::GptFour => "gpt-4",
|
||||
Self::GptFourTurbo => "gpt-4-turbo-preview",
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4Turbo => "gpt-4-turbo-preview",
|
||||
Self::Claude3Opus => "claude-3-opus",
|
||||
Self::Claude3Sonnet => "claude-3-sonnet",
|
||||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
Self::Custom(id) => id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::GptFour => "gpt-4",
|
||||
Self::GptFourTurbo => "gpt-4-turbo",
|
||||
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
|
||||
Self::Gpt4 => "GPT 4",
|
||||
Self::Gpt4Turbo => "GPT 4 Turbo",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::GptThreePointFiveTurbo => 2048,
|
||||
Self::GptFour => 4096,
|
||||
Self::GptFourTurbo => 128000,
|
||||
Self::Gpt3Point5Turbo => 2048,
|
||||
Self::Gpt4 => 4096,
|
||||
Self::Gpt4Turbo => 128000,
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
|
||||
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider {
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
|
||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
|
||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
|
||||
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
||||
LanguageModel::ZedDotDev(
|
||||
ZedDotDevModel::Claude3Opus
|
||||
| ZedDotDevModel::Claude3Sonnet
|
||||
| ZedDotDevModel::Claude3Haiku,
|
||||
) => {
|
||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
messages: request
|
||||
|
@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
|
||||
test-support = ["sqlite"]
|
||||
|
||||
[dependencies]
|
||||
anthropic.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-tungstenite = "0.16"
|
||||
aws-config = { version = "1.1.5" }
|
||||
|
@ -130,6 +130,11 @@ spec:
|
||||
secretKeyRef:
|
||||
name: openai
|
||||
key: api_key
|
||||
- name: ANTHROPIC_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: anthropic
|
||||
key: api_key
|
||||
- name: BLOB_STORE_ACCESS_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
|
@ -134,6 +134,7 @@ pub struct Config {
|
||||
pub zed_environment: Arc<str>,
|
||||
pub openai_api_key: Option<Arc<str>>,
|
||||
pub google_ai_api_key: Option<Arc<str>>,
|
||||
pub anthropic_api_key: Option<Arc<str>>,
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
|
@ -419,6 +419,7 @@ impl Server {
|
||||
session,
|
||||
app_state.config.openai_api_key.clone(),
|
||||
app_state.config.google_ai_api_key.clone(),
|
||||
app_state.config.anthropic_api_key.clone(),
|
||||
)
|
||||
}
|
||||
})
|
||||
@ -3506,6 +3507,7 @@ async fn complete_with_language_model(
|
||||
session: Session,
|
||||
open_ai_api_key: Option<Arc<str>>,
|
||||
google_ai_api_key: Option<Arc<str>>,
|
||||
anthropic_api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
@ -3524,6 +3526,10 @@ async fn complete_with_language_model(
|
||||
let api_key = google_ai_api_key
|
||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||
complete_with_google_ai(request, response, session, api_key).await?;
|
||||
} else if request.model.starts_with("claude") {
|
||||
let api_key = anthropic_api_key
|
||||
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
|
||||
complete_with_anthropic(request, response, session, api_key).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -3621,6 +3627,121 @@ async fn complete_with_google_ai(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_anthropic(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: UserSession,
|
||||
api_key: Arc<str>,
|
||||
) -> Result<()> {
|
||||
let model = anthropic::Model::from_id(&request.model)?;
|
||||
|
||||
let mut system_message = String::new();
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|message| match message.role() {
|
||||
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
|
||||
role: anthropic::Role::User,
|
||||
content: message.content,
|
||||
}),
|
||||
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
|
||||
role: anthropic::Role::Assistant,
|
||||
content: message.content,
|
||||
}),
|
||||
// Anthropic's API breaks system instructions out as a separate field rather
|
||||
// than having a system message role.
|
||||
LanguageModelRole::LanguageModelSystem => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut stream = anthropic::stream_completion(
|
||||
&session.http_client,
|
||||
"https://api.anthropic.com",
|
||||
&api_key,
|
||||
anthropic::Request {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event?;
|
||||
|
||||
match event {
|
||||
anthropic::ResponseEvent::MessageStart { message } => {
|
||||
if let Some(role) = message.role {
|
||||
if role == "assistant" {
|
||||
current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||
} else if role == "user" {
|
||||
current_role = proto::LanguageModelRole::LanguageModelUser;
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
|
||||
match content_block {
|
||||
anthropic::ContentBlock::Text { text } => {
|
||||
if !text.is_empty() {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(current_role as i32),
|
||||
content: Some(text),
|
||||
}),
|
||||
finish_reason: None,
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(current_role as i32),
|
||||
content: Some(text),
|
||||
}),
|
||||
finish_reason: None,
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
},
|
||||
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
|
||||
if let Some(stop_reason) = delta.stop_reason {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: None,
|
||||
finish_reason: Some(stop_reason),
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
|
||||
anthropic::ResponseEvent::MessageStop {} => {}
|
||||
anthropic::ResponseEvent::Ping {} => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CountTokensWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||
|
@ -512,6 +512,7 @@ impl TestServer {
|
||||
blob_store_bucket: None,
|
||||
openai_api_key: None,
|
||||
google_ai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
|
Loading…
Reference in New Issue
Block a user