Extract completion provider crate (#14823)

We will soon need `semantic_index` to be able to use
`CompletionProvider`. This is currently impossible due to a cyclic crate
dependency, because `CompletionProvider` lives in the `assistant` crate,
which depends on `semantic_index`.

This PR breaks the dependency cycle by extracting two crates out of
`assistant`: `language_model` and `completion`.

Only one piece of logic changed: [this
code](922fcaf5a6 (diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69)).
* As of https://github.com/zed-industries/zed/pull/13276, whenever we
ask a given completion provider for its available models, OpenAI
providers would go and ask the global assistant settings whether the
user had configured an `available_models` setting, and if so, return
that.
* This PR changes it so that instead of eagerly asking the assistant
settings for this info (the new crate must not depend on `assistant`, or
else the dependency cycle would be back), OpenAI completion providers
now store the user-configured settings as part of their struct, and
whenever the settings change, we update the provider.

In theory, this change should not change user-visible behavior...but
since it's the only change in this large PR that's more than just moving
code around, I'm mentioning it here in case there's an unexpected
regression in practice! (cc @amtoaer in case you'd like to try out this
branch and verify that the feature is still working the way you expect.)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Richard Feldman 2024-07-19 13:35:34 -04:00 committed by GitHub
parent b9a53ffa0b
commit ec487d8f64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 820 additions and 610 deletions

64
Cargo.lock generated
View File

@ -382,6 +382,7 @@ dependencies = [
"clock", "clock",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"completion",
"ctor", "ctor",
"editor", "editor",
"env_logger", "env_logger",
@ -396,6 +397,7 @@ dependencies = [
"indexed_docs", "indexed_docs",
"indoc", "indoc",
"language", "language",
"language_model",
"log", "log",
"menu", "menu",
"multi_buffer", "multi_buffer",
@ -418,13 +420,11 @@ dependencies = [
"settings", "settings",
"similar", "similar",
"smol", "smol",
"strum",
"telemetry_events", "telemetry_events",
"terminal", "terminal",
"terminal_view", "terminal_view",
"text", "text",
"theme", "theme",
"tiktoken-rs",
"toml 0.8.10", "toml 0.8.10",
"ui", "ui",
"unindent", "unindent",
@ -2491,6 +2491,7 @@ dependencies = [
"clock", "clock",
"collab_ui", "collab_ui",
"collections", "collections",
"completion",
"ctor", "ctor",
"dashmap", "dashmap",
"dev_server_projects", "dev_server_projects",
@ -2673,6 +2674,42 @@ dependencies = [
"gpui", "gpui",
] ]
[[package]]
name = "completion"
version = "0.1.0"
dependencies = [
"anthropic",
"anyhow",
"client",
"collections",
"ctor",
"editor",
"env_logger",
"futures 0.3.28",
"gpui",
"http 0.1.0",
"language",
"language_model",
"log",
"menu",
"ollama",
"open_ai",
"parking_lot",
"project",
"rand 0.8.5",
"serde",
"serde_json",
"settings",
"smol",
"strum",
"text",
"theme",
"tiktoken-rs",
"ui",
"unindent",
"util",
]
[[package]] [[package]]
name = "concurrent-queue" name = "concurrent-queue"
version = "2.2.0" version = "2.2.0"
@ -5996,6 +6033,28 @@ dependencies = [
"util", "util",
] ]
[[package]]
name = "language_model"
version = "0.1.0"
dependencies = [
"anthropic",
"ctor",
"editor",
"env_logger",
"language",
"log",
"ollama",
"open_ai",
"project",
"proto",
"rand 0.8.5",
"schemars",
"serde",
"strum",
"text",
"unindent",
]
[[package]] [[package]]
name = "language_selector" name = "language_selector"
version = "0.1.0" version = "0.1.0"
@ -9510,6 +9569,7 @@ dependencies = [
"client", "client",
"clock", "clock",
"collections", "collections",
"completion",
"env_logger", "env_logger",
"fs", "fs",
"futures 0.3.28", "futures 0.3.28",

View File

@ -19,6 +19,7 @@ members = [
"crates/collections", "crates/collections",
"crates/command_palette", "crates/command_palette",
"crates/command_palette_hooks", "crates/command_palette_hooks",
"crates/completion",
"crates/copilot", "crates/copilot",
"crates/db", "crates/db",
"crates/dev_server_projects", "crates/dev_server_projects",
@ -50,6 +51,7 @@ members = [
"crates/install_cli", "crates/install_cli",
"crates/journal", "crates/journal",
"crates/language", "crates/language",
"crates/language_model",
"crates/language_selector", "crates/language_selector",
"crates/language_tools", "crates/language_tools",
"crates/languages", "crates/languages",
@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" } collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" } command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" } command_palette_hooks = { path = "crates/command_palette_hooks" }
completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" } copilot = { path = "crates/copilot" }
db = { path = "crates/db" } db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" } dev_server_projects = { path = "crates/dev_server_projects" }
@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" } install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" } journal = { path = "crates/journal" }
language = { path = "crates/language" } language = { path = "crates/language" }
language_model = { path = "crates/language_model" }
language_selector = { path = "crates/language_selector" } language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" } language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" } languages = { path = "crates/languages" }

View File

@ -33,6 +33,7 @@ client.workspace = true
clock.workspace = true clock.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
completion.workspace = true
editor.workspace = true editor.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
fs.workspace = true fs.workspace = true
@ -45,6 +46,7 @@ http.workspace = true
indexed_docs.workspace = true indexed_docs.workspace = true
indoc.workspace = true indoc.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true
log.workspace = true log.workspace = true
menu.workspace = true menu.workspace = true
multi_buffer.workspace = true multi_buffer.workspace = true
@ -64,12 +66,10 @@ serde_json.workspace = true
settings.workspace = true settings.workspace = true
similar.workspace = true similar.workspace = true
smol.workspace = true smol.workspace = true
strum.workspace = true
telemetry_events.workspace = true telemetry_events.workspace = true
terminal.workspace = true terminal.workspace = true
terminal_view.workspace = true terminal_view.workspace = true
theme.workspace = true theme.workspace = true
tiktoken-rs.workspace = true
toml.workspace = true toml.workspace = true
ui.workspace = true ui.workspace = true
util.workspace = true util.workspace = true
@ -79,6 +79,7 @@ picker.workspace = true
roxmltree = "0.20.0" roxmltree = "0.20.0"
[dev-dependencies] [dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true ctor.workspace = true
editor = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true env_logger.workspace = true

View File

@ -1,6 +1,5 @@
pub mod assistant_panel; pub mod assistant_panel;
pub mod assistant_settings; pub mod assistant_settings;
mod completion_provider;
mod context; mod context;
pub mod context_store; pub mod context_store;
mod inline_assistant; mod inline_assistant;
@ -12,17 +11,20 @@ mod streaming_diff;
mod terminal_inline_assistant; mod terminal_inline_assistant;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry; use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client}; use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
pub use completion_provider::*; use completion::CompletionProvider;
pub use context::*; pub use context::*;
pub use context_store::*; pub use context_store::*;
use fs::Fs; use fs::Fs;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal}; use gpui::{
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
};
use indexed_docs::IndexedDocsRegistry; use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*; pub(crate) use inline_assistant::*;
use language_model::LanguageModelResponseMessage;
pub(crate) use model_selector::*; pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -32,10 +34,7 @@ use slash_command::{
file_command, now_command, project_command, prompt_command, search_command, symbols_command, file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tabs_command, term_command, tabs_command, term_command,
}; };
use std::{ use std::sync::Arc;
fmt::{self, Display},
sync::Arc,
};
pub(crate) use streaming_diff::*; pub(crate) use streaming_diff::*;
actions!( actions!(
@ -73,166 +72,6 @@ impl MessageId {
} }
} }
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct LanguageModelUsage { pub struct LanguageModelUsage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client); context_store::init(&client);
prompt_library::init(cx); prompt_library::init(cx);
completion_provider::init(client.clone(), cx); init_completion_provider(Arc::clone(&client), cx);
assistant_slash_command::init(cx); assistant_slash_command::init(cx);
register_slash_commands(cx); register_slash_commands(cx);
assistant_panel::init(cx); assistant_panel::init(cx);
@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach(); .detach();
} }
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
})
})
.detach();
}
fn register_slash_commands(cx: &mut AppContext) { fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx); let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true); slash_command_registry.register_command(file_command::FileSlashCommand, true);

