From d6bdaa8a9141e181ec91ffb634cccae03e46ba08 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Sun, 28 Jul 2024 11:07:10 +0200 Subject: [PATCH] Simplify LLM protocol (#15366) In this pull request, we change the zed.dev protocol so that we pass the raw JSON for the specified provider directly to our server. This avoids the need to define a protobuf message that's a superset of all these formats. @bennetbo: We also changed the settings for available_models under zed.dev to be a flat format, because the nesting seemed too confusing. Can you help us upgrade the local provider configuration to be consistent with this? We do whatever we need to do when parsing the settings to make this simple for users, even if it's a bit more complex on our end. We want to use versioning to avoid breaking existing users, but need to keep making progress. ```json "zed.dev": { "available_models": [ { "provider": "anthropic", "name": "some-newly-released-model-we-havent-added", "max_tokens": 200000 } ] } ``` Release Notes: - N/A --------- Co-authored-by: Nathan --- Cargo.lock | 33 +- Cargo.toml | 2 - assets/settings/default.json | 3 + crates/anthropic/src/anthropic.rs | 33 +- crates/assistant/src/assistant_settings.rs | 4 +- crates/assistant_tooling/Cargo.toml | 33 -- crates/assistant_tooling/LICENSE-GPL | 1 - crates/assistant_tooling/README.md | 85 --- .../src/assistant_tooling.rs | 13 - .../src/attachment_registry.rs | 234 -------- .../assistant_tooling/src/project_context.rs | 296 ---------- crates/assistant_tooling/src/tool_registry.rs | 526 ------------------ crates/collab/src/ai.rs | 138 ----- crates/collab/src/lib.rs | 1 - crates/collab/src/rpc.rs | 394 +++---------- crates/google_ai/Cargo.toml | 5 + crates/google_ai/src/google_ai.rs | 110 +++- crates/language_model/Cargo.toml | 1 + .../language_model/src/model/cloud_model.rs | 106 +--- crates/language_model/src/provider.rs | 1 + .../language_model/src/provider/anthropic.rs | 105 +--- crates/language_model/src/provider/cloud.rs | 190 ++++--- crates/language_model/src/provider/google.rs | 351 ++++++++++++ crates/language_model/src/provider/open_ai.rs | 43 +- crates/language_model/src/registry.rs | 20 +- crates/language_model/src/request.rs | 121 +++- crates/language_model/src/role.rs | 1 - crates/language_model/src/settings.rs | 43 +- crates/open_ai/src/open_ai.rs | 42 +- crates/proto/proto/zed.proto | 106 +--- crates/proto/src/proto.rs | 9 +- 31 files changed, 896 insertions(+), 2154 deletions(-) delete mode 100644 crates/assistant_tooling/Cargo.toml delete mode 120000 crates/assistant_tooling/LICENSE-GPL delete mode 100644 crates/assistant_tooling/README.md delete mode 100644 crates/assistant_tooling/src/assistant_tooling.rs delete mode 100644 crates/assistant_tooling/src/attachment_registry.rs delete mode 100644 crates/assistant_tooling/src/project_context.rs delete mode 100644 crates/assistant_tooling/src/tool_registry.rs delete mode 100644 crates/collab/src/ai.rs create mode 100644 crates/language_model/src/provider/google.rs diff --git a/Cargo.lock b/Cargo.lock index 92dd5d9a8e..2876ec86a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,27 +471,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "assistant_tooling" -version = "0.1.0" -dependencies = [ - "anyhow", - "collections", - "futures 0.3.28", - "gpui", - "log", - "project", - "repair_json", - "schemars", - "serde", - "serde_json", - "settings", - "sum_tree", - "ui", - "unindent", - "util", -] - [[package]] name = "async-attributes" version = "1.1.2" @@ -4811,8 +4790,10 @@ dependencies = [ "anyhow", "futures 0.3.28", "http_client", + "schemars", "serde", "serde_json", + "strum", ] [[package]] @@ -5988,6 +5969,7 @@ dependencies = [ "env_logger", "feature_flags", "futures 0.3.28", + "google_ai", "gpui", "http_client", "language", @@ -8715,15 +8697,6 @@ dependencies = [ "bytecheck", ] -[[package]] -name = "repair_json" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15" -dependencies = [ - "thiserror", -] - [[package]] name = "repl" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index b289d083bb..19a6b6b836 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "crates/assets", "crates/assistant", "crates/assistant_slash_command", - "crates/assistant_tooling", "crates/audio", "crates/auto_update", "crates/breadcrumbs", @@ -178,7 +177,6 @@ anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant_slash_command = { path = "crates/assistant_slash_command" } -assistant_tooling = { path = "crates/assistant_tooling" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } breadcrumbs = { path = "crates/breadcrumbs" } diff --git a/assets/settings/default.json b/assets/settings/default.json index 529b91b7cd..a26c7d27a0 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -870,6 +870,9 @@ "openai": { "api_url": "https://api.openai.com/v1" }, + "google": { + "api_url": "https://generativelanguage.googleapis.com" + }, "ollama": { "api_url": "http://localhost:11434" } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 2d9bd311b8..45a4dfc0d3 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -98,7 +98,7 @@ impl From for String { } } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, pub messages: Vec, @@ -113,7 +113,7 @@ pub struct RequestMessage { pub content: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ResponseEvent { MessageStart { @@ -138,7 +138,7 @@ pub enum ResponseEvent { MessageStop {}, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseMessage { #[serde(rename = "type")] pub message_type: Option, @@ -151,19 +151,19 @@ pub struct ResponseMessage { pub usage: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub input_tokens: Option, pub output_tokens: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ContentBlock { Text { text: String }, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum TextDelta { TextDelta { text: String }, @@ -226,6 +226,25 @@ pub async fn stream_completion( } } +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(response) => match response { + ResponseEvent::ContentBlockStart { content_block, .. } => match content_block { + ContentBlock::Text { text } => Some(Ok(text)), + }, + ResponseEvent::ContentBlockDelta { delta, .. } => match delta { + TextDelta::TextDelta { text } => Some(Ok(text)), + }, + _ => None, + }, + Err(error) => Some(Err(error)), + } + }) +} + // #[cfg(test)] // mod tests { // use super::*; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 05c5b56f1c..0d4dbd6824 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -249,9 +249,7 @@ impl AssistantSettingsContent { AssistantSettingsContent::Versioned(settings) => match settings { VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() { "zed.dev" => { - settings.provider = Some(AssistantProviderContentV1::ZedDotDev { - default_model: CloudModel::from_id(&model).ok(), - }); + log::warn!("attempted to set zed.dev model on outdated settings"); } "anthropic" => { let (api_url, low_speed_timeout_in_seconds) = match &settings.provider { diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml deleted file mode 100644 index 79f41faad2..0000000000 --- a/crates/assistant_tooling/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "assistant_tooling" -version = "0.1.0" -edition = "2021" -publish = false -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/assistant_tooling.rs" - -[dependencies] -anyhow.workspace = true -collections.workspace = true -futures.workspace = true -gpui.workspace = true -log.workspace = true -project.workspace = true -repair_json.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -sum_tree.workspace = true -ui.workspace = true -util.workspace = true - -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"] } -unindent.workspace = true diff --git a/crates/assistant_tooling/LICENSE-GPL b/crates/assistant_tooling/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/assistant_tooling/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tooling/README.md b/crates/assistant_tooling/README.md deleted file mode 100644 index 160869ae97..0000000000 --- a/crates/assistant_tooling/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Assistant Tooling - -Bringing Language Model tool calling to GPUI. - -This unlocks: - -- **Structured Extraction** of model responses -- **Validation** of model inputs -- **Execution** of chosen tools - -## Overview - -Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call. - -> **User**: "Hey I need help with implementing a collapsible panel in GPUI" -> -> **Assistant**: "Sure, I can help with that. Let me see what I can find." -> -> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]` -> -> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"` -> -> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you." - -This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`. - -## Using the Tool Registry - -```rust -let mut tool_registry = ToolRegistry::new(); -tool_registry - .register(WeatherTool { api_client }, - }) - .unwrap(); // You can only register one tool per name - -let completion = cx.update(|cx| { - CompletionProvider::get(cx).complete( - model_name, - messages, - Vec::new(), - 1.0, - // The definitions get passed directly to OpenAI when you want - // the model to be able to call your tool - tool_registry.definitions(), - ) -}); - -let mut stream = completion?.await?; - -let mut message = AssistantMessage::new(); - -while let Some(delta) = stream.next().await { - // As messages stream in, you'll get both assistant content - if let Some(content) = &delta.content { - message - .body - .update(cx, |message, cx| message.append(&content, cx)); - } - - // And tool calls! - for tool_call_delta in delta.tool_calls { - let index = tool_call_delta.index as usize; - if index >= message.tool_calls.len() { - message.tool_calls.resize_with(index + 1, Default::default); - } - let tool_call = &mut message.tool_calls[index]; - - // Build up an ID - if let Some(id) = &tool_call_delta.id { - tool_call.id.push_str(id); - } - - tool_registry.update_tool_call( - tool_call, - tool_call_delta.name.as_deref(), - tool_call_delta.arguments.as_deref(), - cx, - ); - } -} -``` - -Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task>`. - -As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`. diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs deleted file mode 100644 index 9dcf2908e9..0000000000 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod attachment_registry; -mod project_context; -mod tool_registry; - -pub use attachment_registry::{ - AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, - UserAttachment, -}; -pub use project_context::ProjectContext; -pub use tool_registry::{ - LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, - ToolRegistry, ToolView, -}; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs deleted file mode 100644 index e8b52d26f0..0000000000 --- a/crates/assistant_tooling/src/attachment_registry.rs +++ /dev/null @@ -1,234 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use collections::HashMap; -use futures::future::join_all; -use gpui::{AnyView, Render, Task, View, WindowContext}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; -use util::ResultExt as _; - -pub struct AttachmentRegistry { - registered_attachments: HashMap, -} - -pub trait AttachmentOutput { - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; -} - -pub trait LanguageModelAttachment { - type Output: DeserializeOwned + Serialize + 'static; - type View: Render + AttachmentOutput; - - fn name(&self) -> Arc; - fn run(&self, cx: &mut WindowContext) -> Task>; - fn view(&self, output: Result, cx: &mut WindowContext) -> View; -} - -/// A collected attachment from running an attachment tool -pub struct UserAttachment { - pub view: AnyView, - name: Arc, - serialized_output: Result, String>, - generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String, -} - -#[derive(Serialize, Deserialize)] -pub struct SavedUserAttachment { - name: Arc, - serialized_output: Result, String>, -} - -/// Internal representation of an attachment tool to allow us to treat them dynamically -struct RegisteredAttachment { - name: Arc, - enabled: AtomicBool, - call: Box Task>>, - deserialize: Box Result>, -} - -impl AttachmentRegistry { - pub fn new() -> Self { - Self { - registered_attachments: HashMap::default(), - } - } - - pub fn register(&mut self, attachment: A) { - let attachment = Arc::new(attachment); - - let call = Box::new({ - let attachment = attachment.clone(); - move |cx: &mut WindowContext| { - let result = attachment.run(cx); - let attachment = attachment.clone(); - cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let serialized_output = - result - .as_ref() - .map_err(ToString::to_string) - .and_then(|output| { - Ok(RawValue::from_string( - serde_json::to_string(output).map_err(|e| e.to_string())?, - ) - .unwrap()) - }); - - let view = cx.update(|cx| attachment.view(result, cx))?; - - Ok(UserAttachment { - name: attachment.name(), - view: view.into(), - generate_fn: generate::, - serialized_output, - }) - }) - } - }); - - let deserialize = Box::new({ - let attachment = attachment.clone(); - move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| { - let serialized_output = saved_attachment.serialized_output.clone(); - let output = match &serialized_output { - Ok(serialized_output) => { - Ok(serde_json::from_str::(serialized_output.get())?) - } - Err(error) => Err(anyhow!("{error}")), - }; - let view = attachment.view(output, cx).into(); - - Ok(UserAttachment { - name: saved_attachment.name.clone(), - view, - serialized_output, - generate_fn: generate::, - }) - } - }); - - self.registered_attachments.insert( - TypeId::of::(), - RegisteredAttachment { - name: attachment.name(), - call, - deserialize, - enabled: AtomicBool::new(true), - }, - ); - return; - - fn generate( - view: AnyView, - project: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - view.downcast::() - .unwrap() - .update(cx, |view, cx| T::View::generate(view, project, cx)) - } - } - - pub fn set_attachment_tool_enabled( - &self, - is_enabled: bool, - ) { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.store(is_enabled, SeqCst); - } - } - - pub fn is_attachment_tool_enabled(&self) -> bool { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.load(SeqCst) - } else { - false - } - } - - pub fn call( - &self, - cx: &mut WindowContext, - ) -> Task> { - let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) else { - return Task::ready(Err(anyhow!("no attachment tool"))); - }; - - (attachment.call)(cx) - } - - pub fn call_all_attachment_tools( - self: Arc, - cx: &mut WindowContext<'_>, - ) -> Task>> { - let this = self.clone(); - cx.spawn(|mut cx| async move { - let attachment_tasks = cx.update(|cx| { - let mut tasks = Vec::new(); - for attachment in this - .registered_attachments - .values() - .filter(|attachment| attachment.enabled.load(SeqCst)) - { - tasks.push((attachment.call)(cx)) - } - - tasks - })?; - - let attachments = join_all(attachment_tasks.into_iter()).await; - - Ok(attachments - .into_iter() - .filter_map(|attachment| attachment.log_err()) - .collect()) - }) - } - - pub fn serialize_user_attachment( - &self, - user_attachment: &UserAttachment, - ) -> SavedUserAttachment { - SavedUserAttachment { - name: user_attachment.name.clone(), - serialized_output: user_attachment.serialized_output.clone(), - } - } - - pub fn deserialize_user_attachment( - &self, - saved_user_attachment: SavedUserAttachment, - cx: &mut WindowContext, - ) -> Result { - if let Some(registered_attachment) = self - .registered_attachments - .values() - .find(|attachment| attachment.name == saved_user_attachment.name) - { - (registered_attachment.deserialize)(&saved_user_attachment, cx) - } else { - Err(anyhow!( - "no attachment tool for name {}", - saved_user_attachment.name - )) - } - } -} - -impl UserAttachment { - pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option { - let result = (self.generate_fn)(self.view.clone(), output, cx); - if result.is_empty() { - None - } else { - Some(result) - } - } -} diff --git a/crates/assistant_tooling/src/project_context.rs b/crates/assistant_tooling/src/project_context.rs deleted file mode 100644 index 2640ce1ed5..0000000000 --- a/crates/assistant_tooling/src/project_context.rs +++ /dev/null @@ -1,296 +0,0 @@ -use anyhow::{anyhow, Result}; -use gpui::{AppContext, Model, Task, WeakModel}; -use project::{Fs, Project, ProjectPath, Worktree}; -use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc}; -use sum_tree::TreeMap; - -pub struct ProjectContext { - files: TreeMap, - project: WeakModel, - fs: Arc, -} - -#[derive(Debug, Clone)] -enum PathState { - PathOnly, - EntireFile, - Excerpts { ranges: Vec> }, -} - -impl ProjectContext { - pub fn new(project: WeakModel, fs: Arc) -> Self { - Self { - files: TreeMap::default(), - fs, - project, - } - } - - pub fn add_path(&mut self, project_path: ProjectPath) { - if self.files.get(&project_path).is_none() { - self.files.insert(project_path, PathState::PathOnly); - } - } - - pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range]) { - let previous_state = self - .files - .get(&project_path) - .unwrap_or(&PathState::PathOnly); - - let mut ranges = match previous_state { - PathState::EntireFile => return, - PathState::PathOnly => Vec::new(), - PathState::Excerpts { ranges } => ranges.to_vec(), - }; - - for new_range in new_ranges { - let ix = ranges.binary_search_by(|probe| { - if probe.end < new_range.start { - Ordering::Less - } else if probe.start > new_range.end { - Ordering::Greater - } else { - Ordering::Equal - } - }); - - match ix { - Ok(mut ix) => { - let existing = &mut ranges[ix]; - existing.start = existing.start.min(new_range.start); - existing.end = existing.end.max(new_range.end); - while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end { - ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end); - ranges.remove(ix + 1); - } - while ix > 0 && ranges[ix - 1].end >= ranges[ix].start { - ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start); - ranges.remove(ix - 1); - ix -= 1; - } - } - Err(ix) => { - ranges.insert(ix, new_range.clone()); - } - } - } - - self.files - .insert(project_path, PathState::Excerpts { ranges }); - } - - pub fn add_file(&mut self, project_path: ProjectPath) { - self.files.insert(project_path, PathState::EntireFile); - } - - pub fn generate_system_message(&self, cx: &mut AppContext) -> Task> { - let project = self - .project - .upgrade() - .ok_or_else(|| anyhow!("project dropped")); - let files = self.files.clone(); - let fs = self.fs.clone(); - cx.spawn(|cx| async move { - let project = project?; - let mut result = "project structure:\n".to_string(); - - let mut last_worktree: Option> = None; - for (project_path, path_state) in files.iter() { - if let Some(worktree) = &last_worktree { - if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id { - last_worktree = None; - } - } - - let worktree; - if let Some(last_worktree) = &last_worktree { - worktree = last_worktree.clone(); - } else if let Some(tree) = project.read_with(&cx, |project, cx| { - project.worktree_for_id(project_path.worktree_id, cx) - })? { - worktree = tree; - last_worktree = Some(worktree.clone()); - let worktree_name = - worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?; - writeln!(&mut result, "# {}", worktree_name).unwrap(); - } else { - continue; - } - - let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?; - let path = &project_path.path; - writeln!(&mut result, "## {}", path.display()).unwrap(); - - match path_state { - PathState::PathOnly => {} - PathState::EntireFile => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - writeln!(&mut result, "~~~\n{text}\n~~~").unwrap(); - } - PathState::Excerpts { ranges } => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - - writeln!(&mut result, "~~~").unwrap(); - - // Assumption: ranges are in order, not overlapping - let mut prev_range_end = 0; - for range in ranges { - if range.start > prev_range_end { - writeln!(&mut result, "...").unwrap(); - prev_range_end = range.end; - } - - let mut start = range.start; - let mut end = range.end.min(text.len()); - while !text.is_char_boundary(start) { - start += 1; - } - while !text.is_char_boundary(end) { - end -= 1; - } - result.push_str(&text[start..end]); - if !result.ends_with('\n') { - result.push('\n'); - } - } - - if prev_range_end < text.len() { - writeln!(&mut result, "...").unwrap(); - } - - writeln!(&mut result, "~~~").unwrap(); - } - } - } - Ok(result) - }) - } -} - -#[cfg(test)] -mod tests { - use std::path::Path; - - use super::*; - use gpui::TestAppContext; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - - use unindent::Unindent as _; - - #[gpui::test] - async fn test_system_message_generation(cx: &mut TestAppContext) { - init_test(cx); - - let file_3_contents = r#" - fn test1() {} - fn test2() {} - fn test3() {} - "# - .unindent(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/code", - json!({ - "root1": { - "lib": { - "file1.rs": "mod example;", - "file2.rs": "", - }, - "test": { - "file3.rs": file_3_contents, - } - }, - "root2": { - "src": { - "main.rs": "" - } - } - }), - ) - .await; - - let project = Project::test( - fs.clone(), - ["/code/root1".as_ref(), "/code/root2".as_ref()], - cx, - ) - .await; - - let worktree_ids = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .map(|worktree| worktree.read(cx).id()) - .collect::>() - }); - - let mut ax = ProjectContext::new(project.downgrade(), fs); - - ax.add_file(ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("lib/file1.rs").into(), - }); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - "# - .unindent(), - message - ); - - ax.add_excerpts( - ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("test/file3.rs").into(), - }, - &[ - file_3_contents.find("fn test2").unwrap() - ..file_3_contents.find("fn test3").unwrap(), - ], - ); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - ## test/file3.rs - ~~~ - ... - fn test2() {} - ... - ~~~ - "# - .unindent(), - message - ); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - }); - } -} diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs deleted file mode 100644 index e5f8914eb5..0000000000 --- a/crates/assistant_tooling/src/tool_registry.rs +++ /dev/null @@ -1,526 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext}; -use repair_json::repair; -use schemars::{schema::RootSchema, schema_for, JsonSchema}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - collections::HashMap, - fmt::Display, - mem, - sync::atomic::{AtomicBool, Ordering::SeqCst}, -}; -use ui::ViewContext; - -pub struct ToolRegistry { - registered_tools: HashMap, -} - -#[derive(Default)] -pub struct ToolFunctionCall { - pub id: String, - pub name: String, - pub arguments: String, - state: ToolFunctionCallState, -} - -#[derive(Default)] -enum ToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool(Box), - ExecutedTool(Box), -} - -trait InternalToolView { - fn view(&self) -> AnyView; - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; - fn try_set_input(&self, input: &str, cx: &mut WindowContext); - fn execute(&self, cx: &mut WindowContext) -> Task>; - fn serialize_output(&self, cx: &mut WindowContext) -> Result>; - fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; -} - -#[derive(Default, Serialize, Deserialize)] -pub struct SavedToolFunctionCall { - id: String, - name: String, - arguments: String, - state: SavedToolFunctionCallState, -} - -#[derive(Default, Serialize, Deserialize)] -enum SavedToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool, - ExecutedTool(Box), -} - -#[derive(Clone, Debug, PartialEq)] -pub struct ToolFunctionDefinition { - pub name: String, - pub description: String, - pub parameters: RootSchema, -} - -pub trait LanguageModelTool { - type View: ToolView; - - /// Returns the name of the tool. - /// - /// This name is exposed to the language model to allow the model to pick - /// which tools to use. As this name is used to identify the tool within a - /// tool registry, it should be unique. - fn name(&self) -> String; - - /// Returns the description of the tool. - /// - /// This can be used to _prompt_ the model as to what the tool does. - fn description(&self) -> String; - - /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. - fn definition(&self) -> ToolFunctionDefinition { - let root_schema = schema_for!(::Input); - - ToolFunctionDefinition { - name: self.name(), - description: self.description(), - parameters: root_schema, - } - } - - /// A view of the output of running the tool, for displaying to the user. - fn view(&self, cx: &mut WindowContext) -> View; -} - -pub trait ToolView: Render { - /// The input type that will be passed in to `execute` when the tool is called - /// by the language model. - type Input: DeserializeOwned + JsonSchema; - - /// The output returned by executing the tool. - type SerializedState: DeserializeOwned + Serialize; - - fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext) -> String; - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext); - fn execute(&mut self, cx: &mut ViewContext) -> Task>; - - fn serialize(&self, cx: &mut ViewContext) -> Self::SerializedState; - fn deserialize( - &mut self, - output: Self::SerializedState, - cx: &mut ViewContext, - ) -> Result<()>; -} - -struct RegisteredTool { - enabled: AtomicBool, - type_id: TypeId, - build_view: Box Box>, - definition: ToolFunctionDefinition, -} - -impl ToolRegistry { - pub fn new() -> Self { - Self { - registered_tools: HashMap::new(), - } - } - - pub fn set_tool_enabled(&self, is_enabled: bool) { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - tool.enabled.store(is_enabled, SeqCst); - return; - } - } - } - - pub fn is_tool_enabled(&self) -> bool { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - return tool.enabled.load(SeqCst); - } - } - false - } - - pub fn definitions(&self) -> Vec { - self.registered_tools - .values() - .filter(|tool| tool.enabled.load(SeqCst)) - .map(|tool| tool.definition.clone()) - .collect() - } - - pub fn update_tool_call( - &self, - call: &mut ToolFunctionCall, - name: Option<&str>, - arguments: Option<&str>, - cx: &mut WindowContext, - ) { - if let Some(name) = name { - call.name.push_str(name); - } - if let Some(arguments) = arguments { - if call.arguments.is_empty() { - if let Some(tool) = self.registered_tools.get(&call.name) { - let view = (tool.build_view)(cx); - call.state = ToolFunctionCallState::KnownTool(view); - } else { - call.state = ToolFunctionCallState::NoSuchTool; - } - } - call.arguments.push_str(arguments); - - if let ToolFunctionCallState::KnownTool(view) = &call.state { - if let Ok(repaired_arguments) = repair(call.arguments.clone()) { - view.try_set_input(&repaired_arguments, cx) - } - } - } - } - - pub fn execute_tool_call( - &self, - tool_call: &mut ToolFunctionCall, - cx: &mut WindowContext, - ) -> Option>> { - if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) { - let task = view.execute(cx); - tool_call.state = ToolFunctionCallState::ExecutedTool(view); - Some(task) - } else { - None - } - } - - pub fn render_tool_call( - &self, - tool_call: &ToolFunctionCall, - _cx: &mut WindowContext, - ) -> Option { - match &tool_call.state { - ToolFunctionCallState::NoSuchTool => { - Some(ui::Label::new("No such tool").into_any_element()) - } - ToolFunctionCallState::Initializing => None, - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - Some(view.view().into_any_element()) - } - } - } - - pub fn content_for_tool_call( - &self, - tool_call: &ToolFunctionCall, - project_context: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - match &tool_call.state { - ToolFunctionCallState::Initializing => String::new(), - ToolFunctionCallState::NoSuchTool => { - format!("No such tool: {}", tool_call.name) - } - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - view.generate(project_context, cx) - } - } - } - - pub fn serialize_tool_call( - &self, - call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - Ok(SavedToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing, - ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool, - ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool, - ToolFunctionCallState::ExecutedTool(view) => { - SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?) - } - }, - }) - } - - pub fn deserialize_tool_call( - &self, - call: &SavedToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - let Some(tool) = self.registered_tools.get(&call.name) else { - return Err(anyhow!("no such tool {}", call.name)); - }; - - Ok(ToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing, - SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool, - SavedToolFunctionCallState::KnownTool => { - log::error!("Deserialized tool that had not executed"); - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - ToolFunctionCallState::KnownTool(view) - } - SavedToolFunctionCallState::ExecutedTool(output) => { - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - view.deserialize_output(output, cx)?; - ToolFunctionCallState::ExecutedTool(view) - } - }, - }) - } - - pub fn register(&mut self, tool: T) -> Result<()> { - let name = tool.name(); - let registered_tool = RegisteredTool { - type_id: TypeId::of::(), - definition: tool.definition(), - enabled: AtomicBool::new(true), - build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))), - }; - - let previous = self.registered_tools.insert(name.clone(), registered_tool); - if previous.is_some() { - return Err(anyhow!("already registered a tool with name {}", name)); - } - - return Ok(()); - } -} - -impl InternalToolView for View { - fn view(&self) -> AnyView { - self.clone().into() - } - - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { - self.update(cx, |view, cx| view.generate(project, cx)) - } - - fn try_set_input(&self, input: &str, cx: &mut WindowContext) { - if let Ok(input) = serde_json::from_str::(input) { - self.update(cx, |view, cx| { - view.set_input(input, cx); - cx.notify(); - }); - } - } - - fn execute(&self, cx: &mut WindowContext) -> Task> { - self.update(cx, |view, cx| view.execute(cx)) - } - - fn serialize_output(&self, cx: &mut WindowContext) -> Result> { - let output = self.update(cx, |view, cx| view.serialize(cx)); - Ok(RawValue::from_string(serde_json::to_string(&output)?)?) - } - - fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> { - let state = serde_json::from_str::(output.get())?; - self.update(cx, |view, cx| view.deserialize(state, cx))?; - Ok(()) - } -} - -impl Display for ToolFunctionDefinition { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let schema = serde_json::to_string(&self.parameters).ok(); - let schema = schema.unwrap_or("None".to_string()); - write!(f, "Name: {}:\n", self.name)?; - write!(f, "Description: {}\n", self.description)?; - write!(f, "Parameters: {}", schema) - } -} - -#[cfg(test)] -mod test { - use super::*; - use gpui::{div, prelude::*, Render, TestAppContext}; - use gpui::{EmptyView, View}; - use schemars::JsonSchema; - use serde::{Deserialize, Serialize}; - use serde_json::json; - - #[derive(Deserialize, Serialize, JsonSchema)] - struct WeatherQuery { - location: String, - unit: String, - } - - #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] - struct WeatherResult { - location: String, - temperature: f64, - unit: String, - } - - struct WeatherView { - input: Option, - result: Option, - - // Fake API call - current_weather: WeatherResult, - } - - #[derive(Clone, Serialize)] - struct WeatherTool { - current_weather: WeatherResult, - } - - impl WeatherView { - fn new(current_weather: WeatherResult) -> Self { - Self { - input: None, - result: None, - current_weather, - } - } - } - - impl Render for WeatherView { - fn render(&mut self, _cx: &mut gpui::ViewContext) -> impl IntoElement { - match self.result { - Some(ref result) => div() - .child(format!("temperature: {}", result.temperature)) - .into_any_element(), - None => div().child("Calculating weather...").into_any_element(), - } - } - } - - impl ToolView for WeatherView { - type Input = WeatherQuery; - - type SerializedState = WeatherResult; - - fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext) -> String { - serde_json::to_string(&self.result).unwrap() - } - - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext) { - self.input = Some(input); - cx.notify(); - } - - fn execute(&mut self, _cx: &mut ViewContext) -> Task> { - let input = self.input.as_ref().unwrap(); - - let _location = input.location.clone(); - let _unit = input.unit.clone(); - - let weather = self.current_weather.clone(); - - self.result = Some(weather); - - Task::ready(Ok(())) - } - - fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { - self.current_weather.clone() - } - - fn deserialize( - &mut self, - output: Self::SerializedState, - _cx: &mut ViewContext, - ) -> Result<()> { - self.current_weather = output; - Ok(()) - } - } - - impl LanguageModelTool for WeatherTool { - type View = WeatherView; - - fn name(&self) -> String { - "get_current_weather".to_string() - } - - fn description(&self) -> String { - "Fetches the current weather for a given location.".to_string() - } - - fn view(&self, cx: &mut WindowContext) -> View { - cx.new_view(|_cx| WeatherView::new(self.current_weather.clone())) - } - } - - #[gpui::test] - async fn test_openai_weather_example(cx: &mut TestAppContext) { - let (_, cx) = cx.add_window_view(|_cx| EmptyView); - - let mut registry = ToolRegistry::new(); - registry - .register(WeatherTool { - current_weather: WeatherResult { - location: "San Francisco".to_string(), - temperature: 21.0, - unit: "Celsius".to_string(), - }, - }) - .unwrap(); - - let definitions = registry.definitions(); - assert_eq!( - definitions, - [ToolFunctionDefinition { - name: "get_current_weather".to_string(), - description: "Fetches the current weather for a given location.".to_string(), - parameters: serde_json::from_value(json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "WeatherQuery", - "type": "object", - "properties": { - "location": { - "type": "string" - }, - "unit": { - "type": "string" - } - }, - "required": ["location", "unit"] - })) - .unwrap(), - }] - ); - - let mut call = ToolFunctionCall { - id: "the-id".to_string(), - name: "get_cur".to_string(), - ..Default::default() - }; - - let task = cx.update(|cx| { - registry.update_tool_call( - &mut call, - Some("rent_weather"), - Some(r#"{"location": "San Francisco","#), - cx, - ); - registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx); - registry.execute_tool_call(&mut call, cx).unwrap() - }); - task.await.unwrap(); - - match &call.state { - ToolFunctionCallState::ExecutedTool(_view) => {} - _ => panic!(), - } - } -} diff --git a/crates/collab/src/ai.rs b/crates/collab/src/ai.rs deleted file mode 100644 index 06c6e77dfd..0000000000 --- a/crates/collab/src/ai.rs +++ /dev/null @@ -1,138 +0,0 @@ -use anyhow::{anyhow, Context as _, Result}; -use rpc::proto; -use util::ResultExt as _; - -pub fn language_model_request_to_open_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(open_ai::Request { - model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo), - messages: request - .messages - .into_iter() - .map(|message: proto::LanguageModelRequestMessage| { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - let openai_message = match role { - proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User { - content: message.content, - }, - proto::LanguageModelRole::LanguageModelAssistant => { - open_ai::RequestMessage::Assistant { - content: Some(message.content), - tool_calls: message - .tool_calls - .into_iter() - .filter_map(|call| { - Some(open_ai::ToolCall { - id: call.id, - content: match call.variant? { - proto::tool_call::Variant::Function(f) => { - open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: f.name, - arguments: f.arguments, - }, - } - } - }, - }) - }) - .collect(), - } - } - proto::LanguageModelRole::LanguageModelSystem => { - open_ai::RequestMessage::System { - content: message.content, - } - } - proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool { - tool_call_id: message - .tool_call_id - .ok_or_else(|| anyhow!("tool message is missing tool call id"))?, - content: message.content, - }, - }; - - Ok(openai_message) - }) - .collect::>>()?, - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: request - .tools - .into_iter() - .filter_map(|tool| { - Some(match tool.variant? { - proto::chat_completion_tool::Variant::Function(f) => { - open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: f.name, - description: f.description, - parameters: if let Some(params) = &f.parameters { - Some( - serde_json::from_str(params) - .context("failed to deserialize tool parameters") - .log_err()?, - ) - } else { - None - }, - }, - } - } - }) - }) - .collect(), - tool_choice: request.tool_choice, - }) -} - -pub fn language_model_request_to_google_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(google_ai::GenerateContentRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - generation_config: None, - safety_settings: None, - }) -} - -pub fn language_model_request_message_to_google_ai( - message: proto::LanguageModelRequestMessage, -) -> Result { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - Ok(google_ai::Content { - parts: vec![google_ai::Part::TextPart(google_ai::TextPart { - text: message.content, - })], - role: match role { - proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model, - proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelTool => { - Err(anyhow!("we don't handle tool calls with google ai yet"))? - } - }, - }) -} - -pub fn count_tokens_request_to_google_ai( - request: proto::CountTokensWithLanguageModel, -) -> Result { - Ok(google_ai::CountTokensRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - }) -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index ae83fccb98..2673ca3fb8 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,4 +1,3 @@ -pub mod ai; pub mod api; pub mod auth; pub mod db; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3ec13ce045..92e5b1a584 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -46,8 +46,8 @@ use http_client::IsahcHttpClient; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole, - LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, ShareProject, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope, }; @@ -618,17 +618,6 @@ impl Server { ) } }) - .add_request_handler({ - let app_state = app_state.clone(); - user_handler(move |request, response, session| { - count_tokens_with_language_model( - request, - response, - session, - app_state.config.google_ai_api_key.clone(), - ) - }) - }) .add_request_handler({ user_handler(move |request, response, session| { get_cached_embeddings(request, response, session) @@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit { } async fn complete_with_language_model( - mut request: proto::CompleteWithLanguageModel, - response: StreamingResponse, + query: proto::QueryLanguageModel, + response: StreamingResponse, session: Session, open_ai_api_key: Option>, google_ai_api_key: Option>, @@ -4525,287 +4514,95 @@ async fn complete_with_language_model( return Err(anyhow!("user not found"))?; }; authorize_access_to_language_models(&session).await?; - session - .rate_limiter - .check::(session.user_id()) - .await?; - let mut provider_and_model = request.model.split('/'); - let (provider, model) = match ( - provider_and_model.next().unwrap(), - provider_and_model.next(), - ) { - (provider, Some(model)) => (provider, model), - (model, None) => { - if model.starts_with("gpt") { - ("openai", model) - } else if model.starts_with("gemini") { - ("google", model) - } else if model.starts_with("claude") { - ("anthropic", model) - } else { - ("unknown", model) - } + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; } - }; - let provider = provider.to_string(); - request.model = model.to_string(); + Some(proto::LanguageModelRequestKind::CountTokens) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; + } + None => Err(anyhow!("unknown request kind"))?, + } - match provider.as_str() { - "openai" => { - let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; - complete_with_open_ai(request, response, session, api_key).await?; - } - "anthropic" => { + match proto::LanguageModelProvider::from_i32(query.provider) { + Some(proto::LanguageModelProvider::Anthropic) => { let api_key = anthropic_api_key.context("no Anthropic AI API key configured on the server")?; - complete_with_anthropic(request, response, session, api_key).await?; + let mut chunks = anthropic::stream_completion( + session.http_client.as_ref(), + anthropic::ANTHROPIC_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } } - "google" => { + Some(proto::LanguageModelProvider::OpenAi) => { + let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; + let mut chunks = open_ai::stream_completion( + session.http_client.as_ref(), + open_ai::OPEN_AI_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } + } + Some(proto::LanguageModelProvider::Google) => { let api_key = google_ai_api_key.context("no Google AI API key configured on the server")?; - complete_with_google_ai(request, response, session, api_key).await?; - } - provider => return Err(anyhow!("unknown provider {:?}", provider))?, - } - Ok(()) -} - -async fn complete_with_open_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut completion_stream = open_ai::stream_completion( - session.http_client.as_ref(), - OPEN_AI_API_URL, - &api_key, - crate::ai::language_model_request_to_open_ai(request)?, - None, - ) - .await - .context("open_ai::stream_completion request failed within collab")?; - - while let Some(event) = completion_stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .choices - .into_iter() - .map(|choice| proto::LanguageModelChoiceDelta { - index: choice.index, - delta: Some(proto::LanguageModelResponseMessage { - role: choice.delta.role.map(|role| match role { - open_ai::Role::User => LanguageModelRole::LanguageModelUser, - open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant, - open_ai::Role::System => LanguageModelRole::LanguageModelSystem, - open_ai::Role::Tool => LanguageModelRole::LanguageModelTool, - } as i32), - content: choice.delta.content, - tool_calls: choice - .delta - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|delta| proto::ToolCallDelta { - index: delta.index as u32, - id: delta.id, - variant: match delta.function { - Some(function) => { - let name = function.name; - let arguments = function.arguments; - - Some(proto::tool_call_delta::Variant::Function( - proto::tool_call_delta::FunctionCallDelta { - name, - arguments, - }, - )) - } - None => None, - }, - }) - .collect(), - }), - finish_reason: choice.finish_reason, - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_google_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut stream = google_ai::stream_generate_content( - session.http_client.clone(), - google_ai::API_URL, - api_key.as_ref(), - &request.model.clone(), - crate::ai::language_model_request_to_google_ai(request)?, - ) - .await - .context("google_ai::stream_generate_content request failed")?; - - while let Some(event) = stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .candidates - .unwrap_or_default() - .into_iter() - .map(|candidate| proto::LanguageModelChoiceDelta { - index: candidate.index as u32, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(match candidate.content.role { - google_ai::Role::User => LanguageModelRole::LanguageModelUser, - google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant, - } as i32), - content: Some( - candidate - .content - .parts - .into_iter() - .filter_map(|part| match part { - google_ai::Part::TextPart(part) => Some(part.text), - google_ai::Part::InlineDataPart(_) => None, - }) - .collect(), - ), - // Tool calls are not supported for Google - tool_calls: Vec::new(), - }), - finish_reason: candidate.finish_reason.map(|reason| reason.to_string()), - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_anthropic( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut system_message = String::new(); - let messages = request - .messages - .into_iter() - .filter_map(|message| { - match message.role() { - LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { - role: anthropic::Role::User, - content: message.content, - }), - LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { - role: anthropic::Role::Assistant, - content: message.content, - }), - // Anthropic's API breaks system instructions out as a separate field rather - // than having a system message role. - LanguageModelRole::LanguageModelSystem => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - - None - } - // We don't yet support tool calls for Anthropic - LanguageModelRole::LanguageModelTool => None, - } - }) - .collect(); - - let mut stream = anthropic::stream_completion( - session.http_client.as_ref(), - anthropic::ANTHROPIC_API_URL, - &api_key, - anthropic::Request { - model: request.model, - messages, - stream: true, - system: system_message, - max_tokens: 4092, - }, - None, - ) - .await?; - - let mut current_role = proto::LanguageModelRole::LanguageModelAssistant; - - while let Some(event) = stream.next().await { - let event = event?; - - match event { - anthropic::ResponseEvent::MessageStart { message } => { - if let Some(role) = message.role { - if role == "assistant" { - current_role = proto::LanguageModelRole::LanguageModelAssistant; - } else if role == "user" { - current_role = proto::LanguageModelRole::LanguageModelUser; + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + let mut chunks = google_ai::stream_generate_content( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; } } - } - anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => { - match content_block { - anthropic::ContentBlock::Text { text } => { - if !text.is_empty() { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } - } - } - } - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta { - anthropic::TextDelta::TextDelta { text } => { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } - }, - anthropic::ResponseEvent::MessageDelta { delta, .. } => { - if let Some(stop_reason) = delta.stop_reason { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: None, - finish_reason: Some(stop_reason), - }], + Some(proto::LanguageModelRequestKind::CountTokens) => { + let tokens_response = google_ai::count_tokens( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&tokens_response)?, })?; } + None => Err(anyhow!("unknown request kind"))?, } - anthropic::ResponseEvent::ContentBlockStop { .. } => {} - anthropic::ResponseEvent::MessageStop {} => {} - anthropic::ResponseEvent::Ping {} => {} } + None => return Err(anyhow!("unknown provider"))?, } Ok(()) @@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit { } } -async fn count_tokens_with_language_model( - request: proto::CountTokensWithLanguageModel, - response: Response, - session: UserSession, - google_ai_api_key: Option>, -) -> Result<()> { - authorize_access_to_language_models(&session).await?; - - if !request.model.starts_with("gemini") { - return Err(anyhow!( - "counting tokens for model: {:?} is not supported", - request.model - ))?; - } - - session - .rate_limiter - .check::(session.user_id()) - .await?; - - let api_key = google_ai_api_key - .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; - let tokens_response = google_ai::count_tokens( - session.http_client.as_ref(), - google_ai::API_URL, - &api_key, - crate::ai::count_tokens_request_to_google_ai(request)?, - ) - .await?; - response.send(proto::CountTokensResponse { - token_count: tokens_response.total_tokens as u32, - })?; - Ok(()) -} - struct ComputeEmbeddingsRateLimit; impl RateLimit for ComputeEmbeddingsRateLimit { diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 1495f55a31..f923e0ec91 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -11,9 +11,14 @@ workspace = true [lib] path = "src/google_ai.rs" +[features] +schemars = ["dep:schemars"] + [dependencies] anyhow.workspace = true futures.workspace = true http_client.workspace = true +schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +strum.workspace = true diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 34c43176d0..b2ecf33243 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,23 +1,21 @@ -use std::sync::Arc; - use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::HttpClient; use serde::{Deserialize, Serialize}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( - client: Arc, + client: &dyn HttpClient, api_url: &str, api_key: &str, - model: &str, - request: GenerateContentRequest, + mut request: GenerateContentRequest, ) -> Result>> { let uri = format!( - "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}", - api_url, api_key + "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", + model = request.model ); + request.model.clear(); let request = serde_json::to_string(&request)?; let mut response = client.post_json(&uri, request.into()).await?; @@ -52,8 +50,8 @@ pub async fn stream_generate_content( } } -pub async fn count_tokens( - client: &T, +pub async fn count_tokens( + client: &dyn HttpClient, api_url: &str, api_key: &str, request: CountTokensRequest, @@ -91,22 +89,24 @@ pub enum Task { BatchEmbedContents, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub model: String, pub contents: Vec, pub generation_config: Option, pub safety_settings: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentResponse { pub candidates: Option>, pub prompt_feedback: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentCandidate { pub index: usize, @@ -157,7 +157,7 @@ pub struct GenerativeContentBlob { pub data: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { pub start_index: Option, @@ -166,13 +166,13 @@ pub struct CitationSource { pub license: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptFeedback { pub block_reason: Option, @@ -180,7 +180,7 @@ pub struct PromptFeedback { pub block_reason_message: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { pub candidate_count: Option, @@ -191,7 +191,7 @@ pub struct GenerationConfig { pub top_k: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting { pub category: HarmCategory, @@ -224,7 +224,7 @@ pub enum HarmCategory { DangerousContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub enum HarmBlockThreshold { #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] Unspecified, @@ -238,7 +238,7 @@ pub enum HarmBlockThreshold { BlockNone, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum HarmProbability { #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] @@ -249,21 +249,85 @@ pub enum HarmProbability { High, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetyRating { pub category: HarmCategory, pub probability: HarmProbability, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensRequest { pub contents: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensResponse { pub total_tokens: usize, } + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +pub enum Model { + #[serde(rename = "gemini-1.5-pro")] + Gemini15Pro, + #[serde(rename = "gemini-1.5-flash")] + Gemini15Flash, + #[serde(rename = "custom")] + Custom { name: String, max_tokens: usize }, +} + +impl Model { + pub fn id(&self) -> &str { + match self { + Model::Gemini15Pro => "gemini-1.5-pro", + Model::Gemini15Flash => "gemini-1.5-flash", + Model::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Model::Gemini15Pro => "Gemini 1.5 Pro", + Model::Gemini15Flash => "Gemini 1.5 Flash", + Model::Custom { name, .. } => name, + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Model::Gemini15Pro => 2_000_000, + Model::Gemini15Flash => 1_000_000, + Model::Custom { max_tokens, .. } => *max_tokens, + } + } +} + +impl std::fmt::Display for Model { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id()) + } +} + +pub fn extract_text_from_events( + events: impl Stream>, +) -> impl Stream> { + events.filter_map(|event| async move { + match event { + Ok(event) => event.candidates.and_then(|candidates| { + candidates.into_iter().next().and_then(|candidate| { + candidate.content.parts.into_iter().next().and_then(|part| { + if let Part::TextPart(TextPart { text }) = part { + Some(Ok(text)) + } else { + None + } + }) + }) + }), + Err(error) => Some(Err(error)), + } + }) +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 1a099897a3..de3ba8ef65 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -28,6 +28,7 @@ collections.workspace = true editor.workspace = true feature_flags.workspace = true futures.workspace = true +google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true http_client.workspace = true menu.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index b7b304a65d..1023ee337a 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,108 +1,42 @@ -pub use anthropic::Model as AnthropicModel; -use anyhow::{anyhow, Result}; -pub use ollama::Model as OllamaModel; -pub use open_ai::Model as OpenAiModel; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use strum::EnumIter; -#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "provider", rename_all = "lowercase")] pub enum CloudModel { - #[serde(rename = "gpt-3.5-turbo")] - Gpt3Point5Turbo, - #[serde(rename = "gpt-4")] - Gpt4, - #[serde(rename = "gpt-4-turbo-preview")] - Gpt4Turbo, - #[serde(rename = "gpt-4o")] - #[default] - Gpt4Omni, - #[serde(rename = "gpt-4o-mini")] - Gpt4OmniMini, - #[serde(rename = "claude-3-5-sonnet")] - Claude3_5Sonnet, - #[serde(rename = "claude-3-opus")] - Claude3Opus, - #[serde(rename = "claude-3-sonnet")] - Claude3Sonnet, - #[serde(rename = "claude-3-haiku")] - Claude3Haiku, - #[serde(rename = "gemini-1.5-pro")] - Gemini15Pro, - #[serde(rename = "gemini-1.5-flash")] - Gemini15Flash, - #[serde(rename = "custom")] - Custom { - name: String, - max_tokens: Option, - }, + Anthropic(anthropic::Model), + OpenAi(open_ai::Model), + Google(google_ai::Model), +} + +impl Default for CloudModel { + fn default() -> Self { + Self::Anthropic(anthropic::Model::default()) + } } impl CloudModel { - pub fn from_id(value: &str) -> Result { - match value { - "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo), - "gpt-4" => Ok(Self::Gpt4), - "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo), - "gpt-4o" => Ok(Self::Gpt4Omni), - "gpt-4o-mini" => Ok(Self::Gpt4OmniMini), - "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet), - "claude-3-opus" => Ok(Self::Claude3Opus), - "claude-3-sonnet" => Ok(Self::Claude3Sonnet), - "claude-3-haiku" => Ok(Self::Claude3Haiku), - "gemini-1.5-pro" => Ok(Self::Gemini15Pro), - "gemini-1.5-flash" => Ok(Self::Gemini15Flash), - _ => Err(anyhow!("invalid model id")), - } - } - pub fn id(&self) -> &str { match self { - Self::Gpt3Point5Turbo => "gpt-3.5-turbo", - Self::Gpt4 => "gpt-4", - Self::Gpt4Turbo => "gpt-4-turbo-preview", - Self::Gpt4Omni => "gpt-4o", - Self::Gpt4OmniMini => "gpt-4o-mini", - Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.id(), + CloudModel::OpenAi(model) => model.id(), + CloudModel::Google(model) => model.id(), } } pub fn display_name(&self) -> &str { match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.display_name(), + CloudModel::OpenAi(model) => model.display_name(), + CloudModel::Google(model) => model.display_name(), } } pub fn max_token_count(&self) -> usize { match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000), + CloudModel::Anthropic(model) => model.max_token_count(), + CloudModel::OpenAi(model) => model.max_token_count(), + CloudModel::Google(model) => model.max_token_count(), } } } diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs index f2713db003..6fe0bfd7a1 100644 --- a/crates/language_model/src/provider.rs +++ b/crates/language_model/src/provider.rs @@ -2,5 +2,6 @@ pub mod anthropic; pub mod cloud; #[cfg(any(test, feature = "test-support"))] pub mod fake; +pub mod google; pub mod ollama; pub mod open_ai; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 52ac22b29f..7cc9922546 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -1,4 +1,4 @@ -use anthropic::{stream_completion, Request, RequestMessage}; +use anthropic::stream_completion; use anyhow::{anyhow, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; @@ -18,7 +18,7 @@ use util::ResultExt; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role, + LanguageModelProviderState, LanguageModelRequest, Role, }; const PROVIDER_ID: &str = "anthropic"; @@ -160,40 +160,6 @@ pub struct AnthropicModel { http_client: Arc, } -impl AnthropicModel { - fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); - - let mut system_message = String::new(); - if request - .messages - .first() - .map_or(false, |message| message.role == Role::System) - { - system_message = request.messages.remove(0).content; - } - - Request { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|msg| RequestMessage { - role: match msg.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("filtered out by preprocess_request"), - }, - content: msg.content.clone(), - }) - .collect(), - stream: true, - system: system_message, - max_tokens: 4092, - } - } -} - pub fn count_anthropic_tokens( request: LanguageModelRequest, cx: &AppContext, @@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_anthropic_request(request); + let request = request.into_anthropic(self.model.id().into()); let http_client = self.http_client.clone(); @@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(response) => match response { - anthropic::ResponseEvent::ContentBlockStart { - content_block, .. - } => match content_block { - anthropic::ContentBlock::Text { text } => Some(Ok(text)), - }, - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => { - match delta { - anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), - } - } - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(anthropic::extract_text_from_events(response).boxed()) } .boxed() } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, state: gpui::Model, diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 1cd8b99e98..d290876ad9 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -7,8 +7,10 @@ use crate::{ use anyhow::Result; use client::Client; use collections::BTreeMap; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::sync::Arc; use strum::IntoEnumIterator; @@ -16,14 +18,29 @@ use ui::prelude::*; use crate::LanguageModelProvider; -use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request}; +use super::anthropic::count_anthropic_tokens; pub const PROVIDER_ID: &str = "zed.dev"; pub const PROVIDER_NAME: &str = "zed.dev"; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { - pub available_models: Vec, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum AvailableProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + provider: AvailableProvider, + name: String, + max_tokens: usize, } pub struct CloudLanguageModelProvider { @@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); - // Add base models from CloudModel::iter() - for model in CloudModel::iter() { - if !matches!(model, CloudModel::Custom { .. }) { - models.insert(model.id().to_string(), model); + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Anthropic(model)); + } + } + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::OpenAi(model)); + } + } + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Google(model)); } } @@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .zed_dot_dev .available_models { + let model = match model.provider { + AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + }; models.insert(model.id().to_string(), model.clone()); } @@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - match &self.model { - CloudModel::Gpt3Point5Turbo => { - count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx) - } - CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx), - CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx), - CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx), - CloudModel::Gpt4OmniMini => { - count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx) - } - CloudModel::Claude3_5Sonnet - | CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx), - CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { - count_anthropic_tokens(request, cx) - } - _ => { - let request = self.client.request(proto::CountTokensWithLanguageModel { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - }); + match self.model.clone() { + CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx), + CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx), + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + let request = google_ai::CountTokensRequest { + contents: request.contents, + }; async move { - let response = request.await?; - Ok(response.token_count as usize) + let request = serde_json::to_string(&request)?; + let response = client.request(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::CountTokens as i32, + request, + }); + let response = response.await?; + let response = + serde_json::from_str::(&response.response)?; + Ok(response.total_tokens) } .boxed() } @@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, - mut request: LanguageModelRequest, + request: LanguageModelRequest, _: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { match &self.model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request), - CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { - preprocess_anthropic_request(&mut request) + CloudModel::Anthropic(model) => { + let client = self.client.clone(); + let request = request.into_anthropic(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Anthropic as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(anthropic::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::OpenAi(model) => { + let client = self.client.clone(); + let request = request.into_open_ai(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::OpenAi as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(open_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(google_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() } - _ => {} } - - let request = proto::CompleteWithLanguageModel { - model: self.id.0.to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - }; - - self.client - .request_stream(request) - .map_ok(|stream| { - stream - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed() - }) - .boxed() } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs new file mode 100644 index 0000000000..3a0c0a3f7e --- /dev/null +++ b/crates/language_model/src/provider/google.rs @@ -0,0 +1,351 @@ +use anyhow::{anyhow, Result}; +use collections::BTreeMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; +use google_ai::stream_generate_content; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, + WhiteSpace, +}; +use http_client::HttpClient; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; +use strum::IntoEnumIterator; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt; + +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, +}; + +const PROVIDER_ID: &str = "google"; +const PROVIDER_NAME: &str = "Google AI"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct GoogleSettings { + pub api_url: String, + pub low_speed_timeout: Option, + pub available_models: Vec, +} + +pub struct GoogleLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + api_key: Option, + _subscription: Subscription, +} + +impl GoogleLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| State { + api_key: None, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for GoogleLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for GoogleLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from google_ai::Model::iter() + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &AllLanguageModelSettings::get_global(cx) + .google + .available_models + { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).api_key.is_some() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated(cx) { + Task::ready(Ok(())) + } else { + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) + .into() + } + + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let state = self.state.clone(); + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + state.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + +pub struct GoogleLanguageModel { + id: LanguageModelId, + model: google_ai::Model, + state: gpui::Model, + http_client: Arc, +} + +impl LanguageModel for GoogleLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("google/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + let request = request.into_google(self.model.id().to_string()); + let http_client = self.http_client.clone(); + let api_key = self.state.read(cx).api_key.clone(); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = google_ai::count_tokens( + http_client.as_ref(), + &api_url, + &api_key, + google_ai::CountTokensRequest { + contents: request.contents, + }, + ) + .await?; + Ok(response.total_tokens) + } + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_google(self.model.id().to_string()); + + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).google; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = + stream_generate_content(http_client.as_ref(), &api_url, &api_key, request); + let events = response.await?; + Ok(google_ai::extract_text_from_events(events).boxed()) + } + .boxed() + } +} + +struct AuthenticationPrompt { + api_key: View, + state: gpui::Model, +} + +impl AuthenticationPrompt { + fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text("AIzaSy...", cx); + editor + }), + state, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let settings = &AllLanguageModelSettings::get_global(cx).google; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + write_credentials.await?; + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 4] = [ + "To use the Google AI assistant, you need to add your Google AI API key.", + "You can create an API key at: https://makersuite.google.com/app/apikey", + "", + "Paste your Google AI API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index c81a435946..1b3bf18dd5 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -7,7 +7,7 @@ use gpui::{ WhiteSpace, }; use http_client::HttpClient; -use open_ai::{stream_completion, Request, RequestMessage}; +use open_ai::stream_completion; use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel { http_client: Arc, } -impl OpenAiLanguageModel { - fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { - Request { - model: self.model.clone(), - messages: request - .messages - .into_iter() - .map(|msg| match msg.role { - Role::User => RequestMessage::User { - content: msg.content, - }, - Role::Assistant => RequestMessage::Assistant { - content: Some(msg.content), - tool_calls: Vec::new(), - }, - Role::System => RequestMessage::System { - content: msg.content, - }, - }) - .collect(), - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - } - } -} - impl LanguageModel for OpenAiLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -226,7 +197,7 @@ impl LanguageModel for OpenAiLanguageModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_open_ai_request(request); + let request = request.into_open_ai(self.model.id().into()); let http_client = self.http_client.clone(); let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { @@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(open_ai::extract_text_from_events(response).boxed()) } .boxed() } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index e787f5f7e7..05dcbced5d 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,17 +1,17 @@ +use crate::{ + provider::{ + anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, + google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider, + open_ai::OpenAiLanguageModelProvider, + }, + LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, +}; use client::Client; use collections::BTreeMap; use gpui::{AppContext, Global, Model, ModelContext}; use std::sync::Arc; use ui::Context; -use crate::{ - provider::{ - anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, - ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, - }, - LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, -}; - pub fn init(client: Arc, cx: &mut AppContext) { let registry = cx.new_model(|cx| { let mut registry = LanguageModelRegistry::default(); @@ -40,6 +40,10 @@ fn register_language_model_providers( OllamaLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + GoogleLanguageModelProvider::new(client.http_client(), cx), + cx, + ); cx.observe_flag::(move |enabled, cx| { let client = client.clone(); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index e3e1d3e77b..fc3b8c0192 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,4 +1,4 @@ -use crate::{role::Role, LanguageModelId}; +use crate::role::Role; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage { pub content: String, } -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - #[derive(Debug, Default, Serialize, Deserialize)] pub struct LanguageModelRequest { pub messages: Vec, @@ -26,14 +15,110 @@ pub struct LanguageModelRequest { } impl LanguageModelRequest { - pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: model_id.0.to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), + pub fn into_open_ai(self, model: String) -> open_ai::Request { + open_ai::Request { + model, + messages: self + .messages + .into_iter() + .map(|msg| match msg.role { + Role::User => open_ai::RequestMessage::User { + content: msg.content, + }, + Role::Assistant => open_ai::RequestMessage::Assistant { + content: Some(msg.content), + tool_calls: Vec::new(), + }, + Role::System => open_ai::RequestMessage::System { + content: msg.content, + }, + }) + .collect(), + stream: true, + stop: self.stop, temperature: self.temperature, - tool_choice: None, tools: Vec::new(), + tool_choice: None, + } + } + + pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest { + google_ai::GenerateContentRequest { + model, + contents: self + .messages + .into_iter() + .map(|msg| google_ai::Content { + parts: vec![google_ai::Part::TextPart(google_ai::TextPart { + text: msg.content, + })], + role: match msg.role { + Role::User => google_ai::Role::User, + Role::Assistant => google_ai::Role::Model, + Role::System => google_ai::Role::User, // Google AI doesn't have a system role + }, + }) + .collect(), + generation_config: Some(google_ai::GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(self.stop), + max_output_tokens: None, + temperature: Some(self.temperature as f64), + top_p: None, + top_k: None, + }), + safety_settings: None, + } + } + + pub fn into_anthropic(self, model: String) -> anthropic::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in self.messages { + if message.content.is_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == message.role { + last_message.content.push_str("\n\n"); + last_message.content.push_str(&message.content); + continue; + } + } + + new_messages.push(message); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + } + } + } + + anthropic::Request { + model, + messages: new_messages + .into_iter() + .filter_map(|message| { + Some(anthropic::RequestMessage { + role: match message.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => return None, + }, + content: message.content, + }) + }) + .collect(), + stream: true, + max_tokens: 4092, + system: system_message, } } } diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs index f6276a4823..82184038f6 100644 --- a/crates/language_model/src/role.rs +++ b/crates/language_model/src/role.rs @@ -15,7 +15,6 @@ impl Role { Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, None => Role::User, } } diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 262e14937a..85ae91649a 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -6,12 +6,12 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use crate::{ - provider::{ - anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings, - open_ai::OpenAiSettings, - }, - CloudModel, +use crate::provider::{ + anthropic::AnthropicSettings, + cloud::{self, ZedDotDevSettings}, + google::GoogleSettings, + ollama::OllamaSettings, + open_ai::OpenAiSettings, }; /// Initializes the language model settings. @@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings { pub ollama: OllamaSettings, pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, + pub google: GoogleSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent { pub openai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, + pub google: Option, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct GoogleSettingsContent { + pub api_url: Option, + pub low_speed_timeout_in_seconds: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { - available_models: Option>, + available_models: Option>, } impl settings::Settings for AllLanguageModelSettings { @@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + + merge( + &mut settings.google.api_url, + value.google.as_ref().and_then(|s| s.api_url.clone()), + ); + if let Some(low_speed_timeout_in_seconds) = value + .google + .as_ref() + .and_then(|s| s.low_speed_timeout_in_seconds) + { + settings.google.low_speed_timeout = + Some(Duration::from_secs(low_speed_timeout_in_seconds)); + } + merge( + &mut settings.google.available_models, + value + .google + .as_ref() + .and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index dfcd6646d1..13a6eb11d1 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Context, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -111,38 +111,27 @@ impl Model { } } -fn serialize_model(model: &Model, serializer: S) -> Result -where - S: serde::Serializer, -{ - match model { - Model::Custom { name, .. } => serializer.serialize_str(name), - _ => serializer.serialize_str(model.id()), - } -} - -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { - #[serde(serialize_with = "serialize_model")] - pub model: Model, + pub model: String, pub messages: Vec, pub stream: bool, pub stop: Vec, pub temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, - #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct FunctionDefinition { pub name: String, pub description: Option, pub parameters: Option>, } -#[derive(Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { #[allow(dead_code)] @@ -213,21 +202,21 @@ pub struct FunctionChunk { pub arguments: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, pub delta: ResponseMessageDelta, pub finish_reason: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, pub model: String, @@ -369,3 +358,14 @@ pub fn embed<'a>( } } } + +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) +} diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 60f8d01558..658d552848 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -13,13 +13,6 @@ message Envelope { optional uint32 responding_to = 2; optional PeerId original_sender_id = 3; - /* - When you are adding a new message type, instead of adding it in semantic order - and bumping the message ID's of everything that follows, add it at the end of the - file and bump the max number. See this - https://github.com/zed-industries/zed/pull/7890#discussion_r1496621823 - - */ oneof payload { Hello hello = 4; Ack ack = 5; @@ -201,10 +194,8 @@ message Envelope { JoinHostedProject join_hosted_project = 164; - CompleteWithLanguageModel complete_with_language_model = 166; - LanguageModelResponse language_model_response = 167; - CountTokensWithLanguageModel count_tokens_with_language_model = 168; - CountTokensResponse count_tokens_response = 169; + QueryLanguageModel query_language_model = 224; + QueryLanguageModelResponse query_language_model_response = 225; // current max GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; @@ -271,10 +262,11 @@ message Envelope { UpdateDevServerProject update_dev_server_project = 221; AddWorktree add_worktree = 222; - AddWorktreeResponse add_worktree_response = 223; // current max + AddWorktreeResponse add_worktree_response = 223; } reserved 158 to 161; + reserved 166 to 169; } // Messages @@ -2051,94 +2043,32 @@ message SetRoomParticipantRole { ChannelRole role = 3; } -message CompleteWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; - repeated string stop = 3; - float temperature = 4; - repeated ChatCompletionTool tools = 5; - optional string tool_choice = 6; -} - -// A tool presented to the language model for its use -message ChatCompletionTool { - oneof variant { - FunctionObject function = 1; - } - - message FunctionObject { - string name = 1; - optional string description = 2; - optional string parameters = 3; - } -} - -// A message to the language model -message LanguageModelRequestMessage { - LanguageModelRole role = 1; - string content = 2; - optional string tool_call_id = 3; - repeated ToolCall tool_calls = 4; -} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; LanguageModelSystem = 2; - LanguageModelTool = 3; + reserved 3; } -message LanguageModelResponseMessage { - optional LanguageModelRole role = 1; - optional string content = 2; - repeated ToolCallDelta tool_calls = 3; +message QueryLanguageModel { + LanguageModelProvider provider = 1; + LanguageModelRequestKind kind = 2; + string request = 3; } -// A request to call a tool, by the language model -message ToolCall { - string id = 1; - - oneof variant { - FunctionCall function = 2; - } - - message FunctionCall { - string name = 1; - string arguments = 2; - } +enum LanguageModelProvider { + Anthropic = 0; + OpenAI = 1; + Google = 2; } -message ToolCallDelta { - uint32 index = 1; - optional string id = 2; - - oneof variant { - FunctionCallDelta function = 3; - } - - message FunctionCallDelta { - optional string name = 1; - optional string arguments = 2; - } +enum LanguageModelRequestKind { + Complete = 0; + CountTokens = 1; } -message LanguageModelResponse { - repeated LanguageModelChoiceDelta choices = 1; -} - -message LanguageModelChoiceDelta { - uint32 index = 1; - LanguageModelResponseMessage delta = 2; - optional string finish_reason = 3; -} - -message CountTokensWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; -} - -message CountTokensResponse { - uint32 token_count = 1; +message QueryLanguageModelResponse { + string response = 1; } message GetCachedEmbeddings { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a205b79ecb..7ef1866acd 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -203,12 +203,9 @@ messages!( (CancelCall, Foreground), (ChannelMessageSent, Foreground), (ChannelMessageUpdate, Foreground), - (CompleteWithLanguageModel, Background), (ComputeEmbeddings, Background), (ComputeEmbeddingsResponse, Background), (CopyProjectEntry, Foreground), - (CountTokensWithLanguageModel, Background), - (CountTokensResponse, Background), (CreateBufferForPeer, Foreground), (CreateChannel, Foreground), (CreateChannelResponse, Foreground), @@ -278,7 +275,6 @@ messages!( (JoinProjectResponse, Foreground), (JoinRoom, Foreground), (JoinRoomResponse, Foreground), - (LanguageModelResponse, Background), (LeaveChannelBuffer, Background), (LeaveChannelChat, Foreground), (LeaveProject, Foreground), @@ -298,6 +294,8 @@ messages!( (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), + (QueryLanguageModel, Background), + (QueryLanguageModelResponse, Background), (RefreshInlayHints, Foreground), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -412,9 +410,7 @@ request_messages!( (Call, Ack), (CancelCall, Ack), (CopyProjectEntry, ProjectEntryResponse), - (CompleteWithLanguageModel, LanguageModelResponse), (ComputeEmbeddings, ComputeEmbeddingsResponse), - (CountTokensWithLanguageModel, CountTokensResponse), (CreateChannel, CreateChannelResponse), (CreateProjectEntry, ProjectEntryResponse), (CreateRoom, CreateRoomResponse), @@ -467,6 +463,7 @@ request_messages!( (PerformRename, PerformRenameResponse), (Ping, Ack), (PrepareRename, PrepareRenameResponse), + (QueryLanguageModel, QueryLanguageModelResponse), (RefreshInlayHints, Ack), (RejoinChannelBuffers, RejoinChannelBuffersResponse), (RejoinRoom, RejoinRoomResponse),