Extract SlashCommand trait from assistant (#12252)

This PR extracts the `SlashCommand` trait (along with the
`SlashCommandRegistry`) from the `assistant` crate.

This will allow us to register slash commands from extensions without
having to make `extension` depend on `assistant`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-05-24 13:03:41 -04:00 committed by GitHub
parent af3d7a60c8
commit 8040e43520
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 183 additions and 91 deletions

13
Cargo.lock generated
View File

@ -337,6 +337,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anthropic", "anthropic",
"anyhow", "anyhow",
"assistant_slash_command",
"cargo_toml", "cargo_toml",
"chrono", "chrono",
"client", "client",
@ -426,6 +427,18 @@ dependencies = [
"workspace", "workspace",
] ]
[[package]]
name = "assistant_slash_command"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"derive_more",
"futures 0.3.28",
"gpui",
"parking_lot",
]
[[package]] [[package]]
name = "assistant_tooling" name = "assistant_tooling"
version = "0.1.0" version = "0.1.0"

View File

@ -5,6 +5,7 @@ members = [
"crates/assets", "crates/assets",
"crates/assistant", "crates/assistant",
"crates/assistant2", "crates/assistant2",
"crates/assistant_slash_command",
"crates/assistant_tooling", "crates/assistant_tooling",
"crates/audio", "crates/audio",
"crates/auto_update", "crates/auto_update",
@ -148,6 +149,7 @@ anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" } assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" } assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" } assistant2 = { path = "crates/assistant2" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_tooling = { path = "crates/assistant_tooling" } assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" } audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" } auto_update = { path = "crates/auto_update" }

View File

@ -12,6 +12,7 @@ doctest = false
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
anthropic = { workspace = true, features = ["schemars"] } anthropic = { workspace = true, features = ["schemars"] }
assistant_slash_command.workspace = true
cargo_toml.workspace = true cargo_toml.workspace = true
chrono.workspace = true chrono.workspace = true
client.workspace = true client.workspace = true

View File

@ -17,7 +17,6 @@ use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*; pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
pub(crate) use prompts::prompt_library::*;
pub(crate) use saved_conversation::*; pub(crate) use saved_conversation::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};

View File

