From f38956943bb004d79511efd7d85fcd8efde88144 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Wed, 4 Sep 2024 12:31:10 -0400
Subject: [PATCH] assistant: Propagate LLM stop reason upwards (#17358)

This PR propagates the `stop_reason` from Anthropic up to the Assistant
so that we can take action based on it.

The `extract_content_from_events` function was moved from the `anthropic`
crate to the `anthropic` provider module in `language_model`, since it is
more useful when it can name the `LanguageModelCompletionEvent` type;
otherwise we'd need an additional layer of plumbing.

Release Notes:

- N/A
---
 Cargo.lock                                    |   1 -
 crates/anthropic/Cargo.toml                   |   1 -
 crates/anthropic/src/anthropic.rs             |  91 +----------
 crates/assistant/src/context.rs               |   5 +
 crates/language_model/src/language_model.rs   |  10 ++
 .../language_model/src/provider/anthropic.rs  | 150 ++++++++++++++----
 crates/language_model/src/provider/cloud.rs   |  29 +---
 7 files changed, 143 insertions(+), 144 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 03ba1aca64..60bb96eee5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -243,7 +243,6 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
- "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",
diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml
index ddab9dfd7c..9e48ad0e57 100644
--- a/crates/anthropic/Cargo.toml
+++ b/crates/anthropic/Cargo.toml
@@ -18,7 +18,6 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
-collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 03aec20568..6ac10ff793 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
 
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
-use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
@@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::{maybe, ResultExt as _};
+use util::ResultExt as _;
 
 pub use supported_countries::*;
 
@@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
     }
 }
 
-pub fn extract_content_from_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
-    struct RawToolUse {
-        id: String,
-        name: String,
-        input_json: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        tool_uses_by_index: HashMap<usize, RawToolUse>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-        },
-        |mut state| async move {
-            while let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => match event {
-                        Event::ContentBlockStart {
-                            index,
-                            content_block,
-                        } => match content_block {
-                            ResponseContent::Text { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ResponseContent::ToolUse { id, name, .. } => {
-                                state.tool_uses_by_index.insert(
-                                    index,
-                                    RawToolUse {
-                                        id,
-                                        name,
-                                        input_json: String::new(),
-                                    },
-                                );
-
-                                return Some((None, state));
-                            }
-                        },
-                        Event::ContentBlockDelta { index, delta } => match delta {
-                            ContentDelta::TextDelta { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ContentDelta::InputJsonDelta { partial_json } => {
-                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
-                                    tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
-                                }
-                            }
-                        },
-                        Event::ContentBlockStop { index } => {
-                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
-                                return Some((
-                                    Some(maybe!({
-                                        Ok(ResponseContent::ToolUse {
-                                            id: tool_use.id,
-                                            name: tool_use.name,
-                                            input: serde_json::Value::from_str(
-                                                &tool_use.input_json,
-                                            )
-                                            .map_err(|err| anyhow!(err))?,
-                                        })
-                                    })),
-                                    state,
-                                ));
-                            }
-                        }
-                        Event::Error { error } => {
-                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
-                        }
-                        _ => {}
-                    },
-                    Err(err) => {
-                        return Some((Some(Err(err)), state));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .filter_map(|event| async move { event })
-}
-
 pub async fn extract_tool_args_from_events(
     tool_name: String,
     mut events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 51e7d626d7..ec2248f19f 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -1999,6 +1999,11 @@ impl Context {
                     });
 
                     match event {
+                        LanguageModelCompletionEvent::Stop(reason) => match reason {
+                            language_model::StopReason::ToolUse => {}
+                            language_model::StopReason::EndTurn => {}
+                            language_model::StopReason::MaxTokens => {}
+                        },
                         LanguageModelCompletionEvent::Text(chunk) => {
                             buffer.edit(
                                 [(
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index cd85ca7f53..d24a5f9001 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
+    Stop(StopReason),
     Text(String),
     ToolUse(LanguageModelToolUse),
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+    EndTurn,
+    MaxTokens,
+    ToolUse,
+}
+
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUse {
     pub id: String,
@@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
             .filter_map(|result| async move {
                 match result {
                     Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                    Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                     Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                     Err(err) => Some(Err(err)),
                 }
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index 8258768a6a..62b049c9ea 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -3,11 +3,12 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
-use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
-use anthropic::AnthropicError;
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
 use anyhow::{anyhow, Context as _, Result};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
@@ -17,11 +18,13 @@ use http_client::HttpClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
-use util::ResultExt;
+use util::{maybe, ResultExt};
 
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
@@ -371,30 +374,9 @@ impl LanguageModel for AnthropicModel {
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
-            Ok(anthropic::extract_content_from_events(response))
+            Ok(map_to_language_model_completion_events(response))
         });
-        async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(|content| match content {
-                            anthropic::ResponseContent::Text { text } => {
-                                LanguageModelCompletionEvent::Text(text)
-                            }
-                            anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-                                    id,
-                                    name,
-                                    input,
-                                })
-                            }
-                        })
-                        .map_err(|err| anyhow!(err))
-                })
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -443,6 +425,120 @@ impl LanguageModel for AnthropicModel {
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_uses_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            while let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => match event {
+                        Event::ContentBlockStart {
+                            index,
+                            content_block,
+                        } => match content_block {
+                            ResponseContent::Text { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ResponseContent::ToolUse { id, name, .. } => {
+                                state.tool_uses_by_index.insert(
+                                    index,
+                                    RawToolUse {
+                                        id,
+                                        name,
+                                        input_json: String::new(),
+                                    },
+                                );
+
+                                return Some((None, state));
+                            }
+                        },
+                        Event::ContentBlockDelta { index, delta } => match delta {
+                            ContentDelta::TextDelta { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ContentDelta::InputJsonDelta { partial_json } => {
+                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                    tool_use.input_json.push_str(&partial_json);
+                                    return Some((None, state));
+                                }
+                            }
+                        },
+                        Event::ContentBlockStop { index } => {
+                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                                return Some((
+                                    Some(maybe!({
+                                        Ok(LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse {
+                                                id: tool_use.id,
+                                                name: tool_use.name,
+                                                input: serde_json::Value::from_str(
+                                                    &tool_use.input_json,
+                                                )
+                                                .map_err(|err| anyhow!(err))?,
+                                            },
+                                        ))
+                                    })),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::MessageDelta { delta, .. } => {
+                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
+                                let stop_reason = match stop_reason {
+                                    "end_turn" => StopReason::EndTurn,
+                                    "max_tokens" => StopReason::MaxTokens,
+                                    "tool_use" => StopReason::ToolUse,
+                                    _ => StopReason::EndTurn,
+                                };
+
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::Error { error } => {
+                            return Some((
+                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                state,
+                            ));
+                        }
+                        _ => {}
+                    },
+                    Err(err) => {
+                        return Some((Some(Err(anyhow!(err))), state));
+                    }
+                }
+            }
+
+            None
+        },
+    )
+    .filter_map(|event| async move { event })
+}
+
 struct ConfigurationView {
     api_key_editor: View<Editor>,
     state: gpui::Model<State>,
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index a9b0008bbd..d3741d2078 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -1,4 +1,5 @@
 use super::open_ai::count_open_ai_tokens;
+use crate::provider::anthropic::map_to_language_model_completion_events;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -33,10 +34,7 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{
-    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
-    LanguageModelToolUse,
-};
+use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
 
 use super::anthropic::count_anthropic_tokens;
 
@@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(anthropic::extract_content_from_events(Box::pin(
+                    Ok(map_to_language_model_completion_events(Box::pin(
                         response_lines(response).map_err(AnthropicError::Other),
                     )))
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| {
-                            result
-                                .map(|content| match content {
-                                    anthropic::ResponseContent::Text { text } => {
-                                        LanguageModelCompletionEvent::Text(text)
-                                    }
-                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                        LanguageModelCompletionEvent::ToolUse(
-                                            LanguageModelToolUse { id, name, input },
-                                        )
-                                    }
-                                })
-                                .map_err(|err| anyhow!(err))
-                        })
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
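
---

Reviewer note (not part of the patch): the context.rs hunk above leaves all
three `StopReason` arms empty. As a standalone illustration, here is a minimal,
self-contained Rust sketch of how a consumer might branch on the new
`Stop(StopReason)` event. The enum definitions mirror the ones this patch adds
to language_model.rs; the handling suggested in the comments is an assumption
for illustration, not Zed's actual behavior.

#[derive(Debug)]
enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug)]
enum LanguageModelCompletionEvent {
    Stop(StopReason),
    Text(String),
}

fn handle_event(event: LanguageModelCompletionEvent, output: &mut String) {
    match event {
        // Streamed text is appended to the output as it arrives.
        LanguageModelCompletionEvent::Text(chunk) => output.push_str(&chunk),
        // The model stopped to call a tool; a real consumer would execute the
        // pending tool use and send the result back to the model.
        LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {}
        // The completion was truncated; a real consumer might warn the user.
        LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {}
        // Normal end of the assistant's turn; nothing more to stream.
        LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {}
    }
}

fn main() {
    let mut output = String::new();
    handle_event(
        LanguageModelCompletionEvent::Text("Hello".into()),
        &mut output,
    );
    handle_event(
        LanguageModelCompletionEvent::Stop(StopReason::EndTurn),
        &mut output,
    );
    assert_eq!(output, "Hello");
}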