From 99bc90a372155f66ca0c6846c36e861575a96501 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 30 Jul 2024 16:18:53 +0200 Subject: [PATCH] Allow customization of the model used for tool calling (#15479) We also eliminated the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan --- Cargo.lock | 28 +- Cargo.toml | 2 - crates/anthropic/src/anthropic.rs | 19 +- crates/assistant/Cargo.toml | 2 - crates/assistant/src/assistant.rs | 23 +- crates/assistant/src/assistant_panel.rs | 34 +- crates/assistant/src/assistant_settings.rs | 24 +- crates/assistant/src/context.rs | 61 ++-- crates/assistant/src/inline_assistant.rs | 34 +- crates/assistant/src/model_selector.rs | 16 +- crates/assistant/src/prompt_library.rs | 19 +- .../src/terminal_inline_assistant.rs | 36 +- crates/collab/Cargo.toml | 1 - crates/collab/src/tests/test_server.rs | 1 - crates/completion/Cargo.toml | 45 --- crates/completion/LICENSE-GPL | 1 - crates/completion/src/completion.rs | 312 ------------------ crates/copilot/src/copilot_chat.rs | 6 +- crates/language_model/Cargo.toml | 2 + crates/language_model/src/language_model.rs | 28 +- .../language_model/src/provider/anthropic.rs | 55 +-- crates/language_model/src/provider/cloud.rs | 93 +++--- .../src/provider/copilot_chat.rs | 59 ++-- crates/language_model/src/provider/fake.rs | 6 +- crates/language_model/src/provider/google.rs | 16 +- crates/language_model/src/provider/ollama.rs | 68 ++-- crates/language_model/src/provider/open_ai.rs | 17 +- crates/language_model/src/rate_limiter.rs | 70 ++++ crates/language_model/src/registry.rs | 75 ++++- crates/language_model/src/settings.rs | 12 +- crates/semantic_index/Cargo.toml | 1 - crates/semantic_index/src/semantic_index.rs | 3 - 32 files changed, 478 insertions(+), 691 deletions(-) delete mode 100644 crates/completion/Cargo.toml delete mode 120000 crates/completion/LICENSE-GPL delete mode 100644 crates/completion/src/completion.rs create mode 100644 crates/language_model/src/rate_limiter.rs diff --git a/Cargo.lock b/Cargo.lock index d2d0b70658..392be06d6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -406,7 +406,6 @@ dependencies = [ "clock", "collections", "command_palette_hooks", - "completion", "ctor", "editor", "env_logger", @@ -2470,7 +2469,6 @@ dependencies = [ "clock", "collab_ui", "collections", - "completion", "ctor", "dashmap 6.0.1", "dev_server_projects", @@ -2655,30 +2653,6 @@ dependencies = [ "gpui", ] -[[package]] -name = "completion" -version = "0.1.0" -dependencies = [ - "anyhow", - "ctor", - "editor", - "env_logger", - "futures 0.3.28", - "gpui", - "language", - "language_model", - "project", - "rand 0.8.5", - "schemars", - "serde", - "serde_json", - "settings", - "smol", - "text", - "ui", - "unindent", -] - [[package]] name = "concurrent-queue" version = "2.2.0" @@ -6048,6 +6022,7 @@ dependencies = [ "serde", "serde_json", "settings", + "smol", "strum", "text", "theme", @@ -9506,7 +9481,6 @@ dependencies = [ "client", "clock", "collections", - "completion", "env_logger", "fs", "futures 0.3.28", diff --git a/Cargo.toml b/Cargo.toml index ca20aa1384..4282eb5caf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,6 @@ members = [ "crates/collections", "crates/command_palette", "crates/command_palette_hooks", - "crates/completion", "crates/copilot", "crates/db", "crates/dev_server_projects", @@ -190,7 +189,6 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } command_palette = { path =
"crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } -completion = { path = "crates/completion" } copilot = { path = "crates/copilot" } db = { path = "crates/db" } dev_server_projects = { path = "crates/dev_server_projects" } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index c24d19bd1d..dea7e531b2 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -21,7 +21,12 @@ pub enum Model { #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")] Claude3Haiku, #[serde(rename = "custom")] - Custom { name: String, max_tokens: usize }, + Custom { + name: String, + max_tokens: usize, + /// Override this model with a different Anthropic model for tool calls. + tool_override: Option<String>, + }, } impl Model { @@ -68,6 +73,18 @@ impl Model { Self::Custom { max_tokens, .. } => *max_tokens, } } + + pub fn tool_model_id(&self) -> &str { + if let Self::Custom { + tool_override: Some(tool_override), + .. + } = self + { + tool_override + } else { + self.id() + } + } } pub async fn complete( diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 29e460b8a6..0bdc4642c5 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -32,7 +32,6 @@ client.workspace = true clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true -completion.workspace = true editor.workspace = true fs.workspace = true futures.workspace = true @@ -77,7 +76,6 @@ workspace.workspace = true picker.workspace = true [dev-dependencies] -completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 181a4165c1..071385d745 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; -use completion::LanguageModelCompletionProvider; pub use context::*; pub use context_store::*; use fs::Fs; @@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) { context_store::init(&client); prompt_library::init(cx); - init_completion_provider(cx); + init_language_model_settings(cx); assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); @@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) { .detach(); } -fn init_completion_provider(cx: &mut AppContext) { - completion::init(cx); +fn init_language_model_settings(cx: &mut AppContext) { update_active_language_model_from_settings(cx); cx.observe_global::<SettingsStore>(update_active_language_model_from_settings) @@ -233,20 +231,9 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) { let settings = AssistantSettings::get_global(cx); let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone()); let model_id = LanguageModelId::from(settings.default_model.model.clone()); - - let Some(provider) = LanguageModelRegistry::global(cx) - .read(cx) - .provider(&provider_name) - else { - return; - }; - - let models = provider.provided_models(cx); - if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() { - LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx); - }); - } + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.select_active_model(&provider_name, &model_id, cx); + }); } fn register_slash_commands(cx: &mut AppContext) { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ad3616c459..3138f0257f 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -19,7 +19,6 @@ use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use client::proto; use collections::{BTreeSet, HashMap, HashSet}; -use completion::LanguageModelCompletionProvider; use editor::{ actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, display_map::{ @@ -43,7 +42,7 @@ use language::{ language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, }; -use language_model::{LanguageModelProviderId, Role}; +use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role}; use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate}; @@ -392,9 +391,9 @@ impl AssistantPanel { cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event), cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event), cx.subscribe(&context_store, Self::handle_context_store_event), - cx.observe( - &LanguageModelCompletionProvider::global(cx), - |this, _, cx| { + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this, _, _: &language_model::ActiveModelChanged, cx| { this.completion_provider_changed(cx); }, ), @@ -560,7 +559,7 @@ impl AssistantPanel { }) } - let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx) + let Some(new_provider_id) = LanguageModelRegistry::read_global(cx) .active_provider() .map(|p| p.id()) else { @@ -599,7 +598,7 @@ impl AssistantPanel { } fn authentication_prompt(cx: &mut WindowContext) -> Option { - if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() { + if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() { if !provider.is_authenticated(cx) { return Some(provider.authentication_prompt(cx)); } @@ -904,9 +903,9 @@ impl AssistantPanel { } fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { - LanguageModelCompletionProvider::read_global(cx) - .reset_credentials(cx) - .detach_and_log_err(cx); + if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() { + provider.reset_credentials(cx).detach_and_log_err(cx); + } } fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext) { @@ -1041,11 +1040,18 @@ impl AssistantPanel { } fn is_authenticated(&mut self, cx: &mut ViewContext) -> bool { - LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) + LanguageModelRegistry::read_global(cx) + .active_provider() + .map_or(false, |provider| provider.is_authenticated(cx)) } fn authenticate(&mut self, cx: &mut ViewContext) -> Task> { - LanguageModelCompletionProvider::read_global(cx).authenticate(cx) + LanguageModelRegistry::read_global(cx) + .active_provider() + .map_or( + Task::ready(Err(anyhow!("no active language model provider"))), + |provider| provider.authenticate(cx), + ) } fn render_signed_in(&mut self, cx: &mut ViewContext) -> impl IntoElement { @@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem { } fn render_remaining_tokens(&self, cx: &mut ViewContext) -> 
Option { - let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx).active_model()?; let context = &self .active_context_editor .as_ref()? @@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem { .whitespace_nowrap() .child( Label::new( - LanguageModelCompletionProvider::read_global(cx) + LanguageModelRegistry::read_global(cx) .active_model() .map(|model| model.name().0) .unwrap_or_else(|| "No model selected".into()), diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 152a2d629d..d62318ef13 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -52,7 +52,7 @@ pub struct AssistantSettings { pub dock: AssistantDockPosition, pub default_width: Pixels, pub default_height: Pixels, - pub default_model: AssistantDefaultModel, + pub default_model: LanguageModelSelection, pub using_outdated_settings_version: bool, } @@ -198,25 +198,25 @@ impl AssistantSettingsContent { .clone() .and_then(|provider| match provider { AssistantProviderContentV1::ZedDotDev { default_model } => { - default_model.map(|model| AssistantDefaultModel { + default_model.map(|model| LanguageModelSelection { provider: "zed.dev".to_string(), model: model.id().to_string(), }) } AssistantProviderContentV1::OpenAi { default_model, .. } => { - default_model.map(|model| AssistantDefaultModel { + default_model.map(|model| LanguageModelSelection { provider: "openai".to_string(), model: model.id().to_string(), }) } AssistantProviderContentV1::Anthropic { default_model, .. } => { - default_model.map(|model| AssistantDefaultModel { + default_model.map(|model| LanguageModelSelection { provider: "anthropic".to_string(), model: model.id().to_string(), }) } AssistantProviderContentV1::Ollama { default_model, .. } => { - default_model.map(|model| AssistantDefaultModel { + default_model.map(|model| LanguageModelSelection { provider: "ollama".to_string(), model: model.id().to_string(), }) @@ -231,7 +231,7 @@ impl AssistantSettingsContent { dock: settings.dock, default_width: settings.default_width, default_height: settings.default_height, - default_model: Some(AssistantDefaultModel { + default_model: Some(LanguageModelSelection { provider: "openai".to_string(), model: settings .default_open_ai_model @@ -325,7 +325,7 @@ impl AssistantSettingsContent { _ => {} }, VersionedAssistantSettingsContent::V2(settings) => { - settings.default_model = Some(AssistantDefaultModel { provider, model }); + settings.default_model = Some(LanguageModelSelection { provider, model }); } }, AssistantSettingsContent::Legacy(settings) => { @@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 { /// Default: 320 default_height: Option, /// The default model to use when creating new contexts. 
- default_model: Option, + default_model: Option, } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] -pub struct AssistantDefaultModel { +pub struct LanguageModelSelection { #[schemars(schema_with = "providers_schema")] pub provider: String, pub model: String, @@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema: .into() } -impl Default for AssistantDefaultModel { +impl Default for LanguageModelSelection { fn default() -> Self { Self { provider: "openai".to_string(), @@ -542,7 +542,7 @@ mod tests { assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version); assert_eq!( AssistantSettings::get_global(cx).default_model, - AssistantDefaultModel { + LanguageModelSelection { provider: "openai".into(), model: "gpt-4o".into(), } @@ -555,7 +555,7 @@ mod tests { |settings, _| { *settings = AssistantSettingsContent::Versioned( VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 { - default_model: Some(AssistantDefaultModel { + default_model: Some(LanguageModelSelection { provider: "test-provider".into(), model: "gpt-99".into(), }), diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 59069aceb2..7392b7d1f7 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,6 +1,6 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, - LanguageModelCompletionProvider, MessageId, MessageStatus, + prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId, + MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ @@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip use language::{ AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset, }; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, + Role, +}; use open_ai::Model as OpenAiModel; use paths::contexts_dir; use project::Project; @@ -1180,17 +1183,16 @@ impl Context { pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { let request = self.to_completion_request(cx); + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; self.pending_token_count = cx.spawn(|this, mut cx| { async move { cx.background_executor() .timer(Duration::from_millis(200)) .await; - let token_count = cx - .update(|cx| { - LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) - })? - .await?; + let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?; this.update(&mut cx, |this, cx| { this.token_count = Some(token_count); cx.notify() @@ -1368,6 +1370,10 @@ impl Context { } } + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return Task::ready(Err(anyhow!("no active model")).log_err()); + }; + let mut request = self.to_completion_request(cx); let edit_step_range = edit_step.source_range.clone(); let step_text = self @@ -1388,12 +1394,7 @@ impl Context { content: prompt, }); - let tool_use = cx - .update(|cx| { - LanguageModelCompletionProvider::read_global(cx) - .use_tool::(request, cx) - })? 
- .await?; + let tool_use = model.use_tool::(request, &cx).await?; this.update(&mut cx, |this, cx| { let step_index = this @@ -1568,6 +1569,8 @@ impl Context { } pub fn assist(&mut self, cx: &mut ModelContext) -> Option { + let provider = LanguageModelRegistry::read_global(cx).active_provider()?; + let model = LanguageModelRegistry::read_global(cx).active_model()?; let last_message_id = self.message_anchors.iter().rev().find_map(|message| { message .start @@ -1575,14 +1578,12 @@ impl Context { .then_some(message.id) })?; - if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) { + if !provider.is_authenticated(cx) { log::info!("completion provider has no credentials"); return None; } let request = self.to_completion_request(cx); - let stream = - LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1594,6 +1595,7 @@ impl Context { let task = cx.spawn({ |this, mut cx| async move { + let stream = model.stream_completion(request, &cx); let assistant_message_id = assistant_message.id; let mut response_latency = None; let stream_completion = async { @@ -1662,14 +1664,10 @@ impl Context { }); if let Some(telemetry) = this.telemetry.as_ref() { - let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) - .active_model() - .map(|m| m.telemetry_id()) - .unwrap_or_default(); telemetry.report_assistant_event( Some(this.id.0.clone()), AssistantKind::Panel, - model_telemetry_id, + model.telemetry_id(), response_latency, error_message, ); @@ -1935,8 +1933,15 @@ impl Context { } pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext) { + let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + return; + }; + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; + if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { - if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) { + if !provider.is_authenticated(cx) { return; } @@ -1953,10 +1958,9 @@ impl Context { temperature: 1.0, }; - let stream = - LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); self.pending_summary = cx.spawn(|this, mut cx| { async move { + let stream = model.stream_completion(request, &cx); let mut messages = stream.await?; let mut replaced = !replace_old; @@ -2490,7 +2494,6 @@ mod tests { fn test_inserting_and_removing_messages(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); language_model::LanguageModelRegistry::test(cx); - completion::LanguageModelCompletionProvider::test(cx); cx.set_global(settings_store); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2623,7 +2626,6 @@ mod tests { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); language_model::LanguageModelRegistry::test(cx); - completion::LanguageModelCompletionProvider::test(cx); assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2717,7 +2719,6 @@ mod tests { fn test_messages_for_offsets(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); language_model::LanguageModelRegistry::test(cx); - completion::LanguageModelCompletionProvider::test(cx); cx.set_global(settings_store); assistant_panel::init(cx); let registry = 
Arc::new(LanguageRegistry::test(cx.background_executor().clone())); @@ -2803,7 +2804,6 @@ mod tests { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); cx.update(language_model::LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.update(Project::init_settings); cx.update(assistant_panel::init); let fs = FakeFs::new(cx.background_executor.clone()); @@ -2930,7 +2930,6 @@ mod tests { cx.set_global(settings_store); let fake_provider = cx.update(language_model::LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); let fake_model = fake_provider.test_model(); cx.update(assistant_panel::init); @@ -3032,7 +3031,6 @@ mod tests { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); cx.update(language_model::LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx)); @@ -3109,7 +3107,6 @@ mod tests { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); cx.update(language_model::LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.update(assistant_panel::init); let slash_commands = cx.update(SlashCommandRegistry::default_global); diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 87bccee8ea..83917ce1d8 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,6 +1,6 @@ use crate::{ humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent, - Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff, + Hunk, ModelSelector, StreamingDiff, }; use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; @@ -27,7 +27,9 @@ use gpui::{ WindowContext, }; use language::{Buffer, IndentKind, Point, Selection, TransactionId}; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; @@ -1328,7 +1330,7 @@ impl Render for PromptEditor { Tooltip::with_meta( format!( "Using {}", - LanguageModelCompletionProvider::read_global(cx) + LanguageModelRegistry::read_global(cx) .active_model() .map(|model| model.name().0) .unwrap_or_else(|| "No model selected".into()), @@ -1662,7 +1664,7 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut ViewContext) -> Option { - let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx).active_model()?; let token_count = self.token_count?; let max_token_count = model.max_token_count(); @@ -2013,8 +2015,12 @@ impl Codegen { assistant_panel_context: Option, cx: &AppContext, ) -> BoxFuture<'static, Result> { - let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx); - LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) + if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx); + model.count_tokens(request, cx) + } else { + future::ready(Err(anyhow!("no active 
model"))).boxed() + } } pub fn start( @@ -2024,6 +2030,10 @@ impl Codegen { assistant_panel_context: Option, cx: &mut ModelContext, ) -> Result<()> { + let model = LanguageModelRegistry::read_global(cx) + .active_model() + .context("no active model")?; + self.undo(cx); // Handle initial insertion @@ -2053,10 +2063,7 @@ impl Codegen { None }; - let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) - .active_model_telemetry_id() - .context("no active model")?; - + let telemetry_id = model.telemetry_id(); let chunks: LocalBoxFuture>>> = if user_prompt .trim() .to_lowercase() @@ -2067,10 +2074,10 @@ impl Codegen { let request = self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx); let chunks = - LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); + cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); async move { Ok(chunks.await?.boxed()) }.boxed_local() }; - self.handle_stream(model_telemetry_id, edit_range, chunks, cx); + self.handle_stream(telemetry_id, edit_range, chunks, cx); Ok(()) } @@ -2657,7 +2664,6 @@ mod tests { async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { cx.set_global(cx.update(SettingsStore::test)); cx.update(language_model::LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.update(language_settings::init); let text = indoc! {" @@ -2789,7 +2795,6 @@ mod tests { mut rng: StdRng, ) { cx.update(LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.set_global(cx.update(SettingsStore::test)); cx.update(language_settings::init); @@ -2853,7 +2858,6 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { cx.update(LanguageModelRegistry::test); - cx.update(completion::LanguageModelCompletionProvider::test); cx.set_global(cx.update(SettingsStore::test)); cx.update(language_settings::init); diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index fbf4eb7bd1..f499c4c8f5 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider}; +use crate::assistant_settings::AssistantSettings; use fs::Fs; use gpui::SharedString; use language_model::LanguageModelRegistry; @@ -81,13 +81,13 @@ impl RenderOnce for ModelSelector { } }, { - let provider = provider.id(); + let provider = provider.clone(); move |cx| { - LanguageModelCompletionProvider::global(cx).update( + LanguageModelRegistry::global(cx).update( cx, |completion_provider, cx| { completion_provider - .set_active_provider(provider.clone(), cx) + .set_active_provider(Some(provider.clone()), cx); }, ); } @@ -95,12 +95,12 @@ impl RenderOnce for ModelSelector { ); } - let selected_model = LanguageModelCompletionProvider::read_global(cx) - .active_model() - .map(|m| m.id()); - let selected_provider = LanguageModelCompletionProvider::read_global(cx) + let selected_provider = LanguageModelRegistry::read_global(cx) .active_provider() .map(|m| m.id()); + let selected_model = LanguageModelRegistry::read_global(cx) + .active_model() + .map(|m| m.id()); for available_model in available_models { menu = menu.custom_entry( diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index cea5db2a6f..81091b3fd5 100644 --- 
a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,5 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant, - LanguageModelCompletionProvider, }; use anyhow::{anyhow, Result}; use assets::Assets; @@ -19,7 +18,9 @@ use gpui::{ }; use heed::{types::SerdeBincode, Database, RoTxn}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; use parking_lot::RwLock; use picker::{Picker, PickerDelegate}; use rope::Rope; @@ -636,7 +637,10 @@ impl PromptLibrary { }; let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor; - let provider = LanguageModelCompletionProvider::read_global(cx); + let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + return; + }; + let initial_prompt = action.prompt.clone(); if provider.is_authenticated(cx) { InlineAssistant::update_global(cx, |assistant, cx| { @@ -725,6 +729,9 @@ impl PromptLibrary { } fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext) { + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) { let editor = &prompt.body_editor.read(cx); let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx); @@ -736,7 +743,7 @@ impl PromptLibrary { cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; let token_count = cx .update(|cx| { - LanguageModelCompletionProvider::read_global(cx).count_tokens( + model.count_tokens( LanguageModelRequest { messages: vec![LanguageModelRequestMessage { role: Role::System, @@ -804,7 +811,7 @@ impl PromptLibrary { let prompt_metadata = self.store.metadata(prompt_id)?; let prompt_editor = &self.prompt_editors[&prompt_id]; let focus_handle = prompt_editor.body_editor.focus_handle(cx); - let current_model = LanguageModelCompletionProvider::read_global(cx).active_model(); + let model = LanguageModelRegistry::read_global(cx).active_model(); let settings = ThemeSettings::get_global(cx); Some( @@ -914,7 +921,7 @@ impl PromptLibrary { None, format!( "Model: {}", - current_model + model .as_ref() .map(|model| model .name() diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index bea35ea89b..029cc079ca 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,6 +1,6 @@ use crate::{ humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel, - AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector, + AssistantPanelEvent, ModelSelector, }; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -16,7 +16,9 @@ use gpui::{ Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, }; use language::Buffer; -use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, +}; use settings::Settings; use std::{ cmp, @@ -556,7 +558,7 @@ impl Render for PromptEditor { Tooltip::with_meta( format!( "Using {}", - LanguageModelCompletionProvider::read_global(cx) + LanguageModelRegistry::read_global(cx) .active_model() .map(|model| model.name().0) .unwrap_or_else(|| "No model 
selected".into()), @@ -700,6 +702,9 @@ impl PromptEditor { fn count_tokens(&mut self, cx: &mut ViewContext) { let assist_id = self.id; + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; self.pending_token_count = cx.spawn(|this, mut cx| async move { cx.background_executor().timer(Duration::from_secs(1)).await; let request = @@ -707,11 +712,7 @@ impl PromptEditor { inline_assistant.request_for_inline_assist(assist_id, cx) })??; - let token_count = cx - .update(|cx| { - LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) - })? - .await?; + let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?; this.update(&mut cx, |this, cx| { this.token_count = Some(token_count); cx.notify(); @@ -840,7 +841,7 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut ViewContext) -> Option { - let model = LanguageModelCompletionProvider::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx).active_model()?; let token_count = self.token_count?; let max_token_count = model.max_token_count(); @@ -982,19 +983,16 @@ impl Codegen { } pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { - self.status = CodegenStatus::Pending; - self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); + let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + return; + }; let telemetry = self.telemetry.clone(); - let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) - .active_model() - .map(|m| m.telemetry_id()) - .unwrap_or_default(); - let response = - LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx); - + self.status = CodegenStatus::Pending; + self.transaction = Some(TerminalTransaction::start(self.terminal.clone())); self.generation = cx.spawn(|this, mut cx| async move { - let response = response.await; + let model_telemetry_id = model.telemetry_id(); + let response = model.stream_completion(prompt, &cx).await; let generate = async { let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 289212a6b1..8ebeb3e555 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -80,7 +80,6 @@ channel.workspace = true client = { workspace = true, features = ["test-support"] } collab_ui = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } -completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 7a3bc92a5f..76174f5953 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -300,7 +300,6 @@ impl TestServer { dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); language_model::LanguageModelRegistry::test(cx); - completion::init(cx); assistant::context_store::init(&client); }); diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml deleted file mode 100644 index 7224dc6b0d..0000000000 --- a/crates/completion/Cargo.toml +++ /dev/null @@ -1,45 +0,0 @@ -[package] -name = "completion" -version = "0.1.0" -edition = "2021" -publish = false -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/completion.rs" -doctest = 
false - -[features] -test-support = [ - "editor/test-support", - "language/test-support", - "language_model/test-support", - "project/test-support", - "text/test-support", -] - -[dependencies] -anyhow.workspace = true -futures.workspace = true -gpui.workspace = true -language_model.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -smol.workspace = true -ui.workspace = true - -[dev-dependencies] -ctor.workspace = true -editor = { workspace = true, features = ["test-support"] } -env_logger.workspace = true -language = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -language_model = { workspace = true, features = ["test-support"] } -rand.workspace = true -text = { workspace = true, features = ["test-support"] } -unindent.workspace = true diff --git a/crates/completion/LICENSE-GPL b/crates/completion/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/completion/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/completion/src/completion.rs b/crates/completion/src/completion.rs deleted file mode 100644 index f55818e284..0000000000 --- a/crates/completion/src/completion.rs +++ /dev/null @@ -1,312 +0,0 @@ -use anyhow::{anyhow, Result}; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; -use gpui::{AppContext, Global, Model, ModelContext, Task}; -use language_model::{ - LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, - LanguageModelRequest, LanguageModelTool, -}; -use smol::{ - future::FutureExt, - lock::{Semaphore, SemaphoreGuardArc}, -}; -use std::{future, pin::Pin, sync::Arc, task::Poll}; -use ui::Context; - -pub fn init(cx: &mut AppContext) { - let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx)); - cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider)); -} - -struct GlobalLanguageModelCompletionProvider(Model); - -impl Global for GlobalLanguageModelCompletionProvider {} - -pub struct LanguageModelCompletionProvider { - active_provider: Option>, - active_model: Option>, - request_limiter: Arc, -} - -const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4; - -pub struct LanguageModelCompletionResponse { - inner: BoxStream<'static, Result>, - _lock: SemaphoreGuardArc, -} - -impl futures::Stream for LanguageModelCompletionResponse { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_next(cx) - } -} - -impl LanguageModelCompletionProvider { - pub fn global(cx: &AppContext) -> Model { - cx.global::() - .0 - .clone() - } - - pub fn read_global(cx: &AppContext) -> &Self { - cx.global::() - .0 - .read(cx) - } - - #[cfg(any(test, feature = "test-support"))] - pub fn test(cx: &mut AppContext) { - let provider = cx.new_model(|cx| { - let mut this = Self::new(cx); - let available_model = LanguageModelRegistry::read_global(cx) - .available_models(cx) - .first() - .unwrap() - .clone(); - this.set_active_model(available_model, cx); - this - }); - cx.set_global(GlobalLanguageModelCompletionProvider(provider)); - } - - pub fn new(cx: &mut ModelContext) -> Self { - cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| { - cx.notify(); - }) - .detach(); - - Self { - active_provider: None, - active_model: None, - request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)), - } 
- } - - pub fn active_provider(&self) -> Option> { - self.active_provider.clone() - } - - pub fn set_active_provider( - &mut self, - provider_id: LanguageModelProviderId, - cx: &mut ModelContext, - ) { - self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id); - self.active_model = None; - cx.notify(); - } - - pub fn active_model(&self) -> Option> { - self.active_model.clone() - } - - pub fn set_active_model(&mut self, model: Arc, cx: &mut ModelContext) { - if self.active_model.as_ref().map_or(false, |m| { - m.id() == model.id() && m.provider_id() == model.provider_id() - }) { - return; - } - - self.active_provider = - LanguageModelRegistry::read_global(cx).provider(&model.provider_id()); - self.active_model = Some(model.clone()); - - if let Some(provider) = self.active_provider.as_ref() { - provider.load_model(model, cx); - } - - cx.notify(); - } - - pub fn is_authenticated(&self, cx: &AppContext) -> bool { - self.active_provider - .as_ref() - .map_or(false, |provider| provider.is_authenticated(cx)) - } - - pub fn authenticate(&self, cx: &AppContext) -> Task> { - self.active_provider - .as_ref() - .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx)) - } - - pub fn reset_credentials(&self, cx: &AppContext) -> Task> { - self.active_provider - .as_ref() - .map_or(Task::ready(Ok(())), |provider| { - provider.reset_credentials(cx) - }) - } - - pub fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> BoxFuture<'static, Result> { - if let Some(model) = self.active_model() { - model.count_tokens(request, cx) - } else { - future::ready(Err(anyhow!("no active model"))).boxed() - } - } - - pub fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> Task> { - if let Some(language_model) = self.active_model() { - let rate_limiter = self.request_limiter.clone(); - cx.spawn(|cx| async move { - let lock = rate_limiter.acquire_arc().await; - let response = language_model.stream_completion(request, &cx).await?; - Ok(LanguageModelCompletionResponse { - inner: response, - _lock: lock, - }) - }) - } else { - Task::ready(Err(anyhow!("No active model set"))) - } - } - - pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task> { - let response = self.stream_completion(request, cx); - cx.foreground_executor().spawn(async move { - let mut chunks = response.await?; - let mut completion = String::new(); - while let Some(chunk) = chunks.next().await { - let chunk = chunk?; - completion.push_str(&chunk); - } - Ok(completion) - }) - } - - pub fn use_tool( - &self, - request: LanguageModelRequest, - cx: &AppContext, - ) -> Task> { - if let Some(language_model) = self.active_model() { - cx.spawn(|cx| async move { - let schema = schemars::schema_for!(T); - let schema_json = serde_json::to_value(&schema).unwrap(); - let request = - language_model.use_tool(request, T::name(), T::description(), schema_json, &cx); - let response = request.await?; - Ok(serde_json::from_value(response)?) 
- }) - } else { - Task::ready(Err(anyhow!("No active model set"))) - } - } - - pub fn active_model_telemetry_id(&self) -> Option { - self.active_model.as_ref().map(|m| m.telemetry_id()) - } -} - -#[cfg(test)] -mod tests { - use futures::StreamExt; - use gpui::AppContext; - use settings::SettingsStore; - use ui::Context; - - use crate::{ - LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS, - }; - - use language_model::LanguageModelRegistry; - - #[gpui::test] - fn test_rate_limiting(cx: &mut AppContext) { - SettingsStore::test(cx); - let fake_provider = LanguageModelRegistry::test(cx); - - let model = LanguageModelRegistry::read_global(cx) - .available_models(cx) - .first() - .cloned() - .unwrap(); - - let provider = cx.new_model(|cx| { - let mut provider = LanguageModelCompletionProvider::new(cx); - provider.set_active_model(model.clone(), cx); - provider - }); - - let fake_model = fake_provider.test_model(); - - // Enqueue some requests - for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 { - let response = provider.read(cx).stream_completion( - LanguageModelRequest { - temperature: i as f32 / 10.0, - ..Default::default() - }, - cx, - ); - cx.background_executor() - .spawn(async move { - let mut stream = response.await.unwrap(); - while let Some(message) = stream.next().await { - message.unwrap(); - } - }) - .detach(); - } - cx.background_executor().run_until_parked(); - assert_eq!( - fake_model.completion_count(), - MAX_CONCURRENT_COMPLETION_REQUESTS - ); - - // Get the first completion request that is in flight and mark it as completed. - let completion = fake_model.pending_completions().into_iter().next().unwrap(); - fake_model.finish_completion(&completion); - - // Ensure that the number of in-flight completion requests is reduced. - assert_eq!( - fake_model.completion_count(), - MAX_CONCURRENT_COMPLETION_REQUESTS - 1 - ); - - cx.background_executor().run_until_parked(); - - // Ensure that another completion request was allowed to acquire the lock. - assert_eq!( - fake_model.completion_count(), - MAX_CONCURRENT_COMPLETION_REQUESTS - ); - - // Mark all completion requests as finished that are in flight. - for request in fake_model.pending_completions() { - fake_model.finish_completion(&request); - } - - assert_eq!(fake_model.completion_count(), 0); - - // Wait until the background tasks acquire the lock again. - cx.background_executor().run_until_parked(); - - assert_eq!( - fake_model.completion_count(), - MAX_CONCURRENT_COMPLETION_REQUESTS - 1 - ); - - // Finish all remaining completion requests. 
- for request in fake_model.pending_completions() { - fake_model.finish_completion(&request); - } - - cx.background_executor().run_until_parked(); - - assert_eq!(fake_model.completion_count(), 0); - } -} diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index 6d3a2ee7dc..4fe76e6e51 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -208,13 +208,13 @@ impl CopilotChat { pub async fn stream_completion( request: Request, low_speed_timeout: Option<Duration>, - cx: &mut AsyncAppContext, + mut cx: AsyncAppContext, ) -> Result<BoxStream<'static, Result<ResponseEvent>>> { let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else { return Err(anyhow!("Copilot chat is not enabled")); }; - let (oauth_token, api_token, client) = this.read_with(cx, |this, _| { + let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| { ( this.oauth_token.clone(), this.api_token.clone(), @@ -229,7 +229,7 @@ impl CopilotChat { _ => { let token = request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?; - this.update(cx, |this, cx| { + this.update(&mut cx, |this, cx| { this.api_token = Some(token.clone()); cx.notify(); })?; diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index d7b609dde0..eb90143847 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -33,6 +33,7 @@ google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true http_client.workspace = true inline_completion_button.workspace = true +log.workspace = true menu.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } @@ -42,6 +43,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +smol.workspace = true strum.workspace = true theme.workspace = true tiktoken-rs.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 0d7a003663..6dcc874721 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -1,24 +1,24 @@ mod model; pub mod provider; +mod rate_limiter; mod registry; mod request; mod role; pub mod settings; -use std::sync::Arc; - use anyhow::Result; use client::Client; use futures::{future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; - pub use model::*; use project::Fs; +pub(crate) use rate_limiter::*; pub use registry::*; pub use request::*; pub use role::*; use schemars::JsonSchema; use serde::de::DeserializeOwned; +use std::{future::Future, sync::Arc}; pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) { settings::init(fs, cx); @@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync { cx: &AsyncAppContext, ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>; - fn use_tool( + fn use_any_tool( &self, request: LanguageModelRequest, name: String, @@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync { ) -> BoxFuture<'static, Result<serde_json::Value>>; } +impl dyn LanguageModel { + pub fn use_tool<T: LanguageModelTool>( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> impl 'static + Future<Output = Result<T>> { + let schema = schemars::schema_for!(T); + let schema_json = serde_json::to_value(&schema).unwrap(); + let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx); + async move { + let response = request.await?; + Ok(serde_json::from_value(response)?)
+ } + } +} + pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { fn name() -> String; fn description() -> String; @@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static { fn provided_models(&self, cx: &AppContext) -> Vec>; fn load_model(&self, _model: Arc, _cx: &AppContext) {} fn is_authenticated(&self, cx: &AppContext) -> bool; - fn authenticate(&self, cx: &AppContext) -> Task>; + fn authenticate(&self, cx: &mut AppContext) -> Task>; fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; - fn reset_credentials(&self, cx: &AppContext) -> Task>; + fn reset_credentials(&self, cx: &mut AppContext) -> Task>; } pub trait LanguageModelProviderState: 'static { diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 32932953e7..ddaad618c4 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -1,7 +1,7 @@ use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, Role, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; @@ -36,6 +36,7 @@ pub struct AnthropicSettings { pub struct AvailableModel { pub name: String, pub max_tokens: usize, + pub tool_override: Option, } pub struct AnthropicLanguageModelProvider { @@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { anthropic::Model::Custom { name: model.name.clone(), max_tokens: model.max_tokens, + tool_override: model.tool_override.clone(), }, ); } @@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { model, state: self.state.clone(), http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), }) as Arc }) .collect() @@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { self.state.read(cx).api_key.is_some() } - fn authenticate(&self, cx: &AppContext) -> Task> { + fn authenticate(&self, cx: &mut AppContext) -> Task> { if self.is_authenticated(cx) { Task::ready(Ok(())) } else { @@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { .into() } - fn reset_credentials(&self, cx: &AppContext) -> Task> { + fn reset_credentials(&self, cx: &mut AppContext) -> Task> { let state = self.state.clone(); let delete_credentials = cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); @@ -171,6 +174,7 @@ pub struct AnthropicModel { model: anthropic::Model, state: gpui::Model, http_client: Arc, + request_limiter: RateLimiter, } pub fn count_anthropic_tokens( @@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel { ) -> BoxFuture<'static, Result>>> { let request = request.into_anthropic(self.model.id().into()); let request = self.stream_completion(request, cx); - async move { + let future = self.request_limiter.stream(async move { let response = request.await?; - Ok(anthropic::extract_text_from_events(response).boxed()) - } - .boxed() + Ok(anthropic::extract_text_from_events(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() } - fn use_tool( + fn use_any_tool( &self, request: LanguageModelRequest, tool_name: String, @@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel { input_schema: serde_json::Value, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result> { - let mut 
request = request.into_anthropic(self.model.id().into()); + let mut request = request.into_anthropic(self.model.tool_model_id().into()); request.tool_choice = Some(anthropic::ToolChoice::Tool { name: tool_name.clone(), }); @@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel { }]; let response = self.request_completion(request, cx); - async move { - let response = response.await?; - response - .content - .into_iter() - .find_map(|content| { - if let anthropic::Content::ToolUse { name, input, .. } = content { - if name == tool_name { - Some(input) + self.request_limiter + .run(async move { + let response = response.await?; + response + .content + .into_iter() + .find_map(|content| { + if let anthropic::Content::ToolUse { name, input, .. } = content { + if name == tool_name { + Some(input) + } else { + None + } } else { None } - } else { - None - } - }) - .context("tool not used") - } - .boxed() + }) + .context("tool not used") + }) + .boxed() } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 8c32c723c9..362539fd85 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens; use crate::{ settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, }; use anyhow::{anyhow, Context as _, Result}; use client::Client; @@ -41,6 +41,7 @@ pub struct AvailableModel { provider: AvailableProvider, name: String, max_tokens: usize, + tool_override: Option, } pub struct CloudLanguageModelProvider { @@ -56,7 +57,7 @@ struct State { } impl State { - fn authenticate(&self, cx: &AppContext) -> Task> { + fn authenticate(&self, cx: &mut AppContext) -> Task> { let client = self.client.clone(); cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) } @@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { name: model.name.clone(), max_tokens: model.max_tokens, + tool_override: model.tool_override.clone(), }), AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { name: model.name.clone(), @@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { id: LanguageModelId::from(model.id().to_string()), model, client: self.client.clone(), + request_limiter: RateLimiter::new(4), }) as Arc }) .collect() @@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider { self.state.read(cx).status.is_connected() } - fn authenticate(&self, cx: &AppContext) -> Task> { - self.state.read(cx).authenticate(cx) + fn authenticate(&self, cx: &mut AppContext) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) } fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { @@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .into() } - fn reset_credentials(&self, _cx: &AppContext) -> Task> { + fn reset_credentials(&self, _cx: &mut AppContext) -> Task> { Task::ready(Ok(())) } } @@ -191,6 +194,7 @@ pub struct CloudLanguageModel { id: LanguageModelId, model: CloudModel, client: Arc, + request_limiter: RateLimiter, } impl LanguageModel for CloudLanguageModel { @@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel { 
CloudModel::Anthropic(model) => { let client = self.client.clone(); let request = request.into_anthropic(model.id().into()); - async move { + let future = self.request_limiter.stream(async move { let request = serde_json::to_string(&request)?; let stream = client .request_stream(proto::StreamCompleteWithLanguageModel { @@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel { .await?; Ok(anthropic::extract_text_from_events( stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), - ) - .boxed()) - } - .boxed() + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() } CloudModel::OpenAi(model) => { let client = self.client.clone(); let request = request.into_open_ai(model.id().into()); - async move { + let future = self.request_limiter.stream(async move { let request = serde_json::to_string(&request)?; let stream = client .request_stream(proto::StreamCompleteWithLanguageModel { @@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel { .await?; Ok(open_ai::extract_text_from_events( stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), - ) - .boxed()) - } - .boxed() + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() } CloudModel::Google(model) => { let client = self.client.clone(); let request = request.into_google(model.id().into()); - async move { + let future = self.request_limiter.stream(async move { let request = serde_json::to_string(&request)?; let stream = client .request_stream(proto::StreamCompleteWithLanguageModel { @@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel { .await?; Ok(google_ai::extract_text_from_events( stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), - ) - .boxed()) - } - .boxed() + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() } } } - fn use_tool( + fn use_any_tool( &self, request: LanguageModelRequest, tool_name: String, @@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel { match &self.model { CloudModel::Anthropic(model) => { let client = self.client.clone(); - let mut request = request.into_anthropic(model.id().into()); + let mut request = request.into_anthropic(model.tool_model_id().into()); request.tool_choice = Some(anthropic::ToolChoice::Tool { name: tool_name.clone(), }); @@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel { input_schema, }]; - async move { - let request = serde_json::to_string(&request)?; - let response = client - .request(proto::CompleteWithLanguageModel { - provider: proto::LanguageModelProvider::Anthropic as i32, - request, - }) - .await?; - let response: anthropic::Response = serde_json::from_str(&response.completion)?; - response - .content - .into_iter() - .find_map(|content| { - if let anthropic::Content::ToolUse { name, input, .. } = content { - if name == tool_name { - Some(input) + self.request_limiter + .run(async move { + let request = serde_json::to_string(&request)?; + let response = client + .request(proto::CompleteWithLanguageModel { + provider: proto::LanguageModelProvider::Anthropic as i32, + request, + }) + .await?; + let response: anthropic::Response = + serde_json::from_str(&response.completion)?; + response + .content + .into_iter() + .find_map(|content| { + if let anthropic::Content::ToolUse { name, input, .. 
diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs
index 285537a848..072c87b92e 100644
--- a/crates/language_model/src/provider/copilot_chat.rs
+++ b/crates/language_model/src/provider/copilot_chat.rs
@@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
 use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
 
 use super::open_ai::count_open_ai_tokens;
@@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 
     fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
+            .map(|model| {
+                Arc::new(CopilotChatLanguageModel {
+                    model,
+                    request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+            })
             .collect()
     }
 
@@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
             .unwrap_or(false)
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let result = if self.is_authenticated(cx) {
             Ok(())
         } else if let Some(copilot) = Copilot::global(cx) {
@@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let Some(copilot) = Copilot::global(cx) else {
             return Task::ready(Err(anyhow::anyhow!(
                 "Copilot is not available. Please ensure Copilot is enabled and running and try again."
@@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for CopilotChatLanguageModel {
@@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
             return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
         };
 
-        cx.spawn(|mut cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(result) => {
-                            let choice = result.choices.first();
-                            match choice {
-                                Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
-                                None => Some(Err(anyhow::anyhow!(
-                                    "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
-                                ))),
+        let request_limiter = self.request_limiter.clone();
+        let future = cx.spawn(|cx| async move {
+            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            request_limiter.stream(async move {
+                let response = response.await?;
+                let stream = response
+                    .filter_map(|response| async move {
+                        match response {
+                            Ok(result) => {
+                                let choice = result.choices.first();
+                                match choice {
+                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    None => Some(Err(anyhow::anyhow!(
+                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
+                                    ))),
+                                }
                             }
+                            Err(err) => Some(Err(err)),
                         }
-                        Err(err) => Some(Err(err)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        })
-        .boxed()
+                    })
+                    .boxed();
+                Ok(stream)
+            }).await
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
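One wrinkle in the Copilot Chat version: the request has to be built inside `cx.spawn`, so the limiter is cloned into the task first. That works because `RateLimiter` is a cheap `Clone` handle over an `Arc<Semaphore>`, and every clone draws from the same permit pool. A minimal sketch of the same shape, assuming the `RateLimiter` added by this patch and the `smol` executor:

```rust
let limiter = RateLimiter::new(4);

// Clone the handle into the spawned task; the clone still counts
// against the same four-permit budget as the original.
let request_limiter = limiter.clone();
smol::spawn(async move {
    request_limiter
        .run(async {
            // Stand-in for the real CopilotChat::stream_completion call.
            anyhow::Ok(())
        })
        .await
})
.detach();
```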
diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs
index 7d5a6192a8..f92ecaf467 100644
--- a/crates/language_model/src/provider/fake.rs
+++ b/crates/language_model/src/provider/fake.rs
@@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         true
     }
 
-    fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 
@@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         unimplemented!()
     }
 
-    fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 }
@@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
         async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
        &self,
         _request: LanguageModelRequest,
         _name: String,
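The `use_any_tool` stubs being renamed here keep one convention, visible in the cloud provider's OpenAI arm above: providers without tool support return an already-resolved error future rather than panicking, so callers see a uniform `BoxFuture` either way. Sketched on its own, using the same `futures` helpers the patch uses (the function name is hypothetical):

```rust
use anyhow::{anyhow, Result};
use futures::{future, future::BoxFuture, FutureExt};

fn use_any_tool_unsupported(provider: &str) -> BoxFuture<'static, Result<serde_json::Value>> {
    // Resolves immediately with an error instead of panicking, so the
    // caller handles missing tool support like any other request failure.
    future::ready(Err(anyhow!("tool use not implemented for {provider}"))).boxed()
}
```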
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index e0969eda0b..2739623c6a 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -20,7 +20,7 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 };
 
 const PROVIDER_ID: &str = "google";
@@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
                 model,
                 state: self.state.clone(),
                 http_client: self.http_client.clone(),
+                rate_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
         .collect()
@@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
         self.state.read(cx).api_key.is_some()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let state = self.state.clone();
         let delete_credentials =
             cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
     model: google_ai::Model,
     state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
+    rate_limiter: RateLimiter,
 }
 
 impl LanguageModel for GoogleLanguageModel {
@@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
+        let future = self.rate_limiter.stream(async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
             let response =
                 stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
-        }
-        .boxed()
+        });
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
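Every streaming provider now ends in the same two-stage shape: `RateLimiter::stream` yields a future that resolves, once a permit is free, to a stream wrapped in `RateLimitGuard`, and the trailing `async move { Ok(future.await?.boxed()) }.boxed()` erases the guard type back into the `BoxFuture` of `BoxStream` the trait expects. The shape in miniature, with a canned stream standing in for provider events (names are illustrative):

```rust
use anyhow::Result;
use futures::{future::BoxFuture, stream, stream::BoxStream, FutureExt, StreamExt};

fn stream_chunks(
    limiter: &RateLimiter,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
    // Stage 1: wait for a permit, then build the stream; the permit is held
    // by the returned RateLimitGuard until the stream itself is dropped.
    let future = limiter.stream(async move {
        Ok(stream::iter(["hello", " world"]).map(|chunk| Ok(chunk.to_string())))
    });
    // Stage 2: box the guarded stream so the method signature stays unchanged.
    async move { Ok(future.await?.boxed()) }.boxed()
}
```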
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index 3502748e08..0364866ccd 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 
 const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@@ -39,7 +39,7 @@ struct State {
 }
 
 impl State {
-    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
         let api_url = settings.api_url.clone();
@@ -80,37 +80,10 @@
                 }),
             }),
         };
-        this.fetch_models(cx).detach();
+        this.state
+            .update(cx, |state, cx| state.fetch_models(cx).detach());
         this
     }
-
-    fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
-        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-        let http_client = self.http_client.clone();
-        let api_url = settings.api_url.clone();
-
-        let state = self.state.clone();
-        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
-        cx.spawn(|mut cx| async move {
-            let models = get_models(http_client.as_ref(), &api_url, None).await?;
-
-            let mut models: Vec<ollama::Model> = models
-                .into_iter()
-                // Since there is no metadata from the Ollama API
-                // indicating which models are embedding models,
-                // simply filter out models with "-embed" in their name
-                .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name))
-                .collect();
-
-            models.sort_by(|a, b| a.name.cmp(&b.name));
-
-            state.update(&mut cx, |this, cx| {
-                this.available_models = models;
-                cx.notify();
-            })
-        })
-    }
 }
 
 impl LanguageModelProviderState for OllamaLanguageModelProvider {
@@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                 id: LanguageModelId::from(model.name.clone()),
                 model: model.clone(),
                 http_client: self.http_client.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
         .collect()
@@ -158,11 +132,11 @@
         !self.state.read(cx).available_models.is_empty()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            self.fetch_models(cx)
+            self.state.update(cx, |state, cx| state.fetch_models(cx))
         }
     }
 
@@ -176,8 +150,8 @@
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.fetch_models(cx))
     }
 }
 
@@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
     id: LanguageModelId,
     model: ollama::Model,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 impl OllamaLanguageModel {
@@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
-    fn max_token_count(&self) -> usize {
-        self.model.max_token_count()
-    }
-
     fn telemetry_id(&self) -> String {
         format!("ollama/{}", self.model.id())
     }
 
+    fn max_token_count(&self) -> usize {
+        self.model.max_token_count()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
+        let future = self.request_limiter.stream(async move {
+            let response =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
+                    .await?;
             let stream = response
                 .filter_map(|response| async move {
                     match response {
@@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
                 })
                 .boxed();
             Ok(stream)
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
    }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
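A side note from the removed copy of `fetch_models` (the logic lives on in `State::fetch_models`): Ollama's API carries no flag marking embedding models, so they are excluded purely by the `-embed` naming convention before the list is sorted. The same normalization isolated, with a stand-in model type:

```rust
// Stand-in for the ollama crate's model type; only the name matters here.
struct LocalModel {
    name: String,
}

fn chat_models(all: Vec<LocalModel>) -> Vec<LocalModel> {
    let mut models: Vec<LocalModel> = all
        .into_iter()
        // No metadata distinguishes embedding models, so filter by name.
        .filter(|model| !model.name.contains("-embed"))
        .collect();
    models.sort_by(|a, b| a.name.cmp(&b.name));
    models
}
```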
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 6beec3d0f5..9f24dabb09 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -20,7 +20,7 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 
 const PROVIDER_ID: &str = "openai";
@@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                 model,
                 state: self.state.clone(),
                 http_client: self.http_client.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
         .collect()
@@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         self.state.read(cx).api_key.is_some()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).openai;
         let delete_credentials = cx.delete_credentials(&settings.api_url);
         let state = self.state.clone();
@@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
     model: open_ai::Model,
     state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for OpenAiLanguageModel {
@@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
+        let future = self.request_limiter.stream(async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
             let request = stream_completion(
                 http_client.as_ref(),
@@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
             );
             let response = request.await?;
             Ok(open_ai::extract_text_from_events(response).boxed())
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
new file mode 100644
index 0000000000..faa117fe53
--- /dev/null
+++ b/crates/language_model/src/rate_limiter.rs
@@ -0,0 +1,70 @@
+use anyhow::Result;
+use futures::Stream;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
+use std::{
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+#[derive(Clone)]
+pub struct RateLimiter {
+    semaphore: Arc<Semaphore>,
+}
+
+pub struct RateLimitGuard<T> {
+    inner: T,
+    _guard: SemaphoreGuardArc,
+}
+
+impl<T> Stream for RateLimitGuard<T>
+where
+    T: Stream,
+{
+    type Item = T::Item;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
+    }
+}
+
+impl RateLimiter {
+    pub fn new(limit: usize) -> Self {
+        Self {
+            semaphore: Arc::new(Semaphore::new(limit)),
+        }
+    }
+
+    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
+    where
+        Fut: 'a + Future<Output = Result<T>>,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let result = future.await?;
+            drop(guard);
+            Ok(result)
+        }
+    }
+
+    pub fn stream<'a, Fut, T>(
+        &self,
+        future: Fut,
+    ) -> impl 'a + Future<Output = Result<RateLimitGuard<T>>>
+    where
+        Fut: 'a + Future<Output = Result<T>>,
+        T: Stream,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let inner = future.await?;
+            Ok(RateLimitGuard {
+                inner,
+                _guard: guard,
+            })
+        }
+    }
+}
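To make the semantics of the new file concrete: `run` holds a semaphore permit for the lifetime of the wrapped future, while `stream` holds it for the lifetime of the returned stream via `RateLimitGuard`. A runnable sketch of `run` under a limit of 2, driving four requests on `smol` (the timings are illustrative):

```rust
use std::time::Duration;

fn main() {
    let limiter = RateLimiter::new(2);
    smol::block_on(async {
        // Four requests, but at most two hold a permit at any instant; the
        // third and fourth wait inside acquire_arc() until a guard drops.
        let requests = (0..4).map(|i| {
            limiter.run(async move {
                smol::Timer::after(Duration::from_millis(50)).await;
                anyhow::Ok(i)
            })
        });
        let results = futures::future::join_all(requests).await;
        assert!(results.into_iter().all(|r| r.is_ok()));
    });
}
```

The `unsafe` block in `poll_next` is a manual pin projection; it stays sound as long as `inner` is never moved out of the pinned guard, which nothing in the file above does.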
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index d90163671e..8bda65b07a 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -4,11 +4,12 @@ use crate::{
         copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
         ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
     },
-    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
+    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderState,
 };
 use client::Client;
 use collections::BTreeMap;
-use gpui::{AppContext, Global, Model, ModelContext};
+use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 use std::sync::Arc;
 use ui::Context;
 
@@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
 
 #[derive(Default)]
 pub struct LanguageModelRegistry {
+    active_model: Option<ActiveModel>,
     providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 }
 
+pub struct ActiveModel {
+    provider: Arc<dyn LanguageModelProvider>,
+    model: Option<Arc<dyn LanguageModel>>,
+}
+
+pub struct ActiveModelChanged;
+
+impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
+
 impl LanguageModelRegistry {
     pub fn global(cx: &AppContext) -> Model<Self> {
         cx.global::<GlobalLanguageModelRegistry>().0.clone()
@@ -88,6 +99,8 @@
         let registry = cx.new_model(|cx| {
             let mut registry = Self::default();
             registry.register_provider(fake_provider.clone(), cx);
+            let model = fake_provider.provided_models(cx)[0].clone();
+            registry.set_active_model(Some(model), cx);
             registry
         });
         cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -136,6 +149,64 @@ impl LanguageModelRegistry {
     ) -> Option<Arc<dyn LanguageModelProvider>> {
         self.providers.get(name).cloned()
     }
+
+    pub fn select_active_model(
+        &mut self,
+        provider: &LanguageModelProviderId,
+        model_id: &LanguageModelId,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let Some(provider) = self.provider(&provider) else {
+            return;
+        };
+
+        let models = provider.provided_models(cx);
+        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
+            self.set_active_model(Some(model), cx);
+        }
+    }
+
+    pub fn set_active_provider(
+        &mut self,
+        provider: Option<Arc<dyn LanguageModelProvider>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        self.active_model = provider.map(|provider| ActiveModel {
+            provider,
+            model: None,
+        });
+        cx.emit(ActiveModelChanged);
+    }
+
+    pub fn set_active_model(
+        &mut self,
+        model: Option<Arc<dyn LanguageModel>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let Some(model) = model {
+            let provider_id = model.provider_id();
+            if let Some(provider) = self.providers.get(&provider_id).cloned() {
+                self.active_model = Some(ActiveModel {
+                    provider,
+                    model: Some(model),
+                });
+                cx.emit(ActiveModelChanged);
+            } else {
+                log::warn!("Active model's provider not found in registry");
+            }
+        } else {
+            self.active_model = None;
+            cx.emit(ActiveModelChanged);
+        }
+    }
+
+    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
+        Some(self.active_model.as_ref()?.provider.clone())
+    }
+
+    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
+        self.active_model.as_ref()?.model.clone()
+    }
 }
 
 #[cfg(test)]
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index 3cb012860c..17bbaf77ad 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -89,9 +89,15 @@ impl AnthropicSettingsContent {
                 models
                     .into_iter()
                     .filter_map(|model| match model {
-                        anthropic::Model::Custom { name, max_tokens } => {
-                            Some(provider::anthropic::AvailableModel { name, max_tokens })
-                        }
+                        anthropic::Model::Custom {
+                            name,
+                            max_tokens,
+                            tool_override,
+                        } => Some(provider::anthropic::AvailableModel {
+                            name,
+                            max_tokens,
+                            tool_override,
+                        }),
                         _ => None,
                     })
                     .collect()
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 6dc12b040c..4fd3a86b29 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -22,7 +22,6 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
-completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 23a04a47eb..404884cfb5 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1261,6 +1261,3 @@ mod tests {
         );
     }
 }
-
-// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
-type _TODO = completion::LanguageModelCompletionProvider;
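Finally, the registry's new `ActiveModel` pair is what replaces the old `LanguageModelCompletionProvider` global: settings code resolves a provider id and model id into an active model, and consumers re-read it when an `ActiveModelChanged` event fires. A sketch of that flow, with hypothetical ids and the gpui plumbing elided:

```rust
// Inside some AppContext-level setup code; `registry` is the global
// Model<LanguageModelRegistry>. The ids here are hypothetical, and
// select_active_model quietly does nothing if either lookup fails.
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
    registry.select_active_model(
        &LanguageModelProviderId::from("anthropic".to_string()),
        &LanguageModelId::from("claude-3-5-sonnet".to_string()),
        cx,
    );
});

// Consumers read the active pair back; an ActiveModelChanged event is
// emitted on every change, so UI like the model selector can re-render.
let active_model = registry.read(cx).active_model();
let active_provider = registry.read(cx).active_provider();
```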