View File

@ -8,18 +8,18 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandRegistry, SlashCommandCompletionProvider, SlashCommandRegistry,
}, },
terminal_inline_assistant::TerminalInlineAssistant, terminal_inline_assistant::TerminalInlineAssistant,
Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split, RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
ToggleFocus, ToggleModelSelector,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use breadcrumbs::Breadcrumbs; use breadcrumbs::Breadcrumbs;
use client::proto; use client::proto;
use collections::{BTreeSet, HashMap, HashSet}; use collections::{BTreeSet, HashMap, HashSet};
use completion::CompletionProvider;
use editor::{ use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{ display_map::{
@ -43,6 +43,7 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset, ToOffset,
}; };
use language_model::Role;
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate}; use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate}; use project::{Project, ProjectLspAdapterDelegate};

View File

@ -1,166 +1,19 @@
use std::fmt; use std::{sync::Arc, time::Duration};
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest}; use anthropic::Model as AnthropicModel;
pub use anthropic::Model as AnthropicModel; use client::Client;
use gpui::Pixels; use completion::{
pub use ollama::Model as OllamaModel; AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
pub use open_ai::Model as OpenAiModel; LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
}; };
use gpui::{AppContext, Pixels};
use language_model::{CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use parking_lot::RwLock;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources}; use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Gpt4OmniMini,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -620,6 +473,124 @@ fn merge<T>(target: &mut T, value: Option<T>) {
} }
} }
pub fn update_completion_provider_settings(
provider: &mut CompletionProvider,
version: usize,
cx: &mut AppContext,
) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => provider
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// Previously configured provider was changed to another one
if updated.is_none() {
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
}
}
pub(crate) fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
available_models.clone(),
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use gpui::{AppContext, UpdateGlobal}; use gpui::{AppContext, UpdateGlobal};

