From 874f0c07126bdba209240dd8c8e6ef58b327b8d4 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:45:47 +0200 Subject: [PATCH] assistant: Use tools in other providers (#15803) - [x] OpenAI - [ ] ~Google~ Moved into a separate branch at: https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've run into issues with having the API digest our schema without tripping over itself - the function call parameters are malformed and whatnot. We can resume from that branch if needed. - [x] Ollama - [x] Cloud - [ ] ~Copilot Chat (?)~ Release Notes: - Added tool calling capabilities to OpenAI and Ollama models. --- crates/language_model/src/provider/cloud.rs | 139 +++++++++++++++++- crates/language_model/src/provider/ollama.rs | 73 +++++++-- crates/language_model/src/provider/open_ai.rs | 133 ++++++++++++----- crates/ollama/src/ollama.rs | 86 ++++++++++- crates/open_ai/src/open_ai.rs | 25 +++- 5 files changed, 392 insertions(+), 64 deletions(-) diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index a5aa993e43..b413f8d2cb 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -4,7 +4,7 @@ use crate::{ LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel, }; -use anyhow::{anyhow, Context as _, Result}; +use anyhow::{anyhow, bail, Context as _, Result}; use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; use collections::BTreeMap; use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels}; @@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel { }) .boxed() } - CloudModel::OpenAi(_) => { - future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed() + CloudModel::OpenAi(model) => { + let mut request = request.into_open_ai(model.id().into()); + let client = self.client.clone(); + let mut function = open_ai::FunctionDefinition { + name: tool_name.clone(), + description: None, + parameters: None, + }; + let func = open_ai::ToolDefinition::Function { + function: function.clone(), + }; + request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone())); + // Fill in description and params separately, as they're not needed for tool_choice field. + function.description = Some(tool_description); + function.parameters = Some(input_schema); + request.tools = vec![open_ai::ToolDefinition::Function { function }]; + self.request_limiter + .run(async move { + let request = serde_json::to_string(&request)?; + let response = client + .request_stream(proto::StreamCompleteWithLanguageModel { + provider: proto::LanguageModelProvider::OpenAi as i32, + request, + }) + .await?; + // Call arguments are gonna be streamed in over multiple chunks.
+ let mut load_state = None; + let mut response = response.map( + |item: Result< + proto::StreamCompleteWithLanguageModelResponse, + anyhow::Error, + >| { + Result::::Ok( + serde_json::from_str(&item?.event)?, + ) + }, + ); + while let Some(Ok(part)) = response.next().await { + for choice in part.choices { + let Some(tool_calls) = choice.delta.tool_calls else { + continue; + }; + + for call in tool_calls { + if let Some(func) = call.function { + if func.name.as_deref() == Some(tool_name.as_str()) { + load_state = Some((String::default(), call.index)); + } + if let Some((arguments, (output, index))) = + func.arguments.zip(load_state.as_mut()) + { + if call.index == *index { + output.push_str(&arguments); + } + } + } + } + } + } + if let Some((arguments, _)) = load_state { + return Ok(serde_json::from_str(&arguments)?); + } else { + bail!("tool not used"); + } + }) + .boxed() } CloudModel::Google(_) => { future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed() } - CloudModel::Zed(_) => { - future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed() + CloudModel::Zed(model) => { + // All Zed models are OpenAI-based at the time of writing. + let mut request = request.into_open_ai(model.id().into()); + let client = self.client.clone(); + let mut function = open_ai::FunctionDefinition { + name: tool_name.clone(), + description: None, + parameters: None, + }; + let func = open_ai::ToolDefinition::Function { + function: function.clone(), + }; + request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone())); + // Fill in description and params separately, as they're not needed for tool_choice field. + function.description = Some(tool_description); + function.parameters = Some(input_schema); + request.tools = vec![open_ai::ToolDefinition::Function { function }]; + self.request_limiter + .run(async move { + let request = serde_json::to_string(&request)?; + let response = client + .request_stream(proto::StreamCompleteWithLanguageModel { + provider: proto::LanguageModelProvider::OpenAi as i32, + request, + }) + .await?; + // Call arguments are gonna be streamed in over multiple chunks. 
+ let mut load_state = None; + let mut response = response.map( + |item: Result< + proto::StreamCompleteWithLanguageModelResponse, + anyhow::Error, + >| { + Result::::Ok( + serde_json::from_str(&item?.event)?, + ) + }, + ); + while let Some(Ok(part)) = response.next().await { + for choice in part.choices { + let Some(tool_calls) = choice.delta.tool_calls else { + continue; + }; + + for call in tool_calls { + if let Some(func) = call.function { + if func.name.as_deref() == Some(tool_name.as_str()) { + load_state = Some((String::default(), call.index)); + } + if let Some((arguments, (output, index))) = + func.arguments.zip(load_state.as_mut()) + { + if call.index == *index { + output.push_str(&arguments); + } + } + } + } + } + } + if let Some((arguments, _)) = load_state { + return Ok(serde_json::from_str(&arguments)?); + } else { + bail!("tool not used"); + } + }) + .boxed() } } } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 15b70b1cb6..fcf253c7a6 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -1,12 +1,14 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; use http_client::HttpClient; use ollama::{ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, + ChatResponseDelta, OllamaToolCall, }; +use serde_json::Value; use settings::{Settings, SettingsStore}; -use std::{future, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use ui::{prelude::*, ButtonLike, Indicator}; use util::ResultExt; @@ -184,6 +186,7 @@ impl OllamaLanguageModel { }, Role::Assistant => ChatMessage::Assistant { content: msg.content, + tool_calls: None, }, Role::System => ChatMessage::System { content: msg.content, @@ -198,8 +201,25 @@ impl OllamaLanguageModel { temperature: Some(request.temperature), ..Default::default() }), + tools: vec![], } } + fn request_completion( + &self, + request: ChatRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result> { + let http_client = self.http_client.clone(); + + let Ok(api_url) = cx.update(|cx| { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + settings.api_url.clone() + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed() + } } impl LanguageModel for OllamaLanguageModel { @@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel { Ok(delta) => { let content = match delta.message { ChatMessage::User { content } => content, - ChatMessage::Assistant { content } => content, + ChatMessage::Assistant { content, .. 
} => content, ChatMessage::System { content } => content, }; Some(Ok(content)) @@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel { fn use_any_tool( &self, - _request: LanguageModelRequest, - _name: String, - _description: String, - _schema: serde_json::Value, - _cx: &AsyncAppContext, + request: LanguageModelRequest, + tool_name: String, + tool_description: String, + schema: serde_json::Value, + cx: &AsyncAppContext, ) -> BoxFuture<'static, Result> { - future::ready(Err(anyhow!("not implemented"))).boxed() + use ollama::{OllamaFunctionTool, OllamaTool}; + let function = OllamaFunctionTool { + name: tool_name.clone(), + description: Some(tool_description), + parameters: Some(schema), + }; + let tools = vec![OllamaTool::Function { function }]; + let request = self.to_ollama_request(request).with_tools(tools); + let response = self.request_completion(request, cx); + self.request_limiter + .run(async move { + let response = response.await?; + let ChatMessage::Assistant { + tool_calls, + content, + } = response.message + else { + bail!("message does not have an assistant role"); + }; + if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) { + for call in tool_calls { + let OllamaToolCall::Function(function) = call; + if function.name == tool_name { + return Ok(function.arguments); + } + } + } else if let Ok(args) = serde_json::from_str::(&content) { + // Parse content as arguments. + return Ok(args); + } else { + bail!("assistant message does not have any tool calls"); + }; + + bail!("tool not used") + }) + .boxed() } } diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index 5121ea802a..aacb386651 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, FutureExt, StreamExt}; @@ -7,11 +7,13 @@ use gpui::{ View, WhiteSpace, }; use http_client::HttpClient; -use open_ai::stream_completion; +use open_ai::{ + stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; -use std::{future, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{prelude::*, Indicator}; @@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel { request_limiter: RateLimiter, } +impl OpenAiLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).openai; + ( + state.api_key.clone(), + settings.api_url.clone(), + settings.low_speed_timeout, + ) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let request = stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + low_speed_timeout, + ); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} impl LanguageModel for 
OpenAiLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel { cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { let request = request.into_open_ai(self.model.id().into()); - - let http_client = self.http_client.clone(); - let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - ( - state.api_key.clone(), - settings.api_url.clone(), - settings.low_speed_timeout, - ) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); - }; - - let future = self.request_limiter.stream(async move { - let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; - let request = stream_completion( - http_client.as_ref(), - &api_url, - &api_key, - request, - low_speed_timeout, - ); - let response = request.await?; - Ok(open_ai::extract_text_from_events(response).boxed()) - }); - - async move { Ok(future.await?.boxed()) }.boxed() + let completions = self.stream_completion(request, cx); + async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed() } fn use_any_tool( &self, - _request: LanguageModelRequest, - _name: String, - _description: String, - _schema: serde_json::Value, - _cx: &AsyncAppContext, + request: LanguageModelRequest, + tool_name: String, + tool_description: String, + schema: serde_json::Value, + cx: &AsyncAppContext, ) -> BoxFuture<'static, Result> { - future::ready(Err(anyhow!("not implemented"))).boxed() + let mut request = request.into_open_ai(self.model.id().into()); + let mut function = FunctionDefinition { + name: tool_name.clone(), + description: None, + parameters: None, + }; + let func = ToolDefinition::Function { + function: function.clone(), + }; + request.tool_choice = Some(ToolChoice::Other(func.clone())); + // Fill in description and params separately, as they're not needed for tool_choice field. + function.description = Some(tool_description); + function.parameters = Some(schema); + request.tools = vec![ToolDefinition::Function { function }]; + let response = self.stream_completion(request, cx); + self.request_limiter + .run(async move { + let mut response = response.await?; + + // Call arguments are gonna be streamed in over multiple chunks. 
+ let mut load_state = None; + while let Some(Ok(part)) = response.next().await { + for choice in part.choices { + let Some(tool_calls) = choice.delta.tool_calls else { + continue; + }; + + for call in tool_calls { + if let Some(func) = call.function { + if func.name.as_deref() == Some(tool_name.as_str()) { + load_state = Some((String::default(), call.index)); + } + if let Some((arguments, (output, index))) = + func.arguments.zip(load_state.as_mut()) + { + if call.index == *index { + output.push_str(&arguments); + } + } + } + } + } + } + if let Some((arguments, _)) = load_state { + return Ok(serde_json::from_str(&arguments)?); + } else { + bail!("tool not used"); + } + }) + .boxed() } } diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index e627118072..9f7cb4db27 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::{convert::TryFrom, sync::Arc, time::Duration}; pub const OLLAMA_API_URL: &str = "http://localhost:11434"; @@ -94,22 +95,63 @@ impl Model { #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum ChatMessage { - Assistant { content: String }, - User { content: String }, - System { content: String }, + Assistant { + content: String, + tool_calls: Option>, + }, + User { + content: String, + }, + System { + content: String, + }, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum OllamaToolCall { + Function(OllamaFunctionCall), +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct OllamaFunctionCall { + pub name: String, + pub arguments: Value, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct OllamaFunctionTool { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum OllamaTool { + Function { function: OllamaFunctionTool }, +} + +#[derive(Serialize, Debug)] pub struct ChatRequest { pub model: String, pub messages: Vec, pub stream: bool, pub keep_alive: KeepAlive, pub options: Option, + pub tools: Vec, +} + +impl ChatRequest { + pub fn with_tools(mut self, tools: Vec) -> Self { + self.stream = false; + self.tools = tools; + self + } } // https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values -#[derive(Serialize, Default)] +#[derive(Serialize, Default, Debug)] pub struct ChatOptions { pub num_ctx: Option, pub num_predict: Option, @@ -118,7 +160,7 @@ pub struct ChatOptions { pub top_p: Option, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct ChatResponseDelta { #[allow(unused)] pub model: String, @@ -162,6 +204,38 @@ pub struct ModelDetails { pub quantization_level: String, } +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + request: ChatRequest, +) -> Result { + let uri = format!("{api_url}/api/chat"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + let serialized_request = serde_json::to_string(&request)?; + let request = request_builder.body(AsyncBody::from(serialized_request))?; + + let mut response = client.send(request).await?; + if 
response.status().is_success() { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + let response_message: ChatResponseDelta = serde_json::from_slice(&body)?; + Ok(response_message) + } else { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + let body_str = std::str::from_utf8(&body)?; + Err(anyhow!( + "Failed to connect to API: {} {}", + response.status(), + body_str + )) + } +} + pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index eb7769c4b6..4fb62831b6 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; -use serde_json::{Map, Value}; +use serde_json::Value; use std::{convert::TryFrom, future::Future, time::Duration}; use strum::EnumIter; @@ -121,25 +121,34 @@ pub struct Request { pub stop: Vec, pub temperature: f32, #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, + pub tool_choice: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, } -#[derive(Debug, Deserialize, Serialize)] -pub struct FunctionDefinition { - pub name: String, - pub description: Option, - pub parameters: Option>, +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto, + Required, + None, + Other(ToolDefinition), } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Clone, Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { #[allow(dead_code)] Function { function: FunctionDefinition }, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option, +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage {
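A note on the streamed tool-call handling added above: OpenAI-compatible backends deliver a function call's JSON arguments as string fragments spread across multiple stream events, keyed by the call's `index`. The `use_any_tool` loops start accumulating once a delta names the requested tool, append every later fragment with the same index, and only then parse the accumulated string with `serde_json::from_str`. Below is a minimal, self-contained sketch of that accumulation; the `FunctionDelta`/`ToolCallDelta` types are simplified stand-ins for the actual `open_ai` response-stream types, and `get_weather` is a made-up tool name used purely for illustration.

```rust
// Simplified stand-ins for the streamed tool-call delta types (assumption,
// not the real `open_ai` crate definitions).
struct FunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

struct ToolCallDelta {
    index: usize,
    function: Option<FunctionDelta>,
}

/// Accumulate the argument fragments for the tool named `tool_name`,
/// mirroring the `load_state` loop in the patch. Each item of `chunks`
/// stands for the tool-call deltas carried by one stream event.
fn accumulate_tool_call(
    chunks: impl IntoIterator<Item = Vec<ToolCallDelta>>,
    tool_name: &str,
) -> Option<String> {
    // (arguments accumulated so far, index of the matching call)
    let mut load_state: Option<(String, usize)> = None;

    for tool_calls in chunks {
        for call in tool_calls {
            let Some(func) = call.function else { continue };
            // The first delta for a call carries the function name.
            if func.name.as_deref() == Some(tool_name) {
                load_state = Some((String::new(), call.index));
            }
            // Later deltas carry argument fragments; append the ones that
            // belong to the matching call's index.
            if let Some((arguments, (output, index))) =
                func.arguments.zip(load_state.as_mut())
            {
                if call.index == *index {
                    output.push_str(&arguments);
                }
            }
        }
    }

    load_state.map(|(arguments, _)| arguments)
}

fn main() {
    let chunks = vec![
        vec![ToolCallDelta {
            index: 0,
            function: Some(FunctionDelta {
                name: Some("get_weather".into()),
                arguments: None,
            }),
        }],
        vec![ToolCallDelta {
            index: 0,
            function: Some(FunctionDelta {
                name: None,
                arguments: Some(r#"{"city":"#.into()),
            }),
        }],
        vec![ToolCallDelta {
            index: 0,
            function: Some(FunctionDelta {
                name: None,
                arguments: Some(r#""Warsaw"}"#.into()),
            }),
        }],
    ];
    // Prints: {"city":"Warsaw"}
    println!("{}", accumulate_tool_call(chunks, "get_weather").unwrap());
}
```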