From ec487d8f649603e040fec4df2764585f9425532f Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Fri, 19 Jul 2024 13:35:34 -0400 Subject: [PATCH] Extract completion provider crate (#14823) We will soon need `semantic_index` to be able to use `CompletionProvider`. This is currently impossible due to a cyclic crate dependency, because `CompletionProvider` lives in the `assistant` crate, which depends on `semantic_index`. This PR breaks the dependency cycle by extracting two crates out of `assistant`: `language_model` and `completion`. Only one piece of logic changed: [this code](https://github.com/zed-industries/zed/commit/922fcaf5a6076e56890373035b1065b13512546d#diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69). * As of https://github.com/zed-industries/zed/pull/13276, whenever we ask a given completion provider for its available models, OpenAI providers would go and ask the global assistant settings whether the user had configured an `available_models` setting, and if so, return that. * This PR changes it so that instead of eagerly asking the assistant settings for this info (the new crate must not depend on `assistant`, or else the dependency cycle would be back), OpenAI completion providers now store the user-configured settings as part of their struct, and whenever the settings change, we update the provider. In theory, this change should not change user-visible behavior...but since it's the only change in this large PR that's more than just moving code around, I'm mentioning it here in case there's an unexpected regression in practice! (cc @amtoaer in case you'd like to try out this branch and verify that the feature is still working the way you expect.) Release Notes: - N/A --------- Co-authored-by: Marshall Bowers --- Cargo.lock | 64 +++- Cargo.toml | 4 + crates/assistant/Cargo.toml | 5 +- crates/assistant/src/assistant.rs | 191 ++---------- crates/assistant/src/assistant_panel.rs | 13 +- crates/assistant/src/assistant_settings.rs | 291 ++++++++---------- crates/assistant/src/context.rs | 11 +- crates/assistant/src/inline_assistant.rs | 9 +- crates/assistant/src/model_selector.rs | 2 +- crates/assistant/src/prompt_library.rs | 3 +- .../src/terminal_inline_assistant.rs | 6 +- crates/collab/Cargo.toml | 2 + crates/collab/src/tests/test_server.rs | 2 +- crates/completion/Cargo.toml | 56 ++++ crates/completion/LICENSE-GPL | 1 + .../src}/anthropic.rs | 57 +--- .../src}/cloud.rs | 7 +- .../src/completion.rs} | 169 ++-------- .../src}/fake.rs | 2 +- .../src}/ollama.rs | 19 +- .../src}/open_ai.rs | 65 ++-- crates/language_model/Cargo.toml | 41 +++ crates/language_model/LICENSE-GPL | 1 + crates/language_model/src/language_model.rs | 7 + .../language_model/src/model/cloud_model.rs | 160 ++++++++++ crates/language_model/src/model/mod.rs | 60 ++++ crates/language_model/src/request.rs | 110 +++++++ crates/language_model/src/role.rs | 68 ++++ crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/semantic_index.rs | 3 + 30 files changed, 820 insertions(+), 610 deletions(-) create mode 100644 crates/completion/Cargo.toml create mode 120000 crates/completion/LICENSE-GPL rename crates/{assistant/src/completion_provider => completion/src}/anthropic.rs (86%) rename crates/{assistant/src/completion_provider => completion/src}/cloud.rs (96%) rename crates/{assistant/src/completion_provider.rs => completion/src/completion.rs} (57%) rename crates/{assistant/src/completion_provider => completion/src}/fake.rs (97%) rename crates/{assistant/src/completion_provider => 
completion/src}/ollama.rs (96%) rename crates/{assistant/src/completion_provider => completion/src}/open_ai.rs (89%) create mode 100644 crates/language_model/Cargo.toml create mode 120000 crates/language_model/LICENSE-GPL create mode 100644 crates/language_model/src/language_model.rs create mode 100644 crates/language_model/src/model/cloud_model.rs create mode 100644 crates/language_model/src/model/mod.rs create mode 100644 crates/language_model/src/request.rs create mode 100644 crates/language_model/src/role.rs diff --git a/Cargo.lock b/Cargo.lock index c734a89f45..40f9c59228 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,7 @@ dependencies = [ "clock", "collections", "command_palette_hooks", + "completion", "ctor", "editor", "env_logger", @@ -396,6 +397,7 @@ dependencies = [ "indexed_docs", "indoc", "language", + "language_model", "log", "menu", "multi_buffer", @@ -418,13 +420,11 @@ dependencies = [ "settings", "similar", "smol", - "strum", "telemetry_events", "terminal", "terminal_view", "text", "theme", - "tiktoken-rs", "toml 0.8.10", "ui", "unindent", @@ -2491,6 +2491,7 @@ dependencies = [ "clock", "collab_ui", "collections", + "completion", "ctor", "dashmap", "dev_server_projects", @@ -2673,6 +2674,42 @@ dependencies = [ "gpui", ] +[[package]] +name = "completion" +version = "0.1.0" +dependencies = [ + "anthropic", + "anyhow", + "client", + "collections", + "ctor", + "editor", + "env_logger", + "futures 0.3.28", + "gpui", + "http 0.1.0", + "language", + "language_model", + "log", + "menu", + "ollama", + "open_ai", + "parking_lot", + "project", + "rand 0.8.5", + "serde", + "serde_json", + "settings", + "smol", + "strum", + "text", + "theme", + "tiktoken-rs", + "ui", + "unindent", + "util", +] + [[package]] name = "concurrent-queue" version = "2.2.0" @@ -5996,6 +6033,28 @@ dependencies = [ "util", ] +[[package]] +name = "language_model" +version = "0.1.0" +dependencies = [ + "anthropic", + "ctor", + "editor", + "env_logger", + "language", + "log", + "ollama", + "open_ai", + "project", + "proto", + "rand 0.8.5", + "schemars", + "serde", + "strum", + "text", + "unindent", +] + [[package]] name = "language_selector" version = "0.1.0" @@ -9510,6 +9569,7 @@ dependencies = [ "client", "clock", "collections", + "completion", "env_logger", "fs", "futures 0.3.28", diff --git a/Cargo.toml b/Cargo.toml index 2f607134c4..1df8affd08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/collections", "crates/command_palette", "crates/command_palette_hooks", + "crates/completion", "crates/copilot", "crates/db", "crates/dev_server_projects", @@ -50,6 +51,7 @@ members = [ "crates/install_cli", "crates/journal", "crates/language", + "crates/language_model", "crates/language_selector", "crates/language_tools", "crates/languages", @@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } +completion = { path = "crates/completion" } copilot = { path = "crates/copilot" } db = { path = "crates/db" } dev_server_projects = { path = "crates/dev_server_projects" } @@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" } install_cli = { path = "crates/install_cli" } journal = { path = "crates/journal" } language = { path = "crates/language" } +language_model = { path = "crates/language_model" } language_selector = { path = "crates/language_selector" } language_tools = { path = 
"crates/language_tools" } languages = { path = "crates/languages" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index e3ddd4e2c7..201e16bd57 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -33,6 +33,7 @@ client.workspace = true clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true +completion.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true @@ -45,6 +46,7 @@ http.workspace = true indexed_docs.workspace = true indoc.workspace = true language.workspace = true +language_model.workspace = true log.workspace = true menu.workspace = true multi_buffer.workspace = true @@ -64,12 +66,10 @@ serde_json.workspace = true settings.workspace = true similar.workspace = true smol.workspace = true -strum.workspace = true telemetry_events.workspace = true terminal.workspace = true terminal_view.workspace = true theme.workspace = true -tiktoken-rs.workspace = true toml.workspace = true ui.workspace = true util.workspace = true @@ -79,6 +79,7 @@ picker.workspace = true roxmltree = "0.20.0" [dev-dependencies] +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index cf3726485f..0b12cc099c 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,6 +1,5 @@ pub mod assistant_panel; pub mod assistant_settings; -mod completion_provider; mod context; pub mod context_store; mod inline_assistant; @@ -12,17 +11,20 @@ mod streaming_diff; mod terminal_inline_assistant; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; -use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; +use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; -pub use completion_provider::*; +use completion::CompletionProvider; pub use context::*; pub use context_store::*; use fs::Fs; -use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal}; +use gpui::{ + actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal, +}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; +use language_model::LanguageModelResponseMessage; pub(crate) use model_selector::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; @@ -32,10 +34,7 @@ use slash_command::{ file_command, now_command, project_command, prompt_command, search_command, symbols_command, tabs_command, term_command, }; -use std::{ - fmt::{self, Display}, - sync::Arc, -}; +use std::sync::Arc; pub(crate) use streaming_diff::*; actions!( @@ -73,166 +72,6 @@ impl MessageId { } } -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, -} - -impl Role { - pub fn from_proto(role: i32) -> Role { - match proto::LanguageModelRole::from_i32(role) { - Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, - Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, - Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, - None => Role::User, - } - } - 
- pub fn to_proto(&self) -> proto::LanguageModelRole { - match self { - Role::User => proto::LanguageModelRole::LanguageModelUser, - Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, - Role::System => proto::LanguageModelRole::LanguageModelSystem, - } - } - - pub fn cycle(self) -> Role { - match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - Role::System => write!(f, "system"), - } - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub enum LanguageModel { - Cloud(CloudModel), - OpenAi(OpenAiModel), - Anthropic(AnthropicModel), - Ollama(OllamaModel), -} - -impl Default for LanguageModel { - fn default() -> Self { - LanguageModel::Cloud(CloudModel::default()) - } -} - -impl LanguageModel { - pub fn telemetry_id(&self) -> String { - match self { - LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), - LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), - LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), - LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), - } - } - - pub fn display_name(&self) -> String { - match self { - LanguageModel::OpenAi(model) => model.display_name().into(), - LanguageModel::Anthropic(model) => model.display_name().into(), - LanguageModel::Cloud(model) => model.display_name().into(), - LanguageModel::Ollama(model) => model.display_name().into(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - LanguageModel::OpenAi(model) => model.max_token_count(), - LanguageModel::Anthropic(model) => model.max_token_count(), - LanguageModel::Cloud(model) => model.max_token_count(), - LanguageModel::Ollama(model) => model.max_token_count(), - } - } - - pub fn id(&self) -> &str { - match self { - LanguageModel::OpenAi(model) => model.id(), - LanguageModel::Anthropic(model) => model.id(), - LanguageModel::Cloud(model) => model.id(), - LanguageModel::Ollama(model) => model.id(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelRequestMessage { - pub role: Role, - pub content: String, -} - -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct LanguageModelRequest { - pub model: LanguageModel, - pub messages: Vec, - pub stop: Vec, - pub temperature: f32, -} - -impl LanguageModelRequest { - pub fn to_proto(&self) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: self.model.id().to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), - temperature: self.temperature, - tool_choice: None, - tools: Vec::new(), - } - } - - /// Before we send the request to the server, we can perform fixups on it appropriate to the model. 
- pub fn preprocess(&mut self) { - match &self.model { - LanguageModel::OpenAi(_) => {} - LanguageModel::Anthropic(_) => {} - LanguageModel::Ollama(_) => {} - LanguageModel::Cloud(model) => match model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => { - preprocess_anthropic_request(self); - } - _ => {} - }, - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelResponseMessage { - pub role: Option<Role>, - pub content: Option<String>, -} - #[derive(Deserialize, Debug)] pub struct LanguageModelUsage { pub prompt_tokens: u32, @@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) { context_store::init(&client); prompt_library::init(cx); - completion_provider::init(client.clone(), cx); + init_completion_provider(Arc::clone(&client), cx); assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); @@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) { .detach(); } +fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) { + let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx); + cx.set_global(CompletionProvider::new(provider, Some(client))); + + let mut settings_version = 0; + cx.observe_global::<SettingsStore>(move |cx| { + settings_version += 1; + cx.update_global::<CompletionProvider, _>(|provider, cx| { + assistant_settings::update_completion_provider_settings(provider, settings_version, cx); + }) + }) + .detach(); +} + fn register_slash_commands(cx: &mut AppContext) { let slash_command_registry = SlashCommandRegistry::global(cx); slash_command_registry.register_command(file_command::FileSlashCommand, true); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 92bd4b9cbe..e02c26837a 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -8,18 +8,18 @@ use crate::{ SlashCommandCompletionProvider, SlashCommandRegistry, }, terminal_inline_assistant::TerminalInlineAssistant, - Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, - CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, - EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, - InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, - QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split, - ToggleFocus, ToggleModelSelector, + Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole, + DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations, + EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor, + MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, + RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use breadcrumbs::Breadcrumbs; use client::proto; use collections::{BTreeSet, HashMap, HashSet}; +use completion::CompletionProvider; use editor::{ actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, display_map::{ @@ -43,6 +43,7 @@ use language::{ language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, }; +use language_model::Role; use
multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate}; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index d341973326..7fca691e7a 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,166 +1,19 @@ -use std::fmt; +use std::{sync::Arc, time::Duration}; -use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest}; -pub use anthropic::Model as AnthropicModel; -use gpui::Pixels; -pub use ollama::Model as OllamaModel; -pub use open_ai::Model as OpenAiModel; -use schemars::{ - schema::{InstanceType, Metadata, Schema, SchemaObject}, - JsonSchema, -}; -use serde::{ - de::{self, Visitor}, - Deserialize, Deserializer, Serialize, Serializer, +use anthropic::Model as AnthropicModel; +use client::Client; +use completion::{ + AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider, + LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider, }; +use gpui::{AppContext, Pixels}; +use language_model::{CloudModel, LanguageModel}; +use ollama::Model as OllamaModel; +use open_ai::Model as OpenAiModel; +use parking_lot::RwLock; +use schemars::{schema::Schema, JsonSchema}; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use strum::{EnumIter, IntoEnumIterator}; - -#[derive(Clone, Debug, Default, PartialEq, EnumIter)] -pub enum CloudModel { - Gpt3Point5Turbo, - Gpt4, - Gpt4Turbo, - #[default] - Gpt4Omni, - Gpt4OmniMini, - Claude3_5Sonnet, - Claude3Opus, - Claude3Sonnet, - Claude3Haiku, - Gemini15Pro, - Gemini15Flash, - Custom(String), -} - -impl Serialize for CloudModel { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(self.id()) - } -} - -impl<'de> Deserialize<'de> for CloudModel { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct ZedDotDevModelVisitor; - - impl<'de> Visitor<'de> for ZedDotDevModelVisitor { - type Value = CloudModel; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string for a ZedDotDevModel variant or a custom model") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - let model = CloudModel::iter() - .find(|model| model.id() == value) - .unwrap_or_else(|| CloudModel::Custom(value.to_string())); - Ok(model) - } - } - - deserializer.deserialize_str(ZedDotDevModelVisitor) - } -} - -impl JsonSchema for CloudModel { - fn schema_name() -> String { - "ZedDotDevModel".to_owned() - } - - fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { - let variants = CloudModel::iter() - .filter_map(|model| { - let id = model.id(); - if id.is_empty() { - None - } else { - Some(id.to_string()) - } - }) - .collect::>(); - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), - metadata: Some(Box::new(Metadata { - title: Some("ZedDotDevModel".to_owned()), - default: Some(CloudModel::default().id().into()), - examples: variants.into_iter().map(Into::into).collect(), - ..Default::default() - })), - ..Default::default() - }) - } -} - -impl CloudModel { - pub fn id(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "gpt-3.5-turbo", - Self::Gpt4 => "gpt-4", - Self::Gpt4Turbo => "gpt-4-turbo-preview", - Self::Gpt4Omni => "gpt-4o", - Self::Gpt4OmniMini => 
"gpt-4o-mini", - Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom(id) => id, - } - } - - pub fn display_name(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom(id) => id.as_str(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom(_) => 4096, // TODO: Make this configurable - } - } - - pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { - match self { - Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { - preprocess_anthropic_request(request) - } - _ => {} - } - } -} #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -620,6 +473,124 @@ fn merge(target: &mut T, value: Option) { } } +pub fn update_completion_provider_settings( + provider: &mut CompletionProvider, + version: usize, + cx: &mut AppContext, +) { + let updated = match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => provider + .update_current_as::<_, CloudCompletionProvider>(|provider| { + provider.update(model.clone(), version); + }), + AssistantProvider::OpenAi { + model, + api_url, + low_speed_timeout_in_seconds, + available_models, + } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| { + provider.update( + choose_openai_model(&model, &available_models), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + cx, + ); + }), + }; + + // Previously configured provider was changed to another one + if updated.is_none() { + provider.update_provider(|client| create_provider_from_settings(client, version, cx)); + } +} + +pub(crate) fn create_provider_from_settings( + client: Arc, + settings_version: usize, + cx: &mut AppContext, +) -> Arc> { + match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( + CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), + )), + AssistantProvider::OpenAi { + model, 
+ api_url, + low_speed_timeout_in_seconds, + available_models, + } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( + choose_openai_model(&model, &available_models), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + available_models.clone(), + ))), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ))), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(OllamaCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + cx, + ))), + } +} + +/// Choose which model to use for openai provider. +/// If the model is not available, try to use the first available model, or fallback to the original model. +fn choose_openai_model( + model: &::open_ai::Model, + available_models: &[::open_ai::Model], +) -> ::open_ai::Model { + available_models + .iter() + .find(|&m| m == model) + .or_else(|| available_models.first()) + .unwrap_or_else(|| model) + .clone() +} + #[cfg(test)] mod tests { use gpui::{AppContext, UpdateGlobal}; diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 25f24753a1..f75b693bbd 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,12 +1,12 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, - LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role, + prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId, + MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, }; -use client::{proto, telemetry::Telemetry}; +use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; use fs::Fs; @@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip use language::{ AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset, }; +use language_model::LanguageModelRequestMessage; +use language_model::{LanguageModelRequest, Role}; use open_ai::Model as OpenAiModel; use paths::contexts_dir; use project::Project; @@ -2477,9 +2479,10 @@ mod tests { use crate::{ assistant_panel, prompt_library, slash_command::{active_command, file_command}, - FakeCompletionProvider, MessageId, + MessageId, }; use assistant_slash_command::{ArgumentCompletion, SlashCommand}; + use completion::FakeCompletionProvider; use fs::FakeFs; use gpui::{AppContext, TestAppContext, WeakView}; use indoc::indoc; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index be14e271e8..b8dbcacd2b 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,7 +1,6 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt, - AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest, - LanguageModelRequestMessage, Role, StreamingDiff, + AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff, 
}; use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; @@ -28,6 +27,7 @@ use gpui::{ WhiteSpace, WindowContext, }; use language::{Buffer, Point, Selection, TransactionId}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; @@ -1432,8 +1432,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); @@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { #[cfg(test)] mod tests { use super::*; - use crate::FakeCompletionProvider; + use completion::FakeCompletionProvider; use futures::stream::{self}; use gpui::{Context, TestAppContext}; use indoc::indoc; diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index a27b2b5565..6cd50a59da 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector { .with_handle(self.handle) .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 9d782aedc7..a59f4e3c0f 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,6 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, - InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, + InlineAssist, InlineAssistant, }; use anyhow::{anyhow, Result}; use assets::Assets; @@ -19,6 +19,7 @@ use gpui::{ }; use heed::{types::SerdeBincode, Database, RoTxn}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use parking_lot::RwLock; use picker::{Picker, PickerDelegate}; use rope::Rope; diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 8f2cd63bac..192db0cf5e 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent, - CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role, + CompletionProvider, }; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -17,6 +17,7 @@ use gpui::{ Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace, }; use language::Buffer; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use settings::{update_settings_file, Settings}; use std::{ cmp, @@ -558,8 +559,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = 
menu.custom_entry( { let model = model.clone(); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index a413a46489..cf99e7c90c 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -30,6 +30,7 @@ chrono.workspace = true clock.workspace = true clickhouse.workspace = true collections.workspace = true +completion.workspace = true dashmap = "5.4" envy = "0.4.2" futures.workspace = true @@ -79,6 +80,7 @@ channel.workspace = true client = { workspace = true, features = ["test-support"] } collab_ui = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index a3eafd0f94..61c0a8239d 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -295,7 +295,7 @@ impl TestServer { menu::init(); dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); - assistant::FakeCompletionProvider::setup_test(cx); + completion::FakeCompletionProvider::setup_test(cx); assistant::context_store::init(&client); }); diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml new file mode 100644 index 0000000000..18181e7bb5 --- /dev/null +++ b/crates/completion/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "completion" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/completion.rs" +doctest = false + +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +futures.workspace = true +gpui.workspace = true +http.workspace = true +language_model.workspace = true +log.workspace = true +menu.workspace = true +ollama = { workspace = true, features = ["schemars"] } +open_ai = { workspace = true, features = ["schemars"] } +parking_lot.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strum.workspace = true +theme.workspace = true +tiktoken-rs.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +ctor.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +text = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git a/crates/completion/LICENSE-GPL b/crates/completion/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/completion/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/completion/src/anthropic.rs similarity index 86% rename from crates/assistant/src/completion_provider/anthropic.rs rename to crates/completion/src/anthropic.rs index 48d2020cbe..dc71ebd8ca 100644 --- a/crates/assistant/src/completion_provider/anthropic.rs +++ b/crates/completion/src/anthropic.rs @@ -1,14 
+1,12 @@ -use crate::{ - assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, - Role, -}; -use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage}; -use anthropic::{stream_completion, Request, RequestMessage}; +use crate::{count_open_ai_tokens, LanguageModelCompletionProvider}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; +use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use http::HttpClient; +use language_model::Role; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; @@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider { } impl LanguageModelCompletionProvider for AnthropicCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { AnthropicModel::iter() .map(LanguageModel::Anthropic) .collect() @@ -176,7 +174,7 @@ impl AnthropicCompletionProvider { } fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); + request.preprocess_anthropic(); let model = match request.model { LanguageModel::Anthropic(model) => model, @@ -213,49 +211,6 @@ impl AnthropicCompletionProvider { } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, api_url: String, diff --git a/crates/assistant/src/completion_provider/cloud.rs b/crates/completion/src/cloud.rs similarity index 96% rename from crates/assistant/src/completion_provider/cloud.rs rename to crates/completion/src/cloud.rs index 32b8587116..f84576aeca 100644 --- a/crates/assistant/src/completion_provider/cloud.rs +++ b/crates/completion/src/cloud.rs @@ -1,11 +1,12 @@ use crate::{ - assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel, - LanguageModelCompletionProvider, LanguageModelRequest, + count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider, + LanguageModelRequest, }; use anyhow::{anyhow, Result}; use client::{proto, Client}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use gpui::{AnyView, AppContext, Task}; +use language_model::CloudModel; use std::{future, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; @@ -52,7 +53,7 @@ impl CloudCompletionProvider { } impl LanguageModelCompletionProvider for 
CloudCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { + fn available_models(&self) -> Vec<LanguageModel> { let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() { Some(custom_model) } else { diff --git a/crates/assistant/src/completion_provider.rs b/crates/completion/src/completion.rs similarity index 57% rename from crates/assistant/src/completion_provider.rs rename to crates/completion/src/completion.rs index 13f91f70e3..a219e90b51 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/completion/src/completion.rs @@ -6,52 +6,19 @@ mod ollama; mod open_ai; pub use anthropic::*; +use anyhow::Result; +use client::Client; pub use cloud::*; #[cfg(any(test, feature = "test-support"))] pub use fake::*; +use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; +use gpui::{AnyView, AppContext, Task, WindowContext}; +use language_model::{LanguageModel, LanguageModelRequest}; pub use ollama::*; pub use open_ai::*; use parking_lot::RwLock; use smol::lock::{Semaphore, SemaphoreGuardArc}; - -use crate::{ - assistant_settings::{AssistantProvider, AssistantSettings}, - LanguageModel, LanguageModelRequest, -}; -use anyhow::Result; -use client::Client; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; -use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext}; -use settings::{Settings, SettingsStore}; -use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration}; - -/// Choose which model to use for openai provider. -/// If the model is not available, try to use the first available model, or fallback to the original model. -fn choose_openai_model( - model: &::open_ai::Model, - available_models: &[::open_ai::Model], -) -> ::open_ai::Model { - available_models - .iter() - .find(|&m| m == model) - .or_else(|| available_models.first()) - .unwrap_or_else(|| model) - .clone() -} - -pub fn init(client: Arc<Client>, cx: &mut AppContext) { - let provider = create_provider_from_settings(client.clone(), 0, cx); - cx.set_global(CompletionProvider::new(provider, Some(client))); - - let mut settings_version = 0; - cx.observe_global::<SettingsStore>(move |cx| { - settings_version += 1; - cx.update_global::<CompletionProvider, _>(|provider, cx| { - provider.update_settings(settings_version, cx); - }) - }) - .detach(); -} +use std::{any::Any, pin::Pin, sync::Arc, task::Poll}; pub struct CompletionResponse { inner: BoxStream<'static, Result<String>>, @@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse { } pub trait LanguageModelCompletionProvider: Send + Sync { - fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>; + fn available_models(&self) -> Vec<LanguageModel>; fn settings_version(&self) -> usize; fn is_authenticated(&self) -> bool; fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>; @@ -110,8 +77,8 @@ impl CompletionProvider { } } - pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> { - self.provider.read().available_models(cx) + pub fn available_models(&self) -> Vec<LanguageModel> { + self.provider.read().available_models() } pub fn settings_version(&self) -> usize { @@ -176,6 +143,17 @@ impl CompletionProvider { Ok(completion) }) } + + pub fn update_provider( + &mut self, + get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>, + ) { + if let Some(client) = &self.client { + self.provider = get_provider(Arc::clone(client)); + } else { + log::warn!("completion provider cannot be updated because its client was not set"); + } + } } impl gpui::Global for CompletionProvider {} @@ -196,109 +174,6 @@ impl CompletionProvider { None } } - - pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) 
{ - let updated = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => self - .update_current_as::<_, CloudCompletionProvider>(|provider| { - provider.update(model.clone(), version); - }), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| { - provider.update( - choose_openai_model(&model, &available_models), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - cx, - ); - }), - }; - - // Previously configured provider was changed to another one - if updated.is_none() { - if let Some(client) = self.client.clone() { - self.provider = create_provider_from_settings(client, version, cx); - } else { - log::warn!("completion provider cannot be created because client is not set"); - } - } - } -} - -fn create_provider_from_settings( - client: Arc, - settings_version: usize, - cx: &mut AppContext, -) -> Arc> { - match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( - CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), - )), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( - choose_openai_model(&model, &available_models), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(OllamaCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - cx, - ))), - } } #[cfg(test)] @@ -311,8 +186,8 @@ mod tests { use smol::stream::StreamExt; use crate::{ - completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider, - FakeCompletionProvider, LanguageModelRequest, + CompletionProvider, FakeCompletionProvider, LanguageModelRequest, + MAX_CONCURRENT_COMPLETION_REQUESTS, }; #[gpui::test] diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/completion/src/fake.rs similarity index 97% rename from crates/assistant/src/completion_provider/fake.rs rename to crates/completion/src/fake.rs index e9ad8d9a0f..9eee0f736f 100644 --- a/crates/assistant/src/completion_provider/fake.rs +++ b/crates/completion/src/fake.rs @@ -62,7 +62,7 @@ impl FakeCompletionProvider { } impl LanguageModelCompletionProvider for FakeCompletionProvider { - fn 
available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { + fn available_models(&self) -> Vec<LanguageModel> { vec![LanguageModel::default()] } diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/completion/src/ollama.rs similarity index 96% rename from crates/assistant/src/completion_provider/ollama.rs rename to crates/completion/src/ollama.rs index 59d79e3ae7..30d797c76b 100644 --- a/crates/assistant/src/completion_provider/ollama.rs +++ b/crates/completion/src/ollama.rs @@ -1,15 +1,14 @@ use crate::LanguageModelCompletionProvider; -use crate::{ - assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, -}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; use anyhow::Result; use futures::StreamExt as _; use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; use gpui::{AnyView, AppContext, Task}; use http::HttpClient; +use language_model::Role; +use ollama::Model as OllamaModel; use ollama::{ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, - Role as OllamaRole, }; use std::sync::Arc; use std::time::Duration; @@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider { } impl LanguageModelCompletionProvider for OllamaCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { + fn available_models(&self) -> Vec<LanguageModel> { self.available_models .iter() .map(|m| LanguageModel::Ollama(m.clone())) @@ -262,16 +261,6 @@ impl OllamaCompletionProvider { } } -impl From<Role> for ollama::Role { - fn from(val: Role) -> Self { - match val { - Role::User => OllamaRole::User, - Role::Assistant => OllamaRole::Assistant, - Role::System => OllamaRole::System, - } - } -} - struct DownloadOllamaMessage { retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>, } diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/completion/src/open_ai.rs similarity index 89% rename from crates/assistant/src/completion_provider/open_ai.rs rename to crates/completion/src/open_ai.rs index fd65d1afe5..0a0f6d5b4a 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/completion/src/open_ai.rs @@ -1,15 +1,13 @@ -use crate::assistant_settings::CloudModel; -use crate::assistant_settings::{AssistantProvider, AssistantSettings}; +use crate::CompletionProvider; use crate::LanguageModelCompletionProvider; -use crate::{ - assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, -}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use http::HttpClient; -use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; +use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role}; +use open_ai::Model as OpenAiModel; +use open_ai::{stream_completion, Request, RequestMessage}; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; @@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider { http_client: Arc<dyn HttpClient>, low_speed_timeout: Option<Duration>, settings_version: usize, + available_models_from_settings: Vec<OpenAiModel>, } impl OpenAiCompletionProvider { @@ -34,6 +33,7 @@ impl OpenAiCompletionProvider { http_client: Arc<dyn HttpClient>, low_speed_timeout: Option<Duration>, settings_version: usize, + available_models_from_settings: Vec<OpenAiModel>, ) -> Self { Self { api_key: None, @@ -42,6 +42,7 @@ impl OpenAiCompletionProvider { http_client, low_speed_timeout, settings_version, + available_models_from_settings, }
} @@ -92,30 +93,26 @@ impl OpenAiCompletionProvider { } impl LanguageModelCompletionProvider for OpenAiCompletionProvider { - fn available_models(&self, cx: &AppContext) -> Vec { - if let AssistantProvider::OpenAi { - available_models, .. - } = &AssistantSettings::get_global(cx).provider - { - if !available_models.is_empty() { - return available_models - .iter() - .cloned() - .map(LanguageModel::OpenAi) - .collect(); - } - } - let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) { - vec![self.model.clone()] - } else { - OpenAiModel::iter() - .filter(|model| !matches!(model, OpenAiModel::Custom { .. })) + fn available_models(&self) -> Vec { + if self.available_models_from_settings.is_empty() { + let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) { + vec![self.model.clone()] + } else { + OpenAiModel::iter() + .filter(|model| !matches!(model, OpenAiModel::Custom { .. })) + .collect() + }; + available_models + .into_iter() + .map(LanguageModel::OpenAi) .collect() - }; - available_models - .into_iter() - .map(LanguageModel::OpenAi) - .collect() + } else { + self.available_models_from_settings + .iter() + .cloned() + .map(LanguageModel::OpenAi) + .collect() + } } fn settings_version(&self) -> usize { @@ -255,16 +252,6 @@ pub fn count_open_ai_tokens( .boxed() } -impl From for open_ai::Role { - fn from(val: Role) -> Self { - match val { - Role::User => OpenAiRole::User, - Role::Assistant => OpenAiRole::Assistant, - Role::System => OpenAiRole::System, - } - } -} - struct AuthenticationPrompt { api_key: View, api_url: String, diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml new file mode 100644 index 0000000000..bdc3ad63d5 --- /dev/null +++ b/crates/language_model/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "language_model" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_model.rs" +doctest = false + +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +ollama = { workspace = true, features = ["schemars"] } +open_ai = { workspace = true, features = ["schemars"] } +schemars.workspace = true +serde.workspace = true +strum.workspace = true +proto = { workspace = true, features = ["test-support"] } + +[dev-dependencies] +ctor.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } +log.workspace = true +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +text = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git a/crates/language_model/LICENSE-GPL b/crates/language_model/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/language_model/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs new file mode 100644 index 0000000000..09de409ff4 --- /dev/null +++ b/crates/language_model/src/language_model.rs @@ -0,0 +1,7 @@ +mod model; +mod request; +mod role; + +pub use model::*; +pub use request::*; +pub use role::*; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs new file mode 100644 
index 0000000000..20b2bf7d4f --- /dev/null +++ b/crates/language_model/src/model/cloud_model.rs @@ -0,0 +1,160 @@ +use crate::LanguageModelRequest; +pub use anthropic::Model as AnthropicModel; +pub use ollama::Model as OllamaModel; +pub use open_ai::Model as OpenAiModel; +use schemars::{ + schema::{InstanceType, Metadata, Schema, SchemaObject}, + JsonSchema, +}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::fmt; +use strum::{EnumIter, IntoEnumIterator}; + +#[derive(Clone, Debug, Default, PartialEq, EnumIter)] +pub enum CloudModel { + Gpt3Point5Turbo, + Gpt4, + Gpt4Turbo, + #[default] + Gpt4Omni, + Gpt4OmniMini, + Claude3_5Sonnet, + Claude3Opus, + Claude3Sonnet, + Claude3Haiku, + Gemini15Pro, + Gemini15Flash, + Custom(String), +} + +impl Serialize for CloudModel { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_str(self.id()) + } +} + +impl<'de> Deserialize<'de> for CloudModel { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + struct ZedDotDevModelVisitor; + + impl<'de> Visitor<'de> for ZedDotDevModelVisitor { + type Value = CloudModel; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string for a ZedDotDevModel variant or a custom model") + } + + fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> + where + E: de::Error, + { + let model = CloudModel::iter() + .find(|model| model.id() == value) + .unwrap_or_else(|| CloudModel::Custom(value.to_string())); + Ok(model) + } + } + + deserializer.deserialize_str(ZedDotDevModelVisitor) + } +} + +impl JsonSchema for CloudModel { + fn schema_name() -> String { + "ZedDotDevModel".to_owned() + } + + fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { + let variants = CloudModel::iter() + .filter_map(|model| { + let id = model.id(); + if id.is_empty() { + None + } else { + Some(id.to_string()) + } + }) + .collect::<Vec<_>>(); + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), + metadata: Some(Box::new(Metadata { + title: Some("ZedDotDevModel".to_owned()), + default: Some(CloudModel::default().id().into()), + examples: variants.into_iter().map(Into::into).collect(), + ..Default::default() + })), + ..Default::default() + }) + } +} + +impl CloudModel { + pub fn id(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "gpt-3.5-turbo", + Self::Gpt4 => "gpt-4", + Self::Gpt4Turbo => "gpt-4-turbo-preview", + Self::Gpt4Omni => "gpt-4o", + Self::Gpt4OmniMini => "gpt-4o-mini", + Self::Claude3_5Sonnet => "claude-3-5-sonnet", + Self::Claude3Opus => "claude-3-opus", + Self::Claude3Sonnet => "claude-3-sonnet", + Self::Claude3Haiku => "claude-3-haiku", + Self::Gemini15Pro => "gemini-1.5-pro", + Self::Gemini15Flash => "gemini-1.5-flash", + Self::Custom(id) => id, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", + Self::Gpt4 => "GPT 4", + Self::Gpt4Turbo => "GPT 4 Turbo", + Self::Gpt4Omni => "GPT 4 Omni", + Self::Gpt4OmniMini => "GPT 4 Omni Mini", + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", + Self::Gemini15Pro => "Gemini 1.5 Pro", + Self::Gemini15Flash => "Gemini 1.5 Flash", + Self::Custom(id) => id.as_str(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + 
Self::Gpt3Point5Turbo => 2048, + Self::Gpt4 => 4096, + Self::Gpt4Turbo | Self::Gpt4Omni => 128000, + Self::Gpt4OmniMini => 128000, + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => 200000, + Self::Gemini15Pro => 128000, + Self::Gemini15Flash => 32000, + Self::Custom(_) => 4096, // TODO: Make this configurable + } + } + + pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { + request.preprocess_anthropic() + } + _ => {} + } + } +} diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs new file mode 100644 index 0000000000..b61766308f --- /dev/null +++ b/crates/language_model/src/model/mod.rs @@ -0,0 +1,60 @@ +pub mod cloud_model; + +pub use anthropic::Model as AnthropicModel; +pub use cloud_model::*; +pub use ollama::Model as OllamaModel; +pub use open_ai::Model as OpenAiModel; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum LanguageModel { + Cloud(CloudModel), + OpenAi(OpenAiModel), + Anthropic(AnthropicModel), + Ollama(OllamaModel), +} + +impl Default for LanguageModel { + fn default() -> Self { + LanguageModel::Cloud(CloudModel::default()) + } +} + +impl LanguageModel { + pub fn telemetry_id(&self) -> String { + match self { + LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), + LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), + LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), + LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), + } + } + + pub fn display_name(&self) -> String { + match self { + LanguageModel::OpenAi(model) => model.display_name().into(), + LanguageModel::Anthropic(model) => model.display_name().into(), + LanguageModel::Cloud(model) => model.display_name().into(), + LanguageModel::Ollama(model) => model.display_name().into(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + LanguageModel::OpenAi(model) => model.max_token_count(), + LanguageModel::Anthropic(model) => model.max_token_count(), + LanguageModel::Cloud(model) => model.max_token_count(), + LanguageModel::Ollama(model) => model.max_token_count(), + } + } + + pub fn id(&self) -> &str { + match self { + LanguageModel::OpenAi(model) => model.id(), + LanguageModel::Anthropic(model) => model.id(), + LanguageModel::Cloud(model) => model.id(), + LanguageModel::Ollama(model) => model.id(), + } + } +} diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs new file mode 100644 index 0000000000..f9c4322cdf --- /dev/null +++ b/crates/language_model/src/request.rs @@ -0,0 +1,110 @@ +use crate::{ + model::{CloudModel, LanguageModel}, + role::Role, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: String, +} + +impl LanguageModelRequestMessage { + pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { + proto::LanguageModelRequestMessage { + role: self.role.to_proto() as i32, + content: self.content.clone(), + tool_calls: Vec::new(), + tool_call_id: None, + } + } +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct LanguageModelRequest { + pub model: LanguageModel, + pub messages: Vec<LanguageModelRequestMessage>, + pub stop: Vec<String>, + pub temperature: f32, +} + +impl LanguageModelRequest { + pub fn to_proto(&self) -> 
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
new file mode 100644
index 0000000000..f9c4322cdf
--- /dev/null
+++ b/crates/language_model/src/request.rs
@@ -0,0 +1,110 @@
+use crate::{
+    model::{CloudModel, LanguageModel},
+    role::Role,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+        proto::LanguageModelRequestMessage {
+            role: self.role.to_proto() as i32,
+            content: self.content.clone(),
+            tool_calls: Vec::new(),
+            tool_call_id: None,
+        }
+    }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct LanguageModelRequest {
+    pub model: LanguageModel,
+    pub messages: Vec<LanguageModelRequestMessage>,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+        proto::CompleteWithLanguageModel {
+            model: self.model.id().to_string(),
+            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+            stop: self.stop.clone(),
+            temperature: self.temperature,
+            tool_choice: None,
+            tools: Vec::new(),
+        }
+    }
+
+    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+    pub fn preprocess(&mut self) {
+        match &self.model {
+            LanguageModel::OpenAi(_) => {}
+            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Ollama(_) => {}
+            LanguageModel::Cloud(model) => match model {
+                CloudModel::Claude3Opus
+                | CloudModel::Claude3Sonnet
+                | CloudModel::Claude3Haiku
+                | CloudModel::Claude3_5Sonnet => {
+                    self.preprocess_anthropic();
+                }
+                _ => {}
+            },
+        }
+    }
+
+    pub fn preprocess_anthropic(&mut self) {
+        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut system_message = String::new();
+
+        for message in self.messages.drain(..) {
+            if message.content.is_empty() {
+                continue;
+            }
+
+            match message.role {
+                Role::User | Role::Assistant => {
+                    if let Some(last_message) = new_messages.last_mut() {
+                        if last_message.role == message.role {
+                            last_message.content.push_str("\n\n");
+                            last_message.content.push_str(&message.content);
+                            continue;
+                        }
+                    }
+
+                    new_messages.push(message);
+                }
+                Role::System => {
+                    if !system_message.is_empty() {
+                        system_message.push_str("\n\n");
+                    }
+                    system_message.push_str(&message.content);
+                }
+            }
+        }
+
+        if !system_message.is_empty() {
+            new_messages.insert(
+                0,
+                LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: system_message,
+                },
+            );
+        }
+
+        self.messages = new_messages;
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
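`preprocess_anthropic` exists because Anthropic's messages API expects alternating user/assistant turns and takes system content out of band, so the method drops empty messages, merges consecutive same-role messages, and hoists all system content into a single leading system message. A sketch of the observable effect (the struct literals rely on the public fields shown above; the crate-root re-exports are assumed):

use language_model::{
    CloudModel, LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, Role,
};

fn main() {
    let mut request = LanguageModelRequest {
        model: LanguageModel::Cloud(CloudModel::Claude3_5Sonnet),
        messages: vec![
            LanguageModelRequestMessage { role: Role::User, content: "First.".into() },
            LanguageModelRequestMessage { role: Role::User, content: "Second.".into() },
            LanguageModelRequestMessage { role: Role::System, content: "Be terse.".into() },
        ],
        stop: Vec::new(),
        temperature: 1.0,
    };

    // Claude 3.5 Sonnet routes through preprocess_anthropic().
    request.preprocess();

    // The two user messages were merged, and the system message moved to the front.
    assert_eq!(request.messages.len(), 2);
    assert_eq!(request.messages[0].role, Role::System);
    assert_eq!(request.messages[1].content, "First.\n\nSecond.");
}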
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
new file mode 100644
index 0000000000..f6276a4823
--- /dev/null
+++ b/crates/language_model/src/role.rs
@@ -0,0 +1,68 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Display};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn from_proto(role: i32) -> Role {
+        match proto::LanguageModelRole::from_i32(role) {
+            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+            None => Role::User,
+        }
+    }
+
+    pub fn to_proto(&self) -> proto::LanguageModelRole {
+        match self {
+            Role::User => proto::LanguageModelRole::LanguageModelUser,
+            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+            Role::System => proto::LanguageModelRole::LanguageModelSystem,
+        }
+    }
+
+    pub fn cycle(self) -> Role {
+        match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "user"),
+            Role::Assistant => write!(f, "assistant"),
+            Role::System => write!(f, "system"),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => ollama::Role::User,
+            Role::Assistant => ollama::Role::Assistant,
+            Role::System => ollama::Role::System,
+        }
+    }
+}
+
+impl From<Role> for open_ai::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => open_ai::Role::User,
+            Role::Assistant => open_ai::Role::Assistant,
+            Role::System => open_ai::Role::System,
+        }
+    }
+}
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 3f49490941..19cb0c96fe 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -22,6 +22,7 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 7a29f3be25..4c43fc1e46 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1261,3 +1261,6 @@ mod tests {
         );
     }
 }
+
+// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
+type _TODO = completion::CompletionProvider;
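Finally, a quick sketch of `Role`'s conversion behavior from `role.rs` above (assuming the crate-root re-export): both unknown proto values and the proto `Tool` role degrade to safe defaults rather than erroring, and `cycle()` steps through all three roles and wraps around:

use language_model::Role;

fn main() {
    // An out-of-range proto value falls back to User instead of failing.
    assert_eq!(Role::from_proto(-1), Role::User);

    // cycle() walks User -> Assistant -> System -> User.
    assert_eq!(Role::User.cycle(), Role::Assistant);
    assert_eq!(Role::Assistant.cycle(), Role::System);
    assert_eq!(Role::System.cycle(), Role::User);

    // Display matches the lowercase serde representation.
    assert_eq!(Role::Assistant.to_string(), "assistant");
}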