diff --git a/Cargo.lock b/Cargo.lock index 92dd5d9a8e..2876ec86a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,27 +471,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "assistant_tooling" -version = "0.1.0" -dependencies = [ - "anyhow", - "collections", - "futures 0.3.28", - "gpui", - "log", - "project", - "repair_json", - "schemars", - "serde", - "serde_json", - "settings", - "sum_tree", - "ui", - "unindent", - "util", -] - [[package]] name = "async-attributes" version = "1.1.2" @@ -4811,8 +4790,10 @@ dependencies = [ "anyhow", "futures 0.3.28", "http_client", + "schemars", "serde", "serde_json", + "strum", ] [[package]] @@ -5988,6 +5969,7 @@ dependencies = [ "env_logger", "feature_flags", "futures 0.3.28", + "google_ai", "gpui", "http_client", "language", @@ -8715,15 +8697,6 @@ dependencies = [ "bytecheck", ] -[[package]] -name = "repair_json" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15" -dependencies = [ - "thiserror", -] - [[package]] name = "repl" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index b289d083bb..19a6b6b836 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "crates/assets", "crates/assistant", "crates/assistant_slash_command", - "crates/assistant_tooling", "crates/audio", "crates/auto_update", "crates/breadcrumbs", @@ -178,7 +177,6 @@ anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant_slash_command = { path = "crates/assistant_slash_command" } -assistant_tooling = { path = "crates/assistant_tooling" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } breadcrumbs = { path = "crates/breadcrumbs" } diff --git a/assets/settings/default.json b/assets/settings/default.json index 529b91b7cd..a26c7d27a0 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -870,6 +870,9 @@ "openai": { "api_url": "https://api.openai.com/v1" }, + "google": { + "api_url": "https://generativelanguage.googleapis.com" + }, "ollama": { "api_url": "http://localhost:11434" } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 2d9bd311b8..45a4dfc0d3 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -98,7 +98,7 @@ impl From for String { } } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, pub messages: Vec, @@ -113,7 +113,7 @@ pub struct RequestMessage { pub content: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ResponseEvent { MessageStart { @@ -138,7 +138,7 @@ pub enum ResponseEvent { MessageStop {}, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseMessage { #[serde(rename = "type")] pub message_type: Option, @@ -151,19 +151,19 @@ pub struct ResponseMessage { pub usage: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct 
Usage { pub input_tokens: Option, pub output_tokens: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ContentBlock { Text { text: String }, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum TextDelta { TextDelta { text: String }, @@ -226,6 +226,25 @@ pub async fn stream_completion( } } +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(response) => match response { + ResponseEvent::ContentBlockStart { content_block, .. } => match content_block { + ContentBlock::Text { text } => Some(Ok(text)), + }, + ResponseEvent::ContentBlockDelta { delta, .. } => match delta { + TextDelta::TextDelta { text } => Some(Ok(text)), + }, + _ => None, + }, + Err(error) => Some(Err(error)), + } + }) +} + // #[cfg(test)] // mod tests { // use super::*; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 05c5b56f1c..0d4dbd6824 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -249,9 +249,7 @@ impl AssistantSettingsContent { AssistantSettingsContent::Versioned(settings) => match settings { VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() { "zed.dev" => { - settings.provider = Some(AssistantProviderContentV1::ZedDotDev { - default_model: CloudModel::from_id(&model).ok(), - }); + log::warn!("attempted to set zed.dev model on outdated settings"); } "anthropic" => { let (api_url, low_speed_timeout_in_seconds) = match &settings.provider { diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml deleted file mode 100644 index 79f41faad2..0000000000 --- a/crates/assistant_tooling/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "assistant_tooling" -version = "0.1.0" -edition = "2021" -publish = false -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/assistant_tooling.rs" - -[dependencies] -anyhow.workspace = true -collections.workspace = true -futures.workspace = true -gpui.workspace = true -log.workspace = true -project.workspace = true -repair_json.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -sum_tree.workspace = true -ui.workspace = true -util.workspace = true - -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"] } -unindent.workspace = true diff --git a/crates/assistant_tooling/LICENSE-GPL b/crates/assistant_tooling/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/assistant_tooling/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tooling/README.md b/crates/assistant_tooling/README.md deleted file mode 100644 index 160869ae97..0000000000 --- a/crates/assistant_tooling/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Assistant Tooling - -Bringing Language Model tool calling to GPUI. - -This unlocks: - -- **Structured Extraction** of model responses -- **Validation** of model inputs -- **Execution** of chosen tools - -## Overview - -Language Models can produce structured outputs that are perfect for calling functions. 
The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call. - -> **User**: "Hey I need help with implementing a collapsible panel in GPUI" -> -> **Assistant**: "Sure, I can help with that. Let me see what I can find." -> -> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]` -> -> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"` -> -> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you." - -This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`. - -## Using the Tool Registry - -```rust -let mut tool_registry = ToolRegistry::new(); -tool_registry - .register(WeatherTool { api_client }, - }) - .unwrap(); // You can only register one tool per name - -let completion = cx.update(|cx| { - CompletionProvider::get(cx).complete( - model_name, - messages, - Vec::new(), - 1.0, - // The definitions get passed directly to OpenAI when you want - // the model to be able to call your tool - tool_registry.definitions(), - ) -}); - -let mut stream = completion?.await?; - -let mut message = AssistantMessage::new(); - -while let Some(delta) = stream.next().await { - // As messages stream in, you'll get both assistant content - if let Some(content) = &delta.content { - message - .body - .update(cx, |message, cx| message.append(&content, cx)); - } - - // And tool calls! - for tool_call_delta in delta.tool_calls { - let index = tool_call_delta.index as usize; - if index >= message.tool_calls.len() { - message.tool_calls.resize_with(index + 1, Default::default); - } - let tool_call = &mut message.tool_calls[index]; - - // Build up an ID - if let Some(id) = &tool_call_delta.id { - tool_call.id.push_str(id); - } - - tool_registry.update_tool_call( - tool_call, - tool_call_delta.name.as_deref(), - tool_call_delta.arguments.as_deref(), - cx, - ); - } -} -``` - -Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task>`. - -As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`. 
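
For reference, here is a minimal sketch of the two traits the README above refers to, pieced together from the `LanguageModelTool` and `ToolView` definitions in `crates/assistant_tooling/src/tool_registry.rs` further down in this diff (both are removed by this change). The `EchoTool`/`EchoView` names and their behavior are illustrative only and are not part of the crate; the point is just the shape of "struct to tool with two simple traits".

```rust
use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolView};
use gpui::{div, prelude::*, Render, Task, View, ViewContext, WindowContext};
use schemars::JsonSchema;
use serde::Deserialize;

// Illustrative input type; the derived JSON schema is what gets sent to the model.
#[derive(Deserialize, JsonSchema)]
struct EchoInput {
    text: String,
}

// Illustrative view; renders whatever the tool produced.
struct EchoView {
    output: Option<String>,
}

impl Render for EchoView {
    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        div().child(self.output.clone().unwrap_or_default())
    }
}

impl ToolView for EchoView {
    type Input = EchoInput;
    type SerializedState = String;

    fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
        self.output.clone().unwrap_or_default()
    }

    fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
        // Called repeatedly as (repaired) argument JSON streams in.
        self.output = Some(input.text);
        cx.notify();
    }

    fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
        // A real tool would kick off work here; echoing needs none.
        Task::ready(Ok(()))
    }

    fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
        self.output.clone().unwrap_or_default()
    }

    fn deserialize(&mut self, output: Self::SerializedState, _cx: &mut ViewContext<Self>) -> Result<()> {
        self.output = Some(output);
        Ok(())
    }
}

// Illustrative tool; registering it with `ToolRegistry::register` exposes it to the model.
struct EchoTool;

impl LanguageModelTool for EchoTool {
    type View = EchoView;

    fn name(&self) -> String {
        "echo".to_string()
    }

    fn description(&self) -> String {
        "Echoes the input text back to the model.".to_string()
    }

    fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
        cx.new_view(|_cx| EchoView { output: None })
    }
}
```
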
diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs deleted file mode 100644 index 9dcf2908e9..0000000000 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod attachment_registry; -mod project_context; -mod tool_registry; - -pub use attachment_registry::{ - AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, - UserAttachment, -}; -pub use project_context::ProjectContext; -pub use tool_registry::{ - LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, - ToolRegistry, ToolView, -}; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs deleted file mode 100644 index e8b52d26f0..0000000000 --- a/crates/assistant_tooling/src/attachment_registry.rs +++ /dev/null @@ -1,234 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use collections::HashMap; -use futures::future::join_all; -use gpui::{AnyView, Render, Task, View, WindowContext}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; -use util::ResultExt as _; - -pub struct AttachmentRegistry { - registered_attachments: HashMap, -} - -pub trait AttachmentOutput { - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; -} - -pub trait LanguageModelAttachment { - type Output: DeserializeOwned + Serialize + 'static; - type View: Render + AttachmentOutput; - - fn name(&self) -> Arc; - fn run(&self, cx: &mut WindowContext) -> Task>; - fn view(&self, output: Result, cx: &mut WindowContext) -> View; -} - -/// A collected attachment from running an attachment tool -pub struct UserAttachment { - pub view: AnyView, - name: Arc, - serialized_output: Result, String>, - generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String, -} - -#[derive(Serialize, Deserialize)] -pub struct SavedUserAttachment { - name: Arc, - serialized_output: Result, String>, -} - -/// Internal representation of an attachment tool to allow us to treat them dynamically -struct RegisteredAttachment { - name: Arc, - enabled: AtomicBool, - call: Box Task>>, - deserialize: Box Result>, -} - -impl AttachmentRegistry { - pub fn new() -> Self { - Self { - registered_attachments: HashMap::default(), - } - } - - pub fn register(&mut self, attachment: A) { - let attachment = Arc::new(attachment); - - let call = Box::new({ - let attachment = attachment.clone(); - move |cx: &mut WindowContext| { - let result = attachment.run(cx); - let attachment = attachment.clone(); - cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let serialized_output = - result - .as_ref() - .map_err(ToString::to_string) - .and_then(|output| { - Ok(RawValue::from_string( - serde_json::to_string(output).map_err(|e| e.to_string())?, - ) - .unwrap()) - }); - - let view = cx.update(|cx| attachment.view(result, cx))?; - - Ok(UserAttachment { - name: attachment.name(), - view: view.into(), - generate_fn: generate::, - serialized_output, - }) - }) - } - }); - - let deserialize = Box::new({ - let attachment = attachment.clone(); - move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| { - let serialized_output = saved_attachment.serialized_output.clone(); - let output = match &serialized_output { - Ok(serialized_output) => { - 
Ok(serde_json::from_str::(serialized_output.get())?) - } - Err(error) => Err(anyhow!("{error}")), - }; - let view = attachment.view(output, cx).into(); - - Ok(UserAttachment { - name: saved_attachment.name.clone(), - view, - serialized_output, - generate_fn: generate::, - }) - } - }); - - self.registered_attachments.insert( - TypeId::of::(), - RegisteredAttachment { - name: attachment.name(), - call, - deserialize, - enabled: AtomicBool::new(true), - }, - ); - return; - - fn generate( - view: AnyView, - project: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - view.downcast::() - .unwrap() - .update(cx, |view, cx| T::View::generate(view, project, cx)) - } - } - - pub fn set_attachment_tool_enabled( - &self, - is_enabled: bool, - ) { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.store(is_enabled, SeqCst); - } - } - - pub fn is_attachment_tool_enabled(&self) -> bool { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.load(SeqCst) - } else { - false - } - } - - pub fn call( - &self, - cx: &mut WindowContext, - ) -> Task> { - let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) else { - return Task::ready(Err(anyhow!("no attachment tool"))); - }; - - (attachment.call)(cx) - } - - pub fn call_all_attachment_tools( - self: Arc, - cx: &mut WindowContext<'_>, - ) -> Task>> { - let this = self.clone(); - cx.spawn(|mut cx| async move { - let attachment_tasks = cx.update(|cx| { - let mut tasks = Vec::new(); - for attachment in this - .registered_attachments - .values() - .filter(|attachment| attachment.enabled.load(SeqCst)) - { - tasks.push((attachment.call)(cx)) - } - - tasks - })?; - - let attachments = join_all(attachment_tasks.into_iter()).await; - - Ok(attachments - .into_iter() - .filter_map(|attachment| attachment.log_err()) - .collect()) - }) - } - - pub fn serialize_user_attachment( - &self, - user_attachment: &UserAttachment, - ) -> SavedUserAttachment { - SavedUserAttachment { - name: user_attachment.name.clone(), - serialized_output: user_attachment.serialized_output.clone(), - } - } - - pub fn deserialize_user_attachment( - &self, - saved_user_attachment: SavedUserAttachment, - cx: &mut WindowContext, - ) -> Result { - if let Some(registered_attachment) = self - .registered_attachments - .values() - .find(|attachment| attachment.name == saved_user_attachment.name) - { - (registered_attachment.deserialize)(&saved_user_attachment, cx) - } else { - Err(anyhow!( - "no attachment tool for name {}", - saved_user_attachment.name - )) - } - } -} - -impl UserAttachment { - pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option { - let result = (self.generate_fn)(self.view.clone(), output, cx); - if result.is_empty() { - None - } else { - Some(result) - } - } -} diff --git a/crates/assistant_tooling/src/project_context.rs b/crates/assistant_tooling/src/project_context.rs deleted file mode 100644 index 2640ce1ed5..0000000000 --- a/crates/assistant_tooling/src/project_context.rs +++ /dev/null @@ -1,296 +0,0 @@ -use anyhow::{anyhow, Result}; -use gpui::{AppContext, Model, Task, WeakModel}; -use project::{Fs, Project, ProjectPath, Worktree}; -use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc}; -use sum_tree::TreeMap; - -pub struct ProjectContext { - files: TreeMap, - project: WeakModel, - fs: Arc, -} - -#[derive(Debug, Clone)] -enum PathState { - PathOnly, - EntireFile, - Excerpts { ranges: Vec> }, -} - -impl 
ProjectContext { - pub fn new(project: WeakModel, fs: Arc) -> Self { - Self { - files: TreeMap::default(), - fs, - project, - } - } - - pub fn add_path(&mut self, project_path: ProjectPath) { - if self.files.get(&project_path).is_none() { - self.files.insert(project_path, PathState::PathOnly); - } - } - - pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range]) { - let previous_state = self - .files - .get(&project_path) - .unwrap_or(&PathState::PathOnly); - - let mut ranges = match previous_state { - PathState::EntireFile => return, - PathState::PathOnly => Vec::new(), - PathState::Excerpts { ranges } => ranges.to_vec(), - }; - - for new_range in new_ranges { - let ix = ranges.binary_search_by(|probe| { - if probe.end < new_range.start { - Ordering::Less - } else if probe.start > new_range.end { - Ordering::Greater - } else { - Ordering::Equal - } - }); - - match ix { - Ok(mut ix) => { - let existing = &mut ranges[ix]; - existing.start = existing.start.min(new_range.start); - existing.end = existing.end.max(new_range.end); - while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end { - ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end); - ranges.remove(ix + 1); - } - while ix > 0 && ranges[ix - 1].end >= ranges[ix].start { - ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start); - ranges.remove(ix - 1); - ix -= 1; - } - } - Err(ix) => { - ranges.insert(ix, new_range.clone()); - } - } - } - - self.files - .insert(project_path, PathState::Excerpts { ranges }); - } - - pub fn add_file(&mut self, project_path: ProjectPath) { - self.files.insert(project_path, PathState::EntireFile); - } - - pub fn generate_system_message(&self, cx: &mut AppContext) -> Task> { - let project = self - .project - .upgrade() - .ok_or_else(|| anyhow!("project dropped")); - let files = self.files.clone(); - let fs = self.fs.clone(); - cx.spawn(|cx| async move { - let project = project?; - let mut result = "project structure:\n".to_string(); - - let mut last_worktree: Option> = None; - for (project_path, path_state) in files.iter() { - if let Some(worktree) = &last_worktree { - if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id { - last_worktree = None; - } - } - - let worktree; - if let Some(last_worktree) = &last_worktree { - worktree = last_worktree.clone(); - } else if let Some(tree) = project.read_with(&cx, |project, cx| { - project.worktree_for_id(project_path.worktree_id, cx) - })? 
{ - worktree = tree; - last_worktree = Some(worktree.clone()); - let worktree_name = - worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?; - writeln!(&mut result, "# {}", worktree_name).unwrap(); - } else { - continue; - } - - let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?; - let path = &project_path.path; - writeln!(&mut result, "## {}", path.display()).unwrap(); - - match path_state { - PathState::PathOnly => {} - PathState::EntireFile => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - writeln!(&mut result, "~~~\n{text}\n~~~").unwrap(); - } - PathState::Excerpts { ranges } => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - - writeln!(&mut result, "~~~").unwrap(); - - // Assumption: ranges are in order, not overlapping - let mut prev_range_end = 0; - for range in ranges { - if range.start > prev_range_end { - writeln!(&mut result, "...").unwrap(); - prev_range_end = range.end; - } - - let mut start = range.start; - let mut end = range.end.min(text.len()); - while !text.is_char_boundary(start) { - start += 1; - } - while !text.is_char_boundary(end) { - end -= 1; - } - result.push_str(&text[start..end]); - if !result.ends_with('\n') { - result.push('\n'); - } - } - - if prev_range_end < text.len() { - writeln!(&mut result, "...").unwrap(); - } - - writeln!(&mut result, "~~~").unwrap(); - } - } - } - Ok(result) - }) - } -} - -#[cfg(test)] -mod tests { - use std::path::Path; - - use super::*; - use gpui::TestAppContext; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - - use unindent::Unindent as _; - - #[gpui::test] - async fn test_system_message_generation(cx: &mut TestAppContext) { - init_test(cx); - - let file_3_contents = r#" - fn test1() {} - fn test2() {} - fn test3() {} - "# - .unindent(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/code", - json!({ - "root1": { - "lib": { - "file1.rs": "mod example;", - "file2.rs": "", - }, - "test": { - "file3.rs": file_3_contents, - } - }, - "root2": { - "src": { - "main.rs": "" - } - } - }), - ) - .await; - - let project = Project::test( - fs.clone(), - ["/code/root1".as_ref(), "/code/root2".as_ref()], - cx, - ) - .await; - - let worktree_ids = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .map(|worktree| worktree.read(cx).id()) - .collect::>() - }); - - let mut ax = ProjectContext::new(project.downgrade(), fs); - - ax.add_file(ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("lib/file1.rs").into(), - }); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - "# - .unindent(), - message - ); - - ax.add_excerpts( - ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("test/file3.rs").into(), - }, - &[ - file_3_contents.find("fn test2").unwrap() - ..file_3_contents.find("fn test3").unwrap(), - ], - ); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - ## test/file3.rs - ~~~ - ... - fn test2() {} - ... 
- ~~~ - "# - .unindent(), - message - ); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - }); - } -} diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs deleted file mode 100644 index e5f8914eb5..0000000000 --- a/crates/assistant_tooling/src/tool_registry.rs +++ /dev/null @@ -1,526 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext}; -use repair_json::repair; -use schemars::{schema::RootSchema, schema_for, JsonSchema}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - collections::HashMap, - fmt::Display, - mem, - sync::atomic::{AtomicBool, Ordering::SeqCst}, -}; -use ui::ViewContext; - -pub struct ToolRegistry { - registered_tools: HashMap, -} - -#[derive(Default)] -pub struct ToolFunctionCall { - pub id: String, - pub name: String, - pub arguments: String, - state: ToolFunctionCallState, -} - -#[derive(Default)] -enum ToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool(Box), - ExecutedTool(Box), -} - -trait InternalToolView { - fn view(&self) -> AnyView; - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; - fn try_set_input(&self, input: &str, cx: &mut WindowContext); - fn execute(&self, cx: &mut WindowContext) -> Task>; - fn serialize_output(&self, cx: &mut WindowContext) -> Result>; - fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; -} - -#[derive(Default, Serialize, Deserialize)] -pub struct SavedToolFunctionCall { - id: String, - name: String, - arguments: String, - state: SavedToolFunctionCallState, -} - -#[derive(Default, Serialize, Deserialize)] -enum SavedToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool, - ExecutedTool(Box), -} - -#[derive(Clone, Debug, PartialEq)] -pub struct ToolFunctionDefinition { - pub name: String, - pub description: String, - pub parameters: RootSchema, -} - -pub trait LanguageModelTool { - type View: ToolView; - - /// Returns the name of the tool. - /// - /// This name is exposed to the language model to allow the model to pick - /// which tools to use. As this name is used to identify the tool within a - /// tool registry, it should be unique. - fn name(&self) -> String; - - /// Returns the description of the tool. - /// - /// This can be used to _prompt_ the model as to what the tool does. - fn description(&self) -> String; - - /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. - fn definition(&self) -> ToolFunctionDefinition { - let root_schema = schema_for!(::Input); - - ToolFunctionDefinition { - name: self.name(), - description: self.description(), - parameters: root_schema, - } - } - - /// A view of the output of running the tool, for displaying to the user. - fn view(&self, cx: &mut WindowContext) -> View; -} - -pub trait ToolView: Render { - /// The input type that will be passed in to `execute` when the tool is called - /// by the language model. - type Input: DeserializeOwned + JsonSchema; - - /// The output returned by executing the tool. 
- type SerializedState: DeserializeOwned + Serialize; - - fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext) -> String; - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext); - fn execute(&mut self, cx: &mut ViewContext) -> Task>; - - fn serialize(&self, cx: &mut ViewContext) -> Self::SerializedState; - fn deserialize( - &mut self, - output: Self::SerializedState, - cx: &mut ViewContext, - ) -> Result<()>; -} - -struct RegisteredTool { - enabled: AtomicBool, - type_id: TypeId, - build_view: Box Box>, - definition: ToolFunctionDefinition, -} - -impl ToolRegistry { - pub fn new() -> Self { - Self { - registered_tools: HashMap::new(), - } - } - - pub fn set_tool_enabled(&self, is_enabled: bool) { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - tool.enabled.store(is_enabled, SeqCst); - return; - } - } - } - - pub fn is_tool_enabled(&self) -> bool { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - return tool.enabled.load(SeqCst); - } - } - false - } - - pub fn definitions(&self) -> Vec { - self.registered_tools - .values() - .filter(|tool| tool.enabled.load(SeqCst)) - .map(|tool| tool.definition.clone()) - .collect() - } - - pub fn update_tool_call( - &self, - call: &mut ToolFunctionCall, - name: Option<&str>, - arguments: Option<&str>, - cx: &mut WindowContext, - ) { - if let Some(name) = name { - call.name.push_str(name); - } - if let Some(arguments) = arguments { - if call.arguments.is_empty() { - if let Some(tool) = self.registered_tools.get(&call.name) { - let view = (tool.build_view)(cx); - call.state = ToolFunctionCallState::KnownTool(view); - } else { - call.state = ToolFunctionCallState::NoSuchTool; - } - } - call.arguments.push_str(arguments); - - if let ToolFunctionCallState::KnownTool(view) = &call.state { - if let Ok(repaired_arguments) = repair(call.arguments.clone()) { - view.try_set_input(&repaired_arguments, cx) - } - } - } - } - - pub fn execute_tool_call( - &self, - tool_call: &mut ToolFunctionCall, - cx: &mut WindowContext, - ) -> Option>> { - if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) { - let task = view.execute(cx); - tool_call.state = ToolFunctionCallState::ExecutedTool(view); - Some(task) - } else { - None - } - } - - pub fn render_tool_call( - &self, - tool_call: &ToolFunctionCall, - _cx: &mut WindowContext, - ) -> Option { - match &tool_call.state { - ToolFunctionCallState::NoSuchTool => { - Some(ui::Label::new("No such tool").into_any_element()) - } - ToolFunctionCallState::Initializing => None, - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - Some(view.view().into_any_element()) - } - } - } - - pub fn content_for_tool_call( - &self, - tool_call: &ToolFunctionCall, - project_context: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - match &tool_call.state { - ToolFunctionCallState::Initializing => String::new(), - ToolFunctionCallState::NoSuchTool => { - format!("No such tool: {}", tool_call.name) - } - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - view.generate(project_context, cx) - } - } - } - - pub fn serialize_tool_call( - &self, - call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - Ok(SavedToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - ToolFunctionCallState::Initializing => 
SavedToolFunctionCallState::Initializing, - ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool, - ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool, - ToolFunctionCallState::ExecutedTool(view) => { - SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?) - } - }, - }) - } - - pub fn deserialize_tool_call( - &self, - call: &SavedToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - let Some(tool) = self.registered_tools.get(&call.name) else { - return Err(anyhow!("no such tool {}", call.name)); - }; - - Ok(ToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing, - SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool, - SavedToolFunctionCallState::KnownTool => { - log::error!("Deserialized tool that had not executed"); - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - ToolFunctionCallState::KnownTool(view) - } - SavedToolFunctionCallState::ExecutedTool(output) => { - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - view.deserialize_output(output, cx)?; - ToolFunctionCallState::ExecutedTool(view) - } - }, - }) - } - - pub fn register(&mut self, tool: T) -> Result<()> { - let name = tool.name(); - let registered_tool = RegisteredTool { - type_id: TypeId::of::(), - definition: tool.definition(), - enabled: AtomicBool::new(true), - build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))), - }; - - let previous = self.registered_tools.insert(name.clone(), registered_tool); - if previous.is_some() { - return Err(anyhow!("already registered a tool with name {}", name)); - } - - return Ok(()); - } -} - -impl InternalToolView for View { - fn view(&self) -> AnyView { - self.clone().into() - } - - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { - self.update(cx, |view, cx| view.generate(project, cx)) - } - - fn try_set_input(&self, input: &str, cx: &mut WindowContext) { - if let Ok(input) = serde_json::from_str::(input) { - self.update(cx, |view, cx| { - view.set_input(input, cx); - cx.notify(); - }); - } - } - - fn execute(&self, cx: &mut WindowContext) -> Task> { - self.update(cx, |view, cx| view.execute(cx)) - } - - fn serialize_output(&self, cx: &mut WindowContext) -> Result> { - let output = self.update(cx, |view, cx| view.serialize(cx)); - Ok(RawValue::from_string(serde_json::to_string(&output)?)?) 
- } - - fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> { - let state = serde_json::from_str::(output.get())?; - self.update(cx, |view, cx| view.deserialize(state, cx))?; - Ok(()) - } -} - -impl Display for ToolFunctionDefinition { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let schema = serde_json::to_string(&self.parameters).ok(); - let schema = schema.unwrap_or("None".to_string()); - write!(f, "Name: {}:\n", self.name)?; - write!(f, "Description: {}\n", self.description)?; - write!(f, "Parameters: {}", schema) - } -} - -#[cfg(test)] -mod test { - use super::*; - use gpui::{div, prelude::*, Render, TestAppContext}; - use gpui::{EmptyView, View}; - use schemars::JsonSchema; - use serde::{Deserialize, Serialize}; - use serde_json::json; - - #[derive(Deserialize, Serialize, JsonSchema)] - struct WeatherQuery { - location: String, - unit: String, - } - - #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] - struct WeatherResult { - location: String, - temperature: f64, - unit: String, - } - - struct WeatherView { - input: Option, - result: Option, - - // Fake API call - current_weather: WeatherResult, - } - - #[derive(Clone, Serialize)] - struct WeatherTool { - current_weather: WeatherResult, - } - - impl WeatherView { - fn new(current_weather: WeatherResult) -> Self { - Self { - input: None, - result: None, - current_weather, - } - } - } - - impl Render for WeatherView { - fn render(&mut self, _cx: &mut gpui::ViewContext) -> impl IntoElement { - match self.result { - Some(ref result) => div() - .child(format!("temperature: {}", result.temperature)) - .into_any_element(), - None => div().child("Calculating weather...").into_any_element(), - } - } - } - - impl ToolView for WeatherView { - type Input = WeatherQuery; - - type SerializedState = WeatherResult; - - fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext) -> String { - serde_json::to_string(&self.result).unwrap() - } - - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext) { - self.input = Some(input); - cx.notify(); - } - - fn execute(&mut self, _cx: &mut ViewContext) -> Task> { - let input = self.input.as_ref().unwrap(); - - let _location = input.location.clone(); - let _unit = input.unit.clone(); - - let weather = self.current_weather.clone(); - - self.result = Some(weather); - - Task::ready(Ok(())) - } - - fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { - self.current_weather.clone() - } - - fn deserialize( - &mut self, - output: Self::SerializedState, - _cx: &mut ViewContext, - ) -> Result<()> { - self.current_weather = output; - Ok(()) - } - } - - impl LanguageModelTool for WeatherTool { - type View = WeatherView; - - fn name(&self) -> String { - "get_current_weather".to_string() - } - - fn description(&self) -> String { - "Fetches the current weather for a given location.".to_string() - } - - fn view(&self, cx: &mut WindowContext) -> View { - cx.new_view(|_cx| WeatherView::new(self.current_weather.clone())) - } - } - - #[gpui::test] - async fn test_openai_weather_example(cx: &mut TestAppContext) { - let (_, cx) = cx.add_window_view(|_cx| EmptyView); - - let mut registry = ToolRegistry::new(); - registry - .register(WeatherTool { - current_weather: WeatherResult { - location: "San Francisco".to_string(), - temperature: 21.0, - unit: "Celsius".to_string(), - }, - }) - .unwrap(); - - let definitions = registry.definitions(); - assert_eq!( - definitions, - [ToolFunctionDefinition { - name: 
"get_current_weather".to_string(), - description: "Fetches the current weather for a given location.".to_string(), - parameters: serde_json::from_value(json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "WeatherQuery", - "type": "object", - "properties": { - "location": { - "type": "string" - }, - "unit": { - "type": "string" - } - }, - "required": ["location", "unit"] - })) - .unwrap(), - }] - ); - - let mut call = ToolFunctionCall { - id: "the-id".to_string(), - name: "get_cur".to_string(), - ..Default::default() - }; - - let task = cx.update(|cx| { - registry.update_tool_call( - &mut call, - Some("rent_weather"), - Some(r#"{"location": "San Francisco","#), - cx, - ); - registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx); - registry.execute_tool_call(&mut call, cx).unwrap() - }); - task.await.unwrap(); - - match &call.state { - ToolFunctionCallState::ExecutedTool(_view) => {} - _ => panic!(), - } - } -} diff --git a/crates/collab/src/ai.rs b/crates/collab/src/ai.rs deleted file mode 100644 index 06c6e77dfd..0000000000 --- a/crates/collab/src/ai.rs +++ /dev/null @@ -1,138 +0,0 @@ -use anyhow::{anyhow, Context as _, Result}; -use rpc::proto; -use util::ResultExt as _; - -pub fn language_model_request_to_open_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(open_ai::Request { - model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo), - messages: request - .messages - .into_iter() - .map(|message: proto::LanguageModelRequestMessage| { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - let openai_message = match role { - proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User { - content: message.content, - }, - proto::LanguageModelRole::LanguageModelAssistant => { - open_ai::RequestMessage::Assistant { - content: Some(message.content), - tool_calls: message - .tool_calls - .into_iter() - .filter_map(|call| { - Some(open_ai::ToolCall { - id: call.id, - content: match call.variant? { - proto::tool_call::Variant::Function(f) => { - open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: f.name, - arguments: f.arguments, - }, - } - } - }, - }) - }) - .collect(), - } - } - proto::LanguageModelRole::LanguageModelSystem => { - open_ai::RequestMessage::System { - content: message.content, - } - } - proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool { - tool_call_id: message - .tool_call_id - .ok_or_else(|| anyhow!("tool message is missing tool call id"))?, - content: message.content, - }, - }; - - Ok(openai_message) - }) - .collect::>>()?, - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: request - .tools - .into_iter() - .filter_map(|tool| { - Some(match tool.variant? 
{ - proto::chat_completion_tool::Variant::Function(f) => { - open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: f.name, - description: f.description, - parameters: if let Some(params) = &f.parameters { - Some( - serde_json::from_str(params) - .context("failed to deserialize tool parameters") - .log_err()?, - ) - } else { - None - }, - }, - } - } - }) - }) - .collect(), - tool_choice: request.tool_choice, - }) -} - -pub fn language_model_request_to_google_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(google_ai::GenerateContentRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - generation_config: None, - safety_settings: None, - }) -} - -pub fn language_model_request_message_to_google_ai( - message: proto::LanguageModelRequestMessage, -) -> Result { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - Ok(google_ai::Content { - parts: vec![google_ai::Part::TextPart(google_ai::TextPart { - text: message.content, - })], - role: match role { - proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model, - proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelTool => { - Err(anyhow!("we don't handle tool calls with google ai yet"))? - } - }, - }) -} - -pub fn count_tokens_request_to_google_ai( - request: proto::CountTokensWithLanguageModel, -) -> Result { - Ok(google_ai::CountTokensRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - }) -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index ae83fccb98..2673ca3fb8 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,4 +1,3 @@ -pub mod ai; pub mod api; pub mod auth; pub mod db; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3ec13ce045..92e5b1a584 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -46,8 +46,8 @@ use http_client::IsahcHttpClient; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole, - LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, ShareProject, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope, }; @@ -618,17 +618,6 @@ impl Server { ) } }) - .add_request_handler({ - let app_state = app_state.clone(); - user_handler(move |request, response, session| { - count_tokens_with_language_model( - request, - response, - session, - app_state.config.google_ai_api_key.clone(), - ) - }) - }) .add_request_handler({ user_handler(move |request, response, session| { get_cached_embeddings(request, response, session) @@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit { } async fn complete_with_language_model( - mut request: proto::CompleteWithLanguageModel, - response: StreamingResponse, + query: proto::QueryLanguageModel, + response: StreamingResponse, session: Session, open_ai_api_key: Option>, google_ai_api_key: Option>, @@ -4525,287 +4514,95 @@ async fn complete_with_language_model( return 
Err(anyhow!("user not found"))?; }; authorize_access_to_language_models(&session).await?; - session - .rate_limiter - .check::(session.user_id()) - .await?; - let mut provider_and_model = request.model.split('/'); - let (provider, model) = match ( - provider_and_model.next().unwrap(), - provider_and_model.next(), - ) { - (provider, Some(model)) => (provider, model), - (model, None) => { - if model.starts_with("gpt") { - ("openai", model) - } else if model.starts_with("gemini") { - ("google", model) - } else if model.starts_with("claude") { - ("anthropic", model) - } else { - ("unknown", model) - } + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; } - }; - let provider = provider.to_string(); - request.model = model.to_string(); + Some(proto::LanguageModelRequestKind::CountTokens) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; + } + None => Err(anyhow!("unknown request kind"))?, + } - match provider.as_str() { - "openai" => { - let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; - complete_with_open_ai(request, response, session, api_key).await?; - } - "anthropic" => { + match proto::LanguageModelProvider::from_i32(query.provider) { + Some(proto::LanguageModelProvider::Anthropic) => { let api_key = anthropic_api_key.context("no Anthropic AI API key configured on the server")?; - complete_with_anthropic(request, response, session, api_key).await?; + let mut chunks = anthropic::stream_completion( + session.http_client.as_ref(), + anthropic::ANTHROPIC_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } } - "google" => { + Some(proto::LanguageModelProvider::OpenAi) => { + let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; + let mut chunks = open_ai::stream_completion( + session.http_client.as_ref(), + open_ai::OPEN_AI_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } + } + Some(proto::LanguageModelProvider::Google) => { let api_key = google_ai_api_key.context("no Google AI API key configured on the server")?; - complete_with_google_ai(request, response, session, api_key).await?; - } - provider => return Err(anyhow!("unknown provider {:?}", provider))?, - } - Ok(()) -} - -async fn complete_with_open_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut completion_stream = open_ai::stream_completion( - session.http_client.as_ref(), - OPEN_AI_API_URL, - &api_key, - crate::ai::language_model_request_to_open_ai(request)?, - None, - ) - .await - .context("open_ai::stream_completion request failed within collab")?; - - while let Some(event) = completion_stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .choices - .into_iter() - .map(|choice| proto::LanguageModelChoiceDelta { - index: choice.index, - delta: Some(proto::LanguageModelResponseMessage { - role: choice.delta.role.map(|role| match role { - 
open_ai::Role::User => LanguageModelRole::LanguageModelUser, - open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant, - open_ai::Role::System => LanguageModelRole::LanguageModelSystem, - open_ai::Role::Tool => LanguageModelRole::LanguageModelTool, - } as i32), - content: choice.delta.content, - tool_calls: choice - .delta - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|delta| proto::ToolCallDelta { - index: delta.index as u32, - id: delta.id, - variant: match delta.function { - Some(function) => { - let name = function.name; - let arguments = function.arguments; - - Some(proto::tool_call_delta::Variant::Function( - proto::tool_call_delta::FunctionCallDelta { - name, - arguments, - }, - )) - } - None => None, - }, - }) - .collect(), - }), - finish_reason: choice.finish_reason, - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_google_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut stream = google_ai::stream_generate_content( - session.http_client.clone(), - google_ai::API_URL, - api_key.as_ref(), - &request.model.clone(), - crate::ai::language_model_request_to_google_ai(request)?, - ) - .await - .context("google_ai::stream_generate_content request failed")?; - - while let Some(event) = stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .candidates - .unwrap_or_default() - .into_iter() - .map(|candidate| proto::LanguageModelChoiceDelta { - index: candidate.index as u32, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(match candidate.content.role { - google_ai::Role::User => LanguageModelRole::LanguageModelUser, - google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant, - } as i32), - content: Some( - candidate - .content - .parts - .into_iter() - .filter_map(|part| match part { - google_ai::Part::TextPart(part) => Some(part.text), - google_ai::Part::InlineDataPart(_) => None, - }) - .collect(), - ), - // Tool calls are not supported for Google - tool_calls: Vec::new(), - }), - finish_reason: candidate.finish_reason.map(|reason| reason.to_string()), - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_anthropic( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut system_message = String::new(); - let messages = request - .messages - .into_iter() - .filter_map(|message| { - match message.role() { - LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { - role: anthropic::Role::User, - content: message.content, - }), - LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { - role: anthropic::Role::Assistant, - content: message.content, - }), - // Anthropic's API breaks system instructions out as a separate field rather - // than having a system message role. 
- LanguageModelRole::LanguageModelSystem => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - - None - } - // We don't yet support tool calls for Anthropic - LanguageModelRole::LanguageModelTool => None, - } - }) - .collect(); - - let mut stream = anthropic::stream_completion( - session.http_client.as_ref(), - anthropic::ANTHROPIC_API_URL, - &api_key, - anthropic::Request { - model: request.model, - messages, - stream: true, - system: system_message, - max_tokens: 4092, - }, - None, - ) - .await?; - - let mut current_role = proto::LanguageModelRole::LanguageModelAssistant; - - while let Some(event) = stream.next().await { - let event = event?; - - match event { - anthropic::ResponseEvent::MessageStart { message } => { - if let Some(role) = message.role { - if role == "assistant" { - current_role = proto::LanguageModelRole::LanguageModelAssistant; - } else if role == "user" { - current_role = proto::LanguageModelRole::LanguageModelUser; + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + let mut chunks = google_ai::stream_generate_content( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; } } - } - anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => { - match content_block { - anthropic::ContentBlock::Text { text } => { - if !text.is_empty() { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } - } - } - } - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta { - anthropic::TextDelta::TextDelta { text } => { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } - }, - anthropic::ResponseEvent::MessageDelta { delta, .. } => { - if let Some(stop_reason) = delta.stop_reason { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: None, - finish_reason: Some(stop_reason), - }], + Some(proto::LanguageModelRequestKind::CountTokens) => { + let tokens_response = google_ai::count_tokens( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&tokens_response)?, })?; } + None => Err(anyhow!("unknown request kind"))?, } - anthropic::ResponseEvent::ContentBlockStop { .. 
} => {} - anthropic::ResponseEvent::MessageStop {} => {} - anthropic::ResponseEvent::Ping {} => {} } + None => return Err(anyhow!("unknown provider"))?, } Ok(()) @@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit { } } -async fn count_tokens_with_language_model( - request: proto::CountTokensWithLanguageModel, - response: Response, - session: UserSession, - google_ai_api_key: Option>, -) -> Result<()> { - authorize_access_to_language_models(&session).await?; - - if !request.model.starts_with("gemini") { - return Err(anyhow!( - "counting tokens for model: {:?} is not supported", - request.model - ))?; - } - - session - .rate_limiter - .check::(session.user_id()) - .await?; - - let api_key = google_ai_api_key - .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; - let tokens_response = google_ai::count_tokens( - session.http_client.as_ref(), - google_ai::API_URL, - &api_key, - crate::ai::count_tokens_request_to_google_ai(request)?, - ) - .await?; - response.send(proto::CountTokensResponse { - token_count: tokens_response.total_tokens as u32, - })?; - Ok(()) -} - struct ComputeEmbeddingsRateLimit; impl RateLimit for ComputeEmbeddingsRateLimit { diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 1495f55a31..f923e0ec91 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -11,9 +11,14 @@ workspace = true [lib] path = "src/google_ai.rs" +[features] +schemars = ["dep:schemars"] + [dependencies] anyhow.workspace = true futures.workspace = true http_client.workspace = true +schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +strum.workspace = true diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 34c43176d0..b2ecf33243 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,23 +1,21 @@ -use std::sync::Arc; - use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::HttpClient; use serde::{Deserialize, Serialize}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( - client: Arc, + client: &dyn HttpClient, api_url: &str, api_key: &str, - model: &str, - request: GenerateContentRequest, + mut request: GenerateContentRequest, ) -> Result>> { let uri = format!( - "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}", - api_url, api_key + "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", + model = request.model ); + request.model.clear(); let request = serde_json::to_string(&request)?; let mut response = client.post_json(&uri, request.into()).await?; @@ -52,8 +50,8 @@ pub async fn stream_generate_content( } } -pub async fn count_tokens( - client: &T, +pub async fn count_tokens( + client: &dyn HttpClient, api_url: &str, api_key: &str, request: CountTokensRequest, @@ -91,22 +89,24 @@ pub enum Task { BatchEmbedContents, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub model: String, pub contents: Vec, pub generation_config: Option, pub safety_settings: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all 
= "camelCase")] pub struct GenerateContentResponse { pub candidates: Option>, pub prompt_feedback: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentCandidate { pub index: usize, @@ -157,7 +157,7 @@ pub struct GenerativeContentBlob { pub data: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { pub start_index: Option, @@ -166,13 +166,13 @@ pub struct CitationSource { pub license: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptFeedback { pub block_reason: Option, @@ -180,7 +180,7 @@ pub struct PromptFeedback { pub block_reason_message: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { pub candidate_count: Option, @@ -191,7 +191,7 @@ pub struct GenerationConfig { pub top_k: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting { pub category: HarmCategory, @@ -224,7 +224,7 @@ pub enum HarmCategory { DangerousContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub enum HarmBlockThreshold { #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] Unspecified, @@ -238,7 +238,7 @@ pub enum HarmBlockThreshold { BlockNone, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum HarmProbability { #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] @@ -249,21 +249,85 @@ pub enum HarmProbability { High, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetyRating { pub category: HarmCategory, pub probability: HarmProbability, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensRequest { pub contents: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensResponse { pub total_tokens: usize, } + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +pub enum Model { + #[serde(rename = "gemini-1.5-pro")] + Gemini15Pro, + #[serde(rename = "gemini-1.5-flash")] + Gemini15Flash, + #[serde(rename = "custom")] + Custom { name: String, max_tokens: usize }, +} + +impl Model { + pub fn id(&self) -> &str { + match self { + Model::Gemini15Pro => "gemini-1.5-pro", + Model::Gemini15Flash => "gemini-1.5-flash", + Model::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Model::Gemini15Pro => "Gemini 1.5 Pro", + Model::Gemini15Flash => "Gemini 1.5 Flash", + Model::Custom { name, .. } => name, + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Model::Gemini15Pro => 2_000_000, + Model::Gemini15Flash => 1_000_000, + Model::Custom { max_tokens, .. 
} => *max_tokens,
+        }
+    }
+}
+
+impl std::fmt::Display for Model {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.id())
+    }
+}
+
+pub fn extract_text_from_events(
+    events: impl Stream<Item = Result<GenerateContentResponse>>,
+) -> impl Stream<Item = Result<String>> {
+    events.filter_map(|event| async move {
+        match event {
+            Ok(event) => event.candidates.and_then(|candidates| {
+                candidates.into_iter().next().and_then(|candidate| {
+                    candidate.content.parts.into_iter().next().and_then(|part| {
+                        if let Part::TextPart(TextPart { text }) = part {
+                            Some(Ok(text))
+                        } else {
+                            None
+                        }
+                    })
+                })
+            }),
+            Err(error) => Some(Err(error)),
+        }
+    })
+}
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 1a099897a3..de3ba8ef65 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -28,6 +28,7 @@ collections.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 futures.workspace = true
+google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 menu.workspace = true
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index b7b304a65d..1023ee337a 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -1,108 +1,42 @@
-pub use anthropic::Model as AnthropicModel;
-use anyhow::{anyhow, Result};
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use strum::EnumIter;
 
-#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "provider", rename_all = "lowercase")]
 pub enum CloudModel {
-    #[serde(rename = "gpt-3.5-turbo")]
-    Gpt3Point5Turbo,
-    #[serde(rename = "gpt-4")]
-    Gpt4,
-    #[serde(rename = "gpt-4-turbo-preview")]
-    Gpt4Turbo,
-    #[serde(rename = "gpt-4o")]
-    #[default]
-    Gpt4Omni,
-    #[serde(rename = "gpt-4o-mini")]
-    Gpt4OmniMini,
-    #[serde(rename = "claude-3-5-sonnet")]
-    Claude3_5Sonnet,
-    #[serde(rename = "claude-3-opus")]
-    Claude3Opus,
-    #[serde(rename = "claude-3-sonnet")]
-    Claude3Sonnet,
-    #[serde(rename = "claude-3-haiku")]
-    Claude3Haiku,
-    #[serde(rename = "gemini-1.5-pro")]
-    Gemini15Pro,
-    #[serde(rename = "gemini-1.5-flash")]
-    Gemini15Flash,
-    #[serde(rename = "custom")]
-    Custom {
-        name: String,
-        max_tokens: Option<usize>,
-    },
+    Anthropic(anthropic::Model),
+    OpenAi(open_ai::Model),
+    Google(google_ai::Model),
+}
+
+impl Default for CloudModel {
+    fn default() -> Self {
+        Self::Anthropic(anthropic::Model::default())
+    }
 }
 
 impl CloudModel {
-    pub fn from_id(value: &str) -> Result<Self> {
-        match value {
-            "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
-            "gpt-4" => Ok(Self::Gpt4),
-            "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
-            "gpt-4o" => Ok(Self::Gpt4Omni),
-            "gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
-            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
-            "claude-3-opus" => Ok(Self::Claude3Opus),
-            "claude-3-sonnet" => Ok(Self::Claude3Sonnet),
-            "claude-3-haiku" => Ok(Self::Claude3Haiku),
-            "gemini-1.5-pro" => Ok(Self::Gemini15Pro),
-            "gemini-1.5-flash" => Ok(Self::Gemini15Flash),
-            _ => Err(anyhow!("invalid model id")),
-        }
-    }
-
     pub fn id(&self) -> &str {
         match self {
-            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
-            Self::Gpt4 => "gpt-4",
-            Self::Gpt4Turbo => "gpt-4-turbo-preview",
-            Self::Gpt4Omni => "gpt-4o",
-            Self::Gpt4OmniMini => "gpt-4o-mini",
-
Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.id(), + CloudModel::OpenAi(model) => model.id(), + CloudModel::Google(model) => model.id(), } } pub fn display_name(&self) -> &str { match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.display_name(), + CloudModel::OpenAi(model) => model.display_name(), + CloudModel::Google(model) => model.display_name(), } } pub fn max_token_count(&self) -> usize { match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000), + CloudModel::Anthropic(model) => model.max_token_count(), + CloudModel::OpenAi(model) => model.max_token_count(), + CloudModel::Google(model) => model.max_token_count(), } } } diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs index f2713db003..6fe0bfd7a1 100644 --- a/crates/language_model/src/provider.rs +++ b/crates/language_model/src/provider.rs @@ -2,5 +2,6 @@ pub mod anthropic; pub mod cloud; #[cfg(any(test, feature = "test-support"))] pub mod fake; +pub mod google; pub mod ollama; pub mod open_ai; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 52ac22b29f..7cc9922546 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -1,4 +1,4 @@ -use anthropic::{stream_completion, Request, RequestMessage}; +use anthropic::stream_completion; use anyhow::{anyhow, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; @@ -18,7 +18,7 @@ use util::ResultExt; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role, + LanguageModelProviderState, LanguageModelRequest, Role, }; const PROVIDER_ID: &str = "anthropic"; @@ -160,40 +160,6 @@ pub struct AnthropicModel { http_client: Arc, } -impl AnthropicModel { - fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); - - let mut system_message = String::new(); - if request - .messages - .first() - .map_or(false, |message| message.role == Role::System) - { - system_message = request.messages.remove(0).content; - } - - Request { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|msg| RequestMessage { - role: match msg.role { - Role::User 
=> anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("filtered out by preprocess_request"), - }, - content: msg.content.clone(), - }) - .collect(), - stream: true, - system: system_message, - max_tokens: 4092, - } - } -} - pub fn count_anthropic_tokens( request: LanguageModelRequest, cx: &AppContext, @@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_anthropic_request(request); + let request = request.into_anthropic(self.model.id().into()); let http_client = self.http_client.clone(); @@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(response) => match response { - anthropic::ResponseEvent::ContentBlockStart { - content_block, .. - } => match content_block { - anthropic::ContentBlock::Text { text } => Some(Ok(text)), - }, - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => { - match delta { - anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), - } - } - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(anthropic::extract_text_from_events(response).boxed()) } .boxed() } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, state: gpui::Model, diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 1cd8b99e98..d290876ad9 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -7,8 +7,10 @@ use crate::{ use anyhow::Result; use client::Client; use collections::BTreeMap; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::sync::Arc; use strum::IntoEnumIterator; @@ -16,14 +18,29 @@ use ui::prelude::*; use crate::LanguageModelProvider; -use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request}; +use super::anthropic::count_anthropic_tokens; pub const PROVIDER_ID: &str = "zed.dev"; pub const PROVIDER_NAME: &str = "zed.dev"; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { - pub available_models: Vec, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, 
Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum AvailableProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + provider: AvailableProvider, + name: String, + max_tokens: usize, } pub struct CloudLanguageModelProvider { @@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); - // Add base models from CloudModel::iter() - for model in CloudModel::iter() { - if !matches!(model, CloudModel::Custom { .. }) { - models.insert(model.id().to_string(), model); + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Anthropic(model)); + } + } + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::OpenAi(model)); + } + } + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Google(model)); } } @@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .zed_dot_dev .available_models { + let model = match model.provider { + AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + }; models.insert(model.id().to_string(), model.clone()); } @@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - match &self.model { - CloudModel::Gpt3Point5Turbo => { - count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx) - } - CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx), - CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx), - CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx), - CloudModel::Gpt4OmniMini => { - count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx) - } - CloudModel::Claude3_5Sonnet - | CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx), - CloudModel::Custom { name, .. 
} if name.starts_with("anthropic/") => { - count_anthropic_tokens(request, cx) - } - _ => { - let request = self.client.request(proto::CountTokensWithLanguageModel { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - }); + match self.model.clone() { + CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx), + CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx), + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + let request = google_ai::CountTokensRequest { + contents: request.contents, + }; async move { - let response = request.await?; - Ok(response.token_count as usize) + let request = serde_json::to_string(&request)?; + let response = client.request(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::CountTokens as i32, + request, + }); + let response = response.await?; + let response = + serde_json::from_str::(&response.response)?; + Ok(response.total_tokens) } .boxed() } @@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, - mut request: LanguageModelRequest, + request: LanguageModelRequest, _: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { match &self.model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request), - CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { - preprocess_anthropic_request(&mut request) + CloudModel::Anthropic(model) => { + let client = self.client.clone(); + let request = request.into_anthropic(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Anthropic as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(anthropic::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::OpenAi(model) => { + let client = self.client.clone(); + let request = request.into_open_ai(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::OpenAi as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(open_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(google_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() } - _ => {} } - - let request = proto::CompleteWithLanguageModel { - model: self.id.0.to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - 
.collect(), - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - }; - - self.client - .request_stream(request) - .map_ok(|stream| { - stream - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed() - }) - .boxed() } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs new file mode 100644 index 0000000000..3a0c0a3f7e --- /dev/null +++ b/crates/language_model/src/provider/google.rs @@ -0,0 +1,351 @@ +use anyhow::{anyhow, Result}; +use collections::BTreeMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; +use google_ai::stream_generate_content; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, + WhiteSpace, +}; +use http_client::HttpClient; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; +use strum::IntoEnumIterator; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt; + +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, +}; + +const PROVIDER_ID: &str = "google"; +const PROVIDER_NAME: &str = "Google AI"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct GoogleSettings { + pub api_url: String, + pub low_speed_timeout: Option, + pub available_models: Vec, +} + +pub struct GoogleLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + api_key: Option, + _subscription: Subscription, +} + +impl GoogleLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| State { + api_key: None, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for GoogleLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for GoogleLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from google_ai::Model::iter() + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &AllLanguageModelSettings::get_global(cx) + .google + .available_models + { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).api_key.is_some() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated(cx) { + Task::ready(Ok(())) + } else { + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) + .into() + } + + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let state = self.state.clone(); + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + state.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + +pub struct GoogleLanguageModel { + id: LanguageModelId, + model: google_ai::Model, + state: gpui::Model, + http_client: Arc, +} + +impl LanguageModel for GoogleLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("google/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + let request = request.into_google(self.model.id().to_string()); + let http_client = self.http_client.clone(); + let api_key = self.state.read(cx).api_key.clone(); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = google_ai::count_tokens( + http_client.as_ref(), + &api_url, + &api_key, + google_ai::CountTokensRequest { + contents: request.contents, + }, + ) + .await?; + Ok(response.total_tokens) + } + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_google(self.model.id().to_string()); + + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| { + let settings = 
&AllLanguageModelSettings::get_global(cx).google; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = + stream_generate_content(http_client.as_ref(), &api_url, &api_key, request); + let events = response.await?; + Ok(google_ai::extract_text_from_events(events).boxed()) + } + .boxed() + } +} + +struct AuthenticationPrompt { + api_key: View, + state: gpui::Model, +} + +impl AuthenticationPrompt { + fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text("AIzaSy...", cx); + editor + }), + state, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let settings = &AllLanguageModelSettings::get_global(cx).google; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); + let state = self.state.clone(); + cx.spawn(|_, mut cx| async move { + write_credentials.await?; + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 4] = [ + "To use the Google AI assistant, you need to add your Google AI API key.", + "You can create an API key at: https://makersuite.google.com/app/apikey", + "", + "Paste your Google AI API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index c81a435946..1b3bf18dd5 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ 
b/crates/language_model/src/provider/open_ai.rs @@ -7,7 +7,7 @@ use gpui::{ WhiteSpace, }; use http_client::HttpClient; -use open_ai::{stream_completion, Request, RequestMessage}; +use open_ai::stream_completion; use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel { http_client: Arc, } -impl OpenAiLanguageModel { - fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { - Request { - model: self.model.clone(), - messages: request - .messages - .into_iter() - .map(|msg| match msg.role { - Role::User => RequestMessage::User { - content: msg.content, - }, - Role::Assistant => RequestMessage::Assistant { - content: Some(msg.content), - tool_calls: Vec::new(), - }, - Role::System => RequestMessage::System { - content: msg.content, - }, - }) - .collect(), - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - } - } -} - impl LanguageModel for OpenAiLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -226,7 +197,7 @@ impl LanguageModel for OpenAiLanguageModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_open_ai_request(request); + let request = request.into_open_ai(self.model.id().into()); let http_client = self.http_client.clone(); let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { @@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(open_ai::extract_text_from_events(response).boxed()) } .boxed() } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index e787f5f7e7..05dcbced5d 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,17 +1,17 @@ +use crate::{ + provider::{ + anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, + google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider, + open_ai::OpenAiLanguageModelProvider, + }, + LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, +}; use client::Client; use collections::BTreeMap; use gpui::{AppContext, Global, Model, ModelContext}; use std::sync::Arc; use ui::Context; -use crate::{ - provider::{ - anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, - ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, - }, - LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, -}; - pub fn init(client: Arc, cx: &mut AppContext) { let registry = cx.new_model(|cx| { let mut registry = LanguageModelRegistry::default(); @@ -40,6 +40,10 @@ fn register_language_model_providers( OllamaLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + GoogleLanguageModelProvider::new(client.http_client(), cx), + cx, + ); cx.observe_flag::(move |enabled, cx| { let client = client.clone(); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index e3e1d3e77b..fc3b8c0192 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,4 
+1,4 @@ -use crate::{role::Role, LanguageModelId}; +use crate::role::Role; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage { pub content: String, } -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - #[derive(Debug, Default, Serialize, Deserialize)] pub struct LanguageModelRequest { pub messages: Vec, @@ -26,14 +15,110 @@ pub struct LanguageModelRequest { } impl LanguageModelRequest { - pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: model_id.0.to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), + pub fn into_open_ai(self, model: String) -> open_ai::Request { + open_ai::Request { + model, + messages: self + .messages + .into_iter() + .map(|msg| match msg.role { + Role::User => open_ai::RequestMessage::User { + content: msg.content, + }, + Role::Assistant => open_ai::RequestMessage::Assistant { + content: Some(msg.content), + tool_calls: Vec::new(), + }, + Role::System => open_ai::RequestMessage::System { + content: msg.content, + }, + }) + .collect(), + stream: true, + stop: self.stop, temperature: self.temperature, - tool_choice: None, tools: Vec::new(), + tool_choice: None, + } + } + + pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest { + google_ai::GenerateContentRequest { + model, + contents: self + .messages + .into_iter() + .map(|msg| google_ai::Content { + parts: vec![google_ai::Part::TextPart(google_ai::TextPart { + text: msg.content, + })], + role: match msg.role { + Role::User => google_ai::Role::User, + Role::Assistant => google_ai::Role::Model, + Role::System => google_ai::Role::User, // Google AI doesn't have a system role + }, + }) + .collect(), + generation_config: Some(google_ai::GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(self.stop), + max_output_tokens: None, + temperature: Some(self.temperature as f64), + top_p: None, + top_k: None, + }), + safety_settings: None, + } + } + + pub fn into_anthropic(self, model: String) -> anthropic::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in self.messages { + if message.content.is_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == message.role { + last_message.content.push_str("\n\n"); + last_message.content.push_str(&message.content); + continue; + } + } + + new_messages.push(message); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + } + } + } + + anthropic::Request { + model, + messages: new_messages + .into_iter() + .filter_map(|message| { + Some(anthropic::RequestMessage { + role: match message.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => return None, + }, + content: message.content, + }) + }) + .collect(), + stream: true, + max_tokens: 4092, + system: system_message, } } } diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs index f6276a4823..82184038f6 
100644 --- a/crates/language_model/src/role.rs +++ b/crates/language_model/src/role.rs @@ -15,7 +15,6 @@ impl Role { Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, None => Role::User, } } diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 262e14937a..85ae91649a 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -6,12 +6,12 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use crate::{ - provider::{ - anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings, - open_ai::OpenAiSettings, - }, - CloudModel, +use crate::provider::{ + anthropic::AnthropicSettings, + cloud::{self, ZedDotDevSettings}, + google::GoogleSettings, + ollama::OllamaSettings, + open_ai::OpenAiSettings, }; /// Initializes the language model settings. @@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings { pub ollama: OllamaSettings, pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, + pub google: GoogleSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent { pub openai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, + pub google: Option, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct GoogleSettingsContent { + pub api_url: Option, + pub low_speed_timeout_in_seconds: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { - available_models: Option>, + available_models: Option>, } impl settings::Settings for AllLanguageModelSettings { @@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + + merge( + &mut settings.google.api_url, + value.google.as_ref().and_then(|s| s.api_url.clone()), + ); + if let Some(low_speed_timeout_in_seconds) = value + .google + .as_ref() + .and_then(|s| s.low_speed_timeout_in_seconds) + { + settings.google.low_speed_timeout = + Some(Duration::from_secs(low_speed_timeout_in_seconds)); + } + merge( + &mut settings.google.available_models, + value + .google + .as_ref() + .and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index dfcd6646d1..13a6eb11d1 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Context, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -111,38 +111,27 @@ impl Model { } } -fn serialize_model(model: &Model, serializer: S) -> Result -where - S: serde::Serializer, -{ - match model { - Model::Custom { name, .. 
} => serializer.serialize_str(name), - _ => serializer.serialize_str(model.id()), - } -} - -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { - #[serde(serialize_with = "serialize_model")] - pub model: Model, + pub model: String, pub messages: Vec, pub stream: bool, pub stop: Vec, pub temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, - #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct FunctionDefinition { pub name: String, pub description: Option, pub parameters: Option>, } -#[derive(Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { #[allow(dead_code)] @@ -213,21 +202,21 @@ pub struct FunctionChunk { pub arguments: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, pub delta: ResponseMessageDelta, pub finish_reason: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, pub model: String, @@ -369,3 +358,14 @@ pub fn embed<'a>( } } } + +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) +} diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 60f8d01558..658d552848 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -13,13 +13,6 @@ message Envelope { optional uint32 responding_to = 2; optional PeerId original_sender_id = 3; - /* - When you are adding a new message type, instead of adding it in semantic order - and bumping the message ID's of everything that follows, add it at the end of the - file and bump the max number. 
See this - https://github.com/zed-industries/zed/pull/7890#discussion_r1496621823 - - */ oneof payload { Hello hello = 4; Ack ack = 5; @@ -201,10 +194,8 @@ message Envelope { JoinHostedProject join_hosted_project = 164; - CompleteWithLanguageModel complete_with_language_model = 166; - LanguageModelResponse language_model_response = 167; - CountTokensWithLanguageModel count_tokens_with_language_model = 168; - CountTokensResponse count_tokens_response = 169; + QueryLanguageModel query_language_model = 224; + QueryLanguageModelResponse query_language_model_response = 225; // current max GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; @@ -271,10 +262,11 @@ message Envelope { UpdateDevServerProject update_dev_server_project = 221; AddWorktree add_worktree = 222; - AddWorktreeResponse add_worktree_response = 223; // current max + AddWorktreeResponse add_worktree_response = 223; } reserved 158 to 161; + reserved 166 to 169; } // Messages @@ -2051,94 +2043,32 @@ message SetRoomParticipantRole { ChannelRole role = 3; } -message CompleteWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; - repeated string stop = 3; - float temperature = 4; - repeated ChatCompletionTool tools = 5; - optional string tool_choice = 6; -} - -// A tool presented to the language model for its use -message ChatCompletionTool { - oneof variant { - FunctionObject function = 1; - } - - message FunctionObject { - string name = 1; - optional string description = 2; - optional string parameters = 3; - } -} - -// A message to the language model -message LanguageModelRequestMessage { - LanguageModelRole role = 1; - string content = 2; - optional string tool_call_id = 3; - repeated ToolCall tool_calls = 4; -} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; LanguageModelSystem = 2; - LanguageModelTool = 3; + reserved 3; } -message LanguageModelResponseMessage { - optional LanguageModelRole role = 1; - optional string content = 2; - repeated ToolCallDelta tool_calls = 3; +message QueryLanguageModel { + LanguageModelProvider provider = 1; + LanguageModelRequestKind kind = 2; + string request = 3; } -// A request to call a tool, by the language model -message ToolCall { - string id = 1; - - oneof variant { - FunctionCall function = 2; - } - - message FunctionCall { - string name = 1; - string arguments = 2; - } +enum LanguageModelProvider { + Anthropic = 0; + OpenAI = 1; + Google = 2; } -message ToolCallDelta { - uint32 index = 1; - optional string id = 2; - - oneof variant { - FunctionCallDelta function = 3; - } - - message FunctionCallDelta { - optional string name = 1; - optional string arguments = 2; - } +enum LanguageModelRequestKind { + Complete = 0; + CountTokens = 1; } -message LanguageModelResponse { - repeated LanguageModelChoiceDelta choices = 1; -} - -message LanguageModelChoiceDelta { - uint32 index = 1; - LanguageModelResponseMessage delta = 2; - optional string finish_reason = 3; -} - -message CountTokensWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; -} - -message CountTokensResponse { - uint32 token_count = 1; +message QueryLanguageModelResponse { + string response = 1; } message GetCachedEmbeddings { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a205b79ecb..7ef1866acd 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -203,12 +203,9 @@ messages!( 
(CancelCall, Foreground), (ChannelMessageSent, Foreground), (ChannelMessageUpdate, Foreground), - (CompleteWithLanguageModel, Background), (ComputeEmbeddings, Background), (ComputeEmbeddingsResponse, Background), (CopyProjectEntry, Foreground), - (CountTokensWithLanguageModel, Background), - (CountTokensResponse, Background), (CreateBufferForPeer, Foreground), (CreateChannel, Foreground), (CreateChannelResponse, Foreground), @@ -278,7 +275,6 @@ messages!( (JoinProjectResponse, Foreground), (JoinRoom, Foreground), (JoinRoomResponse, Foreground), - (LanguageModelResponse, Background), (LeaveChannelBuffer, Background), (LeaveChannelChat, Foreground), (LeaveProject, Foreground), @@ -298,6 +294,8 @@ messages!( (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), + (QueryLanguageModel, Background), + (QueryLanguageModelResponse, Background), (RefreshInlayHints, Foreground), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -412,9 +410,7 @@ request_messages!( (Call, Ack), (CancelCall, Ack), (CopyProjectEntry, ProjectEntryResponse), - (CompleteWithLanguageModel, LanguageModelResponse), (ComputeEmbeddings, ComputeEmbeddingsResponse), - (CountTokensWithLanguageModel, CountTokensResponse), (CreateChannel, CreateChannelResponse), (CreateProjectEntry, ProjectEntryResponse), (CreateRoom, CreateRoomResponse), @@ -467,6 +463,7 @@ request_messages!( (PerformRename, PerformRenameResponse), (Ping, Ack), (PrepareRename, PrepareRenameResponse), + (QueryLanguageModel, QueryLanguageModelResponse), (RefreshInlayHints, Ack), (RejoinChannelBuffers, RejoinChannelBuffersResponse), (RejoinRoom, RejoinRoomResponse),