mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-19 18:41:56 +03:00
Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
28
Cargo.lock
generated
28
Cargo.lock
generated
@ -406,7 +406,6 @@ dependencies = [
|
|||||||
"clock",
|
"clock",
|
||||||
"collections",
|
"collections",
|
||||||
"command_palette_hooks",
|
"command_palette_hooks",
|
||||||
"completion",
|
|
||||||
"ctor",
|
"ctor",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
@ -2470,7 +2469,6 @@ dependencies = [
|
|||||||
"clock",
|
"clock",
|
||||||
"collab_ui",
|
"collab_ui",
|
||||||
"collections",
|
"collections",
|
||||||
"completion",
|
|
||||||
"ctor",
|
"ctor",
|
||||||
"dashmap 6.0.1",
|
"dashmap 6.0.1",
|
||||||
"dev_server_projects",
|
"dev_server_projects",
|
||||||
@ -2655,30 +2653,6 @@ dependencies = [
|
|||||||
"gpui",
|
"gpui",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "completion"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"ctor",
|
|
||||||
"editor",
|
|
||||||
"env_logger",
|
|
||||||
"futures 0.3.28",
|
|
||||||
"gpui",
|
|
||||||
"language",
|
|
||||||
"language_model",
|
|
||||||
"project",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"schemars",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"settings",
|
|
||||||
"smol",
|
|
||||||
"text",
|
|
||||||
"ui",
|
|
||||||
"unindent",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "concurrent-queue"
|
name = "concurrent-queue"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
@ -6048,6 +6022,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
|
"smol",
|
||||||
"strum",
|
"strum",
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
"theme",
|
||||||
@ -9506,7 +9481,6 @@ dependencies = [
|
|||||||
"client",
|
"client",
|
||||||
"clock",
|
"clock",
|
||||||
"collections",
|
"collections",
|
||||||
"completion",
|
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"fs",
|
"fs",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
|
@ -19,7 +19,6 @@ members = [
|
|||||||
"crates/collections",
|
"crates/collections",
|
||||||
"crates/command_palette",
|
"crates/command_palette",
|
||||||
"crates/command_palette_hooks",
|
"crates/command_palette_hooks",
|
||||||
"crates/completion",
|
|
||||||
"crates/copilot",
|
"crates/copilot",
|
||||||
"crates/db",
|
"crates/db",
|
||||||
"crates/dev_server_projects",
|
"crates/dev_server_projects",
|
||||||
@ -190,7 +189,6 @@ collab_ui = { path = "crates/collab_ui" }
|
|||||||
collections = { path = "crates/collections" }
|
collections = { path = "crates/collections" }
|
||||||
command_palette = { path = "crates/command_palette" }
|
command_palette = { path = "crates/command_palette" }
|
||||||
command_palette_hooks = { path = "crates/command_palette_hooks" }
|
command_palette_hooks = { path = "crates/command_palette_hooks" }
|
||||||
completion = { path = "crates/completion" }
|
|
||||||
copilot = { path = "crates/copilot" }
|
copilot = { path = "crates/copilot" }
|
||||||
db = { path = "crates/db" }
|
db = { path = "crates/db" }
|
||||||
dev_server_projects = { path = "crates/dev_server_projects" }
|
dev_server_projects = { path = "crates/dev_server_projects" }
|
||||||
|
@ -21,7 +21,12 @@ pub enum Model {
|
|||||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||||
Claude3Haiku,
|
Claude3Haiku,
|
||||||
#[serde(rename = "custom")]
|
#[serde(rename = "custom")]
|
||||||
Custom { name: String, max_tokens: usize },
|
Custom {
|
||||||
|
name: String,
|
||||||
|
max_tokens: usize,
|
||||||
|
/// Override this model with a different Anthropic model for tool calls.
|
||||||
|
tool_override: Option<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
@ -68,6 +73,18 @@ impl Model {
|
|||||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool_model_id(&self) -> &str {
|
||||||
|
if let Self::Custom {
|
||||||
|
tool_override: Some(tool_override),
|
||||||
|
..
|
||||||
|
} = self
|
||||||
|
{
|
||||||
|
tool_override
|
||||||
|
} else {
|
||||||
|
self.id()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn complete(
|
pub async fn complete(
|
||||||
|
@ -32,7 +32,6 @@ client.workspace = true
|
|||||||
clock.workspace = true
|
clock.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
command_palette_hooks.workspace = true
|
command_palette_hooks.workspace = true
|
||||||
completion.workspace = true
|
|
||||||
editor.workspace = true
|
editor.workspace = true
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
@ -77,7 +76,6 @@ workspace.workspace = true
|
|||||||
picker.workspace = true
|
picker.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
completion = { workspace = true, features = ["test-support"] }
|
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
editor = { workspace = true, features = ["test-support"] }
|
editor = { workspace = true, features = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
|
|||||||
use assistant_slash_command::SlashCommandRegistry;
|
use assistant_slash_command::SlashCommandRegistry;
|
||||||
use client::{proto, Client};
|
use client::{proto, Client};
|
||||||
use command_palette_hooks::CommandPaletteFilter;
|
use command_palette_hooks::CommandPaletteFilter;
|
||||||
use completion::LanguageModelCompletionProvider;
|
|
||||||
pub use context::*;
|
pub use context::*;
|
||||||
pub use context_store::*;
|
pub use context_store::*;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
|
|
||||||
context_store::init(&client);
|
context_store::init(&client);
|
||||||
prompt_library::init(cx);
|
prompt_library::init(cx);
|
||||||
init_completion_provider(cx);
|
init_language_model_settings(cx);
|
||||||
assistant_slash_command::init(cx);
|
assistant_slash_command::init(cx);
|
||||||
register_slash_commands(cx);
|
register_slash_commands(cx);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_completion_provider(cx: &mut AppContext) {
|
fn init_language_model_settings(cx: &mut AppContext) {
|
||||||
completion::init(cx);
|
|
||||||
update_active_language_model_from_settings(cx);
|
update_active_language_model_from_settings(cx);
|
||||||
|
|
||||||
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
||||||
@ -233,21 +231,10 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
|||||||
let settings = AssistantSettings::get_global(cx);
|
let settings = AssistantSettings::get_global(cx);
|
||||||
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
|
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
|
||||||
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||||
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
let Some(provider) = LanguageModelRegistry::global(cx)
|
registry.select_active_model(&provider_name, &model_id, cx);
|
||||||
.read(cx)
|
|
||||||
.provider(&provider_name)
|
|
||||||
else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
let models = provider.provided_models(cx);
|
|
||||||
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
|
|
||||||
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
|
|
||||||
completion_provider.set_active_model(model, cx);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn register_slash_commands(cx: &mut AppContext) {
|
fn register_slash_commands(cx: &mut AppContext) {
|
||||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||||
|
@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
|
|||||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||||
use client::proto;
|
use client::proto;
|
||||||
use collections::{BTreeSet, HashMap, HashSet};
|
use collections::{BTreeSet, HashMap, HashSet};
|
||||||
use completion::LanguageModelCompletionProvider;
|
|
||||||
use editor::{
|
use editor::{
|
||||||
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
||||||
display_map::{
|
display_map::{
|
||||||
@ -43,7 +42,7 @@ use language::{
|
|||||||
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
|
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
|
||||||
ToOffset,
|
ToOffset,
|
||||||
};
|
};
|
||||||
use language_model::{LanguageModelProviderId, Role};
|
use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
|
||||||
use multi_buffer::MultiBufferRow;
|
use multi_buffer::MultiBufferRow;
|
||||||
use picker::{Picker, PickerDelegate};
|
use picker::{Picker, PickerDelegate};
|
||||||
use project::{Project, ProjectLspAdapterDelegate};
|
use project::{Project, ProjectLspAdapterDelegate};
|
||||||
@ -392,9 +391,9 @@ impl AssistantPanel {
|
|||||||
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
||||||
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
||||||
cx.subscribe(&context_store, Self::handle_context_store_event),
|
cx.subscribe(&context_store, Self::handle_context_store_event),
|
||||||
cx.observe(
|
cx.subscribe(
|
||||||
&LanguageModelCompletionProvider::global(cx),
|
&LanguageModelRegistry::global(cx),
|
||||||
|this, _, cx| {
|
|this, _, _: &language_model::ActiveModelChanged, cx| {
|
||||||
this.completion_provider_changed(cx);
|
this.completion_provider_changed(cx);
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@ -560,7 +559,7 @@ impl AssistantPanel {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
|
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
|
||||||
.active_provider()
|
.active_provider()
|
||||||
.map(|p| p.id())
|
.map(|p| p.id())
|
||||||
else {
|
else {
|
||||||
@ -599,7 +598,7 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
|
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
|
||||||
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
|
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
|
||||||
if !provider.is_authenticated(cx) {
|
if !provider.is_authenticated(cx) {
|
||||||
return Some(provider.authentication_prompt(cx));
|
return Some(provider.authentication_prompt(cx));
|
||||||
}
|
}
|
||||||
@ -904,9 +903,9 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
|
||||||
.reset_credentials(cx)
|
provider.reset_credentials(cx).detach_and_log_err(cx);
|
||||||
.detach_and_log_err(cx);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
|
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
|
||||||
@ -1041,11 +1040,18 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||||
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
|
LanguageModelRegistry::read_global(cx)
|
||||||
|
.active_provider()
|
||||||
|
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||||
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
|
LanguageModelRegistry::read_global(cx)
|
||||||
|
.active_provider()
|
||||||
|
.map_or(
|
||||||
|
Task::ready(Err(anyhow!("no active language model provider"))),
|
||||||
|
|provider| provider.authenticate(cx),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||||
let context = &self
|
let context = &self
|
||||||
.active_context_editor
|
.active_context_editor
|
||||||
.as_ref()?
|
.as_ref()?
|
||||||
@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
|
|||||||
.whitespace_nowrap()
|
.whitespace_nowrap()
|
||||||
.child(
|
.child(
|
||||||
Label::new(
|
Label::new(
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
LanguageModelRegistry::read_global(cx)
|
||||||
.active_model()
|
.active_model()
|
||||||
.map(|model| model.name().0)
|
.map(|model| model.name().0)
|
||||||
.unwrap_or_else(|| "No model selected".into()),
|
.unwrap_or_else(|| "No model selected".into()),
|
||||||
|
@ -52,7 +52,7 @@ pub struct AssistantSettings {
|
|||||||
pub dock: AssistantDockPosition,
|
pub dock: AssistantDockPosition,
|
||||||
pub default_width: Pixels,
|
pub default_width: Pixels,
|
||||||
pub default_height: Pixels,
|
pub default_height: Pixels,
|
||||||
pub default_model: AssistantDefaultModel,
|
pub default_model: LanguageModelSelection,
|
||||||
pub using_outdated_settings_version: bool,
|
pub using_outdated_settings_version: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,25 +198,25 @@ impl AssistantSettingsContent {
|
|||||||
.clone()
|
.clone()
|
||||||
.and_then(|provider| match provider {
|
.and_then(|provider| match provider {
|
||||||
AssistantProviderContentV1::ZedDotDev { default_model } => {
|
AssistantProviderContentV1::ZedDotDev { default_model } => {
|
||||||
default_model.map(|model| AssistantDefaultModel {
|
default_model.map(|model| LanguageModelSelection {
|
||||||
provider: "zed.dev".to_string(),
|
provider: "zed.dev".to_string(),
|
||||||
model: model.id().to_string(),
|
model: model.id().to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
AssistantProviderContentV1::OpenAi { default_model, .. } => {
|
AssistantProviderContentV1::OpenAi { default_model, .. } => {
|
||||||
default_model.map(|model| AssistantDefaultModel {
|
default_model.map(|model| LanguageModelSelection {
|
||||||
provider: "openai".to_string(),
|
provider: "openai".to_string(),
|
||||||
model: model.id().to_string(),
|
model: model.id().to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
AssistantProviderContentV1::Anthropic { default_model, .. } => {
|
AssistantProviderContentV1::Anthropic { default_model, .. } => {
|
||||||
default_model.map(|model| AssistantDefaultModel {
|
default_model.map(|model| LanguageModelSelection {
|
||||||
provider: "anthropic".to_string(),
|
provider: "anthropic".to_string(),
|
||||||
model: model.id().to_string(),
|
model: model.id().to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
AssistantProviderContentV1::Ollama { default_model, .. } => {
|
AssistantProviderContentV1::Ollama { default_model, .. } => {
|
||||||
default_model.map(|model| AssistantDefaultModel {
|
default_model.map(|model| LanguageModelSelection {
|
||||||
provider: "ollama".to_string(),
|
provider: "ollama".to_string(),
|
||||||
model: model.id().to_string(),
|
model: model.id().to_string(),
|
||||||
})
|
})
|
||||||
@ -231,7 +231,7 @@ impl AssistantSettingsContent {
|
|||||||
dock: settings.dock,
|
dock: settings.dock,
|
||||||
default_width: settings.default_width,
|
default_width: settings.default_width,
|
||||||
default_height: settings.default_height,
|
default_height: settings.default_height,
|
||||||
default_model: Some(AssistantDefaultModel {
|
default_model: Some(LanguageModelSelection {
|
||||||
provider: "openai".to_string(),
|
provider: "openai".to_string(),
|
||||||
model: settings
|
model: settings
|
||||||
.default_open_ai_model
|
.default_open_ai_model
|
||||||
@ -325,7 +325,7 @@ impl AssistantSettingsContent {
|
|||||||
_ => {}
|
_ => {}
|
||||||
},
|
},
|
||||||
VersionedAssistantSettingsContent::V2(settings) => {
|
VersionedAssistantSettingsContent::V2(settings) => {
|
||||||
settings.default_model = Some(AssistantDefaultModel { provider, model });
|
settings.default_model = Some(LanguageModelSelection { provider, model });
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
AssistantSettingsContent::Legacy(settings) => {
|
AssistantSettingsContent::Legacy(settings) => {
|
||||||
@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
|
|||||||
/// Default: 320
|
/// Default: 320
|
||||||
default_height: Option<f32>,
|
default_height: Option<f32>,
|
||||||
/// The default model to use when creating new contexts.
|
/// The default model to use when creating new contexts.
|
||||||
default_model: Option<AssistantDefaultModel>,
|
default_model: Option<LanguageModelSelection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
pub struct AssistantDefaultModel {
|
pub struct LanguageModelSelection {
|
||||||
#[schemars(schema_with = "providers_schema")]
|
#[schemars(schema_with = "providers_schema")]
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AssistantDefaultModel {
|
impl Default for LanguageModelSelection {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider: "openai".to_string(),
|
provider: "openai".to_string(),
|
||||||
@ -542,7 +542,7 @@ mod tests {
|
|||||||
assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
|
assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
AssistantSettings::get_global(cx).default_model,
|
AssistantSettings::get_global(cx).default_model,
|
||||||
AssistantDefaultModel {
|
LanguageModelSelection {
|
||||||
provider: "openai".into(),
|
provider: "openai".into(),
|
||||||
model: "gpt-4o".into(),
|
model: "gpt-4o".into(),
|
||||||
}
|
}
|
||||||
@ -555,7 +555,7 @@ mod tests {
|
|||||||
|settings, _| {
|
|settings, _| {
|
||||||
*settings = AssistantSettingsContent::Versioned(
|
*settings = AssistantSettingsContent::Versioned(
|
||||||
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
||||||
default_model: Some(AssistantDefaultModel {
|
default_model: Some(LanguageModelSelection {
|
||||||
provider: "test-provider".into(),
|
provider: "test-provider".into(),
|
||||||
model: "gpt-99".into(),
|
model: "gpt-99".into(),
|
||||||
}),
|
}),
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
|
||||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
MessageStatus,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_slash_command::{
|
use assistant_slash_command::{
|
||||||
@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||||||
use language::{
|
use language::{
|
||||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||||
};
|
};
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
|
||||||
|
Role,
|
||||||
|
};
|
||||||
use open_ai::Model as OpenAiModel;
|
use open_ai::Model as OpenAiModel;
|
||||||
use paths::contexts_dir;
|
use paths::contexts_dir;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
@ -1180,17 +1183,16 @@ impl Context {
|
|||||||
|
|
||||||
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
let request = self.to_completion_request(cx);
|
let request = self.to_completion_request(cx);
|
||||||
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
self.pending_token_count = cx.spawn(|this, mut cx| {
|
self.pending_token_count = cx.spawn(|this, mut cx| {
|
||||||
async move {
|
async move {
|
||||||
cx.background_executor()
|
cx.background_executor()
|
||||||
.timer(Duration::from_millis(200))
|
.timer(Duration::from_millis(200))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let token_count = cx
|
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||||
.update(|cx| {
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
cx.notify()
|
cx.notify()
|
||||||
@ -1368,6 +1370,10 @@ impl Context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
|
return Task::ready(Err(anyhow!("no active model")).log_err());
|
||||||
|
};
|
||||||
|
|
||||||
let mut request = self.to_completion_request(cx);
|
let mut request = self.to_completion_request(cx);
|
||||||
let edit_step_range = edit_step.source_range.clone();
|
let edit_step_range = edit_step.source_range.clone();
|
||||||
let step_text = self
|
let step_text = self
|
||||||
@ -1388,12 +1394,7 @@ impl Context {
|
|||||||
content: prompt,
|
content: prompt,
|
||||||
});
|
});
|
||||||
|
|
||||||
let tool_use = cx
|
let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
|
||||||
.update(|cx| {
|
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
|
||||||
.use_tool::<EditTool>(request, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
let step_index = this
|
let step_index = this
|
||||||
@ -1568,6 +1569,8 @@ impl Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
|
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
|
||||||
|
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
|
||||||
|
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||||
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
|
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
|
||||||
message
|
message
|
||||||
.start
|
.start
|
||||||
@ -1575,14 +1578,12 @@ impl Context {
|
|||||||
.then_some(message.id)
|
.then_some(message.id)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
if !provider.is_authenticated(cx) {
|
||||||
log::info!("completion provider has no credentials");
|
log::info!("completion provider has no credentials");
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = self.to_completion_request(cx);
|
let request = self.to_completion_request(cx);
|
||||||
let stream =
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1594,6 +1595,7 @@ impl Context {
|
|||||||
|
|
||||||
let task = cx.spawn({
|
let task = cx.spawn({
|
||||||
|this, mut cx| async move {
|
|this, mut cx| async move {
|
||||||
|
let stream = model.stream_completion(request, &cx);
|
||||||
let assistant_message_id = assistant_message.id;
|
let assistant_message_id = assistant_message.id;
|
||||||
let mut response_latency = None;
|
let mut response_latency = None;
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
@ -1662,14 +1664,10 @@ impl Context {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
|
||||||
.active_model()
|
|
||||||
.map(|m| m.telemetry_id())
|
|
||||||
.unwrap_or_default();
|
|
||||||
telemetry.report_assistant_event(
|
telemetry.report_assistant_event(
|
||||||
Some(this.id.0.clone()),
|
Some(this.id.0.clone()),
|
||||||
AssistantKind::Panel,
|
AssistantKind::Panel,
|
||||||
model_telemetry_id,
|
model.telemetry_id(),
|
||||||
response_latency,
|
response_latency,
|
||||||
error_message,
|
error_message,
|
||||||
);
|
);
|
||||||
@ -1935,8 +1933,15 @@ impl Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||||
|
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
if !provider.is_authenticated(cx) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1953,10 +1958,9 @@ impl Context {
|
|||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let stream =
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
|
||||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||||
async move {
|
async move {
|
||||||
|
let stream = model.stream_completion(request, &cx);
|
||||||
let mut messages = stream.await?;
|
let mut messages = stream.await?;
|
||||||
|
|
||||||
let mut replaced = !replace_old;
|
let mut replaced = !replace_old;
|
||||||
@ -2490,7 +2494,6 @@ mod tests {
|
|||||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
language_model::LanguageModelRegistry::test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
completion::LanguageModelCompletionProvider::test(cx);
|
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
@ -2623,7 +2626,6 @@ mod tests {
|
|||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
language_model::LanguageModelRegistry::test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
completion::LanguageModelCompletionProvider::test(cx);
|
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
|
||||||
@ -2717,7 +2719,6 @@ mod tests {
|
|||||||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
language_model::LanguageModelRegistry::test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
completion::LanguageModelCompletionProvider::test(cx);
|
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
@ -2803,7 +2804,6 @@ mod tests {
|
|||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(language_model::LanguageModelRegistry::test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
cx.update(Project::init_settings);
|
cx.update(Project::init_settings);
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let fs = FakeFs::new(cx.background_executor.clone());
|
let fs = FakeFs::new(cx.background_executor.clone());
|
||||||
@ -2930,7 +2930,6 @@ mod tests {
|
|||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
|
|
||||||
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
|
|
||||||
let fake_model = fake_provider.test_model();
|
let fake_model = fake_provider.test_model();
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
@ -3032,7 +3031,6 @@ mod tests {
|
|||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(language_model::LanguageModelRegistry::test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||||
@ -3109,7 +3107,6 @@ mod tests {
|
|||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(language_model::LanguageModelRegistry::test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
|
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
|
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
|
||||||
Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
|
Hunk, ModelSelector, StreamingDiff,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
@ -27,7 +27,9 @@ use gpui::{
|
|||||||
WindowContext,
|
WindowContext,
|
||||||
};
|
};
|
||||||
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
|
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
use multi_buffer::MultiBufferRow;
|
use multi_buffer::MultiBufferRow;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rope::Rope;
|
use rope::Rope;
|
||||||
@ -1328,7 +1330,7 @@ impl Render for PromptEditor {
|
|||||||
Tooltip::with_meta(
|
Tooltip::with_meta(
|
||||||
format!(
|
format!(
|
||||||
"Using {}",
|
"Using {}",
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
LanguageModelRegistry::read_global(cx)
|
||||||
.active_model()
|
.active_model()
|
||||||
.map(|model| model.name().0)
|
.map(|model| model.name().0)
|
||||||
.unwrap_or_else(|| "No model selected".into()),
|
.unwrap_or_else(|| "No model selected".into()),
|
||||||
@ -1662,7 +1664,7 @@ impl PromptEditor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||||
let token_count = self.token_count?;
|
let token_count = self.token_count?;
|
||||||
let max_token_count = model.max_token_count();
|
let max_token_count = model.max_token_count();
|
||||||
|
|
||||||
@ -2013,8 +2015,12 @@ impl Codegen {
|
|||||||
assistant_panel_context: Option<LanguageModelRequest>,
|
assistant_panel_context: Option<LanguageModelRequest>,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
model.count_tokens(request, cx)
|
||||||
|
} else {
|
||||||
|
future::ready(Err(anyhow!("no active model"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn start(
|
pub fn start(
|
||||||
@ -2024,6 +2030,10 @@ impl Codegen {
|
|||||||
assistant_panel_context: Option<LanguageModelRequest>,
|
assistant_panel_context: Option<LanguageModelRequest>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
let model = LanguageModelRegistry::read_global(cx)
|
||||||
|
.active_model()
|
||||||
|
.context("no active model")?;
|
||||||
|
|
||||||
self.undo(cx);
|
self.undo(cx);
|
||||||
|
|
||||||
// Handle initial insertion
|
// Handle initial insertion
|
||||||
@ -2053,10 +2063,7 @@ impl Codegen {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
let telemetry_id = model.telemetry_id();
|
||||||
.active_model_telemetry_id()
|
|
||||||
.context("no active model")?;
|
|
||||||
|
|
||||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||||
.trim()
|
.trim()
|
||||||
.to_lowercase()
|
.to_lowercase()
|
||||||
@ -2067,10 +2074,10 @@ impl Codegen {
|
|||||||
let request =
|
let request =
|
||||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||||
let chunks =
|
let chunks =
|
||||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
|
||||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||||
};
|
};
|
||||||
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
|
self.handle_stream(telemetry_id, edit_range, chunks, cx);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2657,7 +2664,6 @@ mod tests {
|
|||||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||||
cx.set_global(cx.update(SettingsStore::test));
|
cx.set_global(cx.update(SettingsStore::test));
|
||||||
cx.update(language_model::LanguageModelRegistry::test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
cx.update(language_settings::init);
|
cx.update(language_settings::init);
|
||||||
|
|
||||||
let text = indoc! {"
|
let text = indoc! {"
|
||||||
@ -2789,7 +2795,6 @@ mod tests {
|
|||||||
mut rng: StdRng,
|
mut rng: StdRng,
|
||||||
) {
|
) {
|
||||||
cx.update(LanguageModelRegistry::test);
|
cx.update(LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
cx.set_global(cx.update(SettingsStore::test));
|
cx.set_global(cx.update(SettingsStore::test));
|
||||||
cx.update(language_settings::init);
|
cx.update(language_settings::init);
|
||||||
|
|
||||||
@ -2853,7 +2858,6 @@ mod tests {
|
|||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
|
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
|
||||||
cx.update(LanguageModelRegistry::test);
|
cx.update(LanguageModelRegistry::test);
|
||||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
|
||||||
cx.set_global(cx.update(SettingsStore::test));
|
cx.set_global(cx.update(SettingsStore::test));
|
||||||
cx.update(language_settings::init);
|
cx.update(language_settings::init);
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
|
use crate::assistant_settings::AssistantSettings;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use gpui::SharedString;
|
use gpui::SharedString;
|
||||||
use language_model::LanguageModelRegistry;
|
use language_model::LanguageModelRegistry;
|
||||||
@ -81,13 +81,13 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let provider = provider.id();
|
let provider = provider.clone();
|
||||||
move |cx| {
|
move |cx| {
|
||||||
LanguageModelCompletionProvider::global(cx).update(
|
LanguageModelRegistry::global(cx).update(
|
||||||
cx,
|
cx,
|
||||||
|completion_provider, cx| {
|
|completion_provider, cx| {
|
||||||
completion_provider
|
completion_provider
|
||||||
.set_active_provider(provider.clone(), cx)
|
.set_active_provider(Some(provider.clone()), cx);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -95,12 +95,12 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let selected_model = LanguageModelCompletionProvider::read_global(cx)
|
let selected_provider = LanguageModelRegistry::read_global(cx)
|
||||||
.active_model()
|
|
||||||
.map(|m| m.id());
|
|
||||||
let selected_provider = LanguageModelCompletionProvider::read_global(cx)
|
|
||||||
.active_provider()
|
.active_provider()
|
||||||
.map(|m| m.id());
|
.map(|m| m.id());
|
||||||
|
let selected_model = LanguageModelRegistry::read_global(cx)
|
||||||
|
.active_model()
|
||||||
|
.map(|m| m.id());
|
||||||
|
|
||||||
for available_model in available_models {
|
for available_model in available_models {
|
||||||
menu = menu.custom_entry(
|
menu = menu.custom_entry(
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
|
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
|
||||||
LanguageModelCompletionProvider,
|
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assets::Assets;
|
use assets::Assets;
|
||||||
@ -19,7 +18,9 @@ use gpui::{
|
|||||||
};
|
};
|
||||||
use heed::{types::SerdeBincode, Database, RoTxn};
|
use heed::{types::SerdeBincode, Database, RoTxn};
|
||||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use picker::{Picker, PickerDelegate};
|
use picker::{Picker, PickerDelegate};
|
||||||
use rope::Rope;
|
use rope::Rope;
|
||||||
@ -636,7 +637,10 @@ impl PromptLibrary {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
||||||
let provider = LanguageModelCompletionProvider::read_global(cx);
|
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
let initial_prompt = action.prompt.clone();
|
let initial_prompt = action.prompt.clone();
|
||||||
if provider.is_authenticated(cx) {
|
if provider.is_authenticated(cx) {
|
||||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||||
@ -725,6 +729,9 @@ impl PromptLibrary {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
|
fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
|
||||||
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
|
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
|
||||||
let editor = &prompt.body_editor.read(cx);
|
let editor = &prompt.body_editor.read(cx);
|
||||||
let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
|
let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
|
||||||
@ -736,7 +743,7 @@ impl PromptLibrary {
|
|||||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||||
let token_count = cx
|
let token_count = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
model.count_tokens(
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
messages: vec![LanguageModelRequestMessage {
|
messages: vec![LanguageModelRequestMessage {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
@ -804,7 +811,7 @@ impl PromptLibrary {
|
|||||||
let prompt_metadata = self.store.metadata(prompt_id)?;
|
let prompt_metadata = self.store.metadata(prompt_id)?;
|
||||||
let prompt_editor = &self.prompt_editors[&prompt_id];
|
let prompt_editor = &self.prompt_editors[&prompt_id];
|
||||||
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
||||||
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
|
let model = LanguageModelRegistry::read_global(cx).active_model();
|
||||||
let settings = ThemeSettings::get_global(cx);
|
let settings = ThemeSettings::get_global(cx);
|
||||||
|
|
||||||
Some(
|
Some(
|
||||||
@ -914,7 +921,7 @@ impl PromptLibrary {
|
|||||||
None,
|
None,
|
||||||
format!(
|
format!(
|
||||||
"Model: {}",
|
"Model: {}",
|
||||||
current_model
|
model
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|model| model
|
.map(|model| model
|
||||||
.name()
|
.name()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
|
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
|
||||||
AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
|
AssistantPanelEvent, ModelSelector,
|
||||||
};
|
};
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
@ -16,7 +16,9 @@ use gpui::{
|
|||||||
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
||||||
};
|
};
|
||||||
use language::Buffer;
|
use language::Buffer;
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::{
|
use std::{
|
||||||
cmp,
|
cmp,
|
||||||
@ -556,7 +558,7 @@ impl Render for PromptEditor {
|
|||||||
Tooltip::with_meta(
|
Tooltip::with_meta(
|
||||||
format!(
|
format!(
|
||||||
"Using {}",
|
"Using {}",
|
||||||
LanguageModelCompletionProvider::read_global(cx)
|
LanguageModelRegistry::read_global(cx)
|
||||||
.active_model()
|
.active_model()
|
||||||
.map(|model| model.name().0)
|
.map(|model| model.name().0)
|
||||||
.unwrap_or_else(|| "No model selected".into()),
|
.unwrap_or_else(|| "No model selected".into()),
|
||||||
@ -700,6 +702,9 @@ impl PromptEditor {
|
|||||||
|
|
||||||
fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
|
fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
let assist_id = self.id;
|
let assist_id = self.id;
|
||||||
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
||||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||||
let request =
|
let request =
|
||||||
@ -707,11 +712,7 @@ impl PromptEditor {
|
|||||||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
let token_count = cx
|
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||||
.update(|cx| {
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@ -840,7 +841,7 @@ impl PromptEditor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||||
let token_count = self.token_count?;
|
let token_count = self.token_count?;
|
||||||
let max_token_count = model.max_token_count();
|
let max_token_count = model.max_token_count();
|
||||||
|
|
||||||
@ -982,19 +983,16 @@ impl Codegen {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
|
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
|
||||||
self.status = CodegenStatus::Pending;
|
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
let telemetry = self.telemetry.clone();
|
let telemetry = self.telemetry.clone();
|
||||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
self.status = CodegenStatus::Pending;
|
||||||
.active_model()
|
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||||
.map(|m| m.telemetry_id())
|
|
||||||
.unwrap_or_default();
|
|
||||||
let response =
|
|
||||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
|
|
||||||
|
|
||||||
self.generation = cx.spawn(|this, mut cx| async move {
|
self.generation = cx.spawn(|this, mut cx| async move {
|
||||||
let response = response.await;
|
let model_telemetry_id = model.telemetry_id();
|
||||||
|
let response = model.stream_completion(prompt, &cx).await;
|
||||||
let generate = async {
|
let generate = async {
|
||||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
@ -80,7 +80,6 @@ channel.workspace = true
|
|||||||
client = { workspace = true, features = ["test-support"] }
|
client = { workspace = true, features = ["test-support"] }
|
||||||
collab_ui = { workspace = true, features = ["test-support"] }
|
collab_ui = { workspace = true, features = ["test-support"] }
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
completion = { workspace = true, features = ["test-support"] }
|
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
editor = { workspace = true, features = ["test-support"] }
|
editor = { workspace = true, features = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
@ -300,7 +300,6 @@ impl TestServer {
|
|||||||
dev_server_projects::init(client.clone(), cx);
|
dev_server_projects::init(client.clone(), cx);
|
||||||
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
|
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
|
||||||
language_model::LanguageModelRegistry::test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
completion::init(cx);
|
|
||||||
assistant::context_store::init(&client);
|
assistant::context_store::init(&client);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "completion"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
publish = false
|
|
||||||
license = "GPL-3.0-or-later"
|
|
||||||
|
|
||||||
[lints]
|
|
||||||
workspace = true
|
|
||||||
|
|
||||||
[lib]
|
|
||||||
path = "src/completion.rs"
|
|
||||||
doctest = false
|
|
||||||
|
|
||||||
[features]
|
|
||||||
test-support = [
|
|
||||||
"editor/test-support",
|
|
||||||
"language/test-support",
|
|
||||||
"language_model/test-support",
|
|
||||||
"project/test-support",
|
|
||||||
"text/test-support",
|
|
||||||
]
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
anyhow.workspace = true
|
|
||||||
futures.workspace = true
|
|
||||||
gpui.workspace = true
|
|
||||||
language_model.workspace = true
|
|
||||||
schemars.workspace = true
|
|
||||||
serde.workspace = true
|
|
||||||
serde_json.workspace = true
|
|
||||||
settings.workspace = true
|
|
||||||
smol.workspace = true
|
|
||||||
ui.workspace = true
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
ctor.workspace = true
|
|
||||||
editor = { workspace = true, features = ["test-support"] }
|
|
||||||
env_logger.workspace = true
|
|
||||||
language = { workspace = true, features = ["test-support"] }
|
|
||||||
project = { workspace = true, features = ["test-support"] }
|
|
||||||
language_model = { workspace = true, features = ["test-support"] }
|
|
||||||
rand.workspace = true
|
|
||||||
text = { workspace = true, features = ["test-support"] }
|
|
||||||
unindent.workspace = true
|
|
@ -1 +0,0 @@
|
|||||||
../../LICENSE-GPL
|
|
@ -1,312 +0,0 @@
|
|||||||
use anyhow::{anyhow, Result};
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
|
||||||
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
|
||||||
use language_model::{
|
|
||||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
|
||||||
LanguageModelRequest, LanguageModelTool,
|
|
||||||
};
|
|
||||||
use smol::{
|
|
||||||
future::FutureExt,
|
|
||||||
lock::{Semaphore, SemaphoreGuardArc},
|
|
||||||
};
|
|
||||||
use std::{future, pin::Pin, sync::Arc, task::Poll};
|
|
||||||
use ui::Context;
|
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
|
||||||
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
|
|
||||||
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
|
|
||||||
|
|
||||||
impl Global for GlobalLanguageModelCompletionProvider {}
|
|
||||||
|
|
||||||
pub struct LanguageModelCompletionProvider {
|
|
||||||
active_provider: Option<Arc<dyn LanguageModelProvider>>,
|
|
||||||
active_model: Option<Arc<dyn LanguageModel>>,
|
|
||||||
request_limiter: Arc<Semaphore>,
|
|
||||||
}
|
|
||||||
|
|
||||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
|
||||||
|
|
||||||
pub struct LanguageModelCompletionResponse {
|
|
||||||
inner: BoxStream<'static, Result<String>>,
|
|
||||||
_lock: SemaphoreGuardArc,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl futures::Stream for LanguageModelCompletionResponse {
|
|
||||||
type Item = Result<String>;
|
|
||||||
|
|
||||||
fn poll_next(
|
|
||||||
mut self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Option<Self::Item>> {
|
|
||||||
Pin::new(&mut self.inner).poll_next(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider {
|
|
||||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
|
||||||
cx.global::<GlobalLanguageModelCompletionProvider>()
|
|
||||||
.0
|
|
||||||
.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn read_global(cx: &AppContext) -> &Self {
|
|
||||||
cx.global::<GlobalLanguageModelCompletionProvider>()
|
|
||||||
.0
|
|
||||||
.read(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
|
||||||
pub fn test(cx: &mut AppContext) {
|
|
||||||
let provider = cx.new_model(|cx| {
|
|
||||||
let mut this = Self::new(cx);
|
|
||||||
let available_model = LanguageModelRegistry::read_global(cx)
|
|
||||||
.available_models(cx)
|
|
||||||
.first()
|
|
||||||
.unwrap()
|
|
||||||
.clone();
|
|
||||||
this.set_active_model(available_model, cx);
|
|
||||||
this
|
|
||||||
});
|
|
||||||
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(cx: &mut ModelContext<Self>) -> Self {
|
|
||||||
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
|
|
||||||
cx.notify();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
active_provider: None,
|
|
||||||
active_model: None,
|
|
||||||
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
|
||||||
self.active_provider.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_active_provider(
|
|
||||||
&mut self,
|
|
||||||
provider_id: LanguageModelProviderId,
|
|
||||||
cx: &mut ModelContext<Self>,
|
|
||||||
) {
|
|
||||||
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
|
|
||||||
self.active_model = None;
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
|
||||||
self.active_model.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
|
|
||||||
if self.active_model.as_ref().map_or(false, |m| {
|
|
||||||
m.id() == model.id() && m.provider_id() == model.provider_id()
|
|
||||||
}) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.active_provider =
|
|
||||||
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
|
|
||||||
self.active_model = Some(model.clone());
|
|
||||||
|
|
||||||
if let Some(provider) = self.active_provider.as_ref() {
|
|
||||||
provider.load_model(model, cx);
|
|
||||||
}
|
|
||||||
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
|
|
||||||
self.active_provider
|
|
||||||
.as_ref()
|
|
||||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
self.active_provider
|
|
||||||
.as_ref()
|
|
||||||
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
self.active_provider
|
|
||||||
.as_ref()
|
|
||||||
.map_or(Task::ready(Ok(())), |provider| {
|
|
||||||
provider.reset_credentials(cx)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn count_tokens(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
|
||||||
if let Some(model) = self.active_model() {
|
|
||||||
model.count_tokens(request, cx)
|
|
||||||
} else {
|
|
||||||
future::ready(Err(anyhow!("no active model"))).boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stream_completion(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> Task<Result<LanguageModelCompletionResponse>> {
|
|
||||||
if let Some(language_model) = self.active_model() {
|
|
||||||
let rate_limiter = self.request_limiter.clone();
|
|
||||||
cx.spawn(|cx| async move {
|
|
||||||
let lock = rate_limiter.acquire_arc().await;
|
|
||||||
let response = language_model.stream_completion(request, &cx).await?;
|
|
||||||
Ok(LanguageModelCompletionResponse {
|
|
||||||
inner: response,
|
|
||||||
_lock: lock,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Task::ready(Err(anyhow!("No active model set")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
|
||||||
let response = self.stream_completion(request, cx);
|
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let mut chunks = response.await?;
|
|
||||||
let mut completion = String::new();
|
|
||||||
while let Some(chunk) = chunks.next().await {
|
|
||||||
let chunk = chunk?;
|
|
||||||
completion.push_str(&chunk);
|
|
||||||
}
|
|
||||||
Ok(completion)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn use_tool<T: LanguageModelTool>(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> Task<Result<T>> {
|
|
||||||
if let Some(language_model) = self.active_model() {
|
|
||||||
cx.spawn(|cx| async move {
|
|
||||||
let schema = schemars::schema_for!(T);
|
|
||||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
|
||||||
let request =
|
|
||||||
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
|
|
||||||
let response = request.await?;
|
|
||||||
Ok(serde_json::from_value(response)?)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Task::ready(Err(anyhow!("No active model set")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn active_model_telemetry_id(&self) -> Option<String> {
|
|
||||||
self.active_model.as_ref().map(|m| m.telemetry_id())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use futures::StreamExt;
|
|
||||||
use gpui::AppContext;
|
|
||||||
use settings::SettingsStore;
|
|
||||||
use ui::Context;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
|
|
||||||
};
|
|
||||||
|
|
||||||
use language_model::LanguageModelRegistry;
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
fn test_rate_limiting(cx: &mut AppContext) {
|
|
||||||
SettingsStore::test(cx);
|
|
||||||
let fake_provider = LanguageModelRegistry::test(cx);
|
|
||||||
|
|
||||||
let model = LanguageModelRegistry::read_global(cx)
|
|
||||||
.available_models(cx)
|
|
||||||
.first()
|
|
||||||
.cloned()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let provider = cx.new_model(|cx| {
|
|
||||||
let mut provider = LanguageModelCompletionProvider::new(cx);
|
|
||||||
provider.set_active_model(model.clone(), cx);
|
|
||||||
provider
|
|
||||||
});
|
|
||||||
|
|
||||||
let fake_model = fake_provider.test_model();
|
|
||||||
|
|
||||||
// Enqueue some requests
|
|
||||||
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
|
||||||
let response = provider.read(cx).stream_completion(
|
|
||||||
LanguageModelRequest {
|
|
||||||
temperature: i as f32 / 10.0,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
cx.background_executor()
|
|
||||||
.spawn(async move {
|
|
||||||
let mut stream = response.await.unwrap();
|
|
||||||
while let Some(message) = stream.next().await {
|
|
||||||
message.unwrap();
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
cx.background_executor().run_until_parked();
|
|
||||||
assert_eq!(
|
|
||||||
fake_model.completion_count(),
|
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
|
||||||
);
|
|
||||||
|
|
||||||
// Get the first completion request that is in flight and mark it as completed.
|
|
||||||
let completion = fake_model.pending_completions().into_iter().next().unwrap();
|
|
||||||
fake_model.finish_completion(&completion);
|
|
||||||
|
|
||||||
// Ensure that the number of in-flight completion requests is reduced.
|
|
||||||
assert_eq!(
|
|
||||||
fake_model.completion_count(),
|
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
|
||||||
);
|
|
||||||
|
|
||||||
cx.background_executor().run_until_parked();
|
|
||||||
|
|
||||||
// Ensure that another completion request was allowed to acquire the lock.
|
|
||||||
assert_eq!(
|
|
||||||
fake_model.completion_count(),
|
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
|
||||||
);
|
|
||||||
|
|
||||||
// Mark all completion requests as finished that are in flight.
|
|
||||||
for request in fake_model.pending_completions() {
|
|
||||||
fake_model.finish_completion(&request);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(fake_model.completion_count(), 0);
|
|
||||||
|
|
||||||
// Wait until the background tasks acquire the lock again.
|
|
||||||
cx.background_executor().run_until_parked();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
fake_model.completion_count(),
|
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
|
||||||
);
|
|
||||||
|
|
||||||
// Finish all remaining completion requests.
|
|
||||||
for request in fake_model.pending_completions() {
|
|
||||||
fake_model.finish_completion(&request);
|
|
||||||
}
|
|
||||||
|
|
||||||
cx.background_executor().run_until_parked();
|
|
||||||
|
|
||||||
assert_eq!(fake_model.completion_count(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
@ -208,13 +208,13 @@ impl CopilotChat {
|
|||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
request: Request,
|
request: Request,
|
||||||
low_speed_timeout: Option<Duration>,
|
low_speed_timeout: Option<Duration>,
|
||||||
cx: &mut AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||||
let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
|
let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
|
||||||
return Err(anyhow!("Copilot chat is not enabled"));
|
return Err(anyhow!("Copilot chat is not enabled"));
|
||||||
};
|
};
|
||||||
|
|
||||||
let (oauth_token, api_token, client) = this.read_with(cx, |this, _| {
|
let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
|
||||||
(
|
(
|
||||||
this.oauth_token.clone(),
|
this.oauth_token.clone(),
|
||||||
this.api_token.clone(),
|
this.api_token.clone(),
|
||||||
@ -229,7 +229,7 @@ impl CopilotChat {
|
|||||||
_ => {
|
_ => {
|
||||||
let token =
|
let token =
|
||||||
request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
|
request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
|
||||||
this.update(cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.api_token = Some(token.clone());
|
this.api_token = Some(token.clone());
|
||||||
cx.notify();
|
cx.notify();
|
||||||
})?;
|
})?;
|
||||||
|
@ -33,6 +33,7 @@ google_ai = { workspace = true, features = ["schemars"] }
|
|||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
inline_completion_button.workspace = true
|
inline_completion_button.workspace = true
|
||||||
|
log.workspace = true
|
||||||
menu.workspace = true
|
menu.workspace = true
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
ollama = { workspace = true, features = ["schemars"] }
|
||||||
open_ai = { workspace = true, features = ["schemars"] }
|
open_ai = { workspace = true, features = ["schemars"] }
|
||||||
@ -42,6 +43,7 @@ schemars.workspace = true
|
|||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
|
smol.workspace = true
|
||||||
strum.workspace = true
|
strum.workspace = true
|
||||||
theme.workspace = true
|
theme.workspace = true
|
||||||
tiktoken-rs.workspace = true
|
tiktoken-rs.workspace = true
|
||||||
|
@ -1,24 +1,24 @@
|
|||||||
mod model;
|
mod model;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
|
mod rate_limiter;
|
||||||
mod registry;
|
mod registry;
|
||||||
mod request;
|
mod request;
|
||||||
mod role;
|
mod role;
|
||||||
pub mod settings;
|
pub mod settings;
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use futures::{future::BoxFuture, stream::BoxStream};
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
||||||
|
|
||||||
pub use model::*;
|
pub use model::*;
|
||||||
use project::Fs;
|
use project::Fs;
|
||||||
|
pub(crate) use rate_limiter::*;
|
||||||
pub use registry::*;
|
pub use registry::*;
|
||||||
pub use request::*;
|
pub use request::*;
|
||||||
pub use role::*;
|
pub use role::*;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
|
use std::{future::Future, sync::Arc};
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||||
settings::init(fs, cx);
|
settings::init(fs, cx);
|
||||||
@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
|
|||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
name: String,
|
name: String,
|
||||||
@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
|
|||||||
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl dyn LanguageModel {
|
||||||
|
pub fn use_tool<T: LanguageModelTool>(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> impl 'static + Future<Output = Result<T>> {
|
||||||
|
let schema = schemars::schema_for!(T);
|
||||||
|
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||||
|
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
||||||
|
async move {
|
||||||
|
let response = request.await?;
|
||||||
|
Ok(serde_json::from_value(response)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||||
fn name() -> String;
|
fn name() -> String;
|
||||||
fn description() -> String;
|
fn description() -> String;
|
||||||
@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
|
|||||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
||||||
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
|
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
|
||||||
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait LanguageModelProviderState: 'static {
|
pub trait LanguageModelProviderState: 'static {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
@ -36,6 +36,7 @@ pub struct AnthropicSettings {
|
|||||||
pub struct AvailableModel {
|
pub struct AvailableModel {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub max_tokens: usize,
|
pub max_tokens: usize,
|
||||||
|
pub tool_override: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AnthropicLanguageModelProvider {
|
pub struct AnthropicLanguageModelProvider {
|
||||||
@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||||||
anthropic::Model::Custom {
|
anthropic::Model::Custom {
|
||||||
name: model.name.clone(),
|
name: model.name.clone(),
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
|
tool_override: model.tool_override.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||||||
model,
|
model,
|
||||||
state: self.state.clone(),
|
state: self.state.clone(),
|
||||||
http_client: self.http_client.clone(),
|
http_client: self.http_client.clone(),
|
||||||
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||||||
self.state.read(cx).api_key.is_some()
|
self.state.read(cx).api_key.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
if self.is_authenticated(cx) {
|
if self.is_authenticated(cx) {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let state = self.state.clone();
|
let state = self.state.clone();
|
||||||
let delete_credentials =
|
let delete_credentials =
|
||||||
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
|
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
|
||||||
@ -171,6 +174,7 @@ pub struct AnthropicModel {
|
|||||||
model: anthropic::Model,
|
model: anthropic::Model,
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_anthropic_tokens(
|
pub fn count_anthropic_tokens(
|
||||||
@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
|
|||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let request = request.into_anthropic(self.model.id().into());
|
let request = request.into_anthropic(self.model.id().into());
|
||||||
let request = self.stream_completion(request, cx);
|
let request = self.stream_completion(request, cx);
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(anthropic::extract_text_from_events(response).boxed())
|
Ok(anthropic::extract_text_from_events(response))
|
||||||
}
|
});
|
||||||
.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
|
|||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
let mut request = request.into_anthropic(self.model.id().into());
|
let mut request = request.into_anthropic(self.model.tool_model_id().into());
|
||||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
});
|
});
|
||||||
@ -322,7 +326,8 @@ impl LanguageModel for AnthropicModel {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let response = self.request_completion(request, cx);
|
let response = self.request_completion(request, cx);
|
||||||
async move {
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
let response = response.await?;
|
let response = response.await?;
|
||||||
response
|
response
|
||||||
.content
|
.content
|
||||||
@ -339,7 +344,7 @@ impl LanguageModel for AnthropicModel {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.context("tool not used")
|
.context("tool not used")
|
||||||
}
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
@ -41,6 +41,7 @@ pub struct AvailableModel {
|
|||||||
provider: AvailableProvider,
|
provider: AvailableProvider,
|
||||||
name: String,
|
name: String,
|
||||||
max_tokens: usize,
|
max_tokens: usize,
|
||||||
|
tool_override: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CloudLanguageModelProvider {
|
pub struct CloudLanguageModelProvider {
|
||||||
@ -56,7 +57,7 @@ struct State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||||
}
|
}
|
||||||
@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
|
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
|
||||||
name: model.name.clone(),
|
name: model.name.clone(),
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
|
tool_override: model.tool_override.clone(),
|
||||||
}),
|
}),
|
||||||
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
|
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
|
||||||
name: model.name.clone(),
|
name: model.name.clone(),
|
||||||
@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
id: LanguageModelId::from(model.id().to_string()),
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
model,
|
model,
|
||||||
client: self.client.clone(),
|
client: self.client.clone(),
|
||||||
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
self.state.read(cx).status.is_connected()
|
self.state.read(cx).status.is_connected()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
self.state.read(cx).authenticate(cx)
|
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
|
|||||||
id: LanguageModelId,
|
id: LanguageModelId,
|
||||||
model: CloudModel,
|
model: CloudModel,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for CloudLanguageModel {
|
impl LanguageModel for CloudLanguageModel {
|
||||||
@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_anthropic(model.id().into());
|
let request = request.into_anthropic(model.id().into());
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let stream = client
|
let stream = client
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(anthropic::extract_text_from_events(
|
Ok(anthropic::extract_text_from_events(
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
))
|
||||||
.boxed())
|
});
|
||||||
}
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_open_ai(model.id().into());
|
let request = request.into_open_ai(model.id().into());
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let stream = client
|
let stream = client
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(open_ai::extract_text_from_events(
|
Ok(open_ai::extract_text_from_events(
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
))
|
||||||
.boxed())
|
});
|
||||||
}
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let request = request.into_google(model.id().into());
|
let request = request.into_google(model.id().into());
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let stream = client
|
let stream = client
|
||||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(google_ai::extract_text_from_events(
|
Ok(google_ai::extract_text_from_events(
|
||||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||||
)
|
))
|
||||||
.boxed())
|
});
|
||||||
}
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let mut request = request.into_anthropic(model.id().into());
|
let mut request = request.into_anthropic(model.tool_model_id().into());
|
||||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
});
|
});
|
||||||
@ -331,7 +332,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
async move {
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
let response = client
|
let response = client
|
||||||
.request(proto::CompleteWithLanguageModel {
|
.request(proto::CompleteWithLanguageModel {
|
||||||
@ -339,7 +341,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
request,
|
request,
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
|
let response: anthropic::Response =
|
||||||
|
serde_json::from_str(&response.completion)?;
|
||||||
response
|
response
|
||||||
.content
|
.content
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@ -355,7 +358,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.context("tool not used")
|
.context("tool not used")
|
||||||
}
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(_) => {
|
CloudModel::OpenAi(_) => {
|
||||||
|
@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
|
|||||||
use crate::LanguageModelProviderState;
|
use crate::LanguageModelProviderState;
|
||||||
use crate::{
|
use crate::{
|
||||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::open_ai::count_open_ai_tokens;
|
use super::open_ai::count_open_ai_tokens;
|
||||||
@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||||||
|
|
||||||
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
CopilotChatModel::iter()
|
CopilotChatModel::iter()
|
||||||
.map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
|
.map(|model| {
|
||||||
|
Arc::new(CopilotChatLanguageModel {
|
||||||
|
model,
|
||||||
|
request_limiter: RateLimiter::new(4),
|
||||||
|
}) as Arc<dyn LanguageModel>
|
||||||
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let result = if self.is_authenticated(cx) {
|
let result = if self.is_authenticated(cx) {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else if let Some(copilot) = Copilot::global(cx) {
|
} else if let Some(copilot) = Copilot::global(cx) {
|
||||||
@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||||||
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
|
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let Some(copilot) = Copilot::global(cx) else {
|
let Some(copilot) = Copilot::global(cx) else {
|
||||||
return Task::ready(Err(anyhow::anyhow!(
|
return Task::ready(Err(anyhow::anyhow!(
|
||||||
"Copilot is not available. Please ensure Copilot is enabled and running and try again."
|
"Copilot is not available. Please ensure Copilot is enabled and running and try again."
|
||||||
@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||||||
|
|
||||||
pub struct CopilotChatLanguageModel {
|
pub struct CopilotChatLanguageModel {
|
||||||
model: CopilotChatModel,
|
model: CopilotChatModel,
|
||||||
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for CopilotChatLanguageModel {
|
impl LanguageModel for CopilotChatLanguageModel {
|
||||||
@ -215,8 +221,11 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||||||
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
cx.spawn(|mut cx| async move {
|
let request_limiter = self.request_limiter.clone();
|
||||||
let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
|
let future = cx.spawn(|cx| async move {
|
||||||
|
let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
|
||||||
|
request_limiter.stream(async move {
|
||||||
|
let response = response.await?;
|
||||||
let stream = response
|
let stream = response
|
||||||
.filter_map(|response| async move {
|
.filter_map(|response| async move {
|
||||||
match response {
|
match response {
|
||||||
@ -234,11 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||||||
})
|
})
|
||||||
.boxed();
|
.boxed();
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
})
|
}).await
|
||||||
.boxed()
|
});
|
||||||
|
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
_request: LanguageModelRequest,
|
||||||
_name: String,
|
_name: String,
|
||||||
|
@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
|
|||||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
_request: LanguageModelRequest,
|
||||||
_name: String,
|
_name: String,
|
||||||
|
@ -20,7 +20,7 @@ use util::ResultExt;
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||||
};
|
};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "google";
|
const PROVIDER_ID: &str = "google";
|
||||||
@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||||||
model,
|
model,
|
||||||
state: self.state.clone(),
|
state: self.state.clone(),
|
||||||
http_client: self.http_client.clone(),
|
http_client: self.http_client.clone(),
|
||||||
|
rate_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||||||
self.state.read(cx).api_key.is_some()
|
self.state.read(cx).api_key.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
if self.is_authenticated(cx) {
|
if self.is_authenticated(cx) {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let state = self.state.clone();
|
let state = self.state.clone();
|
||||||
let delete_credentials =
|
let delete_credentials =
|
||||||
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
|
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
|
||||||
@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
|
|||||||
model: google_ai::Model,
|
model: google_ai::Model,
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
rate_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for GoogleLanguageModel {
|
impl LanguageModel for GoogleLanguageModel {
|
||||||
@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
|
|||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
async move {
|
let future = self.rate_limiter.stream(async move {
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
let response =
|
let response =
|
||||||
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
|
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
let events = response.await?;
|
let events = response.await?;
|
||||||
Ok(google_ai::extract_text_from_events(events).boxed())
|
Ok(google_ai::extract_text_from_events(events).boxed())
|
||||||
}
|
});
|
||||||
.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
_request: LanguageModelRequest,
|
||||||
_name: String,
|
_name: String,
|
||||||
|
@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
|
|
||||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||||
@ -39,7 +39,7 @@ struct State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
|
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_url = settings.api_url.clone();
|
let api_url = settings.api_url.clone();
|
||||||
@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
|
|||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
this.fetch_models(cx).detach();
|
this.state
|
||||||
|
.update(cx, |state, cx| state.fetch_models(cx).detach());
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
|
||||||
let http_client = self.http_client.clone();
|
|
||||||
let api_url = settings.api_url.clone();
|
|
||||||
|
|
||||||
let state = self.state.clone();
|
|
||||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
|
||||||
|
|
||||||
let mut models: Vec<ollama::Model> = models
|
|
||||||
.into_iter()
|
|
||||||
// Since there is no metadata from the Ollama API
|
|
||||||
// indicating which models are embedding models,
|
|
||||||
// simply filter out models with "-embed" in their name
|
|
||||||
.filter(|model| !model.name.contains("-embed"))
|
|
||||||
.map(|model| ollama::Model::new(&model.name))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
|
||||||
|
|
||||||
state.update(&mut cx, |this, cx| {
|
|
||||||
this.available_models = models;
|
|
||||||
cx.notify();
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||||
@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||||||
id: LanguageModelId::from(model.name.clone()),
|
id: LanguageModelId::from(model.name.clone()),
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
http_client: self.http_client.clone(),
|
http_client: self.http_client.clone(),
|
||||||
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||||||
!self.state.read(cx).available_models.is_empty()
|
!self.state.read(cx).available_models.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
if self.is_authenticated(cx) {
|
if self.is_authenticated(cx) {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
self.fetch_models(cx)
|
self.state.update(cx, |state, cx| state.fetch_models(cx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
self.fetch_models(cx)
|
self.state.update(cx, |state, cx| state.fetch_models(cx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
|
|||||||
id: LanguageModelId,
|
id: LanguageModelId,
|
||||||
model: ollama::Model,
|
model: ollama::Model,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OllamaLanguageModel {
|
impl OllamaLanguageModel {
|
||||||
@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
|
|||||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_token_count(&self) -> usize {
|
|
||||||
self.model.max_token_count()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn telemetry_id(&self) -> String {
|
fn telemetry_id(&self) -> String {
|
||||||
format!("ollama/{}", self.model.id())
|
format!("ollama/{}", self.model.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
self.model.max_token_count()
|
||||||
|
}
|
||||||
|
|
||||||
fn count_tokens(
|
fn count_tokens(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
|
|||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request =
|
let response =
|
||||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
|
||||||
let response = request.await?;
|
.await?;
|
||||||
let stream = response
|
let stream = response
|
||||||
.filter_map(|response| async move {
|
.filter_map(|response| async move {
|
||||||
match response {
|
match response {
|
||||||
@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
|
|||||||
})
|
})
|
||||||
.boxed();
|
.boxed();
|
||||||
Ok(stream)
|
Ok(stream)
|
||||||
}
|
});
|
||||||
.boxed()
|
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
_request: LanguageModelRequest,
|
||||||
_name: String,
|
_name: String,
|
||||||
|
@ -20,7 +20,7 @@ use util::ResultExt;
|
|||||||
use crate::{
|
use crate::{
|
||||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
|
|
||||||
const PROVIDER_ID: &str = "openai";
|
const PROVIDER_ID: &str = "openai";
|
||||||
@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||||||
model,
|
model,
|
||||||
state: self.state.clone(),
|
state: self.state.clone(),
|
||||||
http_client: self.http_client.clone(),
|
http_client: self.http_client.clone(),
|
||||||
|
request_limiter: RateLimiter::new(4),
|
||||||
}) as Arc<dyn LanguageModel>
|
}) as Arc<dyn LanguageModel>
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||||||
self.state.read(cx).api_key.is_some()
|
self.state.read(cx).api_key.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
if self.is_authenticated(cx) {
|
if self.is_authenticated(cx) {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||||
let delete_credentials = cx.delete_credentials(&settings.api_url);
|
let delete_credentials = cx.delete_credentials(&settings.api_url);
|
||||||
let state = self.state.clone();
|
let state = self.state.clone();
|
||||||
@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
|
|||||||
model: open_ai::Model,
|
model: open_ai::Model,
|
||||||
state: gpui::Model<State>,
|
state: gpui::Model<State>,
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for OpenAiLanguageModel {
|
impl LanguageModel for OpenAiLanguageModel {
|
||||||
@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
let request = stream_completion(
|
let request = stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||||||
);
|
);
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(open_ai::extract_text_from_events(response).boxed())
|
Ok(open_ai::extract_text_from_events(response).boxed())
|
||||||
}
|
});
|
||||||
.boxed()
|
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
_request: LanguageModelRequest,
|
||||||
_name: String,
|
_name: String,
|
||||||
|
70
crates/language_model/src/rate_limiter.rs
Normal file
70
crates/language_model/src/rate_limiter.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use futures::Stream;
|
||||||
|
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
pin::Pin,
|
||||||
|
sync::Arc,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct RateLimiter {
|
||||||
|
semaphore: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RateLimitGuard<T> {
|
||||||
|
inner: T,
|
||||||
|
_guard: SemaphoreGuardArc,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Stream for RateLimitGuard<T>
|
||||||
|
where
|
||||||
|
T: Stream,
|
||||||
|
{
|
||||||
|
type Item = T::Item;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(limit: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
semaphore: Arc::new(Semaphore::new(limit)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
|
||||||
|
where
|
||||||
|
Fut: 'a + Future<Output = Result<T>>,
|
||||||
|
{
|
||||||
|
let guard = self.semaphore.acquire_arc();
|
||||||
|
async move {
|
||||||
|
let guard = guard.await;
|
||||||
|
let result = future.await?;
|
||||||
|
drop(guard);
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stream<'a, Fut, T>(
|
||||||
|
&self,
|
||||||
|
future: Fut,
|
||||||
|
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item>>>
|
||||||
|
where
|
||||||
|
Fut: 'a + Future<Output = Result<T>>,
|
||||||
|
T: Stream,
|
||||||
|
{
|
||||||
|
let guard = self.semaphore.acquire_arc();
|
||||||
|
async move {
|
||||||
|
let guard = guard.await;
|
||||||
|
let inner = future.await?;
|
||||||
|
Ok(RateLimitGuard {
|
||||||
|
inner,
|
||||||
|
_guard: guard,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -4,11 +4,12 @@ use crate::{
|
|||||||
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
|
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
|
||||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||||
},
|
},
|
||||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
|
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
|
||||||
|
LanguageModelProviderState,
|
||||||
};
|
};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use gpui::{AppContext, Global, Model, ModelContext};
|
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use ui::Context;
|
use ui::Context;
|
||||||
|
|
||||||
@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct LanguageModelRegistry {
|
pub struct LanguageModelRegistry {
|
||||||
|
active_model: Option<ActiveModel>,
|
||||||
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ActiveModel {
|
||||||
|
provider: Arc<dyn LanguageModelProvider>,
|
||||||
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ActiveModelChanged;
|
||||||
|
|
||||||
|
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
|
||||||
|
|
||||||
impl LanguageModelRegistry {
|
impl LanguageModelRegistry {
|
||||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||||
cx.global::<GlobalLanguageModelRegistry>().0.clone()
|
cx.global::<GlobalLanguageModelRegistry>().0.clone()
|
||||||
@ -88,6 +99,8 @@ impl LanguageModelRegistry {
|
|||||||
let registry = cx.new_model(|cx| {
|
let registry = cx.new_model(|cx| {
|
||||||
let mut registry = Self::default();
|
let mut registry = Self::default();
|
||||||
registry.register_provider(fake_provider.clone(), cx);
|
registry.register_provider(fake_provider.clone(), cx);
|
||||||
|
let model = fake_provider.provided_models(cx)[0].clone();
|
||||||
|
registry.set_active_model(Some(model), cx);
|
||||||
registry
|
registry
|
||||||
});
|
});
|
||||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||||
@ -136,6 +149,64 @@ impl LanguageModelRegistry {
|
|||||||
) -> Option<Arc<dyn LanguageModelProvider>> {
|
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||||
self.providers.get(name).cloned()
|
self.providers.get(name).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn select_active_model(
|
||||||
|
&mut self,
|
||||||
|
provider: &LanguageModelProviderId,
|
||||||
|
model_id: &LanguageModelId,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
let Some(provider) = self.provider(&provider) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = provider.provided_models(cx);
|
||||||
|
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||||
|
self.set_active_model(Some(model), cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_active_provider(
|
||||||
|
&mut self,
|
||||||
|
provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
self.active_model = provider.map(|provider| ActiveModel {
|
||||||
|
provider,
|
||||||
|
model: None,
|
||||||
|
});
|
||||||
|
cx.emit(ActiveModelChanged);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_active_model(
|
||||||
|
&mut self,
|
||||||
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
if let Some(model) = model {
|
||||||
|
let provider_id = model.provider_id();
|
||||||
|
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||||
|
self.active_model = Some(ActiveModel {
|
||||||
|
provider,
|
||||||
|
model: Some(model),
|
||||||
|
});
|
||||||
|
cx.emit(ActiveModelChanged);
|
||||||
|
} else {
|
||||||
|
log::warn!("Active model's provider not found in registry");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.active_model = None;
|
||||||
|
cx.emit(ActiveModelChanged);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||||
|
Some(self.active_model.as_ref()?.provider.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||||
|
self.active_model.as_ref()?.model.clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -89,9 +89,15 @@ impl AnthropicSettingsContent {
|
|||||||
models
|
models
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|model| match model {
|
.filter_map(|model| match model {
|
||||||
anthropic::Model::Custom { name, max_tokens } => {
|
anthropic::Model::Custom {
|
||||||
Some(provider::anthropic::AvailableModel { name, max_tokens })
|
name,
|
||||||
}
|
max_tokens,
|
||||||
|
tool_override,
|
||||||
|
} => Some(provider::anthropic::AvailableModel {
|
||||||
|
name,
|
||||||
|
max_tokens,
|
||||||
|
tool_override,
|
||||||
|
}),
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -22,7 +22,6 @@ anyhow.workspace = true
|
|||||||
client.workspace = true
|
client.workspace = true
|
||||||
clock.workspace = true
|
clock.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
completion.workspace = true
|
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
futures-batch.workspace = true
|
futures-batch.workspace = true
|
||||||
|
@ -1261,6 +1261,3 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
|
|
||||||
type _TODO = completion::LanguageModelCompletionProvider;
|
|
||||||
|
Loading…
Reference in New Issue
Block a user