assistant: Propagate LLM stop reason upwards (#17358)

This PR makes it so we propagate the `stop_reason` from Anthropic up to
the Assistant so that we can take action based on it.
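For illustration, here is a minimal sketch of the kind of handling this unlocks for a consumer of the completion stream. It assumes the `LanguageModelCompletionEvent` and `StopReason` types introduced further down in this diff; the `on_max_tokens` callback and `handle_event` helper are hypothetical, not part of this PR:

```rust
use language_model::{LanguageModelCompletionEvent, StopReason};

// Hypothetical consumer: append text chunks to a buffer and react when the
// model stopped because it ran out of tokens.
fn handle_event(
    event: LanguageModelCompletionEvent,
    buffer: &mut String,
    on_max_tokens: impl FnOnce(),
) {
    match event {
        LanguageModelCompletionEvent::Text(chunk) => buffer.push_str(&chunk),
        LanguageModelCompletionEvent::ToolUse(_tool_use) => { /* dispatch the tool call */ }
        LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => on_max_tokens(),
        LanguageModelCompletionEvent::Stop(StopReason::EndTurn | StopReason::ToolUse) => {}
    }
}
```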

The `extract_content_from_events` function was moved from the `anthropic` crate to the `anthropic` provider module in `language_model`, since it is more useful if it is able to name the `LanguageModelCompletionEvent` type directly; otherwise we'd need an additional layer of plumbing.
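
Concretely, the move (together with the rename to `map_to_language_model_completion_events`) changes the stream's item type, so callers no longer re-map `anthropic::ResponseContent` themselves. The signatures below are taken from the diffs that follow, with the bodies elided:

```rust
// Before, in the `anthropic` crate: the stream speaks Anthropic's types.
pub fn extract_content_from_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> { /* ... */ }

// After, in `language_model`: the stream speaks the assistant's event type,
// including the new `Stop(StopReason)` variant.
pub fn map_to_language_model_completion_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> { /* ... */ }
```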

Release Notes:

- N/A
Marshall Bowers, 2024-09-04 12:31:10 -04:00, committed by GitHub
parent 7c8f62e943
commit f38956943b
7 changed files with 143 additions and 144 deletions

Cargo.lock (generated)

@@ -243,7 +243,6 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
- "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",


@@ -18,7 +18,6 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
-collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true


@@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
-use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
@@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::{maybe, ResultExt as _};
+use util::ResultExt as _;
 
 pub use supported_countries::*;
@@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
     }
 }
 
-pub fn extract_content_from_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
-    struct RawToolUse {
-        id: String,
-        name: String,
-        input_json: String,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        tool_uses_by_index: HashMap<usize, RawToolUse>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_uses_by_index: HashMap::default(),
-        },
-        |mut state| async move {
-            while let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => match event {
-                        Event::ContentBlockStart {
-                            index,
-                            content_block,
-                        } => match content_block {
-                            ResponseContent::Text { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ResponseContent::ToolUse { id, name, .. } => {
-                                state.tool_uses_by_index.insert(
-                                    index,
-                                    RawToolUse {
-                                        id,
-                                        name,
-                                        input_json: String::new(),
-                                    },
-                                );
-
-                                return Some((None, state));
-                            }
-                        },
-                        Event::ContentBlockDelta { index, delta } => match delta {
-                            ContentDelta::TextDelta { text } => {
-                                return Some((Some(Ok(ResponseContent::Text { text })), state));
-                            }
-                            ContentDelta::InputJsonDelta { partial_json } => {
-                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
-                                    tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
-                                }
-                            }
-                        },
-                        Event::ContentBlockStop { index } => {
-                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
-                                return Some((
-                                    Some(maybe!({
-                                        Ok(ResponseContent::ToolUse {
-                                            id: tool_use.id,
-                                            name: tool_use.name,
-                                            input: serde_json::Value::from_str(
-                                                &tool_use.input_json,
-                                            )
-                                            .map_err(|err| anyhow!(err))?,
-                                        })
-                                    })),
-                                    state,
-                                ));
-                            }
-                        }
-                        Event::Error { error } => {
-                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
-                        }
-                        _ => {}
-                    },
-                    Err(err) => {
-                        return Some((Some(Err(err)), state));
-                    }
-                }
-            }
-
-            None
-        },
-    )
-    .filter_map(|event| async move { event })
-}
-
 pub async fn extract_tool_args_from_events(
     tool_name: String,
     mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,


@@ -1999,6 +1999,11 @@ impl Context {
         });
         match event {
+            LanguageModelCompletionEvent::Stop(reason) => match reason {
+                language_model::StopReason::ToolUse => {}
+                language_model::StopReason::EndTurn => {}
+                language_model::StopReason::MaxTokens => {}
+            },
             LanguageModelCompletionEvent::Text(chunk) => {
                 buffer.edit(
                     [(


@@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
+    Stop(StopReason),
     Text(String),
     ToolUse(LanguageModelToolUse),
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+    EndTurn,
+    MaxTokens,
+    ToolUse,
+}
+
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUse {
     pub id: String,
@@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
             .filter_map(|result| async move {
                 match result {
                     Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                    Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                     Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                     Err(err) => Some(Err(err)),
                 }


@@ -3,11 +3,12 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
-use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
-use anthropic::AnthropicError;
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
 use anyhow::{anyhow, Context as _, Result};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
@@ -17,11 +18,13 @@ use http_client::HttpClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::{sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
-use util::ResultExt;
+use util::{maybe, ResultExt};
 
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
@@ -371,30 +374,9 @@ impl LanguageModel for AnthropicModel {
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
-            Ok(anthropic::extract_content_from_events(response))
+            Ok(map_to_language_model_completion_events(response))
         });
-        async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(|content| match content {
-                            anthropic::ResponseContent::Text { text } => {
-                                LanguageModelCompletionEvent::Text(text)
-                            }
-                            anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-                                    id,
-                                    name,
-                                    input,
-                                })
-                            }
-                        })
-                        .map_err(|err| anyhow!(err))
-                })
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -443,6 +425,120 @@ impl LanguageModel for AnthropicModel {
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_uses_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            while let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => match event {
+                        Event::ContentBlockStart {
+                            index,
+                            content_block,
+                        } => match content_block {
+                            ResponseContent::Text { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ResponseContent::ToolUse { id, name, .. } => {
+                                state.tool_uses_by_index.insert(
+                                    index,
+                                    RawToolUse {
+                                        id,
+                                        name,
+                                        input_json: String::new(),
+                                    },
+                                );
+
+                                return Some((None, state));
+                            }
+                        },
+                        Event::ContentBlockDelta { index, delta } => match delta {
+                            ContentDelta::TextDelta { text } => {
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    state,
+                                ));
+                            }
+                            ContentDelta::InputJsonDelta { partial_json } => {
+                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                    tool_use.input_json.push_str(&partial_json);
+                                    return Some((None, state));
+                                }
+                            }
+                        },
+                        Event::ContentBlockStop { index } => {
+                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                                return Some((
+                                    Some(maybe!({
+                                        Ok(LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse {
+                                                id: tool_use.id,
+                                                name: tool_use.name,
+                                                input: serde_json::Value::from_str(
+                                                    &tool_use.input_json,
+                                                )
+                                                .map_err(|err| anyhow!(err))?,
+                                            },
+                                        ))
+                                    })),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::MessageDelta { delta, .. } => {
+                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
+                                let stop_reason = match stop_reason {
+                                    "end_turn" => StopReason::EndTurn,
+                                    "max_tokens" => StopReason::MaxTokens,
+                                    "tool_use" => StopReason::ToolUse,
+                                    _ => StopReason::EndTurn,
+                                };
+
+                                return Some((
+                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
+                                    state,
+                                ));
+                            }
+                        }
+                        Event::Error { error } => {
+                            return Some((
+                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                state,
+                            ));
+                        }
+                        _ => {}
+                    },
+                    Err(err) => {
+                        return Some((Some(Err(anyhow!(err))), state));
+                    }
+                }
+            }
+
+            None
+        },
+    )
+    .filter_map(|event| async move { event })
+}
+
 struct ConfigurationView {
     api_key_editor: View<Editor>,
     state: gpui::Model<State>,


@@ -1,4 +1,5 @@
 use super::open_ai::count_open_ai_tokens;
+use crate::provider::anthropic::map_to_language_model_completion_events;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -33,10 +34,7 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{
-    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
-    LanguageModelToolUse,
-};
+use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
 
 use super::anthropic::count_anthropic_tokens;
@@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(anthropic::extract_content_from_events(Box::pin(
+                    Ok(map_to_language_model_completion_events(Box::pin(
                         response_lines(response).map_err(AnthropicError::Other),
                     )))
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| {
-                            result
-                                .map(|content| match content {
-                                    anthropic::ResponseContent::Text { text } => {
-                                        LanguageModelCompletionEvent::Text(text)
-                                    }
-                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
-                                        LanguageModelCompletionEvent::ToolUse(
-                                            LanguageModelToolUse { id, name, input },
-                                        )
-                                    }
-                                })
-                                .map_err(|err| anyhow!(err))
-                        })
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();