Mirror of https://github.com/zed-industries/zed.git
Implement Anthropic prompt caching (#16274)
Release Notes:

- Added support for prompt caching in Anthropic. For models that support it, this can dramatically lower cost while improving performance.
Parent: 09b6e3f2a6
Commit: 46fb917e02
@@ -14,6 +14,14 @@ pub use supported_countries::*;
 
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct AnthropicModelCacheConfiguration {
+    pub min_total_token: usize,
+    pub should_speculate: bool,
+    pub max_cache_anchors: usize,
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
@@ -32,6 +40,8 @@ pub enum Model {
         max_tokens: usize,
         /// Override this model with a different Anthropic model for tool calls.
         tool_override: Option<String>,
+        /// Indicates whether this custom model supports caching.
+        cache_configuration: Option<AnthropicModelCacheConfiguration>,
     },
 }
@@ -70,6 +80,21 @@ impl Model {
         }
     }
 
+    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
+        match self {
+            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
+                min_total_token: 2_048,
+                should_speculate: true,
+                max_cache_anchors: 4,
+            }),
+            Self::Custom {
+                cache_configuration,
+                ..
+            } => cache_configuration.clone(),
+            _ => None,
+        }
+    }
+
     pub fn max_token_count(&self) -> usize {
         match self {
             Self::Claude3_5Sonnet
@@ -104,7 +129,10 @@ pub async fn complete(
         .method(Method::POST)
         .uri(uri)
         .header("Anthropic-Version", "2023-06-01")
-        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header(
+            "Anthropic-Beta",
+            "tools-2024-04-04,prompt-caching-2024-07-31",
+        )
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
@@ -161,7 +189,10 @@ pub async fn stream_completion(
         .method(Method::POST)
         .uri(uri)
         .header("Anthropic-Version", "2023-06-01")
-        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header(
+            "Anthropic-Beta",
+            "tools-2024-04-04,prompt-caching-2024-07-31",
+        )
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
     if let Some(low_speed_timeout) = low_speed_timeout {
@@ -226,7 +257,7 @@ pub fn extract_text_from_events(
         match response {
             Ok(response) => match response {
                 Event::ContentBlockStart { content_block, .. } => match content_block {
-                    Content::Text { text } => Some(Ok(text)),
+                    Content::Text { text, .. } => Some(Ok(text)),
                     _ => None,
                 },
                 Event::ContentBlockDelta { delta, .. } => match delta {
@@ -285,13 +316,25 @@ pub async fn extract_tool_args_from_events(
     }))
 }
 
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum CacheControlType {
+    Ephemeral,
+}
+
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+pub struct CacheControl {
+    #[serde(rename = "type")]
+    pub cache_type: CacheControlType,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Message {
     pub role: Role,
     pub content: Vec<Content>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
 #[serde(rename_all = "lowercase")]
 pub enum Role {
     User,
@@ -302,19 +345,31 @@ pub enum Role {
 #[serde(tag = "type")]
 pub enum Content {
     #[serde(rename = "text")]
-    Text { text: String },
+    Text {
+        text: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "image")]
-    Image { source: ImageSource },
+    Image {
+        source: ImageSource,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "tool_use")]
     ToolUse {
         id: String,
         name: String,
        input: serde_json::Value,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
     #[serde(rename = "tool_result")]
     ToolResult {
         tool_use_id: String,
         content: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
 }
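For reference, the new `cache_control` field rides along on every content block, and `skip_serializing_if` keeps it off the wire when unset; requests opt into the feature via the `prompt-caching-2024-07-31` beta header added above. A minimal sketch of the resulting JSON, assuming the types above are in scope and `serde_json` is available:

```rust
use anthropic::{CacheControl, CacheControlType, Content};
use serde_json::json;

fn main() {
    // Marking a text block as a cache breakpoint.
    let cached = Content::Text {
        text: "You are a helpful assistant.".into(),
        cache_control: Some(CacheControl {
            cache_type: CacheControlType::Ephemeral,
        }),
    };
    assert_eq!(
        serde_json::to_value(&cached).unwrap(),
        json!({
            "type": "text",
            "text": "You are a helpful assistant.",
            "cache_control": { "type": "ephemeral" }
        })
    );

    // Without the flag, the field is omitted entirely.
    let plain = Content::Text {
        text: "Hello".into(),
        cache_control: None,
    };
    assert_eq!(
        serde_json::to_value(&plain).unwrap(),
        json!({ "type": "text", "text": "Hello" })
    );
}
```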
@@ -21,8 +21,8 @@ use gpui::{
 
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
+    LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -30,7 +30,7 @@ use project::Project;
 use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
 use std::{
-    cmp::Ordering,
+    cmp::{max, Ordering},
     collections::hash_map,
     fmt::Debug,
     iter, mem,
@@ -107,6 +107,8 @@ impl ContextOperation {
                         message.status.context("invalid status")?,
                     ),
                     timestamp: id.0,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: language::proto::deserialize_version(&insert.version),
             })
@@ -121,6 +123,8 @@ impl ContextOperation {
                     timestamp: language::proto::deserialize_timestamp(
                         update.timestamp.context("invalid timestamp")?,
                     ),
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: language::proto::deserialize_version(&update.version),
             }),
@@ -313,6 +317,8 @@ pub struct MessageMetadata {
     pub role: Role,
     pub status: MessageStatus,
     timestamp: clock::Lamport,
+    should_cache: bool,
+    is_cache_anchor: bool,
 }
 
 #[derive(Clone, Debug)]
@@ -338,6 +344,7 @@ pub struct Message {
     pub anchor: language::Anchor,
     pub role: Role,
     pub status: MessageStatus,
+    pub cache: bool,
 }
 
 impl Message {
@@ -373,6 +380,7 @@ impl Message {
         LanguageModelRequestMessage {
             role: self.role,
             content,
+            cache: self.cache,
         }
     }
 }
@@ -421,6 +429,7 @@ pub struct Context {
     token_count: Option<usize>,
     pending_token_count: Task<Option<()>>,
     pending_save: Task<Result<()>>,
+    pending_cache_warming_task: Task<Option<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
     telemetry: Option<Arc<Telemetry>>,
@@ -498,6 +507,7 @@ impl Context {
             pending_completions: Default::default(),
             token_count: None,
             pending_token_count: Task::ready(None),
+            pending_cache_warming_task: Task::ready(None),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -524,6 +534,8 @@ impl Context {
                     role: Role::User,
                     status: MessageStatus::Done,
                     timestamp: first_message_id.0,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
             );
             this.message_anchors.push(message);
@@ -948,6 +960,7 @@ impl Context {
             let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
             this.update(&mut cx, |this, cx| {
                 this.token_count = Some(token_count);
+                this.start_cache_warming(&model, cx);
                 cx.notify()
             })
         }
@@ -955,6 +968,121 @@ impl Context {
         });
     }
 
+    pub fn mark_longest_messages_for_cache(
+        &mut self,
+        cache_configuration: &Option<LanguageModelCacheConfiguration>,
+        speculative: bool,
+        cx: &mut ModelContext<Self>,
+    ) -> bool {
+        let cache_configuration =
+            cache_configuration
+                .as_ref()
+                .unwrap_or(&LanguageModelCacheConfiguration {
+                    max_cache_anchors: 0,
+                    should_speculate: false,
+                    min_total_token: 0,
+                });
+
+        let messages: Vec<Message> = self
+            .messages_from_anchors(
+                self.message_anchors.iter().take(if speculative {
+                    self.message_anchors.len().saturating_sub(1)
+                } else {
+                    self.message_anchors.len()
+                }),
+                cx,
+            )
+            .filter(|message| message.offset_range.len() >= 5_000)
+            .collect();
+
+        let mut sorted_messages = messages.clone();
+        sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
+        if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
+            // Some models support caching, but don't support anchors. In that case we want to
+            // mark the largest message as needing to be cached, but we will not mark it as an
+            // anchor.
+            sorted_messages.truncate(1);
+        } else {
+            // Save 1 anchor for the inline assistant.
+            sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
+        }
+
+        let longest_message_ids: HashSet<MessageId> = sorted_messages
+            .into_iter()
+            .map(|message| message.id)
+            .collect();
+
+        let cache_deltas: HashSet<MessageId> = self
+            .messages_metadata
+            .iter()
+            .filter_map(|(id, metadata)| {
+                let should_cache = longest_message_ids.contains(id);
+                let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
+                if metadata.should_cache != should_cache
+                    || metadata.is_cache_anchor != should_be_anchor
+                {
+                    Some(*id)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        let mut newly_cached_item = false;
+        for id in cache_deltas {
+            newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
+            self.update_metadata(id, cx, |metadata| {
+                metadata.should_cache = longest_message_ids.contains(&id);
+                metadata.is_cache_anchor =
+                    metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
+            });
+        }
+        newly_cached_item
+    }
+
+    fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
+        let cache_configuration = model.cache_configuration();
+        if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
+            return;
+        }
+        if let Some(cache_configuration) = cache_configuration {
+            if !cache_configuration.should_speculate {
+                return;
+            }
+        }
+
+        let request = {
+            let mut req = self.to_completion_request(cx);
+            // Skip the last message because it's likely to change and
+            // therefore would be a waste to cache.
+            req.messages.pop();
+            req.messages.push(LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Respond only with OK, nothing else.".into()],
+                cache: false,
+            });
+            req
+        };
+
+        let model = Arc::clone(model);
+        self.pending_cache_warming_task = cx.spawn(|_, cx| {
+            async move {
+                match model.stream_completion(request, &cx).await {
+                    Ok(mut stream) => {
+                        stream.next().await;
+                        log::info!("Cache warming completed successfully");
+                    }
+                    Err(e) => {
+                        log::warn!("Cache warming failed: {}", e);
+                    }
+                };
+
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
     pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
         let buffer = self.buffer.read(cx);
         let mut row_ranges = self
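The selection logic above boils down to: take the messages over 5,000 bytes (optionally ignoring the still-changing last message), sort by size descending, and keep at most `max_cache_anchors - 1` of them, holding one anchor in reserve for the inline assistant. A standalone sketch of that heuristic, using plain `(id, len)` pairs in place of real `Message` values:

```rust
use std::collections::HashSet;

// Simplified sketch: which messages should become cache anchors?
fn pick_cache_anchors(
    messages: &[(u64, usize)], // (message id, length in bytes)
    max_cache_anchors: usize,
    should_speculate: bool,
) -> HashSet<u64> {
    // Only messages large enough to be worth caching are candidates.
    let mut candidates: Vec<&(u64, usize)> =
        messages.iter().filter(|(_, len)| *len >= 5_000).collect();
    // Largest first: big, stable prefixes benefit most from caching.
    candidates.sort_by(|a, b| b.1.cmp(&a.1));
    if max_cache_anchors == 0 && should_speculate {
        // The model caches but supports no anchors: mark only the largest message.
        candidates.truncate(1);
    } else {
        // Hold one anchor in reserve for the inline assistant.
        candidates.truncate(max_cache_anchors.max(1) - 1);
    }
    candidates.into_iter().map(|&(id, _)| id).collect()
}

fn main() {
    let messages = vec![(1, 12_000), (2, 800), (3, 6_500), (4, 30_000)];
    // With a budget of 4 anchors, the three largest qualifying messages fit in 4 - 1 slots.
    let anchors = pick_cache_anchors(&messages, 4, true);
    assert_eq!(anchors, HashSet::from([1, 3, 4]));
}
```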
@@ -1352,20 +1480,26 @@ impl Context {
         self.count_remaining_tokens(cx);
     }
 
-    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
-        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
-        let model = LanguageModelRegistry::read_global(cx).active_model()?;
-        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
+    fn get_last_valid_message_id(&self, cx: &ModelContext<Self>) -> Option<MessageId> {
+        self.message_anchors.iter().rev().find_map(|message| {
             message
                 .start
                 .is_valid(self.buffer.read(cx))
                 .then_some(message.id)
-        })?;
+        })
+    }
+
+    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
+        let last_message_id = self.get_last_valid_message_id(cx)?;
 
         if !provider.is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }
+        // Compute which messages to cache, including the last one.
+        self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
 
         let request = self.to_completion_request(cx);
         let assistant_message = self
@@ -1580,6 +1714,8 @@ impl Context {
             role,
             status,
             timestamp: anchor.id.0,
+            should_cache: false,
+            is_cache_anchor: false,
         };
         self.insert_message(anchor.clone(), metadata.clone(), cx);
         self.push_op(
@@ -1696,6 +1832,8 @@ impl Context {
                 role,
                 status: MessageStatus::Done,
                 timestamp: suffix.id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             };
             self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
             self.push_op(
@@ -1745,6 +1883,8 @@ impl Context {
                     role,
                     status: MessageStatus::Done,
                     timestamp: selection.id.0,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 };
                 self.insert_message(selection.clone(), selection_metadata.clone(), cx);
                 self.push_op(
@@ -1811,6 +1951,7 @@ impl Context {
             content: vec![
                 "Summarize the context into a short title without punctuation.".into(),
             ],
+            cache: false,
         }));
         let request = LanguageModelRequest {
             messages: messages.collect(),
@@ -1910,14 +2051,22 @@ impl Context {
         result
     }
 
-    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+    fn messages_from_anchors<'a>(
+        &'a self,
+        message_anchors: impl Iterator<Item = &'a MessageAnchor> + 'a,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = Message> {
         let buffer = self.buffer.read(cx);
-        let messages = self.message_anchors.iter().enumerate();
+        let messages = message_anchors.enumerate();
         let images = self.image_anchors.iter();
 
         Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
     }
 
+    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+        self.messages_from_anchors(self.message_anchors.iter(), cx)
+    }
+
     pub fn messages_from_iters<'a>(
         buffer: &'a Buffer,
         metadata: &'a HashMap<MessageId, MessageMetadata>,
@@ -1969,6 +2118,7 @@ impl Context {
                     anchor: message_anchor.start,
                     role: metadata.role,
                     status: metadata.status.clone(),
+                    cache: metadata.is_cache_anchor,
                     image_offsets,
                 });
             }
@@ -2215,6 +2365,8 @@ impl SavedContext {
                     role: message.metadata.role,
                     status: message.metadata.status,
                     timestamp: message.metadata.timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: version.clone(),
             });
@@ -2231,6 +2383,8 @@ impl SavedContext {
                     role: metadata.role,
                     status: metadata.status,
                     timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: version.clone(),
             });
@@ -2325,6 +2479,8 @@ impl SavedContextV0_3_0 {
                     role: metadata.role,
                     status: metadata.status.clone(),
                     timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 image_offsets: Vec::new(),
             })
@@ -2387,6 +2387,7 @@ impl Codegen {
         messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![prompt.into()],
+            cache: false,
         });
 
         Ok(LanguageModelRequest {
@@ -784,6 +784,7 @@ impl PromptLibrary {
             messages: vec![LanguageModelRequestMessage {
                 role: Role::System,
                 content: vec![body.to_string().into()],
+                cache: false,
             }],
             stop: Vec::new(),
             temperature: 1.,
@@ -277,6 +277,7 @@ impl TerminalInlineAssistant {
         messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![prompt.into()],
+            cache: false,
         });
 
         Ok(LanguageModelRequest {
@@ -136,6 +136,7 @@ impl WorkflowStep {
         request.messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![prompt.into()],
+            cache: false,
        });
 
         // Invoke the model to get its edit suggestions for this workflow step.
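These four call sites (inline assistant, prompt library, terminal assistant, and workflow steps) all construct one-off requests whose prompts change every time, so they opt out by passing `cache: false`. A hedged sketch of a caller that does want a breakpoint, assuming the request types from the `language_model` crate shown below; the helper itself is hypothetical:

```rust
use language_model::{LanguageModelRequestMessage, Role};

// Hypothetical helper: build a message and choose whether it should be a
// cache breakpoint. Stable, frequently reused prefixes (system prompts,
// large pasted context) are worth caching; one-off prompts are not.
fn request_message(text: &str, cache: bool) -> LanguageModelRequestMessage {
    LanguageModelRequestMessage {
        role: Role::User,
        content: vec![text.into()],
        cache,
    }
}
```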
@@ -20,7 +20,7 @@ pub use registry::*;
 pub use request::*;
 pub use role::*;
 use schemars::JsonSchema;
-use serde::de::DeserializeOwned;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::{future::Future, sync::Arc};
 use ui::IconName;
@@ -43,6 +43,14 @@ pub enum LanguageModelAvailability {
     RequiresPlan(Plan),
 }
 
+/// Configuration for caching language model messages.
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct LanguageModelCacheConfiguration {
+    pub max_cache_anchors: usize,
+    pub should_speculate: bool,
+    pub min_total_token: usize,
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
@@ -78,6 +86,10 @@ pub trait LanguageModel: Send + Sync {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 
+    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        None
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
         unimplemented!()
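`cache_configuration()` defaults to `None`, so existing providers are unaffected; only providers that opt in (Anthropic, and Zed's cloud provider below) override it. A sketch of what opting in looks like, with a hypothetical `MyModel` standing in for a real provider model (a real `LanguageModel` impl has many more required methods, omitted here):

```rust
use language_model::LanguageModelCacheConfiguration;

struct MyModel; // hypothetical provider model

impl MyModel {
    // Sketch of the override a caching-capable provider would supply.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        // Opt in: up to 4 cache anchors, speculative warming enabled, and
        // caching only once the prompt exceeds 2,048 tokens.
        Some(LanguageModelCacheConfiguration {
            max_cache_anchors: 4,
            should_speculate: true,
            min_total_token: 2_048,
        })
    }
}
```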
@@ -1,7 +1,7 @@
 use crate::{
-    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
@@ -38,6 +38,7 @@ pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
     pub tool_override: Option<String>,
+    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -171,6 +172,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
                     tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
                 },
             );
         }
@@ -351,6 +359,16 @@ impl LanguageModel for AnthropicModel {
             .boxed()
     }
 
+    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        self.model
+            .cache_configuration()
+            .map(|config| LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors,
+                should_speculate: config.should_speculate,
+                min_total_token: config.min_total_token,
+            })
+    }
+
     fn use_any_tool(
         &self,
         request: LanguageModelRequest,
@@ -1,7 +1,7 @@
 use super::open_ai::count_open_ai_tokens;
 use crate::{
-    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
-    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anthropic::AnthropicError;
@@ -56,6 +56,7 @@ pub struct AvailableModel {
     name: String,
     max_tokens: usize,
     tool_override: Option<String>,
+    cache_configuration: Option<LanguageModelCacheConfiguration>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -202,6 +203,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
                     tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
                 })
             }
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
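The field-by-field `LanguageModelCacheConfiguration` to `anthropic::AnthropicModelCacheConfiguration` mapping now appears in both providers (and again in the settings code below). A `From` impl in the `language_model` crate would be one way to centralize it; this is a possible refactor, not something the commit does:

```rust
// Hypothetical refactor: centralize the repeated field-by-field mapping.
impl From<&LanguageModelCacheConfiguration> for anthropic::AnthropicModelCacheConfiguration {
    fn from(config: &LanguageModelCacheConfiguration) -> Self {
        Self {
            max_cache_anchors: config.max_cache_anchors,
            should_speculate: config.should_speculate,
            min_total_token: config.min_total_token,
        }
    }
}
```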
@@ -193,6 +193,7 @@ impl From<&str> for MessageContent {
 pub struct LanguageModelRequestMessage {
     pub role: Role,
     pub content: Vec<MessageContent>,
+    pub cache: bool,
 }
 
 impl LanguageModelRequestMessage {
@@ -213,7 +214,7 @@ impl LanguageModelRequestMessage {
             .content
             .get(0)
             .map(|content| match content {
-                MessageContent::Text(s) => s.is_empty(),
+                MessageContent::Text(s) => s.trim().is_empty(),
                 MessageContent::Image(_) => true,
             })
             .unwrap_or(false)
@@ -286,7 +287,7 @@ impl LanguageModelRequest {
     }
 
     pub fn into_anthropic(self, model: String) -> anthropic::Request {
-        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();
 
         for message in self.messages {
@@ -296,18 +297,50 @@ impl LanguageModelRequest {
 
             match message.role {
                 Role::User | Role::Assistant => {
+                    let cache_control = if message.cache {
+                        Some(anthropic::CacheControl {
+                            cache_type: anthropic::CacheControlType::Ephemeral,
+                        })
+                    } else {
+                        None
+                    };
+                    let anthropic_message_content: Vec<anthropic::Content> = message
+                        .content
+                        .into_iter()
+                        // TODO: filter out the empty messages in the message construction step
+                        .filter_map(|content| match content {
+                            MessageContent::Text(t) if !t.is_empty() => {
+                                Some(anthropic::Content::Text {
+                                    text: t,
+                                    cache_control,
+                                })
+                            }
+                            MessageContent::Image(i) => Some(anthropic::Content::Image {
+                                source: anthropic::ImageSource {
+                                    source_type: "base64".to_string(),
+                                    media_type: "image/png".to_string(),
+                                    data: i.source.to_string(),
+                                },
+                                cache_control,
+                            }),
+                            _ => None,
+                        })
+                        .collect();
+                    let anthropic_role = match message.role {
+                        Role::User => anthropic::Role::User,
+                        Role::Assistant => anthropic::Role::Assistant,
+                        Role::System => unreachable!("System role should never occur here"),
+                    };
                     if let Some(last_message) = new_messages.last_mut() {
-                        if last_message.role == message.role {
-                            // TODO: is this append done properly?
-                            last_message.content.push(MessageContent::Text(format!(
-                                "\n\n{}",
-                                message.string_contents()
-                            )));
+                        if last_message.role == anthropic_role {
+                            last_message.content.extend(anthropic_message_content);
                             continue;
                         }
                     }
 
-                    new_messages.push(message);
+                    new_messages.push(anthropic::Message {
+                        role: anthropic_role,
+                        content: anthropic_message_content,
+                    });
                 }
                 Role::System => {
                     if !system_message.is_empty() {
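One consequence of the rewrite above: consecutive messages with the same role are now merged by extending the previous message's list of content blocks, rather than the old string concatenation with `"\n\n"`, so per-block `cache_control` markers survive the merge. A standalone sketch of the coalescing rule, using a simplified `Message` type instead of the real `anthropic::Message`:

```rust
// Simplified stand-ins for the real role and message types.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Role { User, Assistant }

#[derive(Debug)]
struct Message { role: Role, content: Vec<String> }

fn coalesce(input: Vec<Message>) -> Vec<Message> {
    let mut out: Vec<Message> = Vec::new();
    for msg in input {
        if let Some(last) = out.last_mut() {
            if last.role == msg.role {
                // Extend the block list rather than concatenating strings, so
                // per-block metadata (like cache_control) has a place to live.
                last.content.extend(msg.content);
                continue;
            }
        }
        out.push(msg);
    }
    out
}

fn main() {
    let merged = coalesce(vec![
        Message { role: Role::User, content: vec!["part one".into()] },
        Message { role: Role::User, content: vec!["part two".into()] },
    ]);
    assert_eq!(merged.len(), 1);
    assert_eq!(merged[0].content, vec!["part one", "part two"]);
}
```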
@@ -320,36 +353,7 @@ impl LanguageModelRequest {
 
         anthropic::Request {
             model,
-            messages: new_messages
-                .into_iter()
-                .filter_map(|message| {
-                    Some(anthropic::Message {
-                        role: match message.role {
-                            Role::User => anthropic::Role::User,
-                            Role::Assistant => anthropic::Role::Assistant,
-                            Role::System => return None,
-                        },
-                        content: message
-                            .content
-                            .into_iter()
-                            // TODO: filter out the empty messages in the message construction step
-                            .filter_map(|content| match content {
-                                MessageContent::Text(t) if !t.is_empty() => {
-                                    Some(anthropic::Content::Text { text: t })
-                                }
-                                MessageContent::Image(i) => Some(anthropic::Content::Image {
-                                    source: anthropic::ImageSource {
-                                        source_type: "base64".to_string(),
-                                        media_type: "image/png".to_string(),
-                                        data: i.source.to_string(),
-                                    },
-                                }),
-                                _ => None,
-                            })
-                            .collect(),
-                    })
-                })
-                .collect(),
+            messages: new_messages,
             max_tokens: 4092,
             system: Some(system_message),
             tools: Vec::new(),
@@ -7,7 +7,8 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{update_settings_file, Settings, SettingsSources};
 
-use crate::provider::{
+use crate::{
+    provider::{
     self,
     anthropic::AnthropicSettings,
     cloud::{self, ZedDotDevSettings},
@@ -15,6 +16,8 @@ use crate::provider::{
     google::GoogleSettings,
     ollama::OllamaSettings,
     open_ai::OpenAiSettings,
+    },
+    LanguageModelCacheConfiguration,
 };
 
 /// Initializes the language model settings.
@@ -93,10 +96,18 @@ impl AnthropicSettingsContent {
                         name,
                         max_tokens,
                         tool_override,
+                        cache_configuration,
                     } => Some(provider::anthropic::AvailableModel {
                         name,
                         max_tokens,
                         tool_override,
+                        cache_configuration: cache_configuration.as_ref().map(
+                            |config| LanguageModelCacheConfiguration {
+                                max_cache_anchors: config.max_cache_anchors,
+                                should_speculate: config.should_speculate,
+                                min_total_token: config.min_total_token,
+                            },
+                        ),
                     }),
                     _ => None,
                 })
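With the settings plumbing in place, a custom Anthropic model can declare its own cache support. The exact settings-file shape is not shown in this diff; the sketch below infers it from `AvailableModel`'s fields and is best treated as illustrative:

```rust
use serde_json::json;

fn main() {
    // Inferred from AvailableModel { name, max_tokens, tool_override, cache_configuration }.
    let custom_model = json!({
        "name": "claude-3-5-sonnet-20240620",
        "max_tokens": 200000,
        "tool_override": null,
        "cache_configuration": {
            "max_cache_anchors": 4,
            "should_speculate": true,
            "min_total_token": 2048
        }
    });
    println!("{custom_model:#}");
}
```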