Simplify LLM protocol (#15366)

In this pull request, we change the zed.dev protocol so that we pass the
raw JSON request for the specified provider directly to our server. This
avoids the need to define a protobuf message that's a superset of every
provider's request format.
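
Concretely, the server-side handler in this diff reads a `proto::QueryLanguageModel` whose `provider` and `kind` fields say which provider and operation the raw `request` JSON is for, and it streams back `proto::QueryLanguageModelResponse` messages whose `response` field is the provider's raw JSON chunk. Here is a rough client-side sketch; the helper functions are illustrative, and the proto field and enum names are the ones used by the handler further down:

```rust
use anyhow::Result;
use rpc::proto;

// Illustrative sketch: wrap a provider-specific request in the new,
// provider-agnostic envelope instead of a protobuf superset message.
fn anthropic_query(request: &anthropic::Request) -> Result<proto::QueryLanguageModel> {
    Ok(proto::QueryLanguageModel {
        kind: proto::LanguageModelRequestKind::Complete as i32,
        provider: proto::LanguageModelProvider::Anthropic as i32,
        // Raw provider JSON; the server deserializes it and forwards it to Anthropic.
        request: serde_json::to_string(request)?,
    })
}

// Each streamed chunk is also raw provider JSON, decoded on the client.
fn decode_anthropic_chunk(
    chunk: &proto::QueryLanguageModelResponse,
) -> Result<anthropic::ResponseEvent> {
    Ok(serde_json::from_str(&chunk.response)?)
}
```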

@bennetbo: We also changed the settings for `available_models` under
zed.dev to a flat format, because the nesting seemed too confusing.
Can you help us upgrade the local provider configuration to be
consistent with this? We should do whatever we need to when parsing the
settings to keep this simple for users, even if it's a bit more complex
on our end. We want to use versioning to avoid breaking existing users,
but we also need to keep making progress. The new flat format looks like
this:

```json
"zed.dev": {
  "available_models": [
    {
      "provider": "anthropic",
        "name": "some-newly-released-model-we-havent-added",
        "max_tokens": 200000
      }
  ]
}
```
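
Under the hood, the updated `CloudModel` in this diff is an internally tagged serde enum (`#[serde(tag = "provider", rename_all = "lowercase")]`) over the provider crates' model types, which is what makes the flat `provider` field above work. Below is a stand-alone sketch of that shape; the `CustomModel` struct is a hypothetical stand-in for the per-provider model types, not the exact way the inner models are parsed:

```rust
use serde::Deserialize;

// Hypothetical stand-in for the provider crates' custom-model data.
#[derive(Debug, Deserialize)]
struct CustomModel {
    name: String,
    max_tokens: usize,
}

// Mirrors the new CloudModel definition: the "provider" field picks the
// variant, and the remaining fields describe the model itself.
#[derive(Debug, Deserialize)]
#[serde(tag = "provider", rename_all = "lowercase")]
enum CloudModelEntry {
    Anthropic(CustomModel),
    OpenAi(CustomModel),
    Google(CustomModel),
}

fn main() -> serde_json::Result<()> {
    let entry: CloudModelEntry = serde_json::from_str(
        r#"{"provider": "anthropic",
            "name": "some-newly-released-model-we-havent-added",
            "max_tokens": 200000}"#,
    )?;
    println!("{entry:?}");
    Ok(())
}
```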

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-07-28 11:07:10 +02:00 committed by GitHub
parent e0fe7f632c
commit d6bdaa8a91
31 changed files with 896 additions and 2154 deletions

Cargo.lock (generated)
View File

@ -471,27 +471,6 @@ dependencies = [
"workspace",
]
[[package]]
name = "assistant_tooling"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"futures 0.3.28",
"gpui",
"log",
"project",
"repair_json",
"schemars",
"serde",
"serde_json",
"settings",
"sum_tree",
"ui",
"unindent",
"util",
]
[[package]]
name = "async-attributes"
version = "1.1.2"
@ -4811,8 +4790,10 @@ dependencies = [
"anyhow",
"futures 0.3.28",
"http_client",
"schemars",
"serde",
"serde_json",
"strum",
]
[[package]]
@ -5988,6 +5969,7 @@ dependencies = [
"env_logger",
"feature_flags",
"futures 0.3.28",
"google_ai",
"gpui",
"http_client",
"language",
@ -8715,15 +8697,6 @@ dependencies = [
"bytecheck",
]
[[package]]
name = "repair_json"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15"
dependencies = [
"thiserror",
]
[[package]]
name = "repl"
version = "0.1.0"

View File

@ -6,7 +6,6 @@ members = [
"crates/assets",
"crates/assistant",
"crates/assistant_slash_command",
"crates/assistant_tooling",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@ -178,7 +177,6 @@ anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
breadcrumbs = { path = "crates/breadcrumbs" }

View File

@ -870,6 +870,9 @@
"openai": {
"api_url": "https://api.openai.com/v1"
},
"google": {
"api_url": "https://generativelanguage.googleapis.com"
},
"ollama": {
"api_url": "http://localhost:11434"
}

View File

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
@ -98,7 +98,7 @@ impl From<Role> for String {
}
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
@ -113,7 +113,7 @@ pub struct RequestMessage {
pub content: String,
}
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
@ -138,7 +138,7 @@ pub enum ResponseEvent {
MessageStop {},
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
@ -151,19 +151,19 @@ pub struct ResponseMessage {
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
@ -226,6 +226,25 @@ pub async fn stream_completion(
}
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseEvent>>,
) -> impl Stream<Item = Result<String>> {
response.filter_map(|response| async move {
match response {
Ok(response) => match response {
ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
ContentBlock::Text { text } => Some(Ok(text)),
},
ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
TextDelta::TextDelta { text } => Some(Ok(text)),
},
_ => None,
},
Err(error) => Some(Err(error)),
}
})
}
// #[cfg(test)]
// mod tests {
// use super::*;

View File

@ -249,9 +249,7 @@ impl AssistantSettingsContent {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
default_model: CloudModel::from_id(&model).ok(),
});
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {

View File

@ -1,33 +0,0 @@
[package]
name = "assistant_tooling"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant_tooling.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
log.workspace = true
project.workspace = true
repair_json.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
sum_tree.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@ -1 +0,0 @@
../../LICENSE-GPL

View File

@ -1,85 +0,0 @@
# Assistant Tooling
Bringing Language Model tool calling to GPUI.
This unlocks:
- **Structured Extraction** of model responses
- **Validation** of model inputs
- **Execution** of chosen tools
## Overview
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
>
> **Assistant**: "Sure, I can help with that. Let me see what I can find."
>
> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
>
> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
>
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`.
## Using the Tool Registry
```rust
let mut tool_registry = ToolRegistry::new();
tool_registry
.register(WeatherTool { api_client },
})
.unwrap(); // You can only register one tool per name
let completion = cx.update(|cx| {
CompletionProvider::get(cx).complete(
model_name,
messages,
Vec::new(),
1.0,
// The definitions get passed directly to OpenAI when you want
// the model to be able to call your tool
tool_registry.definitions(),
)
});
let mut stream = completion?.await?;
let mut message = AssistantMessage::new();
while let Some(delta) = stream.next().await {
// As messages stream in, you'll get both assistant content
if let Some(content) = &delta.content {
message
.body
.update(cx, |message, cx| message.append(&content, cx));
}
// And tool calls!
for tool_call_delta in delta.tool_calls {
let index = tool_call_delta.index as usize;
if index >= message.tool_calls.len() {
message.tool_calls.resize_with(index + 1, Default::default);
}
let tool_call = &mut message.tool_calls[index];
// Build up an ID
if let Some(id) = &tool_call_delta.id {
tool_call.id.push_str(id);
}
tool_registry.update_tool_call(
tool_call,
tool_call_delta.name.as_deref(),
tool_call_delta.arguments.as_deref(),
cx,
);
}
}
```
Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task<Result<()>>`.
As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`.

View File

@ -1,13 +0,0 @@
mod attachment_registry;
mod project_context;
mod tool_registry;
pub use attachment_registry::{
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
UserAttachment,
};
pub use project_context::ProjectContext;
pub use tool_registry::{
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition,
ToolRegistry, ToolView,
};

View File

@ -1,234 +0,0 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
use gpui::{AnyView, Render, Task, View, WindowContext};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
any::TypeId,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use util::ResultExt as _;
pub struct AttachmentRegistry {
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
pub trait AttachmentOutput {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
pub trait LanguageModelAttachment {
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + AttachmentOutput;
fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub view: AnyView,
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
#[derive(Serialize, Deserialize)]
pub struct SavedUserAttachment {
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
}
/// Internal representation of an attachment tool to allow us to treat them dynamically
struct RegisteredAttachment {
name: Arc<str>,
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
}
impl AttachmentRegistry {
pub fn new() -> Self {
Self {
registered_attachments: HashMap::default(),
}
}
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
let attachment = Arc::new(attachment);
let call = Box::new({
let attachment = attachment.clone();
move |cx: &mut WindowContext| {
let result = attachment.run(cx);
let attachment = attachment.clone();
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let serialized_output =
result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| attachment.view(result, cx))?;
Ok(UserAttachment {
name: attachment.name(),
view: view.into(),
generate_fn: generate::<A>,
serialized_output,
})
})
}
});
let deserialize = Box::new({
let attachment = attachment.clone();
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
let serialized_output = saved_attachment.serialized_output.clone();
let output = match &serialized_output {
Ok(serialized_output) => {
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
}
Err(error) => Err(anyhow!("{error}")),
};
let view = attachment.view(output, cx).into();
Ok(UserAttachment {
name: saved_attachment.name.clone(),
view,
serialized_output,
generate_fn: generate::<A>,
})
}
});
self.registered_attachments.insert(
TypeId::of::<A>(),
RegisteredAttachment {
name: attachment.name(),
call,
deserialize,
enabled: AtomicBool::new(true),
},
);
return;
fn generate<T: LanguageModelAttachment>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
}
pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
&self,
is_enabled: bool,
) {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.store(is_enabled, SeqCst);
}
}
pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.load(SeqCst)
} else {
false
}
}
pub fn call<A: LanguageModelAttachment + 'static>(
&self,
cx: &mut WindowContext,
) -> Task<Result<UserAttachment>> {
let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
return Task::ready(Err(anyhow!("no attachment tool")));
};
(attachment.call)(cx)
}
pub fn call_all_attachment_tools(
self: Arc<Self>,
cx: &mut WindowContext<'_>,
) -> Task<Result<Vec<UserAttachment>>> {
let this = self.clone();
cx.spawn(|mut cx| async move {
let attachment_tasks = cx.update(|cx| {
let mut tasks = Vec::new();
for attachment in this
.registered_attachments
.values()
.filter(|attachment| attachment.enabled.load(SeqCst))
{
tasks.push((attachment.call)(cx))
}
tasks
})?;
let attachments = join_all(attachment_tasks.into_iter()).await;
Ok(attachments
.into_iter()
.filter_map(|attachment| attachment.log_err())
.collect())
})
}
pub fn serialize_user_attachment(
&self,
user_attachment: &UserAttachment,
) -> SavedUserAttachment {
SavedUserAttachment {
name: user_attachment.name.clone(),
serialized_output: user_attachment.serialized_output.clone(),
}
}
pub fn deserialize_user_attachment(
&self,
saved_user_attachment: SavedUserAttachment,
cx: &mut WindowContext,
) -> Result<UserAttachment> {
if let Some(registered_attachment) = self
.registered_attachments
.values()
.find(|attachment| attachment.name == saved_user_attachment.name)
{
(registered_attachment.deserialize)(&saved_user_attachment, cx)
} else {
Err(anyhow!(
"no attachment tool for name {}",
saved_user_attachment.name
))
}
}
}
impl UserAttachment {
pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
let result = (self.generate_fn)(self.view.clone(), output, cx);
if result.is_empty() {
None
} else {
Some(result)
}
}
}

View File

@ -1,296 +0,0 @@
use anyhow::{anyhow, Result};
use gpui::{AppContext, Model, Task, WeakModel};
use project::{Fs, Project, ProjectPath, Worktree};
use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
use sum_tree::TreeMap;
pub struct ProjectContext {
files: TreeMap<ProjectPath, PathState>,
project: WeakModel<Project>,
fs: Arc<dyn Fs>,
}
#[derive(Debug, Clone)]
enum PathState {
PathOnly,
EntireFile,
Excerpts { ranges: Vec<Range<usize>> },
}
impl ProjectContext {
pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
Self {
files: TreeMap::default(),
fs,
project,
}
}
pub fn add_path(&mut self, project_path: ProjectPath) {
if self.files.get(&project_path).is_none() {
self.files.insert(project_path, PathState::PathOnly);
}
}
pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
let previous_state = self
.files
.get(&project_path)
.unwrap_or(&PathState::PathOnly);
let mut ranges = match previous_state {
PathState::EntireFile => return,
PathState::PathOnly => Vec::new(),
PathState::Excerpts { ranges } => ranges.to_vec(),
};
for new_range in new_ranges {
let ix = ranges.binary_search_by(|probe| {
if probe.end < new_range.start {
Ordering::Less
} else if probe.start > new_range.end {
Ordering::Greater
} else {
Ordering::Equal
}
});
match ix {
Ok(mut ix) => {
let existing = &mut ranges[ix];
existing.start = existing.start.min(new_range.start);
existing.end = existing.end.max(new_range.end);
while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
ranges.remove(ix + 1);
}
while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
ranges.remove(ix - 1);
ix -= 1;
}
}
Err(ix) => {
ranges.insert(ix, new_range.clone());
}
}
}
self.files
.insert(project_path, PathState::Excerpts { ranges });
}
pub fn add_file(&mut self, project_path: ProjectPath) {
self.files.insert(project_path, PathState::EntireFile);
}
pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
let project = self
.project
.upgrade()
.ok_or_else(|| anyhow!("project dropped"));
let files = self.files.clone();
let fs = self.fs.clone();
cx.spawn(|cx| async move {
let project = project?;
let mut result = "project structure:\n".to_string();
let mut last_worktree: Option<Model<Worktree>> = None;
for (project_path, path_state) in files.iter() {
if let Some(worktree) = &last_worktree {
if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
last_worktree = None;
}
}
let worktree;
if let Some(last_worktree) = &last_worktree {
worktree = last_worktree.clone();
} else if let Some(tree) = project.read_with(&cx, |project, cx| {
project.worktree_for_id(project_path.worktree_id, cx)
})? {
worktree = tree;
last_worktree = Some(worktree.clone());
let worktree_name =
worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
writeln!(&mut result, "# {}", worktree_name).unwrap();
} else {
continue;
}
let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
let path = &project_path.path;
writeln!(&mut result, "## {}", path.display()).unwrap();
match path_state {
PathState::PathOnly => {}
PathState::EntireFile => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
}
PathState::Excerpts { ranges } => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~").unwrap();
// Assumption: ranges are in order, not overlapping
let mut prev_range_end = 0;
for range in ranges {
if range.start > prev_range_end {
writeln!(&mut result, "...").unwrap();
prev_range_end = range.end;
}
let mut start = range.start;
let mut end = range.end.min(text.len());
while !text.is_char_boundary(start) {
start += 1;
}
while !text.is_char_boundary(end) {
end -= 1;
}
result.push_str(&text[start..end]);
if !result.ends_with('\n') {
result.push('\n');
}
}
if prev_range_end < text.len() {
writeln!(&mut result, "...").unwrap();
}
writeln!(&mut result, "~~~").unwrap();
}
}
}
Ok(result)
})
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::TestAppContext;
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use unindent::Unindent as _;
#[gpui::test]
async fn test_system_message_generation(cx: &mut TestAppContext) {
init_test(cx);
let file_3_contents = r#"
fn test1() {}
fn test2() {}
fn test3() {}
"#
.unindent();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/code",
json!({
"root1": {
"lib": {
"file1.rs": "mod example;",
"file2.rs": "",
},
"test": {
"file3.rs": file_3_contents,
}
},
"root2": {
"src": {
"main.rs": ""
}
}
}),
)
.await;
let project = Project::test(
fs.clone(),
["/code/root1".as_ref(), "/code/root2".as_ref()],
cx,
)
.await;
let worktree_ids = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.map(|worktree| worktree.read(cx).id())
.collect::<Vec<_>>()
});
let mut ax = ProjectContext::new(project.downgrade(), fs);
ax.add_file(ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("lib/file1.rs").into(),
});
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
"#
.unindent(),
message
);
ax.add_excerpts(
ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("test/file3.rs").into(),
},
&[
file_3_contents.find("fn test2").unwrap()
..file_3_contents.find("fn test3").unwrap(),
],
);
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
## test/file3.rs
~~~
...
fn test2() {}
...
~~~
"#
.unindent(),
message
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
});
}
}

View File

@ -1,526 +0,0 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
use repair_json::repair;
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
mem,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use ui::ViewContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
state: ToolFunctionCallState,
}
#[derive(Default)]
enum ToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
KnownTool(Box<dyn InternalToolView>),
ExecutedTool(Box<dyn InternalToolView>),
}
trait InternalToolView {
fn view(&self) -> AnyView;
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
}
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
id: String,
name: String,
arguments: String,
state: SavedToolFunctionCallState,
}
#[derive(Default, Serialize, Deserialize)]
enum SavedToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
KnownTool,
ExecutedTool(Box<RawValue>),
}
#[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}
pub trait LanguageModelTool {
type View: ToolView;
/// Returns the name of the tool.
///
/// This name is exposed to the language model to allow the model to pick
/// which tools to use. As this name is used to identify the tool within a
/// tool registry, it should be unique.
fn name(&self) -> String;
/// Returns the description of the tool.
///
/// This can be used to _prompt_ the model as to what the tool does.
fn description(&self) -> String;
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(<Self::View as ToolView>::Input);
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: root_schema,
}
}
/// A view of the output of running the tool, for displaying to the user.
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
pub trait ToolView: Render {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type SerializedState: DeserializeOwned + Serialize;
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
fn deserialize(
&mut self,
output: Self::SerializedState,
cx: &mut ViewContext<Self>,
) -> Result<()>;
}
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
definition: ToolFunctionDefinition,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
registered_tools: HashMap::new(),
}
}
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
}
}
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
}
false
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.registered_tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
.collect()
}
pub fn update_tool_call(
&self,
call: &mut ToolFunctionCall,
name: Option<&str>,
arguments: Option<&str>,
cx: &mut WindowContext,
) {
if let Some(name) = name {
call.name.push_str(name);
}
if let Some(arguments) = arguments {
if call.arguments.is_empty() {
if let Some(tool) = self.registered_tools.get(&call.name) {
let view = (tool.build_view)(cx);
call.state = ToolFunctionCallState::KnownTool(view);
} else {
call.state = ToolFunctionCallState::NoSuchTool;
}
}
call.arguments.push_str(arguments);
if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
view.try_set_input(&repaired_arguments, cx)
}
}
}
}
pub fn execute_tool_call(
&self,
tool_call: &mut ToolFunctionCall,
cx: &mut WindowContext,
) -> Option<Task<Result<()>>> {
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
let task = view.execute(cx);
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
Some(task)
} else {
None
}
}
pub fn render_tool_call(
&self,
tool_call: &ToolFunctionCall,
_cx: &mut WindowContext,
) -> Option<AnyElement> {
match &tool_call.state {
ToolFunctionCallState::NoSuchTool => {
Some(ui::Label::new("No such tool").into_any_element())
}
ToolFunctionCallState::Initializing => None,
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
Some(view.view().into_any_element())
}
}
}
pub fn content_for_tool_call(
&self,
tool_call: &ToolFunctionCall,
project_context: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
match &tool_call.state {
ToolFunctionCallState::Initializing => String::new(),
ToolFunctionCallState::NoSuchTool => {
format!("No such tool: {}", tool_call.name)
}
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
view.generate(project_context, cx)
}
}
}
pub fn serialize_tool_call(
&self,
call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> Result<SavedToolFunctionCall> {
Ok(SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
state: match &call.state {
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
ToolFunctionCallState::ExecutedTool(view) => {
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
}
},
})
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> Result<ToolFunctionCall> {
let Some(tool) = self.registered_tools.get(&call.name) else {
return Err(anyhow!("no such tool {}", call.name));
};
Ok(ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
state: match &call.state {
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx);
view.try_set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view)
}
SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx);
view.try_set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view)
}
},
})
}
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
let name = tool.name();
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
};
let previous = self.registered_tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
return Ok(());
}
}
impl<T: ToolView> InternalToolView for View<T> {
fn view(&self) -> AnyView {
self.clone().into()
}
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
self.update(cx, |view, cx| view.generate(project, cx))
}
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
self.update(cx, |view, cx| {
view.set_input(input, cx);
cx.notify();
});
}
}
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
self.update(cx, |view, cx| view.execute(cx))
}
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
let output = self.update(cx, |view, cx| view.serialize(cx));
Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
}
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
let state = serde_json::from_str::<T::SerializedState>(output.get())?;
self.update(cx, |view, cx| view.deserialize(state, cx))?;
Ok(())
}
}
impl Display for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
write!(f, "Name: {}:\n", self.name)?;
write!(f, "Description: {}\n", self.description)?;
write!(f, "Parameters: {}", schema)
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
location: String,
unit: String,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
temperature: f64,
unit: String,
}
struct WeatherView {
input: Option<WeatherQuery>,
result: Option<WeatherResult>,
// Fake API call
current_weather: WeatherResult,
}
#[derive(Clone, Serialize)]
struct WeatherTool {
current_weather: WeatherResult,
}
impl WeatherView {
fn new(current_weather: WeatherResult) -> Self {
Self {
input: None,
result: None,
current_weather,
}
}
}
impl Render for WeatherView {
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
match self.result {
Some(ref result) => div()
.child(format!("temperature: {}", result.temperature))
.into_any_element(),
None => div().child("Calculating weather...").into_any_element(),
}
}
}
impl ToolView for WeatherView {
type Input = WeatherQuery;
type SerializedState = WeatherResult;
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
serde_json::to_string(&self.result).unwrap()
}
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
self.input = Some(input);
cx.notify();
}
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
let input = self.input.as_ref().unwrap();
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
self.result = Some(weather);
Task::ready(Ok(()))
}
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
self.current_weather.clone()
}
fn deserialize(
&mut self,
output: Self::SerializedState,
_cx: &mut ViewContext<Self>,
) -> Result<()> {
self.current_weather = output;
Ok(())
}
}
impl LanguageModelTool for WeatherTool {
type View = WeatherView;
fn name(&self) -> String {
"get_current_weather".to_string()
}
fn description(&self) -> String {
"Fetches the current weather for a given location.".to_string()
}
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
}
}
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
let mut registry = ToolRegistry::new();
registry
.register(WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
})
.unwrap();
let definitions = registry.definitions();
assert_eq!(
definitions,
[ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: serde_json::from_value(json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
}))
.unwrap(),
}]
);
let mut call = ToolFunctionCall {
id: "the-id".to_string(),
name: "get_cur".to_string(),
..Default::default()
};
let task = cx.update(|cx| {
registry.update_tool_call(
&mut call,
Some("rent_weather"),
Some(r#"{"location": "San Francisco","#),
cx,
);
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
registry.execute_tool_call(&mut call, cx).unwrap()
});
task.await.unwrap();
match &call.state {
ToolFunctionCallState::ExecutedTool(_view) => {}
_ => panic!(),
}
}
}

View File

@ -1,138 +0,0 @@
use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
use util::ResultExt as _;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
Ok(open_ai::Request {
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
messages: request
.messages
.into_iter()
.map(|message: proto::LanguageModelRequestMessage| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
let openai_message = match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
content: message.content,
},
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::RequestMessage::Assistant {
content: Some(message.content),
tool_calls: message
.tool_calls
.into_iter()
.filter_map(|call| {
Some(open_ai::ToolCall {
id: call.id,
content: match call.variant? {
proto::tool_call::Variant::Function(f) => {
open_ai::ToolCallContent::Function {
function: open_ai::FunctionContent {
name: f.name,
arguments: f.arguments,
},
}
}
},
})
})
.collect(),
}
}
proto::LanguageModelRole::LanguageModelSystem => {
open_ai::RequestMessage::System {
content: message.content,
}
}
proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
tool_call_id: message
.tool_call_id
.ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
content: message.content,
},
};
Ok(openai_message)
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: request
.tools
.into_iter()
.filter_map(|tool| {
Some(match tool.variant? {
proto::chat_completion_tool::Variant::Function(f) => {
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: f.name,
description: f.description,
parameters: if let Some(params) = &f.parameters {
Some(
serde_json::from_str(params)
.context("failed to deserialize tool parameters")
.log_err()?,
)
} else {
None
},
},
}
}
})
})
.collect(),
tool_choice: request.tool_choice,
})
}
pub fn language_model_request_to_google_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
Ok(google_ai::GenerateContentRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
generation_config: None,
safety_settings: None,
})
}
pub fn language_model_request_message_to_google_ai(
message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: message.content,
})],
role: match role {
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelTool => {
Err(anyhow!("we don't handle tool calls with google ai yet"))?
}
},
})
}
pub fn count_tokens_request_to_google_ai(
request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
Ok(google_ai::CountTokensRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
})
}

View File

@ -1,4 +1,3 @@
pub mod ai;
pub mod api;
pub mod auth;
pub mod db;

View File

@ -46,8 +46,8 @@ use http_client::IsahcHttpClient;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
@ -618,17 +618,6 @@ impl Server {
)
}
})
.add_request_handler({
let app_state = app_state.clone();
user_handler(move |request, response, session| {
count_tokens_with_language_model(
request,
response,
session,
app_state.config.google_ai_api_key.clone(),
)
})
})
.add_request_handler({
user_handler(move |request, response, session| {
get_cached_embeddings(request, response, session)
@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
mut request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
query: proto::QueryLanguageModel,
response: StreamingResponse<proto::QueryLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
@ -4525,287 +4514,95 @@ async fn complete_with_language_model(
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
let mut provider_and_model = request.model.split('/');
let (provider, model) = match (
provider_and_model.next().unwrap(),
provider_and_model.next(),
) {
(provider, Some(model)) => (provider, model),
(model, None) => {
if model.starts_with("gpt") {
("openai", model)
} else if model.starts_with("gemini") {
("google", model)
} else if model.starts_with("claude") {
("anthropic", model)
} else {
("unknown", model)
}
match proto::LanguageModelRequestKind::from_i32(query.kind) {
Some(proto::LanguageModelRequestKind::Complete) => {
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
}
};
let provider = provider.to_string();
request.model = model.to_string();
Some(proto::LanguageModelRequestKind::CountTokens) => {
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
.await?;
}
None => Err(anyhow!("unknown request kind"))?,
}
match provider.as_str() {
"openai" => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
complete_with_open_ai(request, response, session, api_key).await?;
}
"anthropic" => {
match proto::LanguageModelProvider::from_i32(query.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
complete_with_anthropic(request, response, session, api_key).await?;
let mut chunks = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
&api_key,
serde_json::from_str(&query.request)?,
None,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
})?;
}
}
"google" => {
Some(proto::LanguageModelProvider::OpenAi) => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
let mut chunks = open_ai::stream_completion(
session.http_client.as_ref(),
open_ai::OPEN_AI_API_URL,
&api_key,
serde_json::from_str(&query.request)?,
None,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelProvider::Google) => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
complete_with_google_ai(request, response, session, api_key).await?;
}
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
}
Ok(())
}
async fn complete_with_open_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let mut completion_stream = open_ai::stream_completion(
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
None,
)
.await
.context("open_ai::stream_completion request failed within collab")?;
while let Some(event) = completion_stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.choices
.into_iter()
.map(|choice| proto::LanguageModelChoiceDelta {
index: choice.index,
delta: Some(proto::LanguageModelResponseMessage {
role: choice.delta.role.map(|role| match role {
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
} as i32),
content: choice.delta.content,
tool_calls: choice
.delta
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|delta| proto::ToolCallDelta {
index: delta.index as u32,
id: delta.id,
variant: match delta.function {
Some(function) => {
let name = function.name;
let arguments = function.arguments;
Some(proto::tool_call_delta::Variant::Function(
proto::tool_call_delta::FunctionCallDelta {
name,
arguments,
},
))
}
None => None,
},
})
.collect(),
}),
finish_reason: choice.finish_reason,
})
.collect(),
})?;
}
Ok(())
}
async fn complete_with_google_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let mut stream = google_ai::stream_generate_content(
session.http_client.clone(),
google_ai::API_URL,
api_key.as_ref(),
&request.model.clone(),
crate::ai::language_model_request_to_google_ai(request)?,
)
.await
.context("google_ai::stream_generate_content request failed")?;
while let Some(event) = stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.candidates
.unwrap_or_default()
.into_iter()
.map(|candidate| proto::LanguageModelChoiceDelta {
index: candidate.index as u32,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(match candidate.content.role {
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
} as i32),
content: Some(
candidate
.content
.parts
.into_iter()
.filter_map(|part| match part {
google_ai::Part::TextPart(part) => Some(part.text),
google_ai::Part::InlineDataPart(_) => None,
})
.collect(),
),
// Tool calls are not supported for Google
tool_calls: Vec::new(),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
.collect(),
})?;
}
Ok(())
}
async fn complete_with_anthropic(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let mut system_message = String::new();
let messages = request
.messages
.into_iter()
.filter_map(|message| {
match message.role() {
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
role: anthropic::Role::User,
content: message.content,
}),
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
role: anthropic::Role::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
LanguageModelRole::LanguageModelSystem => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
None
}
// We don't yet support tool calls for Anthropic
LanguageModelRole::LanguageModelTool => None,
}
})
.collect();
let mut stream = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
&api_key,
anthropic::Request {
model: request.model,
messages,
stream: true,
system: system_message,
max_tokens: 4092,
},
None,
)
.await?;
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
while let Some(event) = stream.next().await {
let event = event?;
match event {
anthropic::ResponseEvent::MessageStart { message } => {
if let Some(role) = message.role {
if role == "assistant" {
current_role = proto::LanguageModelRole::LanguageModelAssistant;
} else if role == "user" {
current_role = proto::LanguageModelRole::LanguageModelUser;
match proto::LanguageModelRequestKind::from_i32(query.kind) {
Some(proto::LanguageModelRequestKind::Complete) => {
let mut chunks = google_ai::stream_generate_content(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
serde_json::from_str(&query.request)?,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
})?;
}
}
}
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
match content_block {
anthropic::ContentBlock::Text { text } => {
if !text.is_empty() {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],
})?;
}
}
}
}
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
anthropic::TextDelta::TextDelta { text } => {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],
})?;
}
},
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
if let Some(stop_reason) = delta.stop_reason {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: None,
finish_reason: Some(stop_reason),
}],
Some(proto::LanguageModelRequestKind::CountTokens) => {
let tokens_response = google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
serde_json::from_str(&query.request)?,
)
.await?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&tokens_response)?,
})?;
}
None => Err(anyhow!("unknown request kind"))?,
}
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
anthropic::ResponseEvent::MessageStop {} => {}
anthropic::ResponseEvent::Ping {} => {}
}
None => return Err(anyhow!("unknown provider"))?,
}
Ok(())
@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
}
}
async fn count_tokens_with_language_model(
request: proto::CountTokensWithLanguageModel,
response: Response<proto::CountTokensWithLanguageModel>,
session: UserSession,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
if !request.model.starts_with("gemini") {
return Err(anyhow!(
"counting tokens for model: {:?} is not supported",
request.model
))?;
}
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
.await?;
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let tokens_response = google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
crate::ai::count_tokens_request_to_google_ai(request)?,
)
.await?;
response.send(proto::CountTokensResponse {
token_count: tokens_response.total_tokens as u32,
})?;
Ok(())
}
struct ComputeEmbeddingsRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {

View File

@ -11,9 +11,14 @@ workspace = true
[lib]
path = "src/google_ai.rs"
[features]
schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true

View File

@ -1,23 +1,21 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::HttpClient;
use serde::{Deserialize, Serialize};
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content(
client: Arc<dyn HttpClient>,
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
model: &str,
request: GenerateContentRequest,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
"{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
api_url, api_key
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
model = request.model
);
request.model.clear();
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
@ -52,8 +50,8 @@ pub async fn stream_generate_content(
}
}
pub async fn count_tokens<T: HttpClient>(
client: &T,
pub async fn count_tokens(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CountTokensRequest,
@ -91,22 +89,24 @@ pub enum Task {
BatchEmbedContents,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
pub candidates: Option<Vec<GenerateContentCandidate>>,
pub prompt_feedback: Option<PromptFeedback>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
pub index: usize,
@ -157,7 +157,7 @@ pub struct GenerativeContentBlob {
pub data: String,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
pub start_index: Option<usize>,
@ -166,13 +166,13 @@ pub struct CitationSource {
pub license: Option<String>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
pub citation_sources: Vec<CitationSource>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
pub block_reason: Option<String>,
@ -180,7 +180,7 @@ pub struct PromptFeedback {
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub candidate_count: Option<usize>,
@ -191,7 +191,7 @@ pub struct GenerationConfig {
pub top_k: Option<usize>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
@ -224,7 +224,7 @@ pub enum HarmCategory {
DangerousContent,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
@ -238,7 +238,7 @@ pub enum HarmBlockThreshold {
BlockNone,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
@ -249,21 +249,85 @@ pub enum HarmProbability {
High,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub contents: Vec<Content>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: usize,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
#[serde(rename = "custom")]
Custom { name: String, max_tokens: usize },
}
impl Model {
pub fn id(&self) -> &str {
match self {
Model::Gemini15Pro => "gemini-1.5-pro",
Model::Gemini15Flash => "gemini-1.5-flash",
Model::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
match self {
Model::Gemini15Pro => "Gemini 1.5 Pro",
Model::Gemini15Flash => "Gemini 1.5 Flash",
Model::Custom { name, .. } => name,
}
}
pub fn max_token_count(&self) -> usize {
match self {
Model::Gemini15Pro => 2_000_000,
Model::Gemini15Flash => 1_000_000,
Model::Custom { max_tokens, .. } => *max_tokens,
}
}
}
impl std::fmt::Display for Model {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.id())
}
}
pub fn extract_text_from_events(
events: impl Stream<Item = Result<GenerateContentResponse>>,
) -> impl Stream<Item = Result<String>> {
events.filter_map(|event| async move {
match event {
Ok(event) => event.candidates.and_then(|candidates| {
candidates.into_iter().next().and_then(|candidate| {
candidate.content.parts.into_iter().next().and_then(|part| {
if let Part::TextPart(TextPart { text }) = part {
Some(Ok(text))
} else {
None
}
})
})
}),
Err(error) => Some(Err(error)),
}
})
}

View File

@ -28,6 +28,7 @@ collections.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
menu.workspace = true

View File

@ -1,108 +1,42 @@
pub use anthropic::Model as AnthropicModel;
use anyhow::{anyhow, Result};
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
#[serde(rename = "gpt-3.5-turbo")]
Gpt3Point5Turbo,
#[serde(rename = "gpt-4")]
Gpt4,
#[serde(rename = "gpt-4-turbo-preview")]
Gpt4Turbo,
#[serde(rename = "gpt-4o")]
#[default]
Gpt4Omni,
#[serde(rename = "gpt-4o-mini")]
Gpt4OmniMini,
#[serde(rename = "claude-3-5-sonnet")]
Claude3_5Sonnet,
#[serde(rename = "claude-3-opus")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku")]
Claude3Haiku,
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
#[serde(rename = "custom")]
Custom {
name: String,
max_tokens: Option<usize>,
},
Anthropic(anthropic::Model),
OpenAi(open_ai::Model),
Google(google_ai::Model),
}
impl Default for CloudModel {
fn default() -> Self {
Self::Anthropic(anthropic::Model::default())
}
}
impl CloudModel {
pub fn from_id(value: &str) -> Result<Self> {
match value {
"gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
"gpt-4" => Ok(Self::Gpt4),
"gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
"gpt-4o" => Ok(Self::Gpt4Omni),
"gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
"claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
"claude-3-opus" => Ok(Self::Claude3Opus),
"claude-3-sonnet" => Ok(Self::Claude3Sonnet),
"claude-3-haiku" => Ok(Self::Claude3Haiku),
"gemini-1.5-pro" => Ok(Self::Gemini15Pro),
"gemini-1.5-flash" => Ok(Self::Gemini15Flash),
_ => Err(anyhow!("invalid model id")),
}
}
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom { name, .. } => name,
CloudModel::Anthropic(model) => model.id(),
CloudModel::OpenAi(model) => model.id(),
CloudModel::Google(model) => model.id(),
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom { name, .. } => name,
CloudModel::Anthropic(model) => model.display_name(),
CloudModel::OpenAi(model) => model.display_name(),
CloudModel::Google(model) => model.display_name(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
CloudModel::Anthropic(model) => model.max_token_count(),
CloudModel::OpenAi(model) => model.max_token_count(),
CloudModel::Google(model) => model.max_token_count(),
}
}
}
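To illustrate the delegation, a small hypothetical helper that formats any `CloudModel` via the methods above; the `describe` function and its output format are made up for this sketch.

```rust
// Hypothetical: every variant defers to the wrapped provider model, so one
// helper can describe Anthropic, OpenAI, and Google models uniformly.
fn describe(model: &CloudModel) -> String {
    format!(
        "{} ({}), up to {} tokens",
        model.display_name(),
        model.id(),
        model.max_token_count()
    )
}

// e.g. describe(&CloudModel::default()) describes the default Anthropic model.
```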

View File

@ -2,5 +2,6 @@ pub mod anthropic;
pub mod cloud;
#[cfg(any(test, feature = "test-support"))]
pub mod fake;
pub mod google;
pub mod ollama;
pub mod open_ai;

View File

@ -1,4 +1,4 @@
use anthropic::{stream_completion, Request, RequestMessage};
use anthropic::stream_completion;
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
@ -18,7 +18,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_ID: &str = "anthropic";
@ -160,40 +160,6 @@ pub struct AnthropicModel {
http_client: Arc<dyn HttpClient>,
}
impl AnthropicModel {
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let mut system_message = String::new();
if request
.messages
.first()
.map_or(false, |message| message.role == Role::System)
{
system_message = request.messages.remove(0).content;
}
Request {
model: self.model.id().to_string(),
messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true,
system: system_message,
max_tokens: 4092,
}
}
}
pub fn count_anthropic_tokens(
request: LanguageModelRequest,
cx: &AppContext,
@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_anthropic_request(request);
let request = request.into_anthropic(self.model.id().into());
let http_client = self.http_client.clone();
@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel {
low_speed_timeout,
);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(response) => match response {
anthropic::ResponseEvent::ContentBlockStart {
content_block, ..
} => match content_block {
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
},
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
match delta {
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
}
}
_ => None,
},
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
}
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
struct AuthenticationPrompt {
api_key: View<Editor>,
state: gpui::Model<State>,

View File

@ -7,8 +7,10 @@ use crate::{
use anyhow::Result;
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
@ -16,14 +18,29 @@ use ui::prelude::*;
use crate::LanguageModelProvider;
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
use super::anthropic::count_anthropic_tokens;
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
pub available_models: Vec<CloudModel>,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
Anthropic,
OpenAi,
Google,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
provider: AvailableProvider,
name: String,
max_tokens: usize,
}
pub struct CloudLanguageModelProvider {
@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from CloudModel::iter()
for model in CloudModel::iter() {
if !matches!(model, CloudModel::Custom { .. }) {
models.insert(model.id().to_string(), model);
for model in anthropic::Model::iter() {
if !matches!(model, anthropic::Model::Custom { .. }) {
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
}
}
for model in open_ai::Model::iter() {
if !matches!(model, open_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), CloudModel::OpenAi(model));
}
}
for model in google_ai::Model::iter() {
if !matches!(model, google_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), CloudModel::Google(model));
}
}
@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.zed_dot_dev
.available_models
{
let model = match model.provider {
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
};
models.insert(model.id().to_string(), model.clone());
}
@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match &self.model {
CloudModel::Gpt3Point5Turbo => {
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
}
CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
CloudModel::Gpt4OmniMini => {
count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
}
CloudModel::Claude3_5Sonnet
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
count_anthropic_tokens(request, cx)
}
_ => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model: self.model.id().to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
});
match self.model.clone() {
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
let request = google_ai::CountTokensRequest {
contents: request.contents,
};
async move {
let response = request.await?;
Ok(response.token_count as usize)
let request = serde_json::to_string(&request)?;
let response = client.request(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::CountTokens as i32,
request,
});
let response = response.await?;
let response =
serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
Ok(response.total_tokens)
}
.boxed()
}
@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
mut request: LanguageModelRequest,
request: LanguageModelRequest,
_: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
preprocess_anthropic_request(&mut request)
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
Ok(anthropic::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
)
.boxed())
}
.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
Ok(open_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
)
.boxed())
}
.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
Ok(google_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
)
.boxed())
}
.boxed()
}
_ => {}
}
let request = proto::CompleteWithLanguageModel {
model: self.id.0.to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
};
self.client
.request_stream(request)
.map_ok(|stream| {
stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed()
})
.boxed()
}
}

View File

@ -0,0 +1,351 @@
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use google_ai::stream_generate_content;
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
WhiteSpace,
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<google_ai::Model>,
}
pub struct GoogleLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
}
struct State {
api_key: Option<String>,
_subscription: Subscription,
}
impl GoogleLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
}
impl LanguageModelProviderState for GoogleLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
}
}
impl LanguageModelProvider for GoogleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from google_ai::Model::iter()
for model in google_ai::Model::iter() {
if !matches!(model, google_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.google
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
models
.into_values()
.map(|model| {
Arc::new(GoogleLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
this.api_key = None;
cx.notify();
})
})
}
}
pub struct GoogleLanguageModel {
id: LanguageModelId,
model: google_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
}
impl LanguageModel for GoogleLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn telemetry_id(&self) -> String {
format!("google/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
let api_url = AllLanguageModelSettings::get_global(cx)
.google
.api_url
.clone();
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response = google_ai::count_tokens(
http_client.as_ref(),
&api_url,
&api_key,
google_ai::CountTokensRequest {
contents: request.contents,
},
)
.await?;
Ok(response.total_tokens)
}
.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
}
.boxed()
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
state: gpui::Model<State>,
}
impl AuthenticationPrompt {
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text("AIzaSy...", cx);
editor
}),
state,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key.read(cx).text(cx);
if api_key.is_empty() {
return;
}
let settings = &AllLanguageModelSettings::get_global(cx).google;
let write_credentials =
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.api_key,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
}
impl Render for AuthenticationPrompt {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const INSTRUCTIONS: [&str; 4] = [
"To use the Google AI assistant, you need to add your Google AI API key.",
"You can create an API key at: https://makersuite.google.com/app/apikey",
"",
"Paste your Google AI API key below and hit enter to use the assistant:",
];
v_flex()
.p_4()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.children(
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
"You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
)
.size(LabelSize::Small),
)
.child(
h_flex()
.gap_2()
.child(Label::new("Click on").size(LabelSize::Small))
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
.child(
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
),
)
.into_any()
}
}

View File

@ -7,7 +7,7 @@ use gpui::{
WhiteSpace,
};
use http_client::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage};
use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel {
http_client: Arc<dyn HttpClient>,
}
impl OpenAiLanguageModel {
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
Request {
model: self.model.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => RequestMessage::User {
content: msg.content,
},
Role::Assistant => RequestMessage::Assistant {
content: Some(msg.content),
tool_calls: Vec::new(),
},
Role::System => RequestMessage::System {
content: msg.content,
},
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
}
}
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@ -226,7 +197,7 @@ impl LanguageModel for OpenAiLanguageModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = self.to_open_ai_request(request);
let request = request.into_open_ai(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel {
low_speed_timeout,
);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
Ok(open_ai::extract_text_from_events(response).boxed())
}
.boxed()
}

View File

@ -1,17 +1,17 @@
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider,
open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
use client::Client;
use collections::BTreeMap;
use gpui::{AppContext, Global, Model, ModelContext};
use std::sync::Arc;
use ui::Context;
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let registry = cx.new_model(|cx| {
let mut registry = LanguageModelRegistry::default();
@ -40,6 +40,10 @@ fn register_language_model_providers(
OllamaLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
GoogleLanguageModelProvider::new(client.http_client(), cx),
cx,
);
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
let client = client.clone();

View File

@ -1,4 +1,4 @@
use crate::{role::Role, LanguageModelId};
use crate::role::Role;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage {
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub messages: Vec<LanguageModelRequestMessage>,
@ -26,14 +15,110 @@ pub struct LanguageModelRequest {
}
impl LanguageModelRequest {
pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: model_id.0.to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
pub fn into_open_ai(self, model: String) -> open_ai::Request {
open_ai::Request {
model,
messages: self
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => open_ai::RequestMessage::User {
content: msg.content,
},
Role::Assistant => open_ai::RequestMessage::Assistant {
content: Some(msg.content),
tool_calls: Vec::new(),
},
Role::System => open_ai::RequestMessage::System {
content: msg.content,
},
})
.collect(),
stream: true,
stop: self.stop,
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
tool_choice: None,
}
}
pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
google_ai::GenerateContentRequest {
model,
contents: self
.messages
.into_iter()
.map(|msg| google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: msg.content,
})],
role: match msg.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
.collect(),
generation_config: Some(google_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(self.stop),
max_output_tokens: None,
temperature: Some(self.temperature as f64),
top_p: None,
top_k: None,
}),
safety_settings: None,
}
}
pub fn into_anthropic(self, model: String) -> anthropic::Request {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in self.messages {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
anthropic::Request {
model,
messages: new_messages
.into_iter()
.filter_map(|message| {
Some(anthropic::RequestMessage {
role: match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => return None,
},
content: message.content,
})
})
.collect(),
stream: true,
max_tokens: 4092,
system: system_message,
}
}
}
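As a rough sketch of the new flow, the snippet below builds a request, converts it with `into_anthropic`, and serializes it to the raw JSON string that is forwarded to the server unchanged. The function name, message contents, and model id are made up for illustration; only the types and methods defined above are assumed.

```rust
use anyhow::Result;

// Hypothetical example of producing the provider-specific JSON payload that
// is sent verbatim in proto::QueryLanguageModel { provider, kind, request }.
fn example_anthropic_payload() -> Result<String> {
    let request = LanguageModelRequest {
        messages: vec![
            LanguageModelRequestMessage {
                role: Role::System,
                content: "You are a terse assistant.".into(),
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: "Summarize this diff.".into(),
            },
        ],
        ..Default::default()
    };
    // System messages are folded into the top-level `system` field here.
    let anthropic_request = request.into_anthropic("claude-3-5-sonnet".into());
    Ok(serde_json::to_string(&anthropic_request)?)
}
```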

View File

@ -15,7 +15,6 @@ impl Role {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}

View File

@ -6,12 +6,12 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use crate::{
provider::{
anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings,
open_ai::OpenAiSettings,
},
CloudModel,
use crate::provider::{
anthropic::AnthropicSettings,
cloud::{self, ZedDotDevSettings},
google::GoogleSettings,
ollama::OllamaSettings,
open_ai::OpenAiSettings,
};
/// Initializes the language model settings.
@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings {
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
pub google: GoogleSettings,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent {
pub openai: Option<OpenAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
pub google: Option<GoogleSettingsContent>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent {
pub available_models: Option<Vec<open_ai::Model>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<google_ai::Model>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<CloudModel>>,
available_models: Option<Vec<cloud::AvailableModel>>,
}
impl settings::Settings for AllLanguageModelSettings {
@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
merge(
&mut settings.google.api_url,
value.google.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.google
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.google.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.google.available_models,
value
.google
.as_ref()
.and_then(|s| s.available_models.clone()),
);
}
Ok(settings)
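For reference, a sketch of what a user-facing settings entry for the local Google provider could look like given the fields above. The values are illustrative, and the nested `custom` shape follows serde's default externally tagged representation of `google_ai::Model`.

```json
"google": {
  "api_url": "https://generativelanguage.googleapis.com",
  "low_speed_timeout_in_seconds": 60,
  "available_models": [
    { "custom": { "name": "some-newer-gemini-model", "max_tokens": 1000000 } }
  ]
}
```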

View File

@ -1,5 +1,5 @@
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
@ -111,38 +111,27 @@ impl Model {
}
}
fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match model {
Model::Custom { name, .. } => serializer.serialize_str(name),
_ => serializer.serialize_str(model.id()),
}
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
#[serde(serialize_with = "serialize_model")]
pub model: Model,
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Map<String, Value>>,
}
#[derive(Serialize, Debug)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
@ -213,21 +202,21 @@ pub struct FunctionChunk {
pub arguments: Option<String>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
@ -369,3 +358,14 @@ pub fn embed<'a>(
}
}
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
response.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
}

View File

@ -13,13 +13,6 @@ message Envelope {
optional uint32 responding_to = 2;
optional PeerId original_sender_id = 3;
/*
When you are adding a new message type, instead of adding it in semantic order
and bumping the message ID's of everything that follows, add it at the end of the
file and bump the max number. See this
https://github.com/zed-industries/zed/pull/7890#discussion_r1496621823
*/
oneof payload {
Hello hello = 4;
Ack ack = 5;
@ -201,10 +194,8 @@ message Envelope {
JoinHostedProject join_hosted_project = 164;
CompleteWithLanguageModel complete_with_language_model = 166;
LanguageModelResponse language_model_response = 167;
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169;
QueryLanguageModel query_language_model = 224;
QueryLanguageModelResponse query_language_model_response = 225; // current max
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
@ -271,10 +262,11 @@ message Envelope {
UpdateDevServerProject update_dev_server_project = 221;
AddWorktree add_worktree = 222;
AddWorktreeResponse add_worktree_response = 223; // current max
AddWorktreeResponse add_worktree_response = 223;
}
reserved 158 to 161;
reserved 166 to 169;
}
// Messages
@ -2051,94 +2043,32 @@ message SetRoomParticipantRole {
ChannelRole role = 3;
}
message CompleteWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
repeated ChatCompletionTool tools = 5;
optional string tool_choice = 6;
}
// A tool presented to the language model for its use
message ChatCompletionTool {
oneof variant {
FunctionObject function = 1;
}
message FunctionObject {
string name = 1;
optional string description = 2;
optional string parameters = 3;
}
}
// A message to the language model
message LanguageModelRequestMessage {
LanguageModelRole role = 1;
string content = 2;
optional string tool_call_id = 3;
repeated ToolCall tool_calls = 4;
}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
LanguageModelTool = 3;
reserved 3;
}
message LanguageModelResponseMessage {
optional LanguageModelRole role = 1;
optional string content = 2;
repeated ToolCallDelta tool_calls = 3;
message QueryLanguageModel {
LanguageModelProvider provider = 1;
LanguageModelRequestKind kind = 2;
string request = 3;
}
// A request to call a tool, by the language model
message ToolCall {
string id = 1;
oneof variant {
FunctionCall function = 2;
}
message FunctionCall {
string name = 1;
string arguments = 2;
}
enum LanguageModelProvider {
Anthropic = 0;
OpenAI = 1;
Google = 2;
}
message ToolCallDelta {
uint32 index = 1;
optional string id = 2;
oneof variant {
FunctionCallDelta function = 3;
}
message FunctionCallDelta {
optional string name = 1;
optional string arguments = 2;
}
enum LanguageModelRequestKind {
Complete = 0;
CountTokens = 1;
}
message LanguageModelResponse {
repeated LanguageModelChoiceDelta choices = 1;
}
message LanguageModelChoiceDelta {
uint32 index = 1;
LanguageModelResponseMessage delta = 2;
optional string finish_reason = 3;
}
message CountTokensWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
}
message CountTokensResponse {
uint32 token_count = 1;
message QueryLanguageModelResponse {
string response = 1;
}
message GetCachedEmbeddings {

View File

@ -203,12 +203,9 @@ messages!(
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(ChannelMessageUpdate, Foreground),
(CompleteWithLanguageModel, Background),
(ComputeEmbeddings, Background),
(ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
(CountTokensWithLanguageModel, Background),
(CountTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@ -278,7 +275,6 @@ messages!(
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
(LanguageModelResponse, Background),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
@ -298,6 +294,8 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
(QueryLanguageModel, Background),
(QueryLanguageModelResponse, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@ -412,9 +410,7 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CompleteWithLanguageModel, LanguageModelResponse),
(ComputeEmbeddings, ComputeEmbeddingsResponse),
(CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
@ -467,6 +463,7 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
(QueryLanguageModel, QueryLanguageModelResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),