View File

@ -1,12 +1,12 @@
use crate::{ use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role, MessageStatus,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{ use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
}; };
use client::{proto, telemetry::Telemetry}; use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId; use clock::ReplicaId;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use fs::Fs; use fs::Fs;
@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{ use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset, AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
}; };
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequest, Role};
use open_ai::Model as OpenAiModel; use open_ai::Model as OpenAiModel;
use paths::contexts_dir; use paths::contexts_dir;
use project::Project; use project::Project;
@ -2477,9 +2479,10 @@ mod tests {
use crate::{ use crate::{
assistant_panel, prompt_library, assistant_panel, prompt_library,
slash_command::{active_command, file_command}, slash_command::{active_command, file_command},
FakeCompletionProvider, MessageId, MessageId,
}; };
use assistant_slash_command::{ArgumentCompletion, SlashCommand}; use assistant_slash_command::{ArgumentCompletion, SlashCommand};
use completion::FakeCompletionProvider;
use fs::FakeFs; use fs::FakeFs;
use gpui::{AppContext, TestAppContext, WeakView}; use gpui::{AppContext, TestAppContext, WeakView};
use indoc::indoc; use indoc::indoc;

View File

@ -1,7 +1,6 @@
use crate::{ use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt, assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest, AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
LanguageModelRequestMessage, Role, StreamingDiff,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
@ -28,6 +27,7 @@ use gpui::{
WhiteSpace, WindowContext, WhiteSpace, WindowContext,
}; };
use language::{Buffer, Point, Selection, TransactionId}; use language::{Buffer, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use parking_lot::Mutex; use parking_lot::Mutex;
use rope::Rope; use rope::Rope;
@ -1432,8 +1432,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher") PopoverMenu::new("model-switcher")
.menu(move |cx| { .menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| { ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) for model in CompletionProvider::global(cx).available_models() {
{
menu = menu.custom_entry( menu = menu.custom_entry(
{ {
let model = model.clone(); let model = model.clone();
@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::FakeCompletionProvider; use completion::FakeCompletionProvider;
use futures::stream::{self}; use futures::stream::{self};
use gpui::{Context, TestAppContext}; use gpui::{Context, TestAppContext};
use indoc::indoc; use indoc::indoc;

View File

@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
.with_handle(self.handle) .with_handle(self.handle)
.menu(move |cx| { .menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| { ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) { for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry( menu = menu.custom_entry(
{ {
let model = model.clone(); let model = model.clone();

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, InlineAssist, InlineAssistant,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assets::Assets; use assets::Assets;
@ -19,6 +19,7 @@ use gpui::{
}; };
use heed::{types::SerdeBincode, Database, RoTxn}; use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use parking_lot::RwLock; use parking_lot::RwLock;
use picker::{Picker, PickerDelegate}; use picker::{Picker, PickerDelegate};
use rope::Rope; use rope::Rope;

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
assistant_settings::AssistantSettings, humanize_token_count, assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent, prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role, CompletionProvider,
}; };
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
@ -17,6 +17,7 @@ use gpui::{
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
}; };
use language::Buffer; use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use settings::{update_settings_file, Settings}; use settings::{update_settings_file, Settings};
use std::{ use std::{
cmp, cmp,
@ -558,8 +559,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher") PopoverMenu::new("model-switcher")
.menu(move |cx| { .menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| { ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) for model in CompletionProvider::global(cx).available_models() {
{
menu = menu.custom_entry( menu = menu.custom_entry(
{ {
let model = model.clone(); let model = model.clone();

View File

@ -30,6 +30,7 @@ chrono.workspace = true
clock.workspace = true clock.workspace = true
clickhouse.workspace = true clickhouse.workspace = true
collections.workspace = true collections.workspace = true
completion.workspace = true
dashmap = "5.4" dashmap = "5.4"
envy = "0.4.2" envy = "0.4.2"
futures.workspace = true futures.workspace = true
@ -79,6 +80,7 @@ channel.workspace = true
client = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] }
collab_ui = { workspace = true, features = ["test-support"] } collab_ui = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true ctor.workspace = true
editor = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true env_logger.workspace = true

View File

@ -295,7 +295,7 @@ impl TestServer {
menu::init(); menu::init();
dev_server_projects::init(client.clone(), cx); dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
assistant::FakeCompletionProvider::setup_test(cx); completion::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client); assistant::context_store::init(&client);
}); });

View File

@ -0,0 +1,56 @@
[package]
name = "completion"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/completion.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
client.workspace = true
collections.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
http.workspace = true
language_model.workspace = true
log.workspace = true
menu.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@ -0,0 +1 @@
../../LICENSE-GPL

View File

@ -1,14 +1,12 @@
use crate::{ use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
Role, use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
};
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient; use http::HttpClient;
use language_model::Role;
use settings::Settings; use settings::Settings;
use std::time::Duration; use std::time::Duration;
use std::{env, sync::Arc}; use std::{env, sync::Arc};
@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider {
} }
impl LanguageModelCompletionProvider for AnthropicCompletionProvider { impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { fn available_models(&self) -> Vec<LanguageModel> {
AnthropicModel::iter() AnthropicModel::iter()
.map(LanguageModel::Anthropic) .map(LanguageModel::Anthropic)
.collect() .collect()
@ -176,7 +174,7 @@ impl AnthropicCompletionProvider {
} }
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request); request.preprocess_anthropic();
let model = match request.model { let model = match request.model {
LanguageModel::Anthropic(model) => model, LanguageModel::Anthropic(model) => model,
@ -213,49 +211,6 @@ impl AnthropicCompletionProvider {
} }
} }
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
struct AuthenticationPrompt { struct AuthenticationPrompt {
api_key: View<Editor>, api_key: View<Editor>,
api_url: String, api_url: String,

View File

@ -1,11 +1,12 @@
use crate::{ use crate::{
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel, count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
LanguageModelCompletionProvider, LanguageModelRequest, LanguageModelRequest,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use client::{proto, Client}; use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task}; use gpui::{AnyView, AppContext, Task};
use language_model::CloudModel;
use std::{future, sync::Arc}; use std::{future, sync::Arc};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use ui::prelude::*; use ui::prelude::*;
@ -52,7 +53,7 @@ impl CloudCompletionProvider {
} }
impl LanguageModelCompletionProvider for CloudCompletionProvider { impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { fn available_models(&self) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() { let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model) Some(custom_model)
} else { } else {

View File

@ -6,52 +6,19 @@ mod ollama;
mod open_ai; mod open_ai;
pub use anthropic::*; pub use anthropic::*;
use anyhow::Result;
use client::Client;
pub use cloud::*; pub use cloud::*;
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub use fake::*; pub use fake::*;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, Task, WindowContext};
use language_model::{LanguageModel, LanguageModelRequest};
pub use ollama::*; pub use ollama::*;
pub use open_ai::*; pub use open_ai::*;
use parking_lot::RwLock; use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc}; use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let provider = create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.update_settings(settings_version, cx);
})
})
.detach();
}
pub struct CompletionResponse { pub struct CompletionResponse {
inner: BoxStream<'static, Result<String>>, inner: BoxStream<'static, Result<String>>,
@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse {
} }
pub trait LanguageModelCompletionProvider: Send + Sync { pub trait LanguageModelCompletionProvider: Send + Sync {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>; fn available_models(&self) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize; fn settings_version(&self) -> usize;
fn is_authenticated(&self) -> bool; fn is_authenticated(&self) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>; fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
@ -110,8 +77,8 @@ impl CompletionProvider {
} }
} }
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> { pub fn available_models(&self) -> Vec<LanguageModel> {
self.provider.read().available_models(cx) self.provider.read().available_models()
} }
pub fn settings_version(&self) -> usize { pub fn settings_version(&self) -> usize {
@ -176,6 +143,17 @@ impl CompletionProvider {
Ok(completion) Ok(completion)
}) })
} }
pub fn update_provider(
&mut self,
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
) {
if let Some(client) = &self.client {
self.provider = get_provider(Arc::clone(client));
} else {
log::warn!("completion provider cannot be updated because its client was not set");
}
}
} }
impl gpui::Global for CompletionProvider {} impl gpui::Global for CompletionProvider {}
@ -196,109 +174,6 @@ impl CompletionProvider {
None None
} }
} }
pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => self
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// Previously configured provider was changed to another one
if updated.is_none() {
if let Some(client) = self.client.clone() {
self.provider = create_provider_from_settings(client, version, cx);
} else {
log::warn!("completion provider cannot be created because client is not set");
}
}
}
}
fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
} }
#[cfg(test)] #[cfg(test)]
@ -311,8 +186,8 @@ mod tests {
use smol::stream::StreamExt; use smol::stream::StreamExt;
use crate::{ use crate::{
completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider, CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
FakeCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
}; };
#[gpui::test] #[gpui::test]

View File

@ -62,7 +62,7 @@ impl FakeCompletionProvider {
} }
impl LanguageModelCompletionProvider for FakeCompletionProvider { impl LanguageModelCompletionProvider for FakeCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { fn available_models(&self) -> Vec<LanguageModel> {
vec![LanguageModel::default()] vec![LanguageModel::default()]
} }

View File

@ -1,15 +1,14 @@
use crate::LanguageModelCompletionProvider; use crate::LanguageModelCompletionProvider;
use crate::{ use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result; use anyhow::Result;
use futures::StreamExt as _; use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task}; use gpui::{AnyView, AppContext, Task};
use http::HttpClient; use http::HttpClient;
use language_model::Role;
use ollama::Model as OllamaModel;
use ollama::{ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Role as OllamaRole,
}; };
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider {
} }
impl LanguageModelCompletionProvider for OllamaCompletionProvider { impl LanguageModelCompletionProvider for OllamaCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> { fn available_models(&self) -> Vec<LanguageModel> {
self.available_models self.available_models
.iter() .iter()
.map(|m| LanguageModel::Ollama(m.clone())) .map(|m| LanguageModel::Ollama(m.clone()))
@ -262,16 +261,6 @@ impl OllamaCompletionProvider {
} }
} }
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
}
}
}
struct DownloadOllamaMessage { struct DownloadOllamaMessage {
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>, retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
} }