@ -9,7 +9,8 @@ use crate::{
prompts::prompt::generate_content_prompt, prompts::prompt::generate_content_prompt,
search::*, search::*,
slash_command::{ slash_command::{
SlashCommandCleanup, SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry, current_file_command, file_command, prompt_command, SlashCommandCleanup,
SlashCommandCompletionProvider, SlashCommandLine, SlashCommandRegistry,
}, },
ApplyEdit, Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel, ApplyEdit, Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel,
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus, LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
@ -204,11 +205,21 @@ impl AssistantPanel {
}) })
.detach(); .detach();
let slash_command_registry = SlashCommandRegistry::new( let slash_command_registry = SlashCommandRegistry::new();
let window = cx.window_handle().downcast::<Workspace>();
slash_command_registry.register_command(file_command::FileSlashCommand::new(
workspace.project().clone(), workspace.project().clone(),
prompt_library.clone(), ));
cx.window_handle().downcast::<Workspace>(), slash_command_registry.register_command(
prompt_command::PromptSlashCommand::new(prompt_library.clone()),
); );
if let Some(window) = window {
slash_command_registry.register_command(
current_file_command::CurrentFileSlashCommand::new(window),
);
}
Self { Self {
workspace: workspace_handle, workspace: workspace_handle,
@ -4273,8 +4284,13 @@ mod tests {
let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await; let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
let prompt_library = Arc::new(PromptLibrary::default()); let prompt_library = Arc::new(PromptLibrary::default());
let slash_command_registry = let slash_command_registry = SlashCommandRegistry::new();
SlashCommandRegistry::new(project.clone(), prompt_library, None);
slash_command_registry
.register_command(file_command::FileSlashCommand::new(project.clone()));
slash_command_registry.register_command(prompt_command::PromptSlashCommand::new(
prompt_library.clone(),
));
let registry = Arc::new(LanguageRegistry::test(cx.executor())); let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let conversation = cx.new_model(|cx| { let conversation = cx.new_model(|cx| {

View File

@ -1,12 +1,9 @@
use anyhow::Result; use anyhow::Result;
use collections::HashMap;
use editor::{CompletionProvider, Editor}; use editor::{CompletionProvider, Editor};
use futures::channel::oneshot;
use fuzzy::{match_strings, StringMatchCandidate}; use fuzzy::{match_strings, StringMatchCandidate};
use gpui::{AppContext, Model, Task, ViewContext, WindowHandle}; use gpui::{AppContext, Model, Task, ViewContext};
use language::{Anchor, Buffer, CodeLabel, Documentation, LanguageServerId, ToPoint}; use language::{Anchor, Buffer, CodeLabel, Documentation, LanguageServerId, ToPoint};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use project::Project;
use rope::Point; use rope::Point;
use std::{ use std::{
ops::Range, ops::Range,
@ -15,60 +12,20 @@ use std::{
Arc, Arc,
}, },
}; };
use workspace::Workspace;
use crate::PromptLibrary; pub use assistant_slash_command::{
SlashCommand, SlashCommandCleanup, SlashCommandInvocation, SlashCommandRegistry,
};
mod current_file_command; pub mod current_file_command;
mod file_command; pub mod file_command;
mod prompt_command; pub mod prompt_command;
pub(crate) struct SlashCommandCompletionProvider { pub(crate) struct SlashCommandCompletionProvider {
commands: Arc<SlashCommandRegistry>, commands: Arc<SlashCommandRegistry>,
cancel_flag: Mutex<Arc<AtomicBool>>, cancel_flag: Mutex<Arc<AtomicBool>>,
} }
#[derive(Default)]
pub(crate) struct SlashCommandRegistry {
commands: HashMap<String, Box<dyn SlashCommand>>,
}
pub(crate) trait SlashCommand: 'static + Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn complete_argument(
&self,
query: String,
cancel: Arc<AtomicBool>,
cx: &mut AppContext,
) -> Task<Result<Vec<String>>>;
fn requires_argument(&self) -> bool;
fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation;
}
pub(crate) struct SlashCommandInvocation {
pub output: Task<Result<String>>,
pub invalidated: oneshot::Receiver<()>,
pub cleanup: SlashCommandCleanup,
}
#[derive(Default)]
pub(crate) struct SlashCommandCleanup(Option<Box<dyn FnOnce()>>);
impl SlashCommandCleanup {
pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
Self(Some(Box::new(cleanup)))
}
}
impl Drop for SlashCommandCleanup {
fn drop(&mut self) {
if let Some(cleanup) = self.0.take() {
cleanup();
}
}
}
pub(crate) struct SlashCommandLine { pub(crate) struct SlashCommandLine {
/// The range within the line containing the command name. /// The range within the line containing the command name.
pub name: Range<usize>, pub name: Range<usize>,
@ -76,38 +33,6 @@ pub(crate) struct SlashCommandLine {
pub argument: Option<Range<usize>>, pub argument: Option<Range<usize>>,
} }
impl SlashCommandRegistry {
pub fn new(
project: Model<Project>,
prompt_library: Arc<PromptLibrary>,
window: Option<WindowHandle<Workspace>>,
) -> Arc<Self> {
let mut this = Self {
commands: HashMap::default(),
};
this.register_command(file_command::FileSlashCommand::new(project));
this.register_command(prompt_command::PromptSlashCommand::new(prompt_library));
if let Some(window) = window {
this.register_command(current_file_command::CurrentFileSlashCommand::new(window));
}
Arc::new(this)
}
fn register_command(&mut self, command: impl SlashCommand) {
self.commands.insert(command.name(), Box::new(command));
}
fn command_names(&self) -> impl Iterator<Item = &String> {
self.commands.keys()
}
pub(crate) fn command(&self, name: &str) -> Option<&dyn SlashCommand> {
self.commands.get(name).map(|b| &**b)
}
}
impl SlashCommandCompletionProvider { impl SlashCommandCompletionProvider {
pub fn new(commands: Arc<SlashCommandRegistry>) -> Self { pub fn new(commands: Arc<SlashCommandRegistry>) -> Self {
Self { Self {
@ -125,11 +50,12 @@ impl SlashCommandCompletionProvider {
let candidates = self let candidates = self
.commands .commands
.command_names() .command_names()
.into_iter()
.enumerate() .enumerate()
.map(|(ix, def)| StringMatchCandidate { .map(|(ix, def)| StringMatchCandidate {
id: ix, id: ix,
string: def.clone(), string: def.to_string(),
char_bag: def.as_str().into(), char_bag: def.as_ref().into(),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let commands = self.commands.clone(); let commands = self.commands.clone();

View File

@ -0,0 +1,20 @@
[package]
name = "assistant_slash_command"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant_slash_command.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
futures.workspace = true
gpui.workspace = true
parking_lot.workspace = true

View File

@ -0,0 +1 @@
../../LICENSE-GPL

View File

@ -0,0 +1,50 @@
mod slash_command_registry;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use anyhow::Result;
use futures::channel::oneshot;
use gpui::{AppContext, Task};
pub use slash_command_registry::*;
pub fn init(cx: &mut AppContext) {
SlashCommandRegistry::default_global(cx);
}
pub trait SlashCommand: 'static + Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
fn complete_argument(
&self,
query: String,
cancel: Arc<AtomicBool>,
cx: &mut AppContext,
) -> Task<Result<Vec<String>>>;
fn requires_argument(&self) -> bool;
fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation;
}
pub struct SlashCommandInvocation {
pub output: Task<Result<String>>,
pub invalidated: oneshot::Receiver<()>,
pub cleanup: SlashCommandCleanup,
}
#[derive(Default)]
pub struct SlashCommandCleanup(Option<Box<dyn FnOnce()>>);
impl SlashCommandCleanup {
pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
Self(Some(Box::new(cleanup)))
}
}
impl Drop for SlashCommandCleanup {
fn drop(&mut self) {
if let Some(cleanup) = self.0.take() {
cleanup();
}
}
}

View File

@ -0,0 +1,64 @@
use std::sync::Arc;
use collections::HashMap;
use derive_more::{Deref, DerefMut};
use gpui::Global;
use gpui::{AppContext, ReadGlobal};
use parking_lot::RwLock;
use crate::SlashCommand;
#[derive(Default, Deref, DerefMut)]
struct GlobalSlashCommandRegistry(Arc<SlashCommandRegistry>);
impl Global for GlobalSlashCommandRegistry {}
#[derive(Default)]
struct SlashCommandRegistryState {
commands: HashMap<Arc<str>, Arc<dyn SlashCommand>>,
}
#[derive(Default)]
pub struct SlashCommandRegistry {
state: RwLock<SlashCommandRegistryState>,
}
impl SlashCommandRegistry {
/// Returns the global [`SlashCommandRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> {
GlobalSlashCommandRegistry::global(cx).0.clone()
}
/// Returns the global [`SlashCommandRegistry`].
///
/// Inserts a default [`SlashCommandRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
cx.default_global::<GlobalSlashCommandRegistry>().0.clone()
}
pub fn new() -> Arc<Self> {
Arc::new(Self {
state: RwLock::new(SlashCommandRegistryState {
commands: HashMap::default(),
}),
})
}
/// Registers the provided [`SlashCommand`].
pub fn register_command(&self, command: impl SlashCommand) {
self.state
.write()
.commands
.insert(command.name().into(), Arc::new(command));
}
/// Returns the names of registered [`SlashCommand`]s.
pub fn command_names(&self) -> Vec<Arc<str>> {
self.state.read().commands.keys().cloned().collect()
}
/// Returns the [`SlashCommand`] with the given name.
pub fn command(&self, name: &str) -> Option<Arc<dyn SlashCommand>> {
self.state.read().commands.get(name).cloned()
}
}