From 4c390b82fbe1c512932cce4d65ddb0fdc0d985b0 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Wed, 14 Aug 2024 18:02:46 -0700
Subject: [PATCH] Make LanguageModel::use_any_tool return a stream of chunks
 (#16262)

This PR is a refactor to pave the way for allowing the user to view and edit
workflow step resolutions. I've made tool calls work more like normal
streaming completions for all providers: the `use_any_tool` method now returns
a stream of strings (chunks of JSON) instead of a single `serde_json::Value`.
I've also done some minor cleanup of the language model providers in general,
removing the duplicated logic for handling streaming responses.
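To make the new shape concrete, here is a sketch (not code from this PR) of how a caller can consume the chunked interface; `model`, `request`, `schema`, and `cx` are assumed to be in scope, and `render_partial_resolution` is a hypothetical UI hook:

```rust
// Sketch: consuming `use_any_tool` as a stream of JSON chunks.
// `render_partial_resolution` is hypothetical, not part of this PR.
use futures::StreamExt as _;

let mut chunks = model
    .use_any_tool(request, "tool".into(), "description".into(), schema, &cx)
    .await?;
let mut json = String::new();
while let Some(chunk) = chunks.next().await {
    json.push_str(&chunk?);
    render_partial_resolution(&json); // update the UI as JSON arrives
}
let value: serde_json::Value = serde_json::from_str(&json)?;
```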
Release Notes:

- N/A
---
 crates/anthropic/src/anthropic.rs             |  46 ++-
 crates/assistant/src/context.rs               |  49 +--
 crates/assistant/src/inline_assistant.rs      |   6 -
 crates/gpui/src/elements/img.rs               |   5 +-
 crates/language_model/src/language_model.rs   |  11 +-
 .../language_model/src/provider/anthropic.rs  |  50 +--
 crates/language_model/src/provider/cloud.rs   | 295 ++++++------------
 .../src/provider/copilot_chat.rs              |   2 +-
 crates/language_model/src/provider/fake.rs    |  37 +--
 crates/language_model/src/provider/google.rs  |   2 +-
 crates/language_model/src/provider/ollama.rs  |  17 +-
 crates/language_model/src/provider/open_ai.rs |  70 ++---
 crates/ollama/src/ollama.rs                   |  10 +-
 crates/open_ai/src/open_ai.rs                 |  53 +++-
 14 files changed, 253 insertions(+), 400 deletions(-)

diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 0ceee553d2..7c1774dee4 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
-use std::str::FromStr;
 use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
 
@@ -241,6 +241,50 @@ pub fn extract_text_from_events(
     })
 }
 
+pub async fn extract_tool_args_from_events(
+    tool_name: String,
+    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+    let mut tool_use_index = None;
+    while let Some(event) = events.next().await {
+        if let Event::ContentBlockStart {
+            index,
+            content_block,
+        } = event?
+        {
+            if let Content::ToolUse { name, .. } = content_block {
+                if name == tool_name {
+                    tool_use_index = Some(index);
+                    break;
+                }
+            }
+        }
+    }
+
+    let Some(tool_use_index) = tool_use_index else {
+        return Err(anyhow!("tool not used"));
+    };
+
+    Ok(events.filter_map(move |event| {
+        let result = match event {
+            Err(error) => Some(Err(error)),
+            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
+                ContentDelta::TextDelta { .. } => None,
+                ContentDelta::InputJsonDelta { partial_json } => {
+                    if index == tool_use_index {
+                        Some(Ok(partial_json))
+                    } else {
+                        None
+                    }
+                }
+            },
+            _ => None,
+        };
+
+        async move { result }
+    }))
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Message {
     pub role: Role,
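The helper above resolves its outer future once the matching tool-use content block starts, then yields only that block's `input_json_delta` chunks. Roughly, assuming `events` is a stream of `Result<anthropic::Event>` from a streaming completion (as in the provider change later in this patch):

```rust
// Sketch: driving `extract_tool_args_from_events` with an event stream.
// `events` and the tool name here are assumptions for illustration.
use futures::TryStreamExt as _;

let args_json = anthropic::extract_tool_args_from_events(
    "my_tool".to_string(),
    Box::pin(events),
)
.await? // fails with "tool not used" if the tool never starts
.try_collect::<String>() // or consume chunk-by-chunk for live previews
.await?;
let args: serde_json::Value = serde_json::from_str(&args_json)?;
```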
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 8ff7f1bbef..14791d934a 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -1,6 +1,6 @@
 use crate::{
-    prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
-    InlineAssistId, InlineAssistant, MessageId, MessageStatus,
+    prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
+    InlineAssistant, MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -3342,7 +3342,7 @@ mod tests {
 
         model
             .as_fake()
-            .respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
+            .respond_to_last_tool_use(tool::WorkflowStepResolution {
                 step_title: "Title".into(),
                 suggestions: vec![tool::WorkflowSuggestion {
                     path: "/root/hello.rs".into(),
@@ -3352,8 +3352,7 @@ mod tests {
                         description: "Extract a greeting function".into(),
                     },
                 }],
-            })
-            .unwrap()));
+            });
 
         // Wait for tool use to be processed.
         cx.run_until_parked();
@@ -4084,44 +4083,4 @@ mod tool {
             symbol: String,
         },
     }
-
-    impl WorkflowSuggestionKind {
-        pub fn symbol(&self) -> Option<&str> {
-            match self {
-                Self::Update { symbol, .. } => Some(symbol),
-                Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
-                Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
-                Self::PrependChild { symbol, .. } => symbol.as_deref(),
-                Self::AppendChild { symbol, .. } => symbol.as_deref(),
-                Self::Delete { symbol } => Some(symbol),
-                Self::Create { .. } => None,
-            }
-        }
-
-        pub fn description(&self) -> Option<&str> {
-            match self {
-                Self::Update { description, .. } => Some(description),
-                Self::Create { description } => Some(description),
-                Self::InsertSiblingBefore { description, .. } => Some(description),
-                Self::InsertSiblingAfter { description, .. } => Some(description),
-                Self::PrependChild { description, .. } => Some(description),
-                Self::AppendChild { description, .. } => Some(description),
-                Self::Delete { .. } => None,
-            }
-        }
-
-        pub fn initial_insertion(&self) -> Option<InitialInsertion> {
-            match self {
-                WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
-                    Some(InitialInsertion::NewlineAfter)
-                }
-                WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
-                    Some(InitialInsertion::NewlineBefore)
-                }
-                WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
-                WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
-                _ => None,
-            }
-        }
-    }
 }
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 5b84dc07d9..fbbe7d4224 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<Editor>) -> RenderBlock {
     })
 }
 
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum InitialInsertion {
-    NewlineBefore,
-    NewlineAfter,
-}
-
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct InlineAssistId(usize);
diff --git a/crates/gpui/src/elements/img.rs b/crates/gpui/src/elements/img.rs
index 89fffd054e..2a647c3621 100644
--- a/crates/gpui/src/elements/img.rs
+++ b/crates/gpui/src/elements/img.rs
@@ -351,10 +351,13 @@ impl Asset for ImageAsset {
                 let mut body = Vec::new();
                 response.body_mut().read_to_end(&mut body).await?;
                 if !response.status().is_success() {
+                    let mut body = String::from_utf8_lossy(&body).into_owned();
+                    let first_line = body.lines().next().unwrap_or("").trim_end();
+                    body.truncate(first_line.len());
                     return Err(ImageCacheError::BadStatus {
                         uri,
                         status: response.status(),
-                        body: String::from_utf8_lossy(&body).into_owned(),
+                        body,
                     });
                 }
                 body
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 9377dea178..f0a5754518 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -8,7 +8,7 @@ pub mod settings;
 
 use anyhow::Result;
 use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream};
+use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 };
@@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
         description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>>;
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
@@ -92,10 +92,11 @@ impl dyn LanguageModel {
     ) -> impl 'static + Future<Output = Result<T>> {
         let schema = schemars::schema_for!(T);
         let schema_json = serde_json::to_value(&schema).unwrap();
-        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
+        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
         async move {
-            let response = request.await?;
-            Ok(serde_json::from_value(response)?)
+            let stream = stream.await?;
+            let response = stream.try_collect::<String>().await?;
+            Ok(serde_json::from_str(&response)?)
         }
     }
 }
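Callers that want the old value-at-once behavior can keep using `use_tool`, which now collects the chunks before deserializing. A sketch with a hypothetical tool type (the `LanguageModelTool` shape is implied by the `T::name()` / `T::description()` calls above):

```rust
// Sketch: a typed tool consumed via `use_tool`. `EchoTool` is hypothetical.
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct EchoTool {
    message: String,
}

impl LanguageModelTool for EchoTool {
    fn name() -> String {
        "echo".into()
    }
    fn description() -> String {
        "Echoes a message back".into()
    }
}

// `use_tool` awaits the stream, concatenates every chunk, then parses.
let echo = model.use_tool::<EchoTool>(request, &cx).await?;
```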
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index 2445430775..7f3cbbf44f 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -7,7 +7,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
     View, WhiteSpace,
 };
@@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
 }
 
 impl AnthropicModel {
-    fn request_completion(
-        &self,
-        request: anthropic::Request,
-        cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<anthropic::Response>> {
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-            (state.api_key.clone(), settings.api_url.clone())
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
-                .await
-                .context("failed to retrieve completion")
-        }
-        .boxed()
-    }
-
     fn stream_completion(
         &self,
         request: anthropic::Request,
@@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
         tool_description: String,
         input_schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(self.model.tool_model_id().into());
         request.tool_choice = Some(anthropic::ToolChoice::Tool {
             name: tool_name.clone(),
@@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
             input_schema,
         }];
 
-        let response = self.request_completion(request, cx);
+        let response = self.stream_completion(request, cx);
         self.request_limiter
             .run(async move {
                 let response = response.await?;
-                response
-                    .content
-                    .into_iter()
-                    .find_map(|content| {
-                        if let anthropic::Content::ToolUse { name, input, .. } = content {
-                            if name == tool_name {
-                                Some(input)
-                            } else {
-                                None
-                            }
-                        } else {
-                            None
-                        }
-                    })
-                    .context("tool not used")
+                Ok(anthropic::extract_tool_args_from_events(
+                    tool_name,
+                    Box::pin(response.map_err(|e| anyhow!(e))),
+                )
+                .await?
+                .boxed())
             })
             .boxed()
     }
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 6d51bc6e6b..1bb68a0513 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -5,18 +5,21 @@ use crate::{
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anthropic::AnthropicError;
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, ZedPro};
-use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
+use futures::{
+    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    TryStreamExt as _,
+};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
     Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
 use smol::{
@@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: anthropic::Event = serde_json::from_str(&buffer)
-                            .context("failed to parse Anthropic event")?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(err) => Err(AnthropicError::Other(err.into())),
-                }
-            });
-
-            Ok(anthropic::extract_text_from_events(stream))
+            Ok(anthropic::extract_text_from_events(
+                response_lines(response).map_err(AnthropicError::Other),
+            ))
         });
         async move {
             Ok(future
@@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: open_ai::ResponseStreamEvent =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(open_ai::extract_text_from_events(stream))
+            Ok(open_ai::extract_text_from_events(response_lines(response)))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: google_ai::GenerateContentResponse =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(google_ai::extract_text_from_events(stream))
+            Ok(google_ai::extract_text_from_events(response_lines(
+                response,
+            )))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
-            let body = BufReader::new(response.into_body());
-            let stream = futures::stream::try_unfold(body, move |mut body| async move {
-                let mut buffer = String::new();
-                match body.read_line(&mut buffer).await {
-                    Ok(0) => Ok(None),
-                    Ok(_) => {
-                        let event: open_ai::ResponseStreamEvent =
-                            serde_json::from_str(&buffer)?;
-                        Ok(Some((event, body)))
-                    }
-                    Err(e) => Err(e.into()),
-                }
-            });
-
-            Ok(open_ai::extract_text_from_events(stream))
+            Ok(open_ai::extract_text_from_events(response_lines(response)))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
         tool_description: String,
         input_schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let client = self.client.clone();
                 let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
@@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
@@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
                        )
                        .await?;

-                        let mut tool_use_index = None;
-                        let mut tool_input = String::new();
-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        while body.read_line(&mut line).await? > 0 {
-                            let event: anthropic::Event = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            match event {
-                                anthropic::Event::ContentBlockStart {
-                                    content_block,
-                                    index,
-                                } => {
-                                    if let anthropic::Content::ToolUse { name, .. } = content_block
-                                    {
-                                        if name == tool_name {
-                                            tool_use_index = Some(index);
-                                        }
-                                    }
-                                }
-                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
-                                {
-                                    anthropic::ContentDelta::TextDelta { .. } => {}
-                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
-                                        if Some(index) == tool_use_index {
-                                            tool_input.push_str(&partial_json);
-                                        }
-                                    }
-                                },
-                                anthropic::Event::ContentBlockStop { index } => {
-                                    if Some(index) == tool_use_index {
-                                        return Ok(serde_json::from_str(&tool_input)?);
-                                    }
-                                }
-                                _ => {}
-                            }
-                        }
-
-                        if tool_use_index.is_some() {
-                            Err(anyhow!("tool content incomplete"))
-                        } else {
-                            Err(anyhow!("tool not used"))
-                        }
+                        Ok(anthropic::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
                        )
                        .await?;

-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
                 let mut request = request.into_open_ai(model.id().into());
-                let client = self.client.clone();
-                let mut function = open_ai::FunctionDefinition {
-                    name: tool_name.clone(),
-                    description: None,
-                    parameters: None,
-                };
-                let func = open_ai::ToolDefinition::Function {
-                    function: function.clone(),
-                };
-                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
-                // Fill in description and params separately, as they're not needed for tool_choice field.
-                function.description = Some(tool_description);
-                function.parameters = Some(input_schema);
-                request.tools = vec![open_ai::ToolDefinition::Function { function }];
+                request.tool_choice = Some(open_ai::ToolChoice::Other(
+                    open_ai::ToolDefinition::Function {
+                        function: open_ai::FunctionDefinition {
+                            name: tool_name.clone(),
+                            description: None,
+                            parameters: None,
+                        },
+                    },
+                ));
+                request.tools = vec![open_ai::ToolDefinition::Function {
+                    function: open_ai::FunctionDefinition {
+                        name: tool_name.clone(),
+                        description: Some(tool_description),
+                        parameters: Some(input_schema),
+                    },
+                }];
 
-                let llm_api_token = self.llm_api_token.clone();
                 self.request_limiter
                     .run(async move {
                         let response = Self::perform_llm_completion(
                        )
                        .await?;

-                        let mut body = BufReader::new(response.into_body());
-                        let mut line = String::new();
-                        let mut load_state = None;
-
-                        while body.read_line(&mut line).await? > 0 {
-                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
-                            line.clear();
-
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
+                        Ok(open_ai::extract_tool_args_from_events(
+                            tool_name,
+                            Box::pin(response_lines(response)),
+                        )
+                        .await?
+                        .boxed())
                     })
                     .boxed()
             }
         }
     }
 }
 
+fn response_lines<T: DeserializeOwned>(
+    response: Response<AsyncBody>,
+) -> impl Stream<Item = Result<T>> {
+    futures::stream::try_unfold(
+        (String::new(), BufReader::new(response.into_body())),
+        move |(mut line, mut body)| async {
+            match body.read_line(&mut line).await {
+                Ok(0) => Ok(None),
+                Ok(_) => {
+                    let event: T = serde_json::from_str(&line)?;
+                    line.clear();
+                    Ok(Some((event, (line, body))))
+                }
+                Err(e) => Err(e.into()),
+            }
+        },
+    )
+}
+
 impl LlmApiToken {
     async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
         let lock = self.0.upgradable_read().await;
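`response_lines` centralizes the newline-delimited JSON parsing that each match arm previously open-coded. A usage sketch (the concrete event type is whatever the provider emits, and `response` is assumed to be an `http_client::Response<AsyncBody>`):

```rust
// Sketch: one parsed event per line of the response body.
use futures::TryStreamExt as _;

let events = response_lines::<open_ai::ResponseStreamEvent>(response);
futures::pin_mut!(events); // the unfold stream is not Unpin
while let Some(event) = events.try_next().await? {
    // handle `event`
}
```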
diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs
index e1fc35ed75..f538d31aec 100644
--- a/crates/language_model/src/provider/copilot_chat.rs
+++ b/crates/language_model/src/provider/copilot_chat.rs
@@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         _description: String,
         _schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         future::ready(Err(anyhow!("not implemented"))).boxed()
     }
 }
diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs
index 9939cf2791..f62539aef2 100644
--- a/crates/language_model/src/provider/fake.rs
+++ b/crates/language_model/src/provider/fake.rs
@@ -3,16 +3,11 @@ use crate::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest,
 };
-use anyhow::Context as _;
-use futures::{
-    channel::{mpsc, oneshot},
-    future::BoxFuture,
-    stream::BoxStream,
-    FutureExt, StreamExt,
-};
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
 use parking_lot::Mutex;
+use serde::Serialize;
 use std::sync::Arc;
 use ui::WindowContext;
 
@@ -90,7 +85,7 @@ pub struct ToolUseRequest {
 #[derive(Default)]
 pub struct FakeLanguageModel {
     current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
-    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
+    current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
 }
 
 impl FakeLanguageModel {
@@ -130,25 +125,11 @@ impl FakeLanguageModel {
         self.end_completion_stream(self.pending_completions().last().unwrap());
     }
 
-    pub fn respond_to_tool_use(
-        &self,
-        tool_call: &ToolUseRequest,
-        response: Result<serde_json::Value>,
-    ) {
-        let mut current_tool_call_txs = self.current_tool_use_txs.lock();
-        if let Some(index) = current_tool_call_txs
-            .iter()
-            .position(|(call, _)| call == tool_call)
-        {
-            let (_, tx) = current_tool_call_txs.remove(index);
-            tx.send(response).unwrap();
-        }
-    }
-
-    pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
+    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
+        let response = serde_json::to_string(&response).unwrap();
         let mut current_tool_call_txs = self.current_tool_use_txs.lock();
         let (_, tx) = current_tool_call_txs.pop().unwrap();
-        tx.send(response).unwrap();
+        tx.unbounded_send(response).unwrap();
     }
 }
 
@@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
         description: String,
         schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
-        let (tx, rx) = oneshot::channel();
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
         let tool_call = ToolUseRequest {
             request,
             name,
@@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
             schema,
         };
         self.current_tool_use_txs.lock().push((tool_call, tx));
-        async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
 
     fn as_fake(&self) -> &Self {
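Because `respond_to_last_tool_use` now takes any `Serialize` value and sends it as a single chunk, test call sites stay simple, e.g. (mirroring the `context.rs` test earlier in this patch):

```rust
// Sketch: answering a pending tool use in a test with an ad-hoc JSON value.
model
    .as_fake()
    .respond_to_last_tool_use(serde_json::json!({ "step_title": "Title" }));
```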
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index 1fee3fb348..96475f752f 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
         _description: String,
         _schema: serde_json::Value,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         future::ready(Err(anyhow!("not implemented"))).boxed()
     }
 }
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index 88bc21791f..0ff3d70a6a 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -6,7 +6,6 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
     ChatResponseDelta, OllamaToolCall,
 };
-use serde_json::Value;
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
@@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
         tool_description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         use ollama::{OllamaFunctionTool, OllamaTool};
         let function = OllamaFunctionTool {
             name: tool_name.clone(),
@@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
         self.request_limiter
             .run(async move {
                 let response = response.await?;
-                let ChatMessage::Assistant {
-                    tool_calls,
-                    content,
-                } = response.message
-                else {
+                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                     bail!("message does not have an assistant role");
                 };
                 if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                     for call in tool_calls {
                         let OllamaToolCall::Function(function) = call;
                         if function.name == tool_name {
-                            return Ok(function.arguments);
+                            return Ok(futures::stream::once(async move {
+                                Ok(function.arguments.to_string())
+                            })
+                            .boxed());
                         }
                     }
-                } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
-                    // Parse content as arguments.
-                    return Ok(args);
                 } else {
                     bail!("assistant message does not have any tool calls");
                 };
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index c8b99d8ff0..0d3ee56d74 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, bail, Result};
+use anyhow::{anyhow, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
         async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
+
 impl LanguageModel for OpenAiLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
         tool_description: String,
         schema: serde_json::Value,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_open_ai(self.model.id().into());
-        let mut function = FunctionDefinition {
-            name: tool_name.clone(),
-            description: None,
-            parameters: None,
-        };
-        let func = ToolDefinition::Function {
-            function: function.clone(),
-        };
-        request.tool_choice = Some(ToolChoice::Other(func.clone()));
-        // Fill in description and params separately, as they're not needed for tool_choice field.
-        function.description = Some(tool_description);
-        function.parameters = Some(schema);
-        request.tools = vec![ToolDefinition::Function { function }];
+        request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
+            function: FunctionDefinition {
+                name: tool_name.clone(),
+                description: None,
+                parameters: None,
+            },
+        }));
+        request.tools = vec![ToolDefinition::Function {
+            function: FunctionDefinition {
+                name: tool_name.clone(),
+                description: Some(tool_description),
+                parameters: Some(schema),
+            },
+        }];
+
         let response = self.stream_completion(request, cx);
         self.request_limiter
             .run(async move {
-                let mut response = response.await?;
-
-                // Call arguments are gonna be streamed in over multiple chunks.
-                let mut load_state = None;
-                while let Some(Ok(part)) = response.next().await {
-                    for choice in part.choices {
-                        let Some(tool_calls) = choice.delta.tool_calls else {
-                            continue;
-                        };
-
-                        for call in tool_calls {
-                            if let Some(func) = call.function {
-                                if func.name.as_deref() == Some(tool_name.as_str()) {
-                                    load_state = Some((String::default(), call.index));
-                                }
-                                if let Some((arguments, (output, index))) =
-                                    func.arguments.zip(load_state.as_mut())
-                                {
-                                    if call.index == *index {
-                                        output.push_str(&arguments);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                if let Some((arguments, _)) = load_state {
-                    return Ok(serde_json::from_str(&arguments)?);
-                } else {
-                    bail!("tool not used");
-                }
+                let response = response.await?;
+                Ok(
+                    open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
+                        .await?
+                        .boxed(),
+                )
             })
             .boxed()
     }
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index 9f7cb4db27..b7359acf6f 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{value::RawValue, Value};
 use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -92,7 +92,7 @@ impl Model {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
     Assistant {
@@ -107,16 +107,16 @@ pub enum ChatMessage {
     },
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 #[serde(rename_all = "lowercase")]
 pub enum OllamaToolCall {
     Function(OllamaFunctionCall),
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct OllamaFunctionCall {
     pub name: String,
-    pub arguments: Value,
+    pub arguments: Box<RawValue>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
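Switching `arguments` to `Box<RawValue>` keeps the argument JSON as the exact text Ollama sent, so the `to_string()` call in the provider change above forwards it verbatim as a chunk instead of re-serializing a parsed `Value`. A standalone illustration:

```rust
// Sketch: `RawValue` round-trips the raw JSON text unchanged.
use serde_json::value::RawValue;

let raw: Box<RawValue> = serde_json::from_str(r#"{"path":"hello.rs"}"#)?;
assert_eq!(raw.to_string(), r#"{"path":"hello.rs"}"#);
```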
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 7ef6d1413a..291fc1d0ec 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::{convert::TryFrom, future::Future, time::Duration};
+use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
 use strum::EnumIter;
 
 pub use supported_countries::*;
@@ -384,6 +384,57 @@ pub fn embed<'a>(
     }
 }
 
+pub async fn extract_tool_args_from_events(
+    tool_name: String,
+    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+    let mut tool_use_index = None;
+    let mut first_chunk = None;
+    while let Some(event) = events.next().await {
+        let call = event?.choices.into_iter().find_map(|choice| {
+            choice.delta.tool_calls?.into_iter().find_map(|call| {
+                if call.function.as_ref()?.name.as_deref()? == tool_name {
+                    Some(call)
+                } else {
+                    None
+                }
+            })
+        });
+        if let Some(call) = call {
+            tool_use_index = Some(call.index);
+            first_chunk = call.function.and_then(|func| func.arguments);
+            break;
+        }
+    }
+
+    let Some(tool_use_index) = tool_use_index else {
+        return Err(anyhow!("tool not used"));
+    };
+
+    Ok(events.filter_map(move |event| {
+        let result = match event {
+            Err(error) => Some(Err(error)),
+            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
+                choice.delta.tool_calls?.into_iter().find_map(|call| {
+                    if call.index == tool_use_index {
+                        let func = call.function?;
+                        let mut arguments = func.arguments?;
+                        if let Some(mut first_chunk) = first_chunk.take() {
+                            first_chunk.push_str(&arguments);
+                            arguments = first_chunk
+                        }
+                        Some(Ok(arguments))
+                    } else {
+                        None
+                    }
+                })
+            }),
+        };
+
+        async move { result }
+    }))
+}
+
 pub fn extract_text_from_events(
     response: impl Stream<Item = Result<ResponseStreamEvent>>,
 ) -> impl Stream<Item = Result<String>> {