Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and move its logic into
`LanguageModelRegistry`.
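For illustration, a minimal sketch of the new knob, using the `anthropic` crate types from this diff (the model names here are placeholders, not recommendations):

// Sketch: a custom Anthropic model whose tool calls are routed to a
// different model via the new `tool_override` field.
let model = anthropic::Model::Custom {
    name: "my-custom-model".into(),
    max_tokens: 200_000,
    tool_override: Some("claude-3-5-sonnet-20240620".into()),
};
assert_eq!(model.tool_model_id(), "claude-3-5-sonnet-20240620");
// With no override, `tool_model_id` falls back to `self.id()`.

The same `tool_override` field is exposed on the Anthropic provider's `AvailableModel` settings struct, as shown at the end of this diff.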

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra, 2024-07-30 16:18:53 +02:00, committed by GitHub
parent 1bfea9d443
commit 99bc90a372
32 changed files with 478 additions and 691 deletions

Cargo.lock (generated)

@@ -406,7 +406,6 @@ dependencies = [
  "clock",
  "collections",
  "command_palette_hooks",
- "completion",
  "ctor",
  "editor",
  "env_logger",
@@ -2470,7 +2469,6 @@ dependencies = [
  "clock",
  "collab_ui",
  "collections",
- "completion",
  "ctor",
  "dashmap 6.0.1",
  "dev_server_projects",
@@ -2655,30 +2653,6 @@ dependencies = [
  "gpui",
 ]

-[[package]]
-name = "completion"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "ctor",
- "editor",
- "env_logger",
- "futures 0.3.28",
- "gpui",
- "language",
- "language_model",
- "project",
- "rand 0.8.5",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "smol",
- "text",
- "ui",
- "unindent",
-]
-
 [[package]]
 name = "concurrent-queue"
 version = "2.2.0"
@@ -6048,6 +6022,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "smol",
  "strum",
  "text",
  "theme",
@@ -9506,7 +9481,6 @@ dependencies = [
  "client",
  "clock",
  "collections",
- "completion",
  "env_logger",
  "fs",
  "futures 0.3.28",

Cargo.toml

@@ -19,7 +19,6 @@ members = [
     "crates/collections",
     "crates/command_palette",
     "crates/command_palette_hooks",
-    "crates/completion",
     "crates/copilot",
     "crates/db",
     "crates/dev_server_projects",
@@ -190,7 +189,6 @@ collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }
 command_palette = { path = "crates/command_palette" }
 command_palette_hooks = { path = "crates/command_palette_hooks" }
-completion = { path = "crates/completion" }
 copilot = { path = "crates/copilot" }
 db = { path = "crates/db" }
 dev_server_projects = { path = "crates/dev_server_projects" }

crates/anthropic/src/anthropic.rs

@@ -21,7 +21,12 @@ pub enum Model {
     #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
     Claude3Haiku,
     #[serde(rename = "custom")]
-    Custom { name: String, max_tokens: usize },
+    Custom {
+        name: String,
+        max_tokens: usize,
+        /// Override this model with a different Anthropic model for tool calls.
+        tool_override: Option<String>,
+    },
 }

 impl Model {
@@ -68,6 +73,18 @@ impl Model {
             Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn tool_model_id(&self) -> &str {
+        if let Self::Custom {
+            tool_override: Some(tool_override),
+            ..
+        } = self
+        {
+            tool_override
+        } else {
+            self.id()
+        }
+    }
 }

 pub async fn complete(
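The wiring that consumes `tool_model_id` is not part of the hunks shown here; presumably the provider substitutes it only when issuing tool-use requests, along these lines:

// Assumed usage (not shown in this diff): pick the request's model id based
// on whether the request is a tool call.
fn request_model_id(model: &Model, is_tool_call: bool) -> &str {
    if is_tool_call {
        model.tool_model_id() // the override, or the model's own id if unset
    } else {
        model.id()
    }
}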

crates/assistant/Cargo.toml

@@ -32,7 +32,6 @@ client.workspace = true
 clock.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
-completion.workspace = true
 editor.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -77,7 +76,6 @@ workspace.workspace = true
 picker.workspace = true

 [dev-dependencies]
-completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true

crates/assistant/src/assistant.rs

@@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
-use completion::LanguageModelCompletionProvider;
 pub use context::*;
 pub use context_store::*;
 use fs::Fs;
@@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     context_store::init(&client);
     prompt_library::init(cx);
-    init_completion_provider(cx);
+    init_language_model_settings(cx);
     assistant_slash_command::init(cx);
     register_slash_commands(cx);
     assistant_panel::init(cx);
@@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
         .detach();
 }

-fn init_completion_provider(cx: &mut AppContext) {
-    completion::init(cx);
+fn init_language_model_settings(cx: &mut AppContext) {
     update_active_language_model_from_settings(cx);

     cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
@@ -233,20 +231,9 @@
     let settings = AssistantSettings::get_global(cx);
     let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
     let model_id = LanguageModelId::from(settings.default_model.model.clone());
-
-    let Some(provider) = LanguageModelRegistry::global(cx)
-        .read(cx)
-        .provider(&provider_name)
-    else {
-        return;
-    };
-
-    let models = provider.provided_models(cx);
-    if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
-        LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
-            completion_provider.set_active_model(model, cx);
-        });
-    }
+    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+        registry.select_active_model(&provider_name, &model_id, cx);
+    });
 }

 fn register_slash_commands(cx: &mut AppContext) {
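`select_active_model` is new registry API whose body is not in this diff. Based on the provider/model lookup deleted above and on the deleted `LanguageModelCompletionProvider::set_active_model` later in this commit, it plausibly looks like this (a hypothetical sketch, not the actual implementation):

// Hypothetical sketch of LanguageModelRegistry::select_active_model, inferred
// from the lookup code this commit deletes; `set_active_model` on the
// registry is assumed, mirroring the deleted provider's method.
pub fn select_active_model(
    &mut self,
    provider_id: &LanguageModelProviderId,
    model_id: &LanguageModelId,
    cx: &mut ModelContext<Self>,
) {
    let Some(provider) = self.provider(provider_id) else {
        return;
    };
    if let Some(model) = provider
        .provided_models(cx)
        .into_iter()
        .find(|model| model.id() == *model_id)
    {
        self.set_active_model(Some(model), cx);
    }
}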

crates/assistant/src/assistant_panel.rs

@@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
 use client::proto;
 use collections::{BTreeSet, HashMap, HashSet};
-use completion::LanguageModelCompletionProvider;
 use editor::{
     actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
     display_map::{
@@ -43,7 +42,7 @@ use language::{
     language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
     ToOffset,
 };
-use language_model::{LanguageModelProviderId, Role};
+use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, PickerDelegate};
 use project::{Project, ProjectLspAdapterDelegate};
@@ -392,9 +391,9 @@ impl AssistantPanel {
             cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
             cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
             cx.subscribe(&context_store, Self::handle_context_store_event),
-            cx.observe(
-                &LanguageModelCompletionProvider::global(cx),
-                |this, _, cx| {
+            cx.subscribe(
+                &LanguageModelRegistry::global(cx),
+                |this, _, _: &language_model::ActiveModelChanged, cx| {
                     this.completion_provider_changed(cx);
                 },
             ),
@@ -560,7 +559,7 @@
             })
         }

-        let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
+        let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|p| p.id())
         else {
@@ -599,7 +598,7 @@
     }

     fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
-        if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
+        if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
             if !provider.is_authenticated(cx) {
                 return Some(provider.authentication_prompt(cx));
             }
@@ -904,9 +903,9 @@
     }

     fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        LanguageModelCompletionProvider::read_global(cx)
-            .reset_credentials(cx)
-            .detach_and_log_err(cx);
+        if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
+            provider.reset_credentials(cx).detach_and_log_err(cx);
+        }
     }

     fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
@@ -1041,11 +1040,18 @@
     }

     fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
-        LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
+        LanguageModelRegistry::read_global(cx)
+            .active_provider()
+            .map_or(false, |provider| provider.is_authenticated(cx))
     }

     fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
-        LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
+        LanguageModelRegistry::read_global(cx)
+            .active_provider()
+            .map_or(
+                Task::ready(Err(anyhow!("no active language model provider"))),
+                |provider| provider.authenticate(cx),
+            )
     }

     fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
     }

     fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let context = &self
             .active_context_editor
             .as_ref()?
@@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
                     .whitespace_nowrap()
                     .child(
                         Label::new(
-                            LanguageModelCompletionProvider::read_global(cx)
+                            LanguageModelRegistry::read_global(cx)
                                 .active_model()
                                 .map(|model| model.name().0)
                                 .unwrap_or_else(|| "No model selected".into()),
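The panel now reacts to a typed registry event instead of observing the old global provider. Any other view can hook in the same way; a minimal sketch, assuming `LanguageModelRegistry` implements `EventEmitter<ActiveModelChanged>` as the subscription above implies (`MyView` is a placeholder):

// Minimal sketch of subscribing to active-model changes from a view.
cx.subscribe(
    &LanguageModelRegistry::global(cx),
    |this: &mut MyView, _registry, _event: &language_model::ActiveModelChanged, cx| {
        // React to the newly selected model here.
        this.completion_provider_changed(cx);
    },
)
.detach();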

crates/assistant/src/assistant_settings.rs

@@ -52,7 +52,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: Pixels,
     pub default_height: Pixels,
-    pub default_model: AssistantDefaultModel,
+    pub default_model: LanguageModelSelection,
     pub using_outdated_settings_version: bool,
 }

@@ -198,25 +198,25 @@ impl AssistantSettingsContent {
                 .clone()
                 .and_then(|provider| match provider {
                     AssistantProviderContentV1::ZedDotDev { default_model } => {
-                        default_model.map(|model| AssistantDefaultModel {
+                        default_model.map(|model| LanguageModelSelection {
                             provider: "zed.dev".to_string(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::OpenAi { default_model, .. } => {
-                        default_model.map(|model| AssistantDefaultModel {
+                        default_model.map(|model| LanguageModelSelection {
                             provider: "openai".to_string(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Anthropic { default_model, .. } => {
-                        default_model.map(|model| AssistantDefaultModel {
+                        default_model.map(|model| LanguageModelSelection {
                             provider: "anthropic".to_string(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Ollama { default_model, .. } => {
-                        default_model.map(|model| AssistantDefaultModel {
+                        default_model.map(|model| LanguageModelSelection {
                             provider: "ollama".to_string(),
                             model: model.id().to_string(),
                         })
@@ -231,7 +231,7 @@ impl AssistantSettingsContent {
                 dock: settings.dock,
                 default_width: settings.default_width,
                 default_height: settings.default_height,
-                default_model: Some(AssistantDefaultModel {
+                default_model: Some(LanguageModelSelection {
                     provider: "openai".to_string(),
                     model: settings
                         .default_open_ai_model
@@ -325,7 +325,7 @@ impl AssistantSettingsContent {
                     _ => {}
                 },
                 VersionedAssistantSettingsContent::V2(settings) => {
-                    settings.default_model = Some(AssistantDefaultModel { provider, model });
+                    settings.default_model = Some(LanguageModelSelection { provider, model });
                 }
             },
             AssistantSettingsContent::Legacy(settings) => {
@@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
     /// Default: 320
     default_height: Option<f32>,
     /// The default model to use when creating new contexts.
-    default_model: Option<AssistantDefaultModel>,
+    default_model: Option<LanguageModelSelection>,
 }

 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-pub struct AssistantDefaultModel {
+pub struct LanguageModelSelection {
     #[schemars(schema_with = "providers_schema")]
     pub provider: String,
     pub model: String,
@@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
     .into()
 }

-impl Default for AssistantDefaultModel {
+impl Default for LanguageModelSelection {
     fn default() -> Self {
         Self {
             provider: "openai".to_string(),
@@ -542,7 +542,7 @@ mod tests {
         assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
         assert_eq!(
             AssistantSettings::get_global(cx).default_model,
-            AssistantDefaultModel {
+            LanguageModelSelection {
                 provider: "openai".into(),
                 model: "gpt-4o".into(),
             }
@@ -555,7 +555,7 @@ mod tests {
             |settings, _| {
                 *settings = AssistantSettingsContent::Versioned(
                     VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
-                        default_model: Some(AssistantDefaultModel {
+                        default_model: Some(LanguageModelSelection {
                             provider: "test-provider".into(),
                             model: "gpt-99".into(),
                         }),

crates/assistant/src/context.rs

@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
-    LanguageModelCompletionProvider, MessageId, MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
+    MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
+    Role,
+};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
@@ -1180,17 +1183,16 @@ impl Context {
     pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
         let request = self.to_completion_request(cx);
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         self.pending_token_count = cx.spawn(|this, mut cx| {
             async move {
                 cx.background_executor()
                     .timer(Duration::from_millis(200))
                     .await;
-                let token_count = cx
-                    .update(|cx| {
-                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                    })?
-                    .await?;
+                let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
                 this.update(&mut cx, |this, cx| {
                     this.token_count = Some(token_count);
                     cx.notify()
@@ -1368,6 +1370,10 @@
             }
         }

+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return Task::ready(Err(anyhow!("no active model")).log_err());
+        };
+
         let mut request = self.to_completion_request(cx);
         let edit_step_range = edit_step.source_range.clone();
         let step_text = self
@@ -1388,12 +1394,7 @@
                 content: prompt,
             });

-            let tool_use = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx)
-                        .use_tool::<EditTool>(request, cx)
-                })?
-                .await?;
+            let tool_use = model.use_tool::<EditTool>(request, &cx).await?;

             this.update(&mut cx, |this, cx| {
                 let step_index = this
@@ -1568,6 +1569,8 @@
     }

     pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
             message
                 .start
@@ -1575,14 +1578,12 @@
                 .then_some(message.id)
         })?;

-        if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
+        if !provider.is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }

         let request = self.to_completion_request(cx);
-        let stream =
-            LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -1594,6 +1595,7 @@
         let task = cx.spawn({
             |this, mut cx| async move {
+                let stream = model.stream_completion(request, &cx);
                 let assistant_message_id = assistant_message.id;
                 let mut response_latency = None;
                 let stream_completion = async {
@@ -1662,14 +1664,10 @@
                     });

                     if let Some(telemetry) = this.telemetry.as_ref() {
-                        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-                            .active_model()
-                            .map(|m| m.telemetry_id())
-                            .unwrap_or_default();
                         telemetry.report_assistant_event(
                             Some(this.id.0.clone()),
                             AssistantKind::Panel,
-                            model_telemetry_id,
+                            model.telemetry_id(),
                             response_latency,
                             error_message,
                         );
@@ -1935,8 +1933,15 @@
     }

     pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
+        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+            return;
+        };
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
+
         if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
-            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
+            if !provider.is_authenticated(cx) {
                 return;
             }

@@ -1953,10 +1958,9 @@
                 temperature: 1.0,
             };

-            let stream =
-                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
+                    let stream = model.stream_completion(request, &cx);
                     let mut messages = stream.await?;

                     let mut replaced = !replace_old;
@@ -2490,7 +2494,6 @@
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2623,7 +2626,6 @@
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2717,7 +2719,6 @@
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2803,7 +2804,6 @@
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(Project::init_settings);
         cx.update(assistant_panel::init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -2930,7 +2930,6 @@
         cx.set_global(settings_store);

         let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);

         let fake_model = fake_provider.test_model();
         cx.update(assistant_panel::init);
@@ -3032,7 +3031,6 @@
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(assistant_panel::init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@@ -3109,7 +3107,6 @@
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(assistant_panel::init);

         let slash_commands = cx.update(SlashCommandRegistry::default_global);
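A pattern worth noting across these hunks: the `Arc<dyn LanguageModel>` is resolved from the registry up front and moved into the spawned task, and `stream_completion` is now created inside that task. A condensed sketch of the shape, with error handling and buffer edits elided:

// Condensed sketch of the streaming shape used by assist() and summarize().
use futures::StreamExt as _;

fn stream_reply(&mut self, cx: &mut ModelContext<Self>) {
    let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
        return;
    };
    let request = self.to_completion_request(cx);
    cx.spawn(|_this, cx| async move {
        // The request is only issued once the spawned task actually runs.
        let mut chunks = model.stream_completion(request, &cx).await?;
        while let Some(chunk) = chunks.next().await {
            let _text = chunk?;
            // ...apply the streamed text to the context buffer...
        }
        anyhow::Ok(())
    })
    .detach();
}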

crates/assistant/src/inline_assistant.rs

@@ -1,6 +1,6 @@
 use crate::{
     humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
-    Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
+    Hunk, ModelSelector, StreamingDiff,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::telemetry::Telemetry;
@@ -27,7 +27,9 @@
     WindowContext,
 };
 use language::{Buffer, IndentKind, Point, Selection, TransactionId};
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use rope::Rope;
@@ -1328,7 +1330,7 @@
                         Tooltip::with_meta(
                             format!(
                                 "Using {}",
-                                LanguageModelCompletionProvider::read_global(cx)
+                                LanguageModelRegistry::read_global(cx)
                                     .active_model()
                                     .map(|model| model.name().0)
                                     .unwrap_or_else(|| "No model selected".into()),
@@ -1662,7 +1664,7 @@
     }

     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_count = self.token_count?;
         let max_token_count = model.max_token_count();
@@ -2013,8 +2015,12 @@
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
-        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+            let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
+            model.count_tokens(request, cx)
+        } else {
+            future::ready(Err(anyhow!("no active model"))).boxed()
+        }
     }

     pub fn start(
@@ -2024,6 +2030,10 @@
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
+        let model = LanguageModelRegistry::read_global(cx)
+            .active_model()
+            .context("no active model")?;
+
         self.undo(cx);

         // Handle initial insertion
@@ -2053,10 +2063,7 @@
             None
         };

-        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model_telemetry_id()
-            .context("no active model")?;
+        let telemetry_id = model.telemetry_id();
         let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
             .trim()
             .to_lowercase()
@@ -2067,10 +2074,10 @@
             let request =
                 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
             let chunks =
-                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
+                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
             async move { Ok(chunks.await?.boxed()) }.boxed_local()
         };
-        self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
+        self.handle_stream(telemetry_id, edit_range, chunks, cx);
         Ok(())
     }

@@ -2657,7 +2664,6 @@
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(language_settings::init);

         let text = indoc! {"
@@ -2789,7 +2795,6 @@
         mut rng: StdRng,
     ) {
         cx.update(LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);

@@ -2853,7 +2858,6 @@
     #[gpui::test(iterations = 10)]
     async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
         cx.update(LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);

crates/assistant/src/model_selector.rs

@@ -1,6 +1,6 @@
 use std::sync::Arc;

-use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
+use crate::assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::SharedString;
 use language_model::LanguageModelRegistry;
@@ -81,13 +81,13 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                         }
                     },
                     {
-                        let provider = provider.id();
+                        let provider = provider.clone();
                         move |cx| {
-                            LanguageModelCompletionProvider::global(cx).update(
+                            LanguageModelRegistry::global(cx).update(
                                 cx,
                                 |completion_provider, cx| {
                                     completion_provider
-                                        .set_active_provider(provider.clone(), cx)
+                                        .set_active_provider(Some(provider.clone()), cx);
                                 },
                             );
                         }
@@ -95,12 +95,12 @@
                 );
             }

-        let selected_model = LanguageModelCompletionProvider::read_global(cx)
-            .active_model()
-            .map(|m| m.id());
-        let selected_provider = LanguageModelCompletionProvider::read_global(cx)
+        let selected_provider = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|m| m.id());
+        let selected_model = LanguageModelRegistry::read_global(cx)
+            .active_model()
+            .map(|m| m.id());

         for available_model in available_models {
             menu = menu.custom_entry(

crates/assistant/src/prompt_library.rs

@@ -1,6 +1,5 @@
 use crate::{
     slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
-    LanguageModelCompletionProvider,
 };
 use anyhow::{anyhow, Result};
 use assets::Assets;
@@ -19,7 +18,9 @@
 };
 use heed::{types::SerdeBincode, Database, RoTxn};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use parking_lot::RwLock;
 use picker::{Picker, PickerDelegate};
 use rope::Rope;
@@ -636,7 +637,10 @@ impl PromptLibrary {
         };
         let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
-        let provider = LanguageModelCompletionProvider::read_global(cx);
+        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+            return;
+        };
         let initial_prompt = action.prompt.clone();
         if provider.is_authenticated(cx) {
             InlineAssistant::update_global(cx, |assistant, cx| {
@@ -725,6 +729,9 @@
     }

     fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
             let editor = &prompt.body_editor.read(cx);
             let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
@@ -736,7 +743,7 @@
                 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
                 let token_count = cx
                     .update(|cx| {
-                        LanguageModelCompletionProvider::read_global(cx).count_tokens(
+                        model.count_tokens(
                             LanguageModelRequest {
                                 messages: vec![LanguageModelRequestMessage {
                                     role: Role::System,
@@ -804,7 +811,7 @@
         let prompt_metadata = self.store.metadata(prompt_id)?;
         let prompt_editor = &self.prompt_editors[&prompt_id];
         let focus_handle = prompt_editor.body_editor.focus_handle(cx);
-        let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
+        let model = LanguageModelRegistry::read_global(cx).active_model();
         let settings = ThemeSettings::get_global(cx);

         Some(
@@ -914,7 +921,7 @@
                                         None,
                                         format!(
                                             "Model: {}",
-                                            current_model
+                                            model
                                                 .as_ref()
                                                 .map(|model| model
                                                     .name()

crates/assistant/src/terminal_inline_assistant.rs

@@ -1,6 +1,6 @@
 use crate::{
     humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
-    AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
+    AssistantPanelEvent, ModelSelector,
 };
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
@@ -16,7 +16,9 @@
     Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
 };
 use language::Buffer;
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use settings::Settings;
 use std::{
     cmp,
@@ -556,7 +558,7 @@
                         Tooltip::with_meta(
                             format!(
                                 "Using {}",
-                                LanguageModelCompletionProvider::read_global(cx)
+                                LanguageModelRegistry::read_global(cx)
                                     .active_model()
                                     .map(|model| model.name().0)
                                     .unwrap_or_else(|| "No model selected".into()),
@@ -700,6 +702,9 @@
     fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
         let assist_id = self.id;
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         self.pending_token_count = cx.spawn(|this, mut cx| async move {
             cx.background_executor().timer(Duration::from_secs(1)).await;
             let request =
@@ -707,11 +712,7 @@
                     inline_assistant.request_for_inline_assist(assist_id, cx)
                 })??;

-            let token_count = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })?
-                .await?;
+            let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
             this.update(&mut cx, |this, cx| {
                 this.token_count = Some(token_count);
                 cx.notify();
@@ -840,7 +841,7 @@
     }

     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_count = self.token_count?;
         let max_token_count = model.max_token_count();
@@ -982,19 +983,16 @@
     }

     pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
-        self.status = CodegenStatus::Pending;
-        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };

         let telemetry = self.telemetry.clone();
-        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model()
-            .map(|m| m.telemetry_id())
-            .unwrap_or_default();
-        let response =
-            LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
+        self.status = CodegenStatus::Pending;
+        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
-            let response = response.await;
+            let model_telemetry_id = model.telemetry_id();
+            let response = model.stream_completion(prompt, &cx).await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);

crates/collab/Cargo.toml

@@ -80,7 +80,6 @@ channel.workspace = true
 client = { workspace = true, features = ["test-support"] }
 collab_ui = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
-completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true

crates/collab/src/tests/test_server.rs

@@ -300,7 +300,6 @@ impl TestServer {
             dev_server_projects::init(client.clone(), cx);
             settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
             language_model::LanguageModelRegistry::test(cx);
-            completion::init(cx);
             assistant::context_store::init(&client);
         });

crates/completion/Cargo.toml (deleted)

@@ -1,45 +0,0 @@
[package]
name = "completion"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/completion.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"language_model/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

crates/completion/LICENSE-GPL (deleted symlink)

@@ -1 +0,0 @@
-../../LICENSE-GPL

crates/completion/src/completion.rs (deleted)

@@ -1,312 +0,0 @@
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest, LanguageModelTool,
};
use smol::{
future::FutureExt,
lock::{Semaphore, SemaphoreGuardArc},
};
use std::{future, pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
impl Global for GlobalLanguageModelCompletionProvider {}
pub struct LanguageModelCompletionProvider {
active_provider: Option<Arc<dyn LanguageModelProvider>>,
active_model: Option<Arc<dyn LanguageModel>>,
request_limiter: Arc<Semaphore>,
}
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
pub struct LanguageModelCompletionResponse {
inner: BoxStream<'static, Result<String>>,
_lock: SemaphoreGuardArc,
}
impl futures::Stream for LanguageModelCompletionResponse {
type Item = Result<String>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
impl LanguageModelCompletionProvider {
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.clone()
}
pub fn read_global(cx: &AppContext) -> &Self {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.read(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut AppContext) {
let provider = cx.new_model(|cx| {
let mut this = Self::new(cx);
let available_model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.unwrap()
.clone();
this.set_active_model(available_model, cx);
this
});
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
}
pub fn new(cx: &mut ModelContext<Self>) -> Self {
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
cx.notify();
})
.detach();
Self {
active_provider: None,
active_model: None,
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(
&mut self,
provider_id: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
self.active_model = None;
cx.notify();
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.clone()
}
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
if self.active_model.as_ref().map_or(false, |m| {
m.id() == model.id() && m.provider_id() == model.provider_id()
}) {
return;
}
self.active_provider =
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
self.active_model = Some(model.clone());
if let Some(provider) = self.active_provider.as_ref() {
provider.load_model(model, cx);
}
cx.notify();
}
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
self.active_provider
.as_ref()
.map_or(false, |provider| provider.is_authenticated(cx))
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| {
provider.reset_credentials(cx)
})
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
model.count_tokens(request, cx)
} else {
future::ready(Err(anyhow!("no active model"))).boxed()
}
}
pub fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<LanguageModelCompletionResponse>> {
if let Some(language_model) = self.active_model() {
let rate_limiter = self.request_limiter.clone();
cx.spawn(|cx| async move {
let lock = rate_limiter.acquire_arc().await;
let response = language_model.stream_completion(request, &cx).await?;
Ok(LanguageModelCompletionResponse {
inner: response,
_lock: lock,
})
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
let response = self.stream_completion(request, cx);
cx.foreground_executor().spawn(async move {
let mut chunks = response.await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
completion.push_str(&chunk);
}
Ok(completion)
})
}
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<T>> {
if let Some(language_model) = self.active_model() {
cx.spawn(|cx| async move {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request =
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
let response = request.await?;
Ok(serde_json::from_value(response)?)
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn active_model_telemetry_id(&self) -> Option<String> {
self.active_model.as_ref().map(|m| m.telemetry_id())
}
}
#[cfg(test)]
mod tests {
use futures::StreamExt;
use gpui::AppContext;
use settings::SettingsStore;
use ui::Context;
use crate::{
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
};
use language_model::LanguageModelRegistry;
#[gpui::test]
fn test_rate_limiting(cx: &mut AppContext) {
SettingsStore::test(cx);
let fake_provider = LanguageModelRegistry::test(cx);
let model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.cloned()
.unwrap();
let provider = cx.new_model(|cx| {
let mut provider = LanguageModelCompletionProvider::new(cx);
provider.set_active_model(model.clone(), cx);
provider
});
let fake_model = fake_provider.test_model();
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.read(cx).stream_completion(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
},
cx,
);
cx.background_executor()
.spawn(async move {
let mut stream = response.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
})
.detach();
}
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_model.pending_completions().into_iter().next().unwrap();
fake_model.finish_completion(&completion);
// Ensure that the number of in-flight completion requests is reduced.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
cx.background_executor().run_until_parked();
// Ensure that another completion request was allowed to acquire the lock.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Mark all completion requests as finished that are in flight.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
assert_eq!(fake_model.completion_count(), 0);
// Wait until the background tasks acquire the lock again.
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
// Finish all remaining completion requests.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
cx.background_executor().run_until_parked();
assert_eq!(fake_model.completion_count(), 0);
}
}
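The semaphore-based throttling deleted above does not disappear: it resurfaces as the `rate_limiter` module added to `language_model` below, constructed per model as `RateLimiter::new(4)` in the Anthropic provider. The module body is not in this diff; based on the deleted code, a plausible reduction is:

// A sketch of what the new RateLimiter presumably wraps, based on the
// deleted semaphore logic above; the actual rate_limiter module is not
// shown in this diff.
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::sync::Arc;

#[derive(Clone)]
pub struct RateLimiter {
    semaphore: Arc<Semaphore>,
}

impl RateLimiter {
    pub fn new(limit: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(limit)),
        }
    }

    // Holds a guard for as long as the caller keeps it alive, mirroring
    // LanguageModelCompletionResponse::_lock above.
    pub async fn acquire(&self) -> SemaphoreGuardArc {
        self.semaphore.acquire_arc().await
    }
}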

crates/copilot/src/copilot_chat.rs

@@ -208,13 +208,13 @@ impl CopilotChat {
     pub async fn stream_completion(
         request: Request,
         low_speed_timeout: Option<Duration>,
-        cx: &mut AsyncAppContext,
+        mut cx: AsyncAppContext,
     ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
         let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
             return Err(anyhow!("Copilot chat is not enabled"));
         };

-        let (oauth_token, api_token, client) = this.read_with(cx, |this, _| {
+        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
             (
                 this.oauth_token.clone(),
                 this.api_token.clone(),
@@ -229,7 +229,7 @@ impl CopilotChat {
             _ => {
                 let token =
                     request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
-                this.update(cx, |this, cx| {
+                this.update(&mut cx, |this, cx| {
                     this.api_token = Some(token.clone());
                     cx.notify();
                 })?;

crates/language_model/Cargo.toml

@@ -33,6 +33,7 @@ google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 inline_completion_button.workspace = true
+log.workspace = true
 menu.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
@@ -42,6 +43,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
+smol.workspace = true
 strum.workspace = true
 theme.workspace = true
 tiktoken-rs.workspace = true

View File

@ -1,24 +1,24 @@
mod model; mod model;
pub mod provider; pub mod provider;
mod rate_limiter;
mod registry; mod registry;
mod request; mod request;
mod role; mod role;
pub mod settings; pub mod settings;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use client::Client; use client::Client;
use futures::{future::BoxFuture, stream::BoxStream}; use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext}; use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
pub use model::*; pub use model::*;
use project::Fs; use project::Fs;
pub(crate) use rate_limiter::*;
pub use registry::*; pub use registry::*;
pub use request::*; pub use request::*;
pub use role::*; pub use role::*;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{future::Future, sync::Arc};
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
settings::init(fs, cx); settings::init(fs, cx);
@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>; ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn use_tool( fn use_any_tool(
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
name: String, name: String,
@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<'static, Result<serde_json::Value>>; ) -> BoxFuture<'static, Result<serde_json::Value>>;
} }
impl dyn LanguageModel {
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
let response = request.await?;
Ok(serde_json::from_value(response)?)
}
}
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String; fn name() -> String;
fn description() -> String; fn description() -> String;
@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>; fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {} fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool; fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>; fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView; fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>; fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
} }
pub trait LanguageModelProviderState: 'static { pub trait LanguageModelProviderState: 'static {
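Note: `use_any_tool` is the object-safe method on the trait; the blanket `impl dyn LanguageModel` above derives the tool's JSON schema from the type parameter via schemars and deserializes the model's reply back into it. A minimal sketch of a caller-defined tool, assuming the trait and the serde/schemars derives are in scope; the `FileSummary` type is illustrative and not part of this commit:

use schemars::JsonSchema;
use serde::Deserialize;

// Hypothetical tool type used only for illustration.
#[derive(Debug, Deserialize, JsonSchema)]
struct FileSummary {
    /// A one-line summary of the active file.
    summary: String,
}

impl LanguageModelTool for FileSummary {
    fn name() -> String {
        "file_summary".into()
    }

    fn description() -> String {
        "Produce a one-line summary of the active file".into()
    }
}

// At a call site, the typed wrapper infers the schema and output type from T:
//
//     let summary: FileSummary = model.use_tool(request, cx).await?;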

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap; use collections::BTreeMap;
@ -36,6 +36,7 @@ pub struct AnthropicSettings {
pub struct AvailableModel { pub struct AvailableModel {
pub name: String, pub name: String,
pub max_tokens: usize, pub max_tokens: usize,
pub tool_override: Option<String>,
} }
pub struct AnthropicLanguageModelProvider { pub struct AnthropicLanguageModelProvider {
@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
anthropic::Model::Custom { anthropic::Model::Custom {
name: model.name.clone(), name: model.name.clone(),
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
}, },
); );
} }
@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
model, model,
state: self.state.clone(), state: self.state.clone(),
http_client: self.http_client.clone(), http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel> }) as Arc<dyn LanguageModel>
}) })
.collect() .collect()
@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).api_key.is_some() self.state.read(cx).api_key.is_some()
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) { if self.is_authenticated(cx) {
Task::ready(Ok(())) Task::ready(Ok(()))
} else { } else {
@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.into() .into()
} }
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone(); let state = self.state.clone();
let delete_credentials = let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url); cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
@ -171,6 +174,7 @@ pub struct AnthropicModel {
model: anthropic::Model, model: anthropic::Model,
state: gpui::Model<State>, state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
} }
pub fn count_anthropic_tokens( pub fn count_anthropic_tokens(
@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> { ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into()); let request = request.into_anthropic(self.model.id().into());
let request = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
async move { let future = self.request_limiter.stream(async move {
let response = request.await?; let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed()) Ok(anthropic::extract_text_from_events(response))
} });
.boxed() async move { Ok(future.await?.boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
tool_name: String, tool_name: String,
@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value, input_schema: serde_json::Value,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> { ) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into()); let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool { request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(), name: tool_name.clone(),
}); });
@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel {
}]; }];
let response = self.request_completion(request, cx); let response = self.request_completion(request, cx);
async move { self.request_limiter
let response = response.await?; .run(async move {
response let response = response.await?;
.content response
.into_iter() .content
.find_map(|content| { .into_iter()
if let anthropic::Content::ToolUse { name, input, .. } = content { .find_map(|content| {
if name == tool_name { if let anthropic::Content::ToolUse { name, input, .. } = content {
Some(input) if name == tool_name {
Some(input)
} else {
None
}
} else { } else {
None None
} }
} else { })
None .context("tool not used")
} })
}) .boxed()
.context("tool not used")
}
.boxed()
} }
} }
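Note: two things change in this provider. Every model handle now shares a `RateLimiter` capped at four concurrent requests, and tool calls build their request from `tool_model_id`, which prefers `tool_override` when one is configured. A sketch of the override, with placeholder model names:

fn example() {
    // Placeholder model names; any Anthropic model ids work here.
    let model = anthropic::Model::Custom {
        name: "claude-3-5-sonnet-20240620".into(),
        max_tokens: 200_000,
        tool_override: Some("claude-3-haiku-20240307".into()),
    };

    // `use_any_tool` requests completions with `tool_model_id`, which falls
    // back to the model's own id when no override is set.
    assert_eq!(model.tool_model_id(), "claude-3-haiku-20240307");
}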

View File

@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
use crate::{ use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId, settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderState, LanguageModelRequest, RateLimiter,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use client::Client; use client::Client;
@ -41,6 +41,7 @@ pub struct AvailableModel {
provider: AvailableProvider, provider: AvailableProvider,
name: String, name: String,
max_tokens: usize, max_tokens: usize,
tool_override: Option<String>,
} }
pub struct CloudLanguageModelProvider { pub struct CloudLanguageModelProvider {
@ -56,7 +57,7 @@ struct State {
} }
impl State { impl State {
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let client = self.client.clone(); let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
} }
@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(), name: model.name.clone(),
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
}), }),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(), name: model.name.clone(),
@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()), id: LanguageModelId::from(model.id().to_string()),
model, model,
client: self.client.clone(), client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel> }) as Arc<dyn LanguageModel>
}) })
.collect() .collect()
@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
self.state.read(cx).status.is_connected() self.state.read(cx).status.is_connected()
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.read(cx).authenticate(cx) self.state.update(cx, |state, cx| state.authenticate(cx))
} }
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.into() .into()
} }
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
} }
@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
id: LanguageModelId, id: LanguageModelId,
model: CloudModel, model: CloudModel,
client: Arc<Client>, client: Arc<Client>,
request_limiter: RateLimiter,
} }
impl LanguageModel for CloudLanguageModel { impl LanguageModel for CloudLanguageModel {
@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Anthropic(model) => { CloudModel::Anthropic(model) => {
let client = self.client.clone(); let client = self.client.clone();
let request = request.into_anthropic(model.id().into()); let request = request.into_anthropic(model.id().into());
async move { let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
let stream = client let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel { .request_stream(proto::StreamCompleteWithLanguageModel {
@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
.await?; .await?;
Ok(anthropic::extract_text_from_events( Ok(anthropic::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
) ))
.boxed()) });
} async move { Ok(future.await?.boxed()) }.boxed()
.boxed()
} }
CloudModel::OpenAi(model) => { CloudModel::OpenAi(model) => {
let client = self.client.clone(); let client = self.client.clone();
let request = request.into_open_ai(model.id().into()); let request = request.into_open_ai(model.id().into());
async move { let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
let stream = client let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel { .request_stream(proto::StreamCompleteWithLanguageModel {
@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
.await?; .await?;
Ok(open_ai::extract_text_from_events( Ok(open_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
) ))
.boxed()) });
} async move { Ok(future.await?.boxed()) }.boxed()
.boxed()
} }
CloudModel::Google(model) => { CloudModel::Google(model) => {
let client = self.client.clone(); let client = self.client.clone();
let request = request.into_google(model.id().into()); let request = request.into_google(model.id().into());
async move { let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
let stream = client let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel { .request_stream(proto::StreamCompleteWithLanguageModel {
@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
.await?; .await?;
Ok(google_ai::extract_text_from_events( Ok(google_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)), stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
) ))
.boxed()) });
} async move { Ok(future.await?.boxed()) }.boxed()
.boxed()
} }
} }
} }
fn use_tool( fn use_any_tool(
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
tool_name: String, tool_name: String,
@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
match &self.model { match &self.model {
CloudModel::Anthropic(model) => { CloudModel::Anthropic(model) => {
let client = self.client.clone(); let client = self.client.clone();
let mut request = request.into_anthropic(model.id().into()); let mut request = request.into_anthropic(model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool { request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(), name: tool_name.clone(),
}); });
@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
input_schema, input_schema,
}]; }];
async move { self.request_limiter
let request = serde_json::to_string(&request)?; .run(async move {
let response = client let request = serde_json::to_string(&request)?;
.request(proto::CompleteWithLanguageModel { let response = client
provider: proto::LanguageModelProvider::Anthropic as i32, .request(proto::CompleteWithLanguageModel {
request, provider: proto::LanguageModelProvider::Anthropic as i32,
}) request,
.await?; })
let response: anthropic::Response = serde_json::from_str(&response.completion)?; .await?;
response let response: anthropic::Response =
.content serde_json::from_str(&response.completion)?;
.into_iter() response
.find_map(|content| { .content
if let anthropic::Content::ToolUse { name, input, .. } = content { .into_iter()
if name == tool_name { .find_map(|content| {
Some(input) if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else { } else {
None None
} }
} else { })
None .context("tool not used")
} })
}) .boxed()
.context("tool not used")
}
.boxed()
} }
CloudModel::OpenAi(_) => { CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed() future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()

View File

@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
use crate::LanguageModelProviderState; use crate::LanguageModelProviderState;
use crate::{ use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role, LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
}; };
use super::open_ai::count_open_ai_tokens; use super::open_ai::count_open_ai_tokens;
@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> { fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
CopilotChatModel::iter() CopilotChatModel::iter()
.map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>) .map(|model| {
Arc::new(CopilotChatLanguageModel {
model,
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect() .collect()
} }
@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
.unwrap_or(false) .unwrap_or(false)
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let result = if self.is_authenticated(cx) { let result = if self.is_authenticated(cx) {
Ok(()) Ok(())
} else if let Some(copilot) = Copilot::global(cx) { } else if let Some(copilot) = Copilot::global(cx) {
@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into() cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
} }
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let Some(copilot) = Copilot::global(cx) else { let Some(copilot) = Copilot::global(cx) else {
return Task::ready(Err(anyhow::anyhow!( return Task::ready(Err(anyhow::anyhow!(
"Copilot is not available. Please ensure Copilot is enabled and running and try again." "Copilot is not available. Please ensure Copilot is enabled and running and try again."
@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
pub struct CopilotChatLanguageModel { pub struct CopilotChatLanguageModel {
model: CopilotChatModel, model: CopilotChatModel,
request_limiter: RateLimiter,
} }
impl LanguageModel for CopilotChatLanguageModel { impl LanguageModel for CopilotChatLanguageModel {
@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed(); return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
}; };
cx.spawn(|mut cx| async move { let request_limiter = self.request_limiter.clone();
let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?; let future = cx.spawn(|cx| async move {
let stream = response let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
.filter_map(|response| async move { request_limiter.stream(async move {
match response { let response = response.await?;
Ok(result) => { let stream = response
let choice = result.choices.first(); .filter_map(|response| async move {
match choice { match response {
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())), Ok(result) => {
None => Some(Err(anyhow::anyhow!( let choice = result.choices.first();
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again." match choice {
))), Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
}
} }
Err(err) => Some(Err(err)),
} }
Err(err) => Some(Err(err)), })
} .boxed();
}) Ok(stream)
.boxed(); }).await
Ok(stream) });
})
.boxed() async move { Ok(future.await?.boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
_request: LanguageModelRequest, _request: LanguageModelRequest,
_name: String, _name: String,

View File

@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
true true
} }
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> { fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
unimplemented!() unimplemented!()
} }
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
} }
@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
async move { Ok(rx.map(Ok).boxed()) }.boxed() async move { Ok(rx.map(Ok).boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
_request: LanguageModelRequest, _request: LanguageModelRequest,
_name: String, _name: String,

View File

@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{ use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelProviderState, LanguageModelRequest, RateLimiter,
}; };
const PROVIDER_ID: &str = "google"; const PROVIDER_ID: &str = "google";
@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model, model,
state: self.state.clone(), state: self.state.clone(),
http_client: self.http_client.clone(), http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel> }) as Arc<dyn LanguageModel>
}) })
.collect() .collect()
@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).api_key.is_some() self.state.read(cx).api_key.is_some()
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) { if self.is_authenticated(cx) {
Task::ready(Ok(())) Task::ready(Ok(()))
} else { } else {
@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.into() .into()
} }
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone(); let state = self.state.clone();
let delete_credentials = let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
model: google_ai::Model, model: google_ai::Model,
state: gpui::Model<State>, state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
} }
impl LanguageModel for GoogleLanguageModel { impl LanguageModel for GoogleLanguageModel {
@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
}; };
async move { let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response = let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request); stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?; let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed()) Ok(google_ai::extract_text_from_events(events).boxed())
} });
.boxed() async move { Ok(future.await?.boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
_request: LanguageModelRequest, _request: LanguageModelRequest,
_name: String, _name: String,

View File

@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{ use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download"; const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@ -39,7 +39,7 @@ struct State {
} }
impl State { impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> { fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama; let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
let api_url = settings.api_url.clone(); let api_url = settings.api_url.clone();
@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
}), }),
}), }),
}; };
this.fetch_models(cx).detach(); this.state
.update(cx, |state, cx| state.fetch_models(cx).detach());
this this
} }
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
} }
impl LanguageModelProviderState for OllamaLanguageModelProvider { impl LanguageModelProviderState for OllamaLanguageModelProvider {
@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()), id: LanguageModelId::from(model.name.clone()),
model: model.clone(), model: model.clone(),
http_client: self.http_client.clone(), http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel> }) as Arc<dyn LanguageModel>
}) })
.collect() .collect()
@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
!self.state.read(cx).available_models.is_empty() !self.state.read(cx).available_models.is_empty()
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) { if self.is_authenticated(cx) {
Task::ready(Ok(())) Task::ready(Ok(()))
} else { } else {
self.fetch_models(cx) self.state.update(cx, |state, cx| state.fetch_models(cx))
} }
} }
@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
.into() .into()
} }
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.fetch_models(cx) self.state.update(cx, |state, cx| state.fetch_models(cx))
} }
} }
@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
id: LanguageModelId, id: LanguageModelId,
model: ollama::Model, model: ollama::Model,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
} }
impl OllamaLanguageModel { impl OllamaLanguageModel {
@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into()) LanguageModelProviderName(PROVIDER_NAME.into())
} }
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id()) format!("ollama/{}", self.model.id())
} }
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens( fn count_tokens(
&self, &self,
request: LanguageModelRequest, request: LanguageModelRequest,
@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
}; };
async move { let future = self.request_limiter.stream(async move {
let request = let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout); stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
let response = request.await?; .await?;
let stream = response let stream = response
.filter_map(|response| async move { .filter_map(|response| async move {
match response { match response {
@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
}) })
.boxed(); .boxed();
Ok(stream) Ok(stream)
} });
.boxed()
async move { Ok(future.await?.boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
_request: LanguageModelRequest, _request: LanguageModelRequest,
_name: String, _name: String,

View File

@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{ use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
}; };
const PROVIDER_ID: &str = "openai"; const PROVIDER_ID: &str = "openai";
@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
model, model,
state: self.state.clone(), state: self.state.clone(),
http_client: self.http_client.clone(), http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel> }) as Arc<dyn LanguageModel>
}) })
.collect() .collect()
@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.read(cx).api_key.is_some() self.state.read(cx).api_key.is_some()
} }
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> { fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) { if self.is_authenticated(cx) {
Task::ready(Ok(())) Task::ready(Ok(()))
} else { } else {
@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
.into() .into()
} }
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> { fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).openai; let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url); let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone(); let state = self.state.clone();
@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
model: open_ai::Model, model: open_ai::Model,
state: gpui::Model<State>, state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>, http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
} }
impl LanguageModel for OpenAiLanguageModel { impl LanguageModel for OpenAiLanguageModel {
@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
}; };
async move { let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion( let request = stream_completion(
http_client.as_ref(), http_client.as_ref(),
@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
); );
let response = request.await?; let response = request.await?;
Ok(open_ai::extract_text_from_events(response).boxed()) Ok(open_ai::extract_text_from_events(response).boxed())
} });
.boxed()
async move { Ok(future.await?.boxed()) }.boxed()
} }
fn use_tool( fn use_any_tool(
&self, &self,
_request: LanguageModelRequest, _request: LanguageModelRequest,
_name: String, _name: String,

View File

@ -0,0 +1,70 @@
use anyhow::Result;
use futures::Stream;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
}
pub struct RateLimitGuard<T> {
inner: T,
_guard: SemaphoreGuardArc,
}
impl<T> Stream for RateLimitGuard<T>
where
T: Stream,
{
type Item = T::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
}
}
impl RateLimiter {
pub fn new(limit: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(limit)),
}
}
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
where
Fut: 'a + Future<Output = Result<T>>,
{
let guard = self.semaphore.acquire_arc();
async move {
let guard = guard.await;
let result = future.await?;
drop(guard);
Ok(result)
}
}
pub fn stream<'a, Fut, T>(
&self,
future: Fut,
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item>>>
where
Fut: 'a + Future<Output = Result<T>>,
T: Stream,
{
let guard = self.semaphore.acquire_arc();
async move {
let guard = guard.await;
let inner = future.await?;
Ok(RateLimitGuard {
inner,
_guard: guard,
})
}
}
}
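Note: `run` holds a semaphore permit for the lifetime of the wrapped future, while `stream` moves the permit into the returned `RateLimitGuard`, so a slot frees only when the caller drops the stream rather than when the response first arrives. A usage sketch, with a stand-in for a real provider call:

use anyhow::Result;
use futures::{stream, StreamExt};

async fn example() -> Result<()> {
    // At most two of these streams make progress at a time.
    let limiter = RateLimiter::new(2);

    let events = limiter
        .stream(async {
            // Stand-in for a provider request such as `stream_completion`.
            Ok(stream::iter(["Hello", ", world"]))
        })
        .await?;

    // The permit travels with `events` and is released when it is dropped.
    let text = events.collect::<Vec<_>>().await.concat();
    assert_eq!(text, "Hello, world");
    Ok(())
}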

View File

@ -4,11 +4,12 @@ use crate::{
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider, copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
}, },
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
}; };
use client::Client; use client::Client;
use collections::BTreeMap; use collections::BTreeMap;
use gpui::{AppContext, Global, Model, ModelContext}; use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
use std::sync::Arc; use std::sync::Arc;
use ui::Context; use ui::Context;
@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)] #[derive(Default)]
pub struct LanguageModelRegistry { pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>, providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
} }
pub struct ActiveModel {
provider: Arc<dyn LanguageModelProvider>,
model: Option<Arc<dyn LanguageModel>>,
}
pub struct ActiveModelChanged;
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
impl LanguageModelRegistry { impl LanguageModelRegistry {
pub fn global(cx: &AppContext) -> Model<Self> { pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelRegistry>().0.clone() cx.global::<GlobalLanguageModelRegistry>().0.clone()
@ -88,6 +99,8 @@ impl LanguageModelRegistry {
let registry = cx.new_model(|cx| { let registry = cx.new_model(|cx| {
let mut registry = Self::default(); let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx); registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
registry.set_active_model(Some(model), cx);
registry registry
}); });
cx.set_global(GlobalLanguageModelRegistry(registry)); cx.set_global(GlobalLanguageModelRegistry(registry));
@ -136,6 +149,64 @@ impl LanguageModelRegistry {
) -> Option<Arc<dyn LanguageModelProvider>> { ) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned() self.providers.get(name).cloned()
} }
pub fn select_active_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut ModelContext<Self>,
) {
let Some(provider) = self.provider(&provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_active_model(Some(model), cx);
}
}
pub fn set_active_provider(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
cx: &mut ModelContext<Self>,
) {
self.active_model = provider.map(|provider| ActiveModel {
provider,
model: None,
});
cx.emit(ActiveModelChanged);
}
pub fn set_active_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut ModelContext<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.active_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(ActiveModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(ActiveModelChanged);
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
Some(self.active_model.as_ref()?.provider.clone())
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.as_ref()?.model.clone()
}
} }
#[cfg(test)] #[cfg(test)]
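Note: the registry, rather than the removed `completion` crate, now owns the active provider/model pair and emits `ActiveModelChanged` whenever it moves. A sketch of driving it from app code; the ids are illustrative, and the `From<String>` impl on `LanguageModelProviderId` is assumed (the diff shows one for `LanguageModelId`):

// Illustrative ids; `From<String>` on the id newtypes is assumed here.
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
    registry.select_active_model(
        &LanguageModelProviderId::from("anthropic".to_string()),
        &LanguageModelId::from("claude-3-5-sonnet-20240620".to_string()),
        cx,
    );
});

// Readers pull the current selection back out:
let model = registry.read(cx).active_model();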

View File

@ -89,9 +89,15 @@ impl AnthropicSettingsContent {
models models
.into_iter() .into_iter()
.filter_map(|model| match model { .filter_map(|model| match model {
anthropic::Model::Custom { name, max_tokens } => { anthropic::Model::Custom {
Some(provider::anthropic::AvailableModel { name, max_tokens }) name,
} max_tokens,
tool_override,
} => Some(provider::anthropic::AvailableModel {
name,
max_tokens,
tool_override,
}),
_ => None, _ => None,
}) })
.collect() .collect()

View File

@ -22,7 +22,6 @@ anyhow.workspace = true
client.workspace = true client.workspace = true
clock.workspace = true clock.workspace = true
collections.workspace = true collections.workspace = true
completion.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
futures-batch.workspace = true futures-batch.workspace = true

View File

@ -1261,6 +1261,3 @@ mod tests {
); );
} }
} }
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
type _TODO = completion::LanguageModelCompletionProvider;