View File

@ -1,15 +1,13 @@
use crate::assistant_settings::CloudModel; use crate::CompletionProvider;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
use crate::LanguageModelCompletionProvider; use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient; use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
use open_ai::Model as OpenAiModel;
use open_ai::{stream_completion, Request, RequestMessage};
use settings::Settings; use settings::Settings;
use std::time::Duration; use std::time::Duration;
use std::{env, sync::Arc}; use std::{env, sync::Arc};
@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
available_models_from_settings: Vec<OpenAiModel>,
} }
impl OpenAiCompletionProvider { impl OpenAiCompletionProvider {
@ -34,6 +33,7 @@ impl OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>, low_speed_timeout: Option<Duration>,
settings_version: usize, settings_version: usize,
available_models_from_settings: Vec<OpenAiModel>,
) -> Self { ) -> Self {
Self { Self {
api_key: None, api_key: None,
@ -42,6 +42,7 @@ impl OpenAiCompletionProvider {
http_client, http_client,
low_speed_timeout, low_speed_timeout,
settings_version, settings_version,
available_models_from_settings,
} }
} }
@ -92,30 +93,26 @@ impl OpenAiCompletionProvider {
} }
impl LanguageModelCompletionProvider for OpenAiCompletionProvider { impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> { fn available_models(&self) -> Vec<LanguageModel> {
if let AssistantProvider::OpenAi { if self.available_models_from_settings.is_empty() {
available_models, .. let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
} = &AssistantSettings::get_global(cx).provider vec![self.model.clone()]
{ } else {
if !available_models.is_empty() { OpenAiModel::iter()
return available_models .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.iter() .collect()
.cloned() };
.map(LanguageModel::OpenAi) available_models
.collect(); .into_iter()
} .map(LanguageModel::OpenAi)
}
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
vec![self.model.clone()]
} else {
OpenAiModel::iter()
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.collect() .collect()
}; } else {
available_models self.available_models_from_settings
.into_iter() .iter()
.map(LanguageModel::OpenAi) .cloned()
.collect() .map(LanguageModel::OpenAi)
.collect()
}
} }
fn settings_version(&self) -> usize { fn settings_version(&self) -> usize {
@ -255,16 +252,6 @@ pub fn count_open_ai_tokens(
.boxed() .boxed()
} }
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt { struct AuthenticationPrompt {
api_key: View<Editor>, api_key: View<Editor>,
api_url: String, api_url: String,

View File

@ -0,0 +1,41 @@
[package]
name = "language_model"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/language_model.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
strum.workspace = true
proto = { workspace = true, features = ["test-support"] }
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
log.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@ -0,0 +1 @@
../../LICENSE-GPL

View File

@ -0,0 +1,7 @@
mod model;
mod request;
mod role;
pub use model::*;
pub use request::*;
pub use role::*;

View File

@ -0,0 +1,160 @@
use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Gpt4OmniMini,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
request.preprocess_anthropic()
}
_ => {}
}
}
}

View File

@ -0,0 +1,60 @@
pub mod cloud_model;
pub use anthropic::Model as AnthropicModel;
pub use cloud_model::*;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}

View File

@ -0,0 +1,110 @@
use crate::{
model::{CloudModel, LanguageModel},
role::Role,
};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
self.preprocess_anthropic();
}
_ => {}
},
}
}
pub fn preprocess_anthropic(&mut self) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in self.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
self.messages = new_messages;
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}

View File

@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => ollama::Role::User,
Role::Assistant => ollama::Role::Assistant,
Role::System => ollama::Role::System,
}
}
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => open_ai::Role::User,
Role::Assistant => open_ai::Role::Assistant,
Role::System => open_ai::Role::System,
}
}
}

View File

@ -22,6 +22,7 @@ anyhow.workspace = true
client.workspace = true client.workspace = true
clock.workspace = true clock.workspace = true
collections.workspace = true collections.workspace = true
completion.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
futures-batch.workspace = true futures-batch.workspace = true

View File

@ -1261,3 +1261,6 @@ mod tests {
); );
} }
} }
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
type _TODO = completion::CompletionProvider;