Assistant Tooling
Bringing OpenAI-compatible tool calling to GPUI.
This unlocks:
- Structured extraction of model responses
- Validation of model inputs
- Execution of chosen tools
Overview
Language models can produce structured outputs that are perfect for calling functions. The best-known example is OpenAI's tool calling: when you make a chat completion request, you can pass a list of tools available to the model. The model will choose 0..n tools to help it complete the user's task. It's up to you to create the tools that the model can call.
```text
User: "Hey I need help with implementing a collapsible panel in GPUI"
Assistant: "Sure, I can help with that. Let me see what I can find."
tool_calls: [{"name": "query_codebase", "arguments": "{ \"query\": \"GPUI collapsible panel\" }"}]
result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"
Assistant: "Here are some excerpts from the GPUI codebase that might help you."
```
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`.
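To make the struct-to-tool idea concrete without pulling in GPUI, here is a minimal, std-only sketch. `SimpleTool`, `ErasedTool`, `Registry`, and `WordCountTool` are hypothetical names, not this crate's API: the real `LanguageModelTool` deserializes its input with serde, describes it with `schemars`, and returns a GPUI `Task`. The sketch only illustrates the general shape: a typed trait per tool, plus a type-erased wrapper so tools with different `Input`/`Output` types can share one registry keyed by name.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified stand-in for `LanguageModelTool`.
/// Each tool declares its own typed input and output.
trait SimpleTool {
    type Input;
    type Output: ToString;

    fn name(&self) -> String;
    fn description(&self) -> String;
    /// Convert the model's raw argument string into typed input
    /// (the real crate deserializes JSON with serde instead).
    fn parse(&self, args: &str) -> Result<Self::Input, String>;
    fn execute(&self, input: Self::Input) -> Result<Self::Output, String>;
}

/// Object-safe view of a tool: the associated types are erased to strings,
/// so tools with different `Input`/`Output` types fit in one map.
trait ErasedTool {
    fn call(&self, args: &str) -> Result<String, String>;
}

impl<T: SimpleTool> ErasedTool for T {
    fn call(&self, args: &str) -> Result<String, String> {
        let input = self.parse(args)?;
        self.execute(input).map(|output| output.to_string())
    }
}

/// Hypothetical registry that dispatches tool calls by name.
#[derive(Default)]
struct Registry {
    tools: HashMap<String, Box<dyn ErasedTool>>,
}

impl Registry {
    /// Register a tool under its name; duplicate names are rejected.
    fn register<T: SimpleTool + 'static>(&mut self, tool: T) -> Result<(), String> {
        let name = tool.name();
        if self.tools.contains_key(&name) {
            return Err(format!("tool `{name}` is already registered"));
        }
        self.tools.insert(name, Box::new(tool));
        Ok(())
    }

    /// Look up a tool by name and run it on the model's raw arguments.
    fn call(&self, name: &str, args: &str) -> Result<String, String> {
        let tool = self
            .tools
            .get(name)
            .ok_or_else(|| format!("unknown tool `{name}`"))?;
        tool.call(args)
    }
}

/// Example tool: counts the whitespace-separated words in its input.
struct WordCountTool;

impl SimpleTool for WordCountTool {
    type Input = String;
    type Output = usize;

    fn name(&self) -> String {
        "word_count".to_string()
    }

    fn description(&self) -> String {
        "Counts the words in the given text".to_string()
    }

    fn parse(&self, args: &str) -> Result<String, String> {
        Ok(args.to_string())
    }

    fn execute(&self, input: String) -> Result<usize, String> {
        Ok(input.split_whitespace().count())
    }
}
```

The erasure step is the key design choice: a map cannot hold values whose trait has differing associated types, so the registry stores tools behind an interface that speaks only in strings, and each tool converts to and from its own types internally.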
Example
Let's let the model query a semantic index directly. First, we'll set up the necessary imports:
```rust
use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ToolRegistry};
use gpui::{App, AppContext, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::json;
```
Then we'll define the query structure the model must fill in. This must derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate.
```rust
#[derive(Deserialize, JsonSchema)]
struct CodebaseQuery {
    /// The text to search the codebase for.
    query: String,
}
```
After that, we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`.
```rust
struct ProjectIndex {}

impl ProjectIndex {
    fn new() -> Self {
        ProjectIndex {}
    }

    fn search(&self, query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
        // Instead of hooking up a real index, we're going to fake it: any query
        // mentioning GPUI (in any letter case) returns one canned excerpt.
        if query.to_lowercase().contains("gpui") {
            return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
//! # Welcome to GPUI!
//!
//! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
//! for Rust, designed to support a wide variety of applications
"#
            .to_string()]));
        }
        Task::ready(Ok(vec![]))
    }
}

struct ProjectIndexTool {
    project_index: ProjectIndex,
}
```
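The fake index above ignores its `limit` argument. If you want a stand-in that behaves a little more like a real search, a std-only sketch could filter an in-memory list of excerpts case-insensitively and cap the number of hits. `InMemoryIndex` is a hypothetical name, and unlike `semantic_index::ProjectIndex` it returns results synchronously rather than as a GPUI `Task`; it also matches the whole query as a substring, which is far cruder than semantic search.

```rust
/// Hypothetical in-memory stand-in for a project index.
struct InMemoryIndex {
    excerpts: Vec<String>,
}

impl InMemoryIndex {
    /// Return up to `limit` excerpts that contain `query`, ignoring letter case,
    /// in the order they were stored.
    fn search(&self, query: &str, limit: usize) -> Vec<String> {
        let needle = query.to_lowercase();
        self.excerpts
            .iter()
            .filter(|excerpt| excerpt.to_lowercase().contains(&needle))
            .take(limit)
            .cloned()
            .collect()
    }
}
```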
Now we can implement the `LanguageModelTool` trait for our tool by:
- Defining the `Input` from the model, which is `CodebaseQuery`
- Defining the `Output`
- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool
- Implementing the `execute` function to run the tool
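Together, `name`, `description`, and the JSON schema of `Input` are what the model sees when choosing a tool. As a rough, std-only sketch of an entry in an OpenAI-style `tools` array: the helpers `escape_json` and `tool_definition` are hypothetical, and the schema string is written by hand rather than generated by `schemars` (whose actual output may contain extra keys such as `title`). This crate may serialize its definitions differently.

```rust
/// Escape a string for embedding inside a JSON string literal.
/// Handles quotes, backslashes, and control characters.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

/// Hypothetical helper: build one entry of an OpenAI-style `tools` array.
/// `parameters_schema` must already be a valid JSON object.
fn tool_definition(name: &str, description: &str, parameters_schema: &str) -> String {
    let name = escape_json(name);
    let description = escape_json(description);
    format!(
        r#"{{"type":"function","function":{{"name":"{name}","description":"{description}","parameters":{parameters_schema}}}}}"#
    )
}

/// Hand-written JSON schema for `CodebaseQuery { query: String }`.
const CODEBASE_QUERY_SCHEMA: &str =
    r#"{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}"#;
```

In practice a JSON library such as `serde_json` is the safer way to build this; the hand-rolled escaping only keeps the sketch dependency-free.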
```rust
impl LanguageModelTool for ProjectIndexTool {
    type Input = CodebaseQuery;
    type Output = String;

    fn name(&self) -> String {
        "query_codebase".to_string()
    }

    fn description(&self) -> String {
        "Executes a query against the codebase, returning excerpts related to the query".to_string()
    }

    fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
        // Start the search right away; it resolves asynchronously.
        let results = self.project_index.search(query.query.as_str(), 10, cx);

        cx.spawn(|_cx| async move {
            let results = results.await?;
            // Always hand the model some text, even when nothing matched.
            if !results.is_empty() {
                Ok(results.join("\n"))
            } else {
                Ok("No results".to_string())
            }
        })
    }
}
```
For the sake of this example, let's look at the types that OpenAI will be passing to us:
```rust
// OpenAI definitions, shown here for demonstration
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    // OpenAI's API names this field `arguments`; accept it as an alias of `args`.
    #[serde(alias = "arguments")]
    args: String,
}

#[derive(Deserialize, Eq, PartialEq)]
enum ToolCallType {
    #[serde(rename = "function")]
    Function,
    Other,
}

#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
struct ToolCallId(String);

#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolCall {
    Function {
        #[allow(dead_code)]
        id: ToolCallId,
        function: FunctionCall,
    },
    Other {
        #[allow(dead_code)]
        id: ToolCallId,
    },
}

#[derive(Deserialize)]
struct AssistantMessage {
    role: String,
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}
```
When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await.
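The filtering described here (handle `function` calls, skip everything else, and keep each result paired with its call id so the caller can tell which output belongs to which call) can be sketched without GPUI or serde. This is a hypothetical, synchronous stand-in for `ToolRegistry`: `dispatch` and the `Handler` alias are illustrative names, and `ToolCall` is flattened to plain strings.

```rust
use std::collections::HashMap;

/// Simplified mirror of the `ToolCall` enum above, with the function payload inlined.
enum ToolCall {
    Function { id: String, name: String, args: String },
    Other {
        #[allow(dead_code)]
        id: String,
    },
}

/// A tool implementation: raw argument string in, output or error out.
type Handler = fn(&str) -> Result<String, String>;

/// Hypothetical dispatcher: runs every `function` call through its registered handler.
/// Unknown tool names produce an error entry; non-function calls are skipped entirely.
fn dispatch(
    calls: &[ToolCall],
    handlers: &HashMap<&str, Handler>,
) -> Vec<(String, Result<String, String>)> {
    calls
        .iter()
        .filter_map(|call| match call {
            ToolCall::Function { id, name, args } => {
                let result = match handlers.get(name.as_str()) {
                    Some(handler) => handler(args.as_str()),
                    None => Err(format!("unknown tool `{name}`")),
                };
                Some((id.clone(), result))
            }
            ToolCall::Other { .. } => None,
        })
        .collect()
}
```

Reporting an unknown tool as an error entry, rather than dropping it, keeps every function call accounted for, which is useful when the results are sent back to the model.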
```rust
// Inside `fn main()`
App::new().run(|cx: &mut AppContext| {
    let tool = ProjectIndexTool {
        project_index: ProjectIndex::new(),
    };

    let mut registry = ToolRegistry::new();
    let registered = registry.register(tool);
    assert!(registered.is_ok());

    // Let's pretend the model sent us back a message requesting a `query_codebase` tool call.
    let model_response = json!({
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "function": {
                    "name": "query_codebase",
                    "args": r#"{"query":"GPUI Task background_executor"}"#
                },
                "type": "function"
            }
        ]
    });

    let message: AssistantMessage = serde_json::from_value(model_response).unwrap();

    // We know there's a single function tool call, so let's skip straight to it for this example.
    let tool_call = message.tool_calls.unwrap().into_iter().next().unwrap();
    let ToolCall::Function { function, .. } = tool_call else {
        panic!("expected a function tool call");
    };

    // We can now use our registry to call the tool.
    let task = registry.call(function.name, function.args);

    // Await the result in the background. Dropping a gpui `Task` cancels it, so detach it.
    cx.spawn(|_cx| async move {
        let result = task.await?;
        println!("{}", result.unwrap());
        anyhow::Ok(())
    })
    .detach();
});
```