assistant: Add tool registry (#17331)

This PR adds a tool registry to hold tools that can be called by the
Assistant.

Currently we just have a `now` tool for retrieving the current datetime.

This is all behind the `assistant-tool-use` feature flag which currently
needs to be explicitly opted-in to in order for the LLM to see the
tools.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-03 19:14:36 -04:00 committed by GitHub
parent c2448e1673
commit e81b484bf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 243 additions and 2 deletions

15
Cargo.lock generated
View File

@ -373,6 +373,7 @@ dependencies = [
"anyhow",
"assets",
"assistant_slash_command",
"assistant_tool",
"async-watch",
"cargo_toml",
"chrono",
@ -454,6 +455,20 @@ dependencies = [
"workspace",
]
[[package]]
name = "assistant_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"derive_more",
"gpui",
"parking_lot",
"serde",
"serde_json",
"workspace",
]
[[package]]
name = "async-attributes"
version = "1.1.2"

View File

@ -6,6 +6,7 @@ members = [
"crates/assets",
"crates/assistant",
"crates/assistant_slash_command",
"crates/assistant_tool",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_tool = { path = "crates/assistant_tool" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
breadcrumbs = { path = "crates/breadcrumbs" }

View File

@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
assets.workspace = true
assistant_slash_command.workspace = true
assistant_tool.workspace = true
async-watch.workspace = true
cargo_toml.workspace = true
chrono.workspace = true

View File

@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker;
pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
mod tools;
mod workflow;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use assistant_tool::ToolRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
@ -214,6 +216,7 @@ pub fn init(
prompt_library::init(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
assistant_tool::init(cx);
assistant_panel::init(cx);
context_servers::init(cx);
@ -228,6 +231,7 @@ pub fn init(
.map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
register_slash_commands(Some(prompt_builder.clone()), cx);
register_tools(cx);
inline_assistant::init(
fs.clone(),
prompt_builder.clone(),
@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
}
}
fn register_tools(cx: &mut AppContext) {
let tool_registry = ToolRegistry::global(cx);
tool_registry.register_tool(tools::now_tool::NowTool);
}
pub fn humanize_token_count(count: usize) -> String {
match count {
0..=999 => count.to_string(),

View File

@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
};
use assistant_tool::ToolRegistry;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use fs::{Fs, RemoveOptions};
use futures::{
future::{self, Shared},
@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role,
LanguageModelRequestTool, MessageContent, Role,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
@ -1942,7 +1944,21 @@ impl Context {
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
let request = self.to_completion_request(cx);
let mut request = self.to_completion_request(cx);
if cx.has_flag::<ToolUseFeatureFlag>() {
let tool_registry = ToolRegistry::global(cx);
request.tools = tool_registry
.tools()
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
})
.collect();
}
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus {
Error(String),
}
pub(crate) struct ToolUseFeatureFlag;
impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";
fn enabled_for_staff() -> bool {
false
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: String,

View File

@ -0,0 +1 @@
pub mod now_tool;

View File

@ -0,0 +1,60 @@
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{Task, WeakView, WindowContext};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Timezone {
/// Use UTC for the datetime.
Utc,
/// Use local time for the datetime.
Local,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FileToolInput {
/// The timezone to use for the datetime.
timezone: Timezone,
}
pub struct NowTool;
impl Tool for NowTool {
fn name(&self) -> String {
"now".into()
}
fn description(&self) -> String {
"Returns the current datetime in RFC 3339 format.".into()
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(FileToolInput);
serde_json::to_value(&schema).unwrap()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_workspace: WeakView<workspace::Workspace>,
_cx: &mut WindowContext,
) -> Task<Result<String>> {
let input: FileToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),
};
let text = format!("The current datetime is {now}.");
Task::ready(Ok(text))
}
}

View File

@ -0,0 +1,22 @@
[package]
name = "assistant_tool"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant_tool.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace.workspace = true

View File

@ -0,0 +1 @@
../../LICENSE-GPL

View File

@ -0,0 +1,35 @@
mod tool_registry;
use std::sync::Arc;
use anyhow::Result;
use gpui::{AppContext, Task, WeakView, WindowContext};
use workspace::Workspace;
pub use tool_registry::*;
pub fn init(cx: &mut AppContext) {
ToolRegistry::default_global(cx);
}
/// A tool that can be used by a language model.
pub trait Tool: 'static + Send + Sync {
/// Returns the name of the tool.
fn name(&self) -> String;
/// Returns the description of the tool.
fn description(&self) -> String;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: serde_json::Value,
workspace: WeakView<Workspace>,
cx: &mut WindowContext,
) -> Task<Result<String>>;
}

View File

@ -0,0 +1,69 @@
use std::sync::Arc;
use collections::HashMap;
use derive_more::{Deref, DerefMut};
use gpui::Global;
use gpui::{AppContext, ReadGlobal};
use parking_lot::RwLock;
use crate::Tool;
#[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>);
impl Global for GlobalToolRegistry {}
#[derive(Default)]
struct ToolRegistryState {
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
}
#[derive(Default)]
pub struct ToolRegistry {
state: RwLock<ToolRegistryState>,
}
impl ToolRegistry {
/// Returns the global [`ToolRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> {
GlobalToolRegistry::global(cx).0.clone()
}
/// Returns the global [`ToolRegistry`].
///
/// Inserts a default [`ToolRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
cx.default_global::<GlobalToolRegistry>().0.clone()
}
pub fn new() -> Arc<Self> {
Arc::new(Self {
state: RwLock::new(ToolRegistryState {
tools: HashMap::default(),
}),
})
}
/// Registers the provided [`Tool`].
pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool));
}
/// Unregisters the provided [`Tool`].
pub fn unregister_tool(&self, tool: impl Tool) {
self.unregister_tool_by_name(tool.name().as_str())
}
/// Unregisters the tool with the given name.
pub fn unregister_tool_by_name(&self, tool_name: &str) {
let mut state = self.state.write();
state.tools.remove(tool_name);
}
/// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
self.state.read().tools.values().cloned().collect()
}
}