diff --git a/Cargo.lock b/Cargo.lock index c734a89f45..40f9c59228 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,7 @@ dependencies = [ "clock", "collections", "command_palette_hooks", + "completion", "ctor", "editor", "env_logger", @@ -396,6 +397,7 @@ dependencies = [ "indexed_docs", "indoc", "language", + "language_model", "log", "menu", "multi_buffer", @@ -418,13 +420,11 @@ dependencies = [ "settings", "similar", "smol", - "strum", "telemetry_events", "terminal", "terminal_view", "text", "theme", - "tiktoken-rs", "toml 0.8.10", "ui", "unindent", @@ -2491,6 +2491,7 @@ dependencies = [ "clock", "collab_ui", "collections", + "completion", "ctor", "dashmap", "dev_server_projects", @@ -2673,6 +2674,42 @@ dependencies = [ "gpui", ] +[[package]] +name = "completion" +version = "0.1.0" +dependencies = [ + "anthropic", + "anyhow", + "client", + "collections", + "ctor", + "editor", + "env_logger", + "futures 0.3.28", + "gpui", + "http 0.1.0", + "language", + "language_model", + "log", + "menu", + "ollama", + "open_ai", + "parking_lot", + "project", + "rand 0.8.5", + "serde", + "serde_json", + "settings", + "smol", + "strum", + "text", + "theme", + "tiktoken-rs", + "ui", + "unindent", + "util", +] + [[package]] name = "concurrent-queue" version = "2.2.0" @@ -5996,6 +6033,28 @@ dependencies = [ "util", ] +[[package]] +name = "language_model" +version = "0.1.0" +dependencies = [ + "anthropic", + "ctor", + "editor", + "env_logger", + "language", + "log", + "ollama", + "open_ai", + "project", + "proto", + "rand 0.8.5", + "schemars", + "serde", + "strum", + "text", + "unindent", +] + [[package]] name = "language_selector" version = "0.1.0" @@ -9510,6 +9569,7 @@ dependencies = [ "client", "clock", "collections", + "completion", "env_logger", "fs", "futures 0.3.28", diff --git a/Cargo.toml b/Cargo.toml index 2f607134c4..1df8affd08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/collections", "crates/command_palette", "crates/command_palette_hooks", + "crates/completion", "crates/copilot", "crates/db", "crates/dev_server_projects", @@ -50,6 +51,7 @@ members = [ "crates/install_cli", "crates/journal", "crates/language", + "crates/language_model", "crates/language_selector", "crates/language_tools", "crates/languages", @@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } +completion = { path = "crates/completion" } copilot = { path = "crates/copilot" } db = { path = "crates/db" } dev_server_projects = { path = "crates/dev_server_projects" } @@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" } install_cli = { path = "crates/install_cli" } journal = { path = "crates/journal" } language = { path = "crates/language" } +language_model = { path = "crates/language_model" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } languages = { path = "crates/languages" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index e3ddd4e2c7..201e16bd57 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -33,6 +33,7 @@ client.workspace = true clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true +completion.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true @@ -45,6 +46,7 @@ http.workspace = true 
indexed_docs.workspace = true indoc.workspace = true language.workspace = true +language_model.workspace = true log.workspace = true menu.workspace = true multi_buffer.workspace = true @@ -64,12 +66,10 @@ serde_json.workspace = true settings.workspace = true similar.workspace = true smol.workspace = true -strum.workspace = true telemetry_events.workspace = true terminal.workspace = true terminal_view.workspace = true theme.workspace = true -tiktoken-rs.workspace = true toml.workspace = true ui.workspace = true util.workspace = true @@ -79,6 +79,7 @@ picker.workspace = true roxmltree = "0.20.0" [dev-dependencies] +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index cf3726485f..0b12cc099c 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,6 +1,5 @@ pub mod assistant_panel; pub mod assistant_settings; -mod completion_provider; mod context; pub mod context_store; mod inline_assistant; @@ -12,17 +11,20 @@ mod streaming_diff; mod terminal_inline_assistant; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; -use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; +use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; -pub use completion_provider::*; +use completion::CompletionProvider; pub use context::*; pub use context_store::*; use fs::Fs; -use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal}; +use gpui::{ + actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal, +}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; +use language_model::LanguageModelResponseMessage; pub(crate) use model_selector::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; @@ -32,10 +34,7 @@ use slash_command::{ file_command, now_command, project_command, prompt_command, search_command, symbols_command, tabs_command, term_command, }; -use std::{ - fmt::{self, Display}, - sync::Arc, -}; +use std::sync::Arc; pub(crate) use streaming_diff::*; actions!( @@ -73,166 +72,6 @@ impl MessageId { } } -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, -} - -impl Role { - pub fn from_proto(role: i32) -> Role { - match proto::LanguageModelRole::from_i32(role) { - Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, - Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, - Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, - None => Role::User, - } - } - - pub fn to_proto(&self) -> proto::LanguageModelRole { - match self { - Role::User => proto::LanguageModelRole::LanguageModelUser, - Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, - Role::System => proto::LanguageModelRole::LanguageModelSystem, - } - } - - pub fn cycle(self) -> Role { - match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - Role::System => write!(f, "system"), - } - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub enum LanguageModel { - Cloud(CloudModel), - OpenAi(OpenAiModel), - Anthropic(AnthropicModel), - Ollama(OllamaModel), -} - -impl Default for LanguageModel { - fn default() -> Self { - LanguageModel::Cloud(CloudModel::default()) - } -} - -impl LanguageModel { - pub fn telemetry_id(&self) -> String { - match self { - LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), - LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), - LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), - LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), - } - } - - pub fn display_name(&self) -> String { - match self { - LanguageModel::OpenAi(model) => model.display_name().into(), - LanguageModel::Anthropic(model) => model.display_name().into(), - LanguageModel::Cloud(model) => model.display_name().into(), - LanguageModel::Ollama(model) => model.display_name().into(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - LanguageModel::OpenAi(model) => model.max_token_count(), - LanguageModel::Anthropic(model) => model.max_token_count(), - LanguageModel::Cloud(model) => model.max_token_count(), - LanguageModel::Ollama(model) => model.max_token_count(), - } - } - - pub fn id(&self) -> &str { - match self { - LanguageModel::OpenAi(model) => model.id(), - LanguageModel::Anthropic(model) => model.id(), - LanguageModel::Cloud(model) => model.id(), - LanguageModel::Ollama(model) => model.id(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelRequestMessage { - pub role: Role, - pub content: String, -} - -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct LanguageModelRequest { - pub model: LanguageModel, - pub messages: Vec, - pub stop: Vec, - pub temperature: f32, -} - -impl LanguageModelRequest { - pub fn to_proto(&self) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: self.model.id().to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), - temperature: self.temperature, - tool_choice: None, - tools: Vec::new(), - } - } - - /// Before we send the request to the server, we can perform fixups on it appropriate to the model. 
- pub fn preprocess(&mut self) { - match &self.model { - LanguageModel::OpenAi(_) => {} - LanguageModel::Anthropic(_) => {} - LanguageModel::Ollama(_) => {} - LanguageModel::Cloud(model) => match model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => { - preprocess_anthropic_request(self); - } - _ => {} - }, - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelResponseMessage { - pub role: Option, - pub content: Option, -} - #[derive(Deserialize, Debug)] pub struct LanguageModelUsage { pub prompt_tokens: u32, @@ -343,7 +182,7 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { context_store::init(&client); prompt_library::init(cx); - completion_provider::init(client.clone(), cx); + init_completion_provider(Arc::clone(&client), cx); assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); @@ -368,6 +207,20 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { .detach(); } +fn init_completion_provider(client: Arc, cx: &mut AppContext) { + let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx); + cx.set_global(CompletionProvider::new(provider, Some(client))); + + let mut settings_version = 0; + cx.observe_global::(move |cx| { + settings_version += 1; + cx.update_global::(|provider, cx| { + assistant_settings::update_completion_provider_settings(provider, settings_version, cx); + }) + }) + .detach(); +} + fn register_slash_commands(cx: &mut AppContext) { let slash_command_registry = SlashCommandRegistry::global(cx); slash_command_registry.register_command(file_command::FileSlashCommand, true); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 92bd4b9cbe..e02c26837a 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -8,18 +8,18 @@ use crate::{ SlashCommandCompletionProvider, SlashCommandRegistry, }, terminal_inline_assistant::TerminalInlineAssistant, - Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, - CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, - EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, - InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, - QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split, - ToggleFocus, ToggleModelSelector, + Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole, + DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations, + EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor, + MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, + RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use breadcrumbs::Breadcrumbs; use client::proto; use collections::{BTreeSet, HashMap, HashSet}; +use completion::CompletionProvider; use editor::{ actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, display_map::{ @@ -43,6 +43,7 @@ use language::{ language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, }; +use language_model::Role; use 
multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate}; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index d341973326..7fca691e7a 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,166 +1,19 @@ -use std::fmt; +use std::{sync::Arc, time::Duration}; -use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest}; -pub use anthropic::Model as AnthropicModel; -use gpui::Pixels; -pub use ollama::Model as OllamaModel; -pub use open_ai::Model as OpenAiModel; -use schemars::{ - schema::{InstanceType, Metadata, Schema, SchemaObject}, - JsonSchema, -}; -use serde::{ - de::{self, Visitor}, - Deserialize, Deserializer, Serialize, Serializer, +use anthropic::Model as AnthropicModel; +use client::Client; +use completion::{ + AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider, + LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider, }; +use gpui::{AppContext, Pixels}; +use language_model::{CloudModel, LanguageModel}; +use ollama::Model as OllamaModel; +use open_ai::Model as OpenAiModel; +use parking_lot::RwLock; +use schemars::{schema::Schema, JsonSchema}; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use strum::{EnumIter, IntoEnumIterator}; - -#[derive(Clone, Debug, Default, PartialEq, EnumIter)] -pub enum CloudModel { - Gpt3Point5Turbo, - Gpt4, - Gpt4Turbo, - #[default] - Gpt4Omni, - Gpt4OmniMini, - Claude3_5Sonnet, - Claude3Opus, - Claude3Sonnet, - Claude3Haiku, - Gemini15Pro, - Gemini15Flash, - Custom(String), -} - -impl Serialize for CloudModel { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(self.id()) - } -} - -impl<'de> Deserialize<'de> for CloudModel { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct ZedDotDevModelVisitor; - - impl<'de> Visitor<'de> for ZedDotDevModelVisitor { - type Value = CloudModel; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string for a ZedDotDevModel variant or a custom model") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - let model = CloudModel::iter() - .find(|model| model.id() == value) - .unwrap_or_else(|| CloudModel::Custom(value.to_string())); - Ok(model) - } - } - - deserializer.deserialize_str(ZedDotDevModelVisitor) - } -} - -impl JsonSchema for CloudModel { - fn schema_name() -> String { - "ZedDotDevModel".to_owned() - } - - fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { - let variants = CloudModel::iter() - .filter_map(|model| { - let id = model.id(); - if id.is_empty() { - None - } else { - Some(id.to_string()) - } - }) - .collect::>(); - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), - metadata: Some(Box::new(Metadata { - title: Some("ZedDotDevModel".to_owned()), - default: Some(CloudModel::default().id().into()), - examples: variants.into_iter().map(Into::into).collect(), - ..Default::default() - })), - ..Default::default() - }) - } -} - -impl CloudModel { - pub fn id(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "gpt-3.5-turbo", - Self::Gpt4 => "gpt-4", - Self::Gpt4Turbo => "gpt-4-turbo-preview", - Self::Gpt4Omni => "gpt-4o", - Self::Gpt4OmniMini => 
"gpt-4o-mini", - Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom(id) => id, - } - } - - pub fn display_name(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom(id) => id.as_str(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom(_) => 4096, // TODO: Make this configurable - } - } - - pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { - match self { - Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { - preprocess_anthropic_request(request) - } - _ => {} - } - } -} #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -620,6 +473,124 @@ fn merge(target: &mut T, value: Option) { } } +pub fn update_completion_provider_settings( + provider: &mut CompletionProvider, + version: usize, + cx: &mut AppContext, +) { + let updated = match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => provider + .update_current_as::<_, CloudCompletionProvider>(|provider| { + provider.update(model.clone(), version); + }), + AssistantProvider::OpenAi { + model, + api_url, + low_speed_timeout_in_seconds, + available_models, + } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| { + provider.update( + choose_openai_model(&model, &available_models), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + cx, + ); + }), + }; + + // Previously configured provider was changed to another one + if updated.is_none() { + provider.update_provider(|client| create_provider_from_settings(client, version, cx)); + } +} + +pub(crate) fn create_provider_from_settings( + client: Arc, + settings_version: usize, + cx: &mut AppContext, +) -> Arc> { + match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( + CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), + )), + AssistantProvider::OpenAi { + model, 
+ api_url, + low_speed_timeout_in_seconds, + available_models, + } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( + choose_openai_model(&model, &available_models), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + available_models.clone(), + ))), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ))), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(OllamaCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + cx, + ))), + } +} + +/// Choose which model to use for openai provider. +/// If the model is not available, try to use the first available model, or fallback to the original model. +fn choose_openai_model( + model: &::open_ai::Model, + available_models: &[::open_ai::Model], +) -> ::open_ai::Model { + available_models + .iter() + .find(|&m| m == model) + .or_else(|| available_models.first()) + .unwrap_or_else(|| model) + .clone() +} + #[cfg(test)] mod tests { use gpui::{AppContext, UpdateGlobal}; diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 25f24753a1..f75b693bbd 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,12 +1,12 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, - LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role, + prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId, + MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, }; -use client::{proto, telemetry::Telemetry}; +use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; use fs::Fs; @@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip use language::{ AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset, }; +use language_model::LanguageModelRequestMessage; +use language_model::{LanguageModelRequest, Role}; use open_ai::Model as OpenAiModel; use paths::contexts_dir; use project::Project; @@ -2477,9 +2479,10 @@ mod tests { use crate::{ assistant_panel, prompt_library, slash_command::{active_command, file_command}, - FakeCompletionProvider, MessageId, + MessageId, }; use assistant_slash_command::{ArgumentCompletion, SlashCommand}; + use completion::FakeCompletionProvider; use fs::FakeFs; use gpui::{AppContext, TestAppContext, WeakView}; use indoc::indoc; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index be14e271e8..b8dbcacd2b 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,7 +1,6 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt, - AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest, - LanguageModelRequestMessage, Role, StreamingDiff, + AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff, 
}; use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; @@ -28,6 +27,7 @@ use gpui::{ WhiteSpace, WindowContext, }; use language::{Buffer, Point, Selection, TransactionId}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; @@ -1432,8 +1432,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); @@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { #[cfg(test)] mod tests { use super::*; - use crate::FakeCompletionProvider; + use completion::FakeCompletionProvider; use futures::stream::{self}; use gpui::{Context, TestAppContext}; use indoc::indoc; diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index a27b2b5565..6cd50a59da 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector { .with_handle(self.handle) .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 9d782aedc7..a59f4e3c0f 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,6 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, - InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, + InlineAssist, InlineAssistant, }; use anyhow::{anyhow, Result}; use assets::Assets; @@ -19,6 +19,7 @@ use gpui::{ }; use heed::{types::SerdeBincode, Database, RoTxn}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use parking_lot::RwLock; use picker::{Picker, PickerDelegate}; use rope::Rope; diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 8f2cd63bac..192db0cf5e 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent, - CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role, + CompletionProvider, }; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -17,6 +17,7 @@ use gpui::{ Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace, }; use language::Buffer; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use settings::{update_settings_file, Settings}; use std::{ cmp, @@ -558,8 +559,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = 
menu.custom_entry( { let model = model.clone(); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index a413a46489..cf99e7c90c 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -30,6 +30,7 @@ chrono.workspace = true clock.workspace = true clickhouse.workspace = true collections.workspace = true +completion.workspace = true dashmap = "5.4" envy = "0.4.2" futures.workspace = true @@ -79,6 +80,7 @@ channel.workspace = true client = { workspace = true, features = ["test-support"] } collab_ui = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index a3eafd0f94..61c0a8239d 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -295,7 +295,7 @@ impl TestServer { menu::init(); dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); - assistant::FakeCompletionProvider::setup_test(cx); + completion::FakeCompletionProvider::setup_test(cx); assistant::context_store::init(&client); }); diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml new file mode 100644 index 0000000000..18181e7bb5 --- /dev/null +++ b/crates/completion/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "completion" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/completion.rs" +doctest = false + +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +futures.workspace = true +gpui.workspace = true +http.workspace = true +language_model.workspace = true +log.workspace = true +menu.workspace = true +ollama = { workspace = true, features = ["schemars"] } +open_ai = { workspace = true, features = ["schemars"] } +parking_lot.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strum.workspace = true +theme.workspace = true +tiktoken-rs.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +ctor.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +text = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git a/crates/completion/LICENSE-GPL b/crates/completion/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/completion/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/completion/src/anthropic.rs similarity index 86% rename from crates/assistant/src/completion_provider/anthropic.rs rename to crates/completion/src/anthropic.rs index 48d2020cbe..dc71ebd8ca 100644 --- a/crates/assistant/src/completion_provider/anthropic.rs +++ b/crates/completion/src/anthropic.rs @@ -1,14 
+1,12 @@ -use crate::{ - assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, - Role, -}; -use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage}; -use anthropic::{stream_completion, Request, RequestMessage}; +use crate::{count_open_ai_tokens, LanguageModelCompletionProvider}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; +use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use http::HttpClient; +use language_model::Role; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; @@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider { } impl LanguageModelCompletionProvider for AnthropicCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { AnthropicModel::iter() .map(LanguageModel::Anthropic) .collect() @@ -176,7 +174,7 @@ impl AnthropicCompletionProvider { } fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); + request.preprocess_anthropic(); let model = match request.model { LanguageModel::Anthropic(model) => model, @@ -213,49 +211,6 @@ impl AnthropicCompletionProvider { } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, api_url: String, diff --git a/crates/assistant/src/completion_provider/cloud.rs b/crates/completion/src/cloud.rs similarity index 96% rename from crates/assistant/src/completion_provider/cloud.rs rename to crates/completion/src/cloud.rs index 32b8587116..f84576aeca 100644 --- a/crates/assistant/src/completion_provider/cloud.rs +++ b/crates/completion/src/cloud.rs @@ -1,11 +1,12 @@ use crate::{ - assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel, - LanguageModelCompletionProvider, LanguageModelRequest, + count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider, + LanguageModelRequest, }; use anyhow::{anyhow, Result}; use client::{proto, Client}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use gpui::{AnyView, AppContext, Task}; +use language_model::CloudModel; use std::{future, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; @@ -52,7 +53,7 @@ impl CloudCompletionProvider { } impl LanguageModelCompletionProvider for 
CloudCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() { Some(custom_model) } else { diff --git a/crates/assistant/src/completion_provider.rs b/crates/completion/src/completion.rs similarity index 57% rename from crates/assistant/src/completion_provider.rs rename to crates/completion/src/completion.rs index 13f91f70e3..a219e90b51 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/completion/src/completion.rs @@ -6,52 +6,19 @@ mod ollama; mod open_ai; pub use anthropic::*; +use anyhow::Result; +use client::Client; pub use cloud::*; #[cfg(any(test, feature = "test-support"))] pub use fake::*; +use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; +use gpui::{AnyView, AppContext, Task, WindowContext}; +use language_model::{LanguageModel, LanguageModelRequest}; pub use ollama::*; pub use open_ai::*; use parking_lot::RwLock; use smol::lock::{Semaphore, SemaphoreGuardArc}; - -use crate::{ - assistant_settings::{AssistantProvider, AssistantSettings}, - LanguageModel, LanguageModelRequest, -}; -use anyhow::Result; -use client::Client; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; -use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext}; -use settings::{Settings, SettingsStore}; -use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration}; - -/// Choose which model to use for openai provider. -/// If the model is not available, try to use the first available model, or fallback to the original model. -fn choose_openai_model( - model: &::open_ai::Model, - available_models: &[::open_ai::Model], -) -> ::open_ai::Model { - available_models - .iter() - .find(|&m| m == model) - .or_else(|| available_models.first()) - .unwrap_or_else(|| model) - .clone() -} - -pub fn init(client: Arc, cx: &mut AppContext) { - let provider = create_provider_from_settings(client.clone(), 0, cx); - cx.set_global(CompletionProvider::new(provider, Some(client))); - - let mut settings_version = 0; - cx.observe_global::(move |cx| { - settings_version += 1; - cx.update_global::(|provider, cx| { - provider.update_settings(settings_version, cx); - }) - }) - .detach(); -} +use std::{any::Any, pin::Pin, sync::Arc, task::Poll}; pub struct CompletionResponse { inner: BoxStream<'static, Result>, @@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse { } pub trait LanguageModelCompletionProvider: Send + Sync { - fn available_models(&self, cx: &AppContext) -> Vec; + fn available_models(&self) -> Vec; fn settings_version(&self) -> usize; fn is_authenticated(&self) -> bool; fn authenticate(&self, cx: &AppContext) -> Task>; @@ -110,8 +77,8 @@ impl CompletionProvider { } } - pub fn available_models(&self, cx: &AppContext) -> Vec { - self.provider.read().available_models(cx) + pub fn available_models(&self) -> Vec { + self.provider.read().available_models() } pub fn settings_version(&self) -> usize { @@ -176,6 +143,17 @@ impl CompletionProvider { Ok(completion) }) } + + pub fn update_provider( + &mut self, + get_provider: impl FnOnce(Arc) -> Arc>, + ) { + if let Some(client) = &self.client { + self.provider = get_provider(Arc::clone(client)); + } else { + log::warn!("completion provider cannot be updated because its client was not set"); + } + } } impl gpui::Global for CompletionProvider {} @@ -196,109 +174,6 @@ impl CompletionProvider { None } } - - pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) 
{ - let updated = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => self - .update_current_as::<_, CloudCompletionProvider>(|provider| { - provider.update(model.clone(), version); - }), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| { - provider.update( - choose_openai_model(&model, &available_models), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - cx, - ); - }), - }; - - // Previously configured provider was changed to another one - if updated.is_none() { - if let Some(client) = self.client.clone() { - self.provider = create_provider_from_settings(client, version, cx); - } else { - log::warn!("completion provider cannot be created because client is not set"); - } - } - } -} - -fn create_provider_from_settings( - client: Arc, - settings_version: usize, - cx: &mut AppContext, -) -> Arc> { - match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( - CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), - )), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( - choose_openai_model(&model, &available_models), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(OllamaCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - cx, - ))), - } } #[cfg(test)] @@ -311,8 +186,8 @@ mod tests { use smol::stream::StreamExt; use crate::{ - completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider, - FakeCompletionProvider, LanguageModelRequest, + CompletionProvider, FakeCompletionProvider, LanguageModelRequest, + MAX_CONCURRENT_COMPLETION_REQUESTS, }; #[gpui::test] diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/completion/src/fake.rs similarity index 97% rename from crates/assistant/src/completion_provider/fake.rs rename to crates/completion/src/fake.rs index e9ad8d9a0f..9eee0f736f 100644 --- a/crates/assistant/src/completion_provider/fake.rs +++ b/crates/completion/src/fake.rs @@ -62,7 +62,7 @@ impl FakeCompletionProvider { } impl LanguageModelCompletionProvider for FakeCompletionProvider { - fn 
available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { vec![LanguageModel::default()] } diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/completion/src/ollama.rs similarity index 96% rename from crates/assistant/src/completion_provider/ollama.rs rename to crates/completion/src/ollama.rs index 59d79e3ae7..30d797c76b 100644 --- a/crates/assistant/src/completion_provider/ollama.rs +++ b/crates/completion/src/ollama.rs @@ -1,15 +1,14 @@ use crate::LanguageModelCompletionProvider; -use crate::{ - assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, -}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; use anyhow::Result; use futures::StreamExt as _; use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; use gpui::{AnyView, AppContext, Task}; use http::HttpClient; +use language_model::Role; +use ollama::Model as OllamaModel; use ollama::{ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, - Role as OllamaRole, }; use std::sync::Arc; use std::time::Duration; @@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider { } impl LanguageModelCompletionProvider for OllamaCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { self.available_models .iter() .map(|m| LanguageModel::Ollama(m.clone())) @@ -262,16 +261,6 @@ impl OllamaCompletionProvider { } } -impl From for ollama::Role { - fn from(val: Role) -> Self { - match val { - Role::User => OllamaRole::User, - Role::Assistant => OllamaRole::Assistant, - Role::System => OllamaRole::System, - } - } -} - struct DownloadOllamaMessage { retry_connection: Box Task>>, } diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/completion/src/open_ai.rs similarity index 89% rename from crates/assistant/src/completion_provider/open_ai.rs rename to crates/completion/src/open_ai.rs index fd65d1afe5..0a0f6d5b4a 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/completion/src/open_ai.rs @@ -1,15 +1,13 @@ -use crate::assistant_settings::CloudModel; -use crate::assistant_settings::{AssistantProvider, AssistantSettings}; +use crate::CompletionProvider; use crate::LanguageModelCompletionProvider; -use crate::{ - assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, -}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use http::HttpClient; -use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; +use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role}; +use open_ai::Model as OpenAiModel; +use open_ai::{stream_completion, Request, RequestMessage}; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; @@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider { http_client: Arc, low_speed_timeout: Option, settings_version: usize, + available_models_from_settings: Vec, } impl OpenAiCompletionProvider { @@ -34,6 +33,7 @@ impl OpenAiCompletionProvider { http_client: Arc, low_speed_timeout: Option, settings_version: usize, + available_models_from_settings: Vec, ) -> Self { Self { api_key: None, @@ -42,6 +42,7 @@ impl OpenAiCompletionProvider { http_client, low_speed_timeout, settings_version, + available_models_from_settings, } 
} @@ -92,30 +93,26 @@ impl OpenAiCompletionProvider { } impl LanguageModelCompletionProvider for OpenAiCompletionProvider { - fn available_models(&self, cx: &AppContext) -> Vec { - if let AssistantProvider::OpenAi { - available_models, .. - } = &AssistantSettings::get_global(cx).provider - { - if !available_models.is_empty() { - return available_models - .iter() - .cloned() - .map(LanguageModel::OpenAi) - .collect(); - } - } - let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) { - vec![self.model.clone()] - } else { - OpenAiModel::iter() - .filter(|model| !matches!(model, OpenAiModel::Custom { .. })) + fn available_models(&self) -> Vec { + if self.available_models_from_settings.is_empty() { + let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) { + vec![self.model.clone()] + } else { + OpenAiModel::iter() + .filter(|model| !matches!(model, OpenAiModel::Custom { .. })) + .collect() + }; + available_models + .into_iter() + .map(LanguageModel::OpenAi) .collect() - }; - available_models - .into_iter() - .map(LanguageModel::OpenAi) - .collect() + } else { + self.available_models_from_settings + .iter() + .cloned() + .map(LanguageModel::OpenAi) + .collect() + } } fn settings_version(&self) -> usize { @@ -255,16 +252,6 @@ pub fn count_open_ai_tokens( .boxed() } -impl From for open_ai::Role { - fn from(val: Role) -> Self { - match val { - Role::User => OpenAiRole::User, - Role::Assistant => OpenAiRole::Assistant, - Role::System => OpenAiRole::System, - } - } -} - struct AuthenticationPrompt { api_key: View, api_url: String, diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml new file mode 100644 index 0000000000..bdc3ad63d5 --- /dev/null +++ b/crates/language_model/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "language_model" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_model.rs" +doctest = false + +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +ollama = { workspace = true, features = ["schemars"] } +open_ai = { workspace = true, features = ["schemars"] } +schemars.workspace = true +serde.workspace = true +strum.workspace = true +proto = { workspace = true, features = ["test-support"] } + +[dev-dependencies] +ctor.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } +log.workspace = true +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +text = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git a/crates/language_model/LICENSE-GPL b/crates/language_model/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/language_model/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs new file mode 100644 index 0000000000..09de409ff4 --- /dev/null +++ b/crates/language_model/src/language_model.rs @@ -0,0 +1,7 @@ +mod model; +mod request; +mod role; + +pub use model::*; +pub use request::*; +pub use role::*; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs new file mode 100644 
index 0000000000..20b2bf7d4f --- /dev/null +++ b/crates/language_model/src/model/cloud_model.rs @@ -0,0 +1,160 @@ +use crate::LanguageModelRequest; +pub use anthropic::Model as AnthropicModel; +pub use ollama::Model as OllamaModel; +pub use open_ai::Model as OpenAiModel; +use schemars::{ + schema::{InstanceType, Metadata, Schema, SchemaObject}, + JsonSchema, +}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::fmt; +use strum::{EnumIter, IntoEnumIterator}; + +#[derive(Clone, Debug, Default, PartialEq, EnumIter)] +pub enum CloudModel { + Gpt3Point5Turbo, + Gpt4, + Gpt4Turbo, + #[default] + Gpt4Omni, + Gpt4OmniMini, + Claude3_5Sonnet, + Claude3Opus, + Claude3Sonnet, + Claude3Haiku, + Gemini15Pro, + Gemini15Flash, + Custom(String), +} + +impl Serialize for CloudModel { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(self.id()) + } +} + +impl<'de> Deserialize<'de> for CloudModel { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ZedDotDevModelVisitor; + + impl<'de> Visitor<'de> for ZedDotDevModelVisitor { + type Value = CloudModel; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string for a ZedDotDevModel variant or a custom model") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + let model = CloudModel::iter() + .find(|model| model.id() == value) + .unwrap_or_else(|| CloudModel::Custom(value.to_string())); + Ok(model) + } + } + + deserializer.deserialize_str(ZedDotDevModelVisitor) + } +} + +impl JsonSchema for CloudModel { + fn schema_name() -> String { + "ZedDotDevModel".to_owned() + } + + fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { + let variants = CloudModel::iter() + .filter_map(|model| { + let id = model.id(); + if id.is_empty() { + None + } else { + Some(id.to_string()) + } + }) + .collect::>(); + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), + metadata: Some(Box::new(Metadata { + title: Some("ZedDotDevModel".to_owned()), + default: Some(CloudModel::default().id().into()), + examples: variants.into_iter().map(Into::into).collect(), + ..Default::default() + })), + ..Default::default() + }) + } +} + +impl CloudModel { + pub fn id(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "gpt-3.5-turbo", + Self::Gpt4 => "gpt-4", + Self::Gpt4Turbo => "gpt-4-turbo-preview", + Self::Gpt4Omni => "gpt-4o", + Self::Gpt4OmniMini => "gpt-4o-mini", + Self::Claude3_5Sonnet => "claude-3-5-sonnet", + Self::Claude3Opus => "claude-3-opus", + Self::Claude3Sonnet => "claude-3-sonnet", + Self::Claude3Haiku => "claude-3-haiku", + Self::Gemini15Pro => "gemini-1.5-pro", + Self::Gemini15Flash => "gemini-1.5-flash", + Self::Custom(id) => id, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", + Self::Gpt4 => "GPT 4", + Self::Gpt4Turbo => "GPT 4 Turbo", + Self::Gpt4Omni => "GPT 4 Omni", + Self::Gpt4OmniMini => "GPT 4 Omni Mini", + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", + Self::Gemini15Pro => "Gemini 1.5 Pro", + Self::Gemini15Flash => "Gemini 1.5 Flash", + Self::Custom(id) => id.as_str(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + 
Self::Gpt3Point5Turbo => 2048, + Self::Gpt4 => 4096, + Self::Gpt4Turbo | Self::Gpt4Omni => 128000, + Self::Gpt4OmniMini => 128000, + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => 200000, + Self::Gemini15Pro => 128000, + Self::Gemini15Flash => 32000, + Self::Custom(_) => 4096, // TODO: Make this configurable + } + } + + pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { + request.preprocess_anthropic() + } + _ => {} + } + } +} diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs new file mode 100644 index 0000000000..b61766308f --- /dev/null +++ b/crates/language_model/src/model/mod.rs @@ -0,0 +1,60 @@ +pub mod cloud_model; + +pub use anthropic::Model as AnthropicModel; +pub use cloud_model::*; +pub use ollama::Model as OllamaModel; +pub use open_ai::Model as OpenAiModel; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum LanguageModel { + Cloud(CloudModel), + OpenAi(OpenAiModel), + Anthropic(AnthropicModel), + Ollama(OllamaModel), +} + +impl Default for LanguageModel { + fn default() -> Self { + LanguageModel::Cloud(CloudModel::default()) + } +} + +impl LanguageModel { + pub fn telemetry_id(&self) -> String { + match self { + LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), + LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), + LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), + LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), + } + } + + pub fn display_name(&self) -> String { + match self { + LanguageModel::OpenAi(model) => model.display_name().into(), + LanguageModel::Anthropic(model) => model.display_name().into(), + LanguageModel::Cloud(model) => model.display_name().into(), + LanguageModel::Ollama(model) => model.display_name().into(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + LanguageModel::OpenAi(model) => model.max_token_count(), + LanguageModel::Anthropic(model) => model.max_token_count(), + LanguageModel::Cloud(model) => model.max_token_count(), + LanguageModel::Ollama(model) => model.max_token_count(), + } + } + + pub fn id(&self) -> &str { + match self { + LanguageModel::OpenAi(model) => model.id(), + LanguageModel::Anthropic(model) => model.id(), + LanguageModel::Cloud(model) => model.id(), + LanguageModel::Ollama(model) => model.id(), + } + } +} diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs new file mode 100644 index 0000000000..f9c4322cdf --- /dev/null +++ b/crates/language_model/src/request.rs @@ -0,0 +1,110 @@ +use crate::{ + model::{CloudModel, LanguageModel}, + role::Role, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: String, +} + +impl LanguageModelRequestMessage { + pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { + proto::LanguageModelRequestMessage { + role: self.role.to_proto() as i32, + content: self.content.clone(), + tool_calls: Vec::new(), + tool_call_id: None, + } + } +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct LanguageModelRequest { + pub model: LanguageModel, + pub messages: Vec, + pub stop: Vec, + pub temperature: f32, +} + +impl LanguageModelRequest { + pub fn to_proto(&self) -> 
proto::CompleteWithLanguageModel {
+        proto::CompleteWithLanguageModel {
+            model: self.model.id().to_string(),
+            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+            stop: self.stop.clone(),
+            temperature: self.temperature,
+            tool_choice: None,
+            tools: Vec::new(),
+        }
+    }
+
+    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+    pub fn preprocess(&mut self) {
+        match &self.model {
+            LanguageModel::OpenAi(_) => {}
+            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Ollama(_) => {}
+            LanguageModel::Cloud(model) => match model {
+                CloudModel::Claude3Opus
+                | CloudModel::Claude3Sonnet
+                | CloudModel::Claude3Haiku
+                | CloudModel::Claude3_5Sonnet => {
+                    self.preprocess_anthropic();
+                }
+                _ => {}
+            },
+        }
+    }
+
+    pub fn preprocess_anthropic(&mut self) {
+        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut system_message = String::new();
+
+        for message in self.messages.drain(..) {
+            if message.content.is_empty() {
+                continue;
+            }
+
+            match message.role {
+                Role::User | Role::Assistant => {
+                    if let Some(last_message) = new_messages.last_mut() {
+                        if last_message.role == message.role {
+                            last_message.content.push_str("\n\n");
+                            last_message.content.push_str(&message.content);
+                            continue;
+                        }
+                    }
+
+                    new_messages.push(message);
+                }
+                Role::System => {
+                    if !system_message.is_empty() {
+                        system_message.push_str("\n\n");
+                    }
+                    system_message.push_str(&message.content);
+                }
+            }
+        }
+
+        if !system_message.is_empty() {
+            new_messages.insert(
+                0,
+                LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: system_message,
+                },
+            );
+        }
+
+        self.messages = new_messages;
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
new file mode 100644
index 0000000000..f6276a4823
--- /dev/null
+++ b/crates/language_model/src/role.rs
@@ -0,0 +1,68 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Display};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn from_proto(role: i32) -> Role {
+        match proto::LanguageModelRole::from_i32(role) {
+            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+            None => Role::User,
+        }
+    }
+
+    pub fn to_proto(&self) -> proto::LanguageModelRole {
+        match self {
+            Role::User => proto::LanguageModelRole::LanguageModelUser,
+            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+            Role::System => proto::LanguageModelRole::LanguageModelSystem,
+        }
+    }
+
+    pub fn cycle(self) -> Role {
+        match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "user"),
+            Role::Assistant => write!(f, "assistant"),
+            Role::System => write!(f, "system"),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => ollama::Role::User,
+            Role::Assistant => ollama::Role::Assistant,
+            Role::System => ollama::Role::System,
+        }
+    }
+}
+
+impl From<Role> for open_ai::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => open_ai::Role::User,
+            Role::Assistant => open_ai::Role::Assistant,
+            Role::System => open_ai::Role::System,
+        }
+    }
+}
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 3f49490941..19cb0c96fe 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -22,6 +22,7 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 7a29f3be25..4c43fc1e46 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1261,3 +1261,6 @@ mod tests {
         );
     }
 }
+
+// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
+type _TODO = completion::CompletionProvider;