Mirror of https://github.com/zed-industries/zed.git, synced 2024-11-07 20:39:04 +03:00.
Extract completion provider crate (#14823)
We will soon need `semantic_index` to be able to use
`CompletionProvider`. This is currently impossible due to a cyclic crate
dependency, because `CompletionProvider` lives in the `assistant` crate,
which depends on `semantic_index`.
This PR breaks the dependency cycle by extracting two crates out of
`assistant`: `language_model` and `completion`.
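For orientation: `language_model` now owns the model types (`LanguageModel`, `CloudModel`, the request/response structs, and `Role`), while `completion` owns `CompletionProvider` and the per-backend providers behind this trait (abridged from `crates/completion/src/completion.rs` in the diff below; the remaining methods are elided here):

```rust
// The `completion` crate defines this trait; `assistant` now consumes it
// instead of defining it, which is what breaks the dependency cycle.
pub trait LanguageModelCompletionProvider: Send + Sync {
    fn available_models(&self) -> Vec<LanguageModel>;
    fn settings_version(&self) -> usize;
    fn is_authenticated(&self) -> bool;
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    // ...remaining methods elided
}
```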
Only one piece of logic changed: [this code](922fcaf5a6#diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69).
* As of https://github.com/zed-industries/zed/pull/13276, whenever we
asked an OpenAI completion provider for its available models, it would
consult the global assistant settings to see whether the user had
configured an `available_models` setting and, if so, return that list.
* This PR changes it so that instead of eagerly asking the assistant
settings for this info (the new crate must not depend on `assistant`, or
else the dependency cycle would be back), OpenAI completion providers
now store the user-configured settings as part of their struct, and
whenever the settings change, we update the provider.
In theory, this change should not change user-visible behavior...but
since it's the only change in this large PR that's more than just moving
code around, I'm mentioning it here in case there's an unexpected
regression in practice! (cc @amtoaer in case you'd like to try out this
branch and verify that the feature is still working the way you expect.)
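
Concretely, the new wiring builds a provider from the current settings once, then observes the global settings store and pushes changes into the provider (lightly annotated from `init_completion_provider` in the diff below):

```rust
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
    // Build the initial provider from whatever the assistant settings say right now.
    let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
    cx.set_global(CompletionProvider::new(provider, Some(client)));

    // On every settings change, push the new values into the provider,
    // so providers no longer read the assistant settings themselves.
    let mut settings_version = 0;
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
        })
    })
    .detach();
}
```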
Release Notes:
- N/A
---------
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
parent b9a53ffa0b
commit ec487d8f64

Cargo.lock (generated, 64 changes)
@@ -382,6 +382,7 @@ dependencies = [
  "clock",
  "collections",
  "command_palette_hooks",
+ "completion",
  "ctor",
  "editor",
  "env_logger",
@@ -396,6 +397,7 @@ dependencies = [
  "indexed_docs",
  "indoc",
  "language",
+ "language_model",
  "log",
  "menu",
  "multi_buffer",
@@ -418,13 +420,11 @@ dependencies = [
  "settings",
  "similar",
  "smol",
- "strum",
  "telemetry_events",
  "terminal",
  "terminal_view",
  "text",
  "theme",
- "tiktoken-rs",
  "toml 0.8.10",
  "ui",
  "unindent",
@@ -2491,6 +2491,7 @@ dependencies = [
  "clock",
  "collab_ui",
  "collections",
+ "completion",
  "ctor",
  "dashmap",
  "dev_server_projects",
@@ -2673,6 +2674,42 @@ dependencies = [
  "gpui",
 ]
 
+[[package]]
+name = "completion"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "anyhow",
+ "client",
+ "collections",
+ "ctor",
+ "editor",
+ "env_logger",
+ "futures 0.3.28",
+ "gpui",
+ "http 0.1.0",
+ "language",
+ "language_model",
+ "log",
+ "menu",
+ "ollama",
+ "open_ai",
+ "parking_lot",
+ "project",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+ "strum",
+ "text",
+ "theme",
+ "tiktoken-rs",
+ "ui",
+ "unindent",
+ "util",
+]
+
 [[package]]
 name = "concurrent-queue"
 version = "2.2.0"
@@ -5996,6 +6033,28 @@ dependencies = [
  "util",
 ]
 
+[[package]]
+name = "language_model"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "ctor",
+ "editor",
+ "env_logger",
+ "language",
+ "log",
+ "ollama",
+ "open_ai",
+ "project",
+ "proto",
+ "rand 0.8.5",
+ "schemars",
+ "serde",
+ "strum",
+ "text",
+ "unindent",
+]
+
 [[package]]
 name = "language_selector"
 version = "0.1.0"
@@ -9510,6 +9569,7 @@ dependencies = [
  "client",
  "clock",
  "collections",
+ "completion",
  "env_logger",
  "fs",
  "futures 0.3.28",
@@ -19,6 +19,7 @@ members = [
     "crates/collections",
     "crates/command_palette",
     "crates/command_palette_hooks",
+    "crates/completion",
     "crates/copilot",
     "crates/db",
     "crates/dev_server_projects",
@@ -50,6 +51,7 @@ members = [
     "crates/install_cli",
     "crates/journal",
     "crates/language",
+    "crates/language_model",
     "crates/language_selector",
     "crates/language_tools",
     "crates/languages",
@@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }
 command_palette = { path = "crates/command_palette" }
 command_palette_hooks = { path = "crates/command_palette_hooks" }
+completion = { path = "crates/completion" }
 copilot = { path = "crates/copilot" }
 db = { path = "crates/db" }
 dev_server_projects = { path = "crates/dev_server_projects" }
@@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" }
 install_cli = { path = "crates/install_cli" }
 journal = { path = "crates/journal" }
 language = { path = "crates/language" }
+language_model = { path = "crates/language_model" }
 language_selector = { path = "crates/language_selector" }
 language_tools = { path = "crates/language_tools" }
 languages = { path = "crates/languages" }
@@ -33,6 +33,7 @@ client.workspace = true
 clock.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
+completion.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 fs.workspace = true
@@ -45,6 +46,7 @@ http.workspace = true
 indexed_docs.workspace = true
 indoc.workspace = true
 language.workspace = true
+language_model.workspace = true
 log.workspace = true
 menu.workspace = true
 multi_buffer.workspace = true
@@ -64,12 +66,10 @@ serde_json.workspace = true
 settings.workspace = true
 similar.workspace = true
 smol.workspace = true
-strum.workspace = true
 telemetry_events.workspace = true
 terminal.workspace = true
 terminal_view.workspace = true
 theme.workspace = true
-tiktoken-rs.workspace = true
 toml.workspace = true
 ui.workspace = true
 util.workspace = true
@@ -79,6 +79,7 @@ picker.workspace = true
 roxmltree = "0.20.0"
 
 [dev-dependencies]
+completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true
@@ -1,6 +1,5 @@
 pub mod assistant_panel;
 pub mod assistant_settings;
-mod completion_provider;
 mod context;
 pub mod context_store;
 mod inline_assistant;
@@ -12,17 +11,20 @@ mod streaming_diff;
 mod terminal_inline_assistant;
 
 pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
-use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
+use assistant_settings::AssistantSettings;
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
-pub use completion_provider::*;
+use completion::CompletionProvider;
 pub use context::*;
 pub use context_store::*;
 use fs::Fs;
-use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
+use gpui::{
+    actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
+};
 use indexed_docs::IndexedDocsRegistry;
 pub(crate) use inline_assistant::*;
+use language_model::LanguageModelResponseMessage;
 pub(crate) use model_selector::*;
 use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 use serde::{Deserialize, Serialize};
@@ -32,10 +34,7 @@ use slash_command::{
     file_command, now_command, project_command, prompt_command, search_command, symbols_command,
     tabs_command, term_command,
 };
-use std::{
-    fmt::{self, Display},
-    sync::Arc,
-};
+use std::sync::Arc;
 pub(crate) use streaming_diff::*;
 
 actions!(
@@ -73,166 +72,6 @@ impl MessageId {
     }
 }
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn from_proto(role: i32) -> Role {
-        match proto::LanguageModelRole::from_i32(role) {
-            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
-            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
-            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
-            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
-            None => Role::User,
-        }
-    }
-
-    pub fn to_proto(&self) -> proto::LanguageModelRole {
-        match self {
-            Role::User => proto::LanguageModelRole::LanguageModelUser,
-            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
-            Role::System => proto::LanguageModelRole::LanguageModelSystem,
-        }
-    }
-
-    pub fn cycle(self) -> Role {
-        match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "user"),
-            Role::Assistant => write!(f, "assistant"),
-            Role::System => write!(f, "system"),
-        }
-    }
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub enum LanguageModel {
-    Cloud(CloudModel),
-    OpenAi(OpenAiModel),
-    Anthropic(AnthropicModel),
-    Ollama(OllamaModel),
-}
-
-impl Default for LanguageModel {
-    fn default() -> Self {
-        LanguageModel::Cloud(CloudModel::default())
-    }
-}
-
-impl LanguageModel {
-    pub fn telemetry_id(&self) -> String {
-        match self {
-            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
-            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
-            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
-            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
-        }
-    }
-
-    pub fn display_name(&self) -> String {
-        match self {
-            LanguageModel::OpenAi(model) => model.display_name().into(),
-            LanguageModel::Anthropic(model) => model.display_name().into(),
-            LanguageModel::Cloud(model) => model.display_name().into(),
-            LanguageModel::Ollama(model) => model.display_name().into(),
-        }
-    }
-
-    pub fn max_token_count(&self) -> usize {
-        match self {
-            LanguageModel::OpenAi(model) => model.max_token_count(),
-            LanguageModel::Anthropic(model) => model.max_token_count(),
-            LanguageModel::Cloud(model) => model.max_token_count(),
-            LanguageModel::Ollama(model) => model.max_token_count(),
-        }
-    }
-
-    pub fn id(&self) -> &str {
-        match self {
-            LanguageModel::OpenAi(model) => model.id(),
-            LanguageModel::Anthropic(model) => model.id(),
-            LanguageModel::Cloud(model) => model.id(),
-            LanguageModel::Ollama(model) => model.id(),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LanguageModelRequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-impl LanguageModelRequestMessage {
-    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
-        proto::LanguageModelRequestMessage {
-            role: self.role.to_proto() as i32,
-            content: self.content.clone(),
-            tool_calls: Vec::new(),
-            tool_call_id: None,
-        }
-    }
-}
-
-#[derive(Debug, Default, Serialize, Deserialize)]
-pub struct LanguageModelRequest {
-    pub model: LanguageModel,
-    pub messages: Vec<LanguageModelRequestMessage>,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-impl LanguageModelRequest {
-    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
-        proto::CompleteWithLanguageModel {
-            model: self.model.id().to_string(),
-            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
-            stop: self.stop.clone(),
-            temperature: self.temperature,
-            tool_choice: None,
-            tools: Vec::new(),
-        }
-    }
-
-    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
-    pub fn preprocess(&mut self) {
-        match &self.model {
-            LanguageModel::OpenAi(_) => {}
-            LanguageModel::Anthropic(_) => {}
-            LanguageModel::Ollama(_) => {}
-            LanguageModel::Cloud(model) => match model {
-                CloudModel::Claude3Opus
-                | CloudModel::Claude3Sonnet
-                | CloudModel::Claude3Haiku
-                | CloudModel::Claude3_5Sonnet => {
-                    preprocess_anthropic_request(self);
-                }
-                _ => {}
-            },
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LanguageModelResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
 #[derive(Deserialize, Debug)]
 pub struct LanguageModelUsage {
     pub prompt_tokens: u32,
@@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
 
     context_store::init(&client);
     prompt_library::init(cx);
-    completion_provider::init(client.clone(), cx);
+    init_completion_provider(Arc::clone(&client), cx);
     assistant_slash_command::init(cx);
     register_slash_commands(cx);
     assistant_panel::init(cx);
@@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     .detach();
 }
 
+fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
+    let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
+    cx.set_global(CompletionProvider::new(provider, Some(client)));
+
+    let mut settings_version = 0;
+    cx.observe_global::<SettingsStore>(move |cx| {
+        settings_version += 1;
+        cx.update_global::<CompletionProvider, _>(|provider, cx| {
+            assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
+        })
+    })
+    .detach();
+}
+
 fn register_slash_commands(cx: &mut AppContext) {
     let slash_command_registry = SlashCommandRegistry::global(cx);
     slash_command_registry.register_command(file_command::FileSlashCommand, true);
@@ -8,18 +8,18 @@ use crate::{
         SlashCommandCompletionProvider, SlashCommandRegistry,
     },
     terminal_inline_assistant::TerminalInlineAssistant,
-    Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
-    CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep,
-    EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant,
-    InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
-    QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split,
-    ToggleFocus, ToggleModelSelector,
+    Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
+    DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
+    EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
+    MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
+    RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
 };
 use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
 use breadcrumbs::Breadcrumbs;
 use client::proto;
 use collections::{BTreeSet, HashMap, HashSet};
+use completion::CompletionProvider;
 use editor::{
     actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
     display_map::{
@@ -43,6 +43,7 @@ use language::{
     language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
     ToOffset,
 };
+use language_model::Role;
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, PickerDelegate};
 use project::{Project, ProjectLspAdapterDelegate};
@@ -1,166 +1,19 @@
-use std::fmt;
+use std::{sync::Arc, time::Duration};
 
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-pub use anthropic::Model as AnthropicModel;
-use gpui::Pixels;
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
-use schemars::{
-    schema::{InstanceType, Metadata, Schema, SchemaObject},
-    JsonSchema,
-};
-use serde::{
-    de::{self, Visitor},
-    Deserialize, Deserializer, Serialize, Serializer,
-};
+use anthropic::Model as AnthropicModel;
+use client::Client;
+use completion::{
+    AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
+    LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
+};
+use gpui::{AppContext, Pixels};
+use language_model::{CloudModel, LanguageModel};
+use ollama::Model as OllamaModel;
+use open_ai::Model as OpenAiModel;
+use parking_lot::RwLock;
+use schemars::{schema::Schema, JsonSchema};
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsSources};
-use strum::{EnumIter, IntoEnumIterator};
 
-#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
-pub enum CloudModel {
-    Gpt3Point5Turbo,
-    Gpt4,
-    Gpt4Turbo,
-    #[default]
-    Gpt4Omni,
-    Gpt4OmniMini,
-    Claude3_5Sonnet,
-    Claude3Opus,
-    Claude3Sonnet,
-    Claude3Haiku,
-    Gemini15Pro,
-    Gemini15Flash,
-    Custom(String),
-}
-
-impl Serialize for CloudModel {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(self.id())
-    }
-}
-
-impl<'de> Deserialize<'de> for CloudModel {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        struct ZedDotDevModelVisitor;
-
-        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
-            type Value = CloudModel;
-
-            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
-                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
-            }
-
-            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
-            where
-                E: de::Error,
-            {
-                let model = CloudModel::iter()
-                    .find(|model| model.id() == value)
-                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
-                Ok(model)
-            }
-        }
-
-        deserializer.deserialize_str(ZedDotDevModelVisitor)
-    }
-}
-
-impl JsonSchema for CloudModel {
-    fn schema_name() -> String {
-        "ZedDotDevModel".to_owned()
-    }
-
-    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
-        let variants = CloudModel::iter()
-            .filter_map(|model| {
-                let id = model.id();
-                if id.is_empty() {
-                    None
-                } else {
-                    Some(id.to_string())
-                }
-            })
-            .collect::<Vec<_>>();
-        Schema::Object(SchemaObject {
-            instance_type: Some(InstanceType::String.into()),
-            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
-            metadata: Some(Box::new(Metadata {
-                title: Some("ZedDotDevModel".to_owned()),
-                default: Some(CloudModel::default().id().into()),
-                examples: variants.into_iter().map(Into::into).collect(),
-                ..Default::default()
-            })),
-            ..Default::default()
-        })
-    }
-}
-
-impl CloudModel {
-    pub fn id(&self) -> &str {
-        match self {
-            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
-            Self::Gpt4 => "gpt-4",
-            Self::Gpt4Turbo => "gpt-4-turbo-preview",
-            Self::Gpt4Omni => "gpt-4o",
-            Self::Gpt4OmniMini => "gpt-4o-mini",
-            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
-            Self::Claude3Opus => "claude-3-opus",
-            Self::Claude3Sonnet => "claude-3-sonnet",
-            Self::Claude3Haiku => "claude-3-haiku",
-            Self::Gemini15Pro => "gemini-1.5-pro",
-            Self::Gemini15Flash => "gemini-1.5-flash",
-            Self::Custom(id) => id,
-        }
-    }
-
-    pub fn display_name(&self) -> &str {
-        match self {
-            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
-            Self::Gpt4 => "GPT 4",
-            Self::Gpt4Turbo => "GPT 4 Turbo",
-            Self::Gpt4Omni => "GPT 4 Omni",
-            Self::Gpt4OmniMini => "GPT 4 Omni Mini",
-            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
-            Self::Claude3Opus => "Claude 3 Opus",
-            Self::Claude3Sonnet => "Claude 3 Sonnet",
-            Self::Claude3Haiku => "Claude 3 Haiku",
-            Self::Gemini15Pro => "Gemini 1.5 Pro",
-            Self::Gemini15Flash => "Gemini 1.5 Flash",
-            Self::Custom(id) => id.as_str(),
-        }
-    }
-
-    pub fn max_token_count(&self) -> usize {
-        match self {
-            Self::Gpt3Point5Turbo => 2048,
-            Self::Gpt4 => 4096,
-            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
-            Self::Gpt4OmniMini => 128000,
-            Self::Claude3_5Sonnet
-            | Self::Claude3Opus
-            | Self::Claude3Sonnet
-            | Self::Claude3Haiku => 200000,
-            Self::Gemini15Pro => 128000,
-            Self::Gemini15Flash => 32000,
-            Self::Custom(_) => 4096, // TODO: Make this configurable
-        }
-    }
-
-    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
-        match self {
-            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
-                preprocess_anthropic_request(request)
-            }
-            _ => {}
-        }
-    }
-}
-
 #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
@@ -620,6 +473,124 @@ fn merge<T>(target: &mut T, value: Option<T>) {
     }
 }
 
+pub fn update_completion_provider_settings(
+    provider: &mut CompletionProvider,
+    version: usize,
+    cx: &mut AppContext,
+) {
+    let updated = match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => provider
+            .update_current_as::<_, CloudCompletionProvider>(|provider| {
+                provider.update(model.clone(), version);
+            }),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+            provider.update(
+                choose_openai_model(&model, &available_models),
+                api_url.clone(),
+                low_speed_timeout_in_seconds.map(Duration::from_secs),
+                version,
+            );
+        }),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+            provider.update(
+                model.clone(),
+                api_url.clone(),
+                low_speed_timeout_in_seconds.map(Duration::from_secs),
+                version,
+            );
+        }),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+            provider.update(
+                model.clone(),
+                api_url.clone(),
+                low_speed_timeout_in_seconds.map(Duration::from_secs),
+                version,
+                cx,
+            );
+        }),
+    };
+
+    // Previously configured provider was changed to another one
+    if updated.is_none() {
+        provider.update_provider(|client| create_provider_from_settings(client, version, cx));
+    }
+}
+
+pub(crate) fn create_provider_from_settings(
+    client: Arc<Client>,
+    settings_version: usize,
+    cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+    match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        )),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+            choose_openai_model(&model, &available_models),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            available_models.clone(),
+        ))),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            cx,
+        ))),
+    }
+}
+
+/// Choose which model to use for openai provider.
+/// If the model is not available, try to use the first available model, or fallback to the original model.
+fn choose_openai_model(
+    model: &::open_ai::Model,
+    available_models: &[::open_ai::Model],
+) -> ::open_ai::Model {
+    available_models
+        .iter()
+        .find(|&m| m == model)
+        .or_else(|| available_models.first())
+        .unwrap_or_else(|| model)
+        .clone()
+}
+
 #[cfg(test)]
 mod tests {
     use gpui::{AppContext, UpdateGlobal};
@@ -1,12 +1,12 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider,
-    LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
+    MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
     SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
 };
-use client::{proto, telemetry::Telemetry};
+use client::{self, proto, telemetry::Telemetry};
 use clock::ReplicaId;
 use collections::{HashMap, HashSet};
 use fs::Fs;
@@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
+use language_model::LanguageModelRequestMessage;
+use language_model::{LanguageModelRequest, Role};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
@@ -2477,9 +2479,10 @@ mod tests {
     use crate::{
         assistant_panel, prompt_library,
         slash_command::{active_command, file_command},
-        FakeCompletionProvider, MessageId,
+        MessageId,
     };
     use assistant_slash_command::{ArgumentCompletion, SlashCommand};
+    use completion::FakeCompletionProvider;
     use fs::FakeFs;
     use gpui::{AppContext, TestAppContext, WeakView};
     use indoc::indoc;
@@ -1,7 +1,6 @@
 use crate::{
     assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
-    AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
-    LanguageModelRequestMessage, Role, StreamingDiff,
+    AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::telemetry::Telemetry;
@@ -28,6 +27,7 @@ use gpui::{
     WhiteSpace, WindowContext,
 };
 use language::{Buffer, Point, Selection, TransactionId};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use rope::Rope;
@@ -1432,8 +1432,7 @@ impl Render for PromptEditor {
         PopoverMenu::new("model-switcher")
             .menu(move |cx| {
                 ContextMenu::build(cx, |mut menu, cx| {
-                    for model in CompletionProvider::global(cx).available_models(cx)
-                    {
+                    for model in CompletionProvider::global(cx).available_models() {
                         menu = menu.custom_entry(
                             {
                                 let model = model.clone();
@@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::FakeCompletionProvider;
+    use completion::FakeCompletionProvider;
     use futures::stream::{self};
     use gpui::{Context, TestAppContext};
     use indoc::indoc;
@@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
         .with_handle(self.handle)
         .menu(move |cx| {
             ContextMenu::build(cx, |mut menu, cx| {
-                for model in CompletionProvider::global(cx).available_models(cx) {
+                for model in CompletionProvider::global(cx).available_models() {
                     menu = menu.custom_entry(
                         {
                             let model = model.clone();
@@ -1,6 +1,6 @@
 use crate::{
     slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
-    InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
+    InlineAssist, InlineAssistant,
 };
 use anyhow::{anyhow, Result};
 use assets::Assets;
@@ -19,6 +19,7 @@ use gpui::{
 };
 use heed::{types::SerdeBincode, Database, RoTxn};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
 use parking_lot::RwLock;
 use picker::{Picker, PickerDelegate};
 use rope::Rope;
@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::AssistantSettings, humanize_token_count,
     prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
-    CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
+    CompletionProvider,
 };
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
@@ -17,6 +17,7 @@ use gpui::{
     Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
 };
 use language::Buffer;
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
 use settings::{update_settings_file, Settings};
 use std::{
     cmp,
@@ -558,8 +559,7 @@ impl Render for PromptEditor {
         PopoverMenu::new("model-switcher")
             .menu(move |cx| {
                 ContextMenu::build(cx, |mut menu, cx| {
-                    for model in CompletionProvider::global(cx).available_models(cx)
-                    {
+                    for model in CompletionProvider::global(cx).available_models() {
                         menu = menu.custom_entry(
                             {
                                 let model = model.clone();
@@ -30,6 +30,7 @@ chrono.workspace = true
 clock.workspace = true
 clickhouse.workspace = true
 collections.workspace = true
+completion.workspace = true
 dashmap = "5.4"
 envy = "0.4.2"
 futures.workspace = true
@@ -79,6 +80,7 @@ channel.workspace = true
 client = { workspace = true, features = ["test-support"] }
 collab_ui = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
+completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true
@@ -295,7 +295,7 @@ impl TestServer {
             menu::init();
             dev_server_projects::init(client.clone(), cx);
             settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
-            assistant::FakeCompletionProvider::setup_test(cx);
+            completion::FakeCompletionProvider::setup_test(cx);
             assistant::context_store::init(&client);
         });
 
crates/completion/Cargo.toml (new file, 56 lines)

@@ -0,0 +1,56 @@
+[package]
+name = "completion"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/completion.rs"
+doctest = false
+
+[features]
+test-support = [
+    "editor/test-support",
+    "language/test-support",
+    "project/test-support",
+    "text/test-support",
+]
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+anyhow.workspace = true
+client.workspace = true
+collections.workspace = true
+editor.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http.workspace = true
+language_model.workspace = true
+log.workspace = true
+menu.workspace = true
+ollama = { workspace = true, features = ["schemars"] }
+open_ai = { workspace = true, features = ["schemars"] }
+parking_lot.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+strum.workspace = true
+theme.workspace = true
+tiktoken-rs.workspace = true
+ui.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+ctor.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
+unindent.workspace = true

crates/completion/LICENSE-GPL (new symbolic link)

@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -1,14 +1,12 @@
-use crate::{
-    assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
-    Role,
-};
-use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
-use anthropic::{stream_completion, Request, RequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
+use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
+use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 use http::HttpClient;
+use language_model::Role;
 use settings::Settings;
 use std::time::Duration;
 use std::{env, sync::Arc};
@@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider {
 }
 
 impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
-    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+    fn available_models(&self) -> Vec<LanguageModel> {
         AnthropicModel::iter()
             .map(LanguageModel::Anthropic)
             .collect()
@@ -176,7 +174,7 @@ impl AnthropicCompletionProvider {
     }
 
     fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
-        preprocess_anthropic_request(&mut request);
+        request.preprocess_anthropic();
 
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
@@ -213,49 +211,6 @@ impl AnthropicCompletionProvider {
     }
 }
-
-pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
-    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
-    let mut system_message = String::new();
-
-    for message in request.messages.drain(..) {
-        if message.content.is_empty() {
-            continue;
-        }
-
-        match message.role {
-            Role::User | Role::Assistant => {
-                if let Some(last_message) = new_messages.last_mut() {
-                    if last_message.role == message.role {
-                        last_message.content.push_str("\n\n");
-                        last_message.content.push_str(&message.content);
-                        continue;
-                    }
-                }
-
-                new_messages.push(message);
-            }
-            Role::System => {
-                if !system_message.is_empty() {
-                    system_message.push_str("\n\n");
-                }
-                system_message.push_str(&message.content);
-            }
-        }
-    }
-
-    if !system_message.is_empty() {
-        new_messages.insert(
-            0,
-            LanguageModelRequestMessage {
-                role: Role::System,
-                content: system_message,
-            },
-        );
-    }
-
-    request.messages = new_messages;
-}
 
 struct AuthenticationPrompt {
     api_key: View<Editor>,
     api_url: String,
@@ -1,11 +1,12 @@
 use crate::{
-    assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
-    LanguageModelCompletionProvider, LanguageModelRequest,
+    count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
+    LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
 use client::{proto, Client};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
 use gpui::{AnyView, AppContext, Task};
+use language_model::CloudModel;
 use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -52,7 +53,7 @@ impl CloudCompletionProvider {
 }
 
 impl LanguageModelCompletionProvider for CloudCompletionProvider {
-    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+    fn available_models(&self) -> Vec<LanguageModel> {
         let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
@ -6,52 +6,19 @@ mod ollama;
|
|||||||
mod open_ai;
|
mod open_ai;
|
||||||
|
|
||||||
pub use anthropic::*;
|
pub use anthropic::*;
|
||||||
|
use anyhow::Result;
|
||||||
|
use client::Client;
|
||||||
pub use cloud::*;
|
pub use cloud::*;
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub use fake::*;
|
pub use fake::*;
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||||
|
use gpui::{AnyView, AppContext, Task, WindowContext};
|
||||||
|
use language_model::{LanguageModel, LanguageModelRequest};
|
||||||
pub use ollama::*;
|
pub use ollama::*;
|
||||||
pub use open_ai::*;
|
pub use open_ai::*;
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||||
|
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
|
||||||
use crate::{
|
|
||||||
assistant_settings::{AssistantProvider, AssistantSettings},
|
|
||||||
LanguageModel, LanguageModelRequest,
|
|
||||||
};
|
|
||||||
use anyhow::Result;
|
|
||||||
use client::Client;
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
|
||||||
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
|
|
||||||
use settings::{Settings, SettingsStore};
|
|
||||||
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
|
|
||||||
|
|
||||||
/// Choose which model to use for openai provider.
|
|
||||||
/// If the model is not available, try to use the first available model, or fallback to the original model.
|
|
||||||
fn choose_openai_model(
|
|
||||||
model: &::open_ai::Model,
|
|
||||||
available_models: &[::open_ai::Model],
|
|
||||||
) -> ::open_ai::Model {
|
|
||||||
available_models
|
|
||||||
.iter()
|
|
||||||
.find(|&m| m == model)
|
|
||||||
.or_else(|| available_models.first())
|
|
||||||
.unwrap_or_else(|| model)
|
|
||||||
.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
|
||||||
let provider = create_provider_from_settings(client.clone(), 0, cx);
|
|
||||||
cx.set_global(CompletionProvider::new(provider, Some(client)));
|
|
||||||
|
|
||||||
let mut settings_version = 0;
|
|
||||||
cx.observe_global::<SettingsStore>(move |cx| {
|
|
||||||
settings_version += 1;
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
|
||||||
provider.update_settings(settings_version, cx);
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CompletionResponse {
|
pub struct CompletionResponse {
|
||||||
inner: BoxStream<'static, Result<String>>,
|
inner: BoxStream<'static, Result<String>>,
|
||||||
@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait LanguageModelCompletionProvider: Send + Sync {
|
pub trait LanguageModelCompletionProvider: Send + Sync {
|
||||||
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
|
fn available_models(&self) -> Vec<LanguageModel>;
|
||||||
fn settings_version(&self) -> usize;
|
fn settings_version(&self) -> usize;
|
||||||
fn is_authenticated(&self) -> bool;
|
fn is_authenticated(&self) -> bool;
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||||
@ -110,8 +77,8 @@ impl CompletionProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
|
pub fn available_models(&self) -> Vec<LanguageModel> {
|
||||||
self.provider.read().available_models(cx)
|
self.provider.read().available_models()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
@ -176,6 +143,17 @@ impl CompletionProvider {
|
|||||||
Ok(completion)
|
Ok(completion)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_provider(
|
||||||
|
&mut self,
|
||||||
|
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||||
|
) {
|
||||||
|
if let Some(client) = &self.client {
|
||||||
|
self.provider = get_provider(Arc::clone(client));
|
||||||
|
} else {
|
||||||
|
log::warn!("completion provider cannot be updated because its client was not set");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl gpui::Global for CompletionProvider {}
|
impl gpui::Global for CompletionProvider {}
|
||||||
@ -196,109 +174,6 @@ impl CompletionProvider {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
-    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
-        let updated = match &AssistantSettings::get_global(cx).provider {
-            AssistantProvider::ZedDotDev { model } => self
-                .update_current_as::<_, CloudCompletionProvider>(|provider| {
-                    provider.update(model.clone(), version);
-                }),
-            AssistantProvider::OpenAi {
-                model,
-                api_url,
-                low_speed_timeout_in_seconds,
-                available_models,
-            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
-                provider.update(
-                    choose_openai_model(&model, &available_models),
-                    api_url.clone(),
-                    low_speed_timeout_in_seconds.map(Duration::from_secs),
-                    version,
-                );
-            }),
-            AssistantProvider::Anthropic {
-                model,
-                api_url,
-                low_speed_timeout_in_seconds,
-            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
-                provider.update(
-                    model.clone(),
-                    api_url.clone(),
-                    low_speed_timeout_in_seconds.map(Duration::from_secs),
-                    version,
-                );
-            }),
-            AssistantProvider::Ollama {
-                model,
-                api_url,
-                low_speed_timeout_in_seconds,
-            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
-                provider.update(
-                    model.clone(),
-                    api_url.clone(),
-                    low_speed_timeout_in_seconds.map(Duration::from_secs),
-                    version,
-                    cx,
-                );
-            }),
-        };
-
-        // Previously configured provider was changed to another one
-        if updated.is_none() {
-            if let Some(client) = self.client.clone() {
-                self.provider = create_provider_from_settings(client, version, cx);
-            } else {
-                log::warn!("completion provider cannot be created because client is not set");
-            }
-        }
-    }
-}
-
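The removed `update_current_as::<_, T>` calls return `Some(())` only when the currently installed provider is already of concrete type `T`, so `updated.is_none()` signals that the user switched provider kinds and the provider must be rebuilt. A reduced sketch of the downcast pattern this implies (the trait and helper below are assumptions for illustration, not this codebase's API):

use std::any::Any;

trait Provider: Any {
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

fn update_current_as<T: Provider + 'static>(
    current: &mut dyn Provider,
    update: impl FnOnce(&mut T),
) -> Option<()> {
    // None here means the concrete provider type changed, so the caller
    // must recreate the provider from settings instead of updating in place.
    current.as_any_mut().downcast_mut::<T>().map(update)
}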
-fn create_provider_from_settings(
-    client: Arc<Client>,
-    settings_version: usize,
-    cx: &mut AppContext,
-) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
-    match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
-            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
-        )),
-        AssistantProvider::OpenAi {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-            available_models,
-        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
-            choose_openai_model(&model, &available_models),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        ))),
-        AssistantProvider::Anthropic {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        ))),
-        AssistantProvider::Ollama {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-            cx,
-        ))),
-    }
-}
 }

 #[cfg(test)]
@@ -311,8 +186,8 @@ mod tests {
 use smol::stream::StreamExt;

 use crate::{
-    completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
-    FakeCompletionProvider, LanguageModelRequest,
+    CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
+    MAX_CONCURRENT_COMPLETION_REQUESTS,
 };

 #[gpui::test]
@@ -62,7 +62,7 @@ impl FakeCompletionProvider {
 }

 impl LanguageModelCompletionProvider for FakeCompletionProvider {
-    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+    fn available_models(&self) -> Vec<LanguageModel> {
         vec![LanguageModel::default()]
     }

@@ -1,15 +1,14 @@
 use crate::LanguageModelCompletionProvider;
-use crate::{
-    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
-};
+use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
 use anyhow::Result;
 use futures::StreamExt as _;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
 use gpui::{AnyView, AppContext, Task};
 use http::HttpClient;
+use language_model::Role;
+use ollama::Model as OllamaModel;
 use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
-    Role as OllamaRole,
 };
 use std::sync::Arc;
 use std::time::Duration;
@@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider {
 }

 impl LanguageModelCompletionProvider for OllamaCompletionProvider {
-    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+    fn available_models(&self) -> Vec<LanguageModel> {
         self.available_models
             .iter()
             .map(|m| LanguageModel::Ollama(m.clone()))
@@ -262,16 +261,6 @@ impl OllamaCompletionProvider {
         }
     }

-impl From<Role> for ollama::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => OllamaRole::User,
-            Role::Assistant => OllamaRole::Assistant,
-            Role::System => OllamaRole::System,
-        }
-    }
-}

 struct DownloadOllamaMessage {
     retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
 }
@@ -1,15 +1,13 @@
-use crate::assistant_settings::CloudModel;
-use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::CompletionProvider;
 use crate::LanguageModelCompletionProvider;
-use crate::{
-    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
-};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 use http::HttpClient;
-use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
+use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
+use open_ai::Model as OpenAiModel;
+use open_ai::{stream_completion, Request, RequestMessage};
 use settings::Settings;
 use std::time::Duration;
 use std::{env, sync::Arc};
@@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider {
     http_client: Arc<dyn HttpClient>,
     low_speed_timeout: Option<Duration>,
     settings_version: usize,
+    available_models_from_settings: Vec<OpenAiModel>,
 }

 impl OpenAiCompletionProvider {
@@ -34,6 +33,7 @@ impl OpenAiCompletionProvider {
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        available_models_from_settings: Vec<OpenAiModel>,
     ) -> Self {
         Self {
             api_key: None,
@@ -42,6 +42,7 @@ impl OpenAiCompletionProvider {
             http_client,
             low_speed_timeout,
             settings_version,
+            available_models_from_settings,
         }
     }

@@ -92,30 +93,26 @@ impl OpenAiCompletionProvider {
 }

 impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
-    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        if let AssistantProvider::OpenAi {
-            available_models, ..
-        } = &AssistantSettings::get_global(cx).provider
-        {
-            if !available_models.is_empty() {
-                return available_models
-                    .iter()
-                    .cloned()
-                    .map(LanguageModel::OpenAi)
-                    .collect();
-            }
-        }
-
-        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            vec![self.model.clone()]
-        } else {
-            OpenAiModel::iter()
-                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
-                .collect()
-        };
-        available_models
-            .into_iter()
-            .map(LanguageModel::OpenAi)
-            .collect()
+    fn available_models(&self) -> Vec<LanguageModel> {
+        if self.available_models_from_settings.is_empty() {
+            let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
+                vec![self.model.clone()]
+            } else {
+                OpenAiModel::iter()
+                    .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+                    .collect()
+            };
+            available_models
+                .into_iter()
+                .map(LanguageModel::OpenAi)
+                .collect()
+        } else {
+            self.available_models_from_settings
+                .iter()
+                .cloned()
+                .map(LanguageModel::OpenAi)
+                .collect()
+        }
     }

     fn settings_version(&self) -> usize {
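This is the one behavioral change called out in the PR description: instead of reading `AssistantSettings` eagerly, the provider now consults the model list it was handed when settings last changed. User-configured models still take absolute precedence, and the built-in list (or the single custom model) remains the fallback. A reduced sketch of the rule over strings (hypothetical, not part of the diff):

fn effective_models(from_settings: &[String], built_in: &[String]) -> Vec<String> {
    if from_settings.is_empty() {
        built_in.to_vec() // no user configuration: expose the provider defaults
    } else {
        from_settings.to_vec() // user-configured models win
    }
}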
@@ -255,16 +252,6 @@ pub fn count_open_ai_tokens(
         .boxed()
 }

-impl From<Role> for open_ai::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => OpenAiRole::User,
-            Role::Assistant => OpenAiRole::Assistant,
-            Role::System => OpenAiRole::System,
-        }
-    }
-}

 struct AuthenticationPrompt {
     api_key: View<Editor>,
     api_url: String,
crates/language_model/Cargo.toml (new file, 41 lines)
@@ -0,0 +1,41 @@
+[package]
+name = "language_model"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/language_model.rs"
+doctest = false
+
+[features]
+test-support = [
+    "editor/test-support",
+    "language/test-support",
+    "project/test-support",
+    "text/test-support",
+]
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+ollama = { workspace = true, features = ["schemars"] }
+open_ai = { workspace = true, features = ["schemars"] }
+schemars.workspace = true
+serde.workspace = true
+strum.workspace = true
+proto = { workspace = true, features = ["test-support"] }
+
+[dev-dependencies]
+ctor.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
+log.workspace = true
+project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
crates/language_model/LICENSE-GPL (new symbolic link)
@@ -0,0 +1 @@
+../../LICENSE-GPL
crates/language_model/src/language_model.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
+mod model;
+mod request;
+mod role;
+
+pub use model::*;
+pub use request::*;
+pub use role::*;
crates/language_model/src/model/cloud_model.rs (new file, 160 lines)
@@ -0,0 +1,160 @@
+use crate::LanguageModelRequest;
+pub use anthropic::Model as AnthropicModel;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+use schemars::{
+    schema::{InstanceType, Metadata, Schema, SchemaObject},
+    JsonSchema,
+};
+use serde::{
+    de::{self, Visitor},
+    Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::fmt;
+use strum::{EnumIter, IntoEnumIterator};
+
+#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+pub enum CloudModel {
+    Gpt3Point5Turbo,
+    Gpt4,
+    Gpt4Turbo,
+    #[default]
+    Gpt4Omni,
+    Gpt4OmniMini,
+    Claude3_5Sonnet,
+    Claude3Opus,
+    Claude3Sonnet,
+    Claude3Haiku,
+    Gemini15Pro,
+    Gemini15Flash,
+    Custom(String),
+}
+
+impl Serialize for CloudModel {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(self.id())
+    }
+}
+
+impl<'de> Deserialize<'de> for CloudModel {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct ZedDotDevModelVisitor;
+
+        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
+            type Value = CloudModel;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                let model = CloudModel::iter()
+                    .find(|model| model.id() == value)
+                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
+                Ok(model)
+            }
+        }
+
+        deserializer.deserialize_str(ZedDotDevModelVisitor)
+    }
+}
+
+impl JsonSchema for CloudModel {
+    fn schema_name() -> String {
+        "ZedDotDevModel".to_owned()
+    }
+
+    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
+        let variants = CloudModel::iter()
+            .filter_map(|model| {
+                let id = model.id();
+                if id.is_empty() {
+                    None
+                } else {
+                    Some(id.to_string())
+                }
+            })
+            .collect::<Vec<_>>();
+        Schema::Object(SchemaObject {
+            instance_type: Some(InstanceType::String.into()),
+            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
+            metadata: Some(Box::new(Metadata {
+                title: Some("ZedDotDevModel".to_owned()),
+                default: Some(CloudModel::default().id().into()),
+                examples: variants.into_iter().map(Into::into).collect(),
+                ..Default::default()
+            })),
+            ..Default::default()
+        })
+    }
+}
+
+impl CloudModel {
+    pub fn id(&self) -> &str {
+        match self {
+            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
+            Self::Gpt4 => "gpt-4",
+            Self::Gpt4Turbo => "gpt-4-turbo-preview",
+            Self::Gpt4Omni => "gpt-4o",
+            Self::Gpt4OmniMini => "gpt-4o-mini",
+            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
+            Self::Claude3Opus => "claude-3-opus",
+            Self::Claude3Sonnet => "claude-3-sonnet",
+            Self::Claude3Haiku => "claude-3-haiku",
+            Self::Gemini15Pro => "gemini-1.5-pro",
+            Self::Gemini15Flash => "gemini-1.5-flash",
+            Self::Custom(id) => id,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
+            Self::Gpt4 => "GPT 4",
+            Self::Gpt4Turbo => "GPT 4 Turbo",
+            Self::Gpt4Omni => "GPT 4 Omni",
+            Self::Gpt4OmniMini => "GPT 4 Omni Mini",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+            Self::Claude3Opus => "Claude 3 Opus",
+            Self::Claude3Sonnet => "Claude 3 Sonnet",
+            Self::Claude3Haiku => "Claude 3 Haiku",
+            Self::Gemini15Pro => "Gemini 1.5 Pro",
+            Self::Gemini15Flash => "Gemini 1.5 Flash",
+            Self::Custom(id) => id.as_str(),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            Self::Gpt3Point5Turbo => 2048,
+            Self::Gpt4 => 4096,
+            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
+            Self::Gpt4OmniMini => 128000,
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 200000,
+            Self::Gemini15Pro => 128000,
+            Self::Gemini15Flash => 32000,
+            Self::Custom(_) => 4096, // TODO: Make this configurable
+        }
+    }
+
+    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
+        match self {
+            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
+                request.preprocess_anthropic()
+            }
+            _ => {}
+        }
+    }
+}
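The custom Deserialize impl above lets unknown model ids round-trip instead of failing: any string that does not match a known variant becomes CloudModel::Custom, and serialization goes back through id(). A short sketch of the behavior, assuming serde_json as the format (serde_json is not a dependency in this diff):

fn main() {
    let known: CloudModel = serde_json::from_str("\"gpt-4o\"").unwrap();
    assert_eq!(known, CloudModel::Gpt4Omni);

    let unknown: CloudModel = serde_json::from_str("\"my-fine-tune\"").unwrap();
    assert_eq!(unknown, CloudModel::Custom("my-fine-tune".into()));

    // Serialization goes through id(), so round-tripping preserves the string.
    assert_eq!(serde_json::to_string(&known).unwrap(), "\"gpt-4o\"");
}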
crates/language_model/src/model/mod.rs (new file, 60 lines)
@@ -0,0 +1,60 @@
+pub mod cloud_model;
+
+pub use anthropic::Model as AnthropicModel;
+pub use cloud_model::*;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum LanguageModel {
+    Cloud(CloudModel),
+    OpenAi(OpenAiModel),
+    Anthropic(AnthropicModel),
+    Ollama(OllamaModel),
+}
+
+impl Default for LanguageModel {
+    fn default() -> Self {
+        LanguageModel::Cloud(CloudModel::default())
+    }
+}
+
+impl LanguageModel {
+    pub fn telemetry_id(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
+            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
+            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
+        }
+    }
+
+    pub fn display_name(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => model.display_name().into(),
+            LanguageModel::Anthropic(model) => model.display_name().into(),
+            LanguageModel::Cloud(model) => model.display_name().into(),
+            LanguageModel::Ollama(model) => model.display_name().into(),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            LanguageModel::OpenAi(model) => model.max_token_count(),
+            LanguageModel::Anthropic(model) => model.max_token_count(),
+            LanguageModel::Cloud(model) => model.max_token_count(),
+            LanguageModel::Ollama(model) => model.max_token_count(),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            LanguageModel::OpenAi(model) => model.id(),
+            LanguageModel::Anthropic(model) => model.id(),
+            LanguageModel::Cloud(model) => model.id(),
+            LanguageModel::Ollama(model) => model.id(),
+        }
+    }
+}
crates/language_model/src/request.rs (new file, 110 lines)
@@ -0,0 +1,110 @@
+use crate::{
+    model::{CloudModel, LanguageModel},
+    role::Role,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+        proto::LanguageModelRequestMessage {
+            role: self.role.to_proto() as i32,
+            content: self.content.clone(),
+            tool_calls: Vec::new(),
+            tool_call_id: None,
+        }
+    }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct LanguageModelRequest {
+    pub model: LanguageModel,
+    pub messages: Vec<LanguageModelRequestMessage>,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+        proto::CompleteWithLanguageModel {
+            model: self.model.id().to_string(),
+            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+            stop: self.stop.clone(),
+            temperature: self.temperature,
+            tool_choice: None,
+            tools: Vec::new(),
+        }
+    }
+
+    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+    pub fn preprocess(&mut self) {
+        match &self.model {
+            LanguageModel::OpenAi(_) => {}
+            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Ollama(_) => {}
+            LanguageModel::Cloud(model) => match model {
+                CloudModel::Claude3Opus
+                | CloudModel::Claude3Sonnet
+                | CloudModel::Claude3Haiku
+                | CloudModel::Claude3_5Sonnet => {
+                    self.preprocess_anthropic();
+                }
+                _ => {}
+            },
+        }
+    }
+
+    pub fn preprocess_anthropic(&mut self) {
+        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut system_message = String::new();
+
+        for message in self.messages.drain(..) {
+            if message.content.is_empty() {
+                continue;
+            }
+
+            match message.role {
+                Role::User | Role::Assistant => {
+                    if let Some(last_message) = new_messages.last_mut() {
+                        if last_message.role == message.role {
+                            last_message.content.push_str("\n\n");
+                            last_message.content.push_str(&message.content);
+                            continue;
+                        }
+                    }
+
+                    new_messages.push(message);
+                }
+                Role::System => {
+                    if !system_message.is_empty() {
+                        system_message.push_str("\n\n");
+                    }
+                    system_message.push_str(&message.content);
+                }
+            }
+        }
+
+        if !system_message.is_empty() {
+            new_messages.insert(
+                0,
+                LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: system_message,
+                },
+            );
+        }
+
+        self.messages = new_messages;
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
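preprocess_anthropic enforces the shape Anthropic's API expects: empty messages are dropped, runs of consecutive same-role messages are merged with blank lines, and all system content is hoisted into a single leading system message. A hypothetical before/after, not part of the diff:

// Before preprocess_anthropic():
//   [User "Hi", User "Actually, hello", System "Be terse", Assistant "", System "Answer in French"]
// After (empty messages dropped, same-role runs merged, system content hoisted):
//   [System "Be terse\n\nAnswer in French", User "Hi\n\nActually, hello"]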
crates/language_model/src/role.rs (new file, 68 lines)
@@ -0,0 +1,68 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Display};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn from_proto(role: i32) -> Role {
+        match proto::LanguageModelRole::from_i32(role) {
+            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+            None => Role::User,
+        }
+    }
+
+    pub fn to_proto(&self) -> proto::LanguageModelRole {
+        match self {
+            Role::User => proto::LanguageModelRole::LanguageModelUser,
+            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+            Role::System => proto::LanguageModelRole::LanguageModelSystem,
+        }
+    }
+
+    pub fn cycle(self) -> Role {
+        match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "user"),
+            Role::Assistant => write!(f, "assistant"),
+            Role::System => write!(f, "system"),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => ollama::Role::User,
+            Role::Assistant => ollama::Role::Assistant,
+            Role::System => ollama::Role::System,
+        }
+    }
+}
+
+impl From<Role> for open_ai::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => open_ai::Role::User,
+            Role::Assistant => open_ai::Role::Assistant,
+            Role::System => open_ai::Role::System,
+        }
+    }
+}
@@ -22,6 +22,7 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
@@ -1261,3 +1261,6 @@ mod tests {
         );
     }
 }
+
+// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
+type _TODO = completion::CompletionProvider;