mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
assistant: Propagate LLM stop reason upwards (#17358)
This PR makes it so we propagate the `stop_reason` from Anthropic up to the Assistant so that we can take action based on it. The `extract_content_from_events` function was moved from `anthropic` to the `anthropic` module in `language_model` since it is more useful if it is able to name the `LanguageModelCompletionEvent` type, as otherwise we'd need an additional layer of plumbing. Release Notes: - N/A
This commit is contained in:
parent
7c8f62e943
commit
f38956943b
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -243,7 +243,6 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"collections",
|
||||
"futures 0.3.30",
|
||||
"http_client",
|
||||
"isahc",
|
||||
|
@ -18,7 +18,6 @@ path = "src/anthropic.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc.workspace = true
|
||||
|
@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
use util::{maybe, ResultExt as _};
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_content_from_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
struct State {
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
},
|
||||
|mut state| async move {
|
||||
while let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => match event {
|
||||
Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ResponseContent::Text { text } => {
|
||||
return Some((Some(Ok(ResponseContent::Text { text })), state));
|
||||
}
|
||||
ResponseContent::ToolUse { id, name, .. } => {
|
||||
state.tool_uses_by_index.insert(
|
||||
index,
|
||||
RawToolUse {
|
||||
id,
|
||||
name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
|
||||
return Some((None, state));
|
||||
}
|
||||
},
|
||||
Event::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
return Some((Some(Ok(ResponseContent::Text { text })), state));
|
||||
}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
return Some((None, state));
|
||||
}
|
||||
}
|
||||
},
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||
return Some((
|
||||
Some(maybe!({
|
||||
Ok(ResponseContent::ToolUse {
|
||||
id: tool_use.id,
|
||||
name: tool_use.name,
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_use.input_json,
|
||||
)
|
||||
.map_err(|err| anyhow!(err))?,
|
||||
})
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
Event::Error { error } => {
|
||||
return Some((Some(Err(AnthropicError::ApiError(error))), state));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((Some(Err(err)), state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
)
|
||||
.filter_map(|event| async move { event })
|
||||
}
|
||||
|
||||
pub async fn extract_tool_args_from_events(
|
||||
tool_name: String,
|
||||
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
|
||||
|
@ -1999,6 +1999,11 @@ impl Context {
|
||||
});
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::Stop(reason) => match reason {
|
||||
language_model::StopReason::ToolUse => {}
|
||||
language_model::StopReason::EndTurn => {}
|
||||
language_model::StopReason::MaxTokens => {}
|
||||
},
|
||||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
buffer.edit(
|
||||
[(
|
||||
|
@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
|
||||
/// A completion event from a language model.
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelCompletionEvent {
|
||||
Stop(StopReason),
|
||||
Text(String),
|
||||
ToolUse(LanguageModelToolUse),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopReason {
|
||||
EndTurn,
|
||||
MaxTokens,
|
||||
ToolUse,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelToolUse {
|
||||
pub id: String,
|
||||
@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
.filter_map(|result| async move {
|
||||
match result {
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||
Err(err) => Some(Err(err)),
|
||||
}
|
||||
|
@ -3,11 +3,12 @@ use crate::{
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
|
||||
use anthropic::AnthropicError;
|
||||
use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::BTreeMap;
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::Stream;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
|
||||
use gpui::{
|
||||
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
|
||||
@ -17,11 +18,13 @@ use http_client::HttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, Icon, IconName, Tooltip};
|
||||
use util::ResultExt;
|
||||
use util::{maybe, ResultExt};
|
||||
|
||||
const PROVIDER_ID: &str = "anthropic";
|
||||
const PROVIDER_NAME: &str = "Anthropic";
|
||||
@ -371,30 +374,9 @@ impl LanguageModel for AnthropicModel {
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
Ok(anthropic::extract_content_from_events(response))
|
||||
Ok(map_to_language_model_completion_events(response))
|
||||
});
|
||||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| {
|
||||
result
|
||||
.map(|content| match content {
|
||||
anthropic::ResponseContent::Text { text } => {
|
||||
LanguageModelCompletionEvent::Text(text)
|
||||
}
|
||||
anthropic::ResponseContent::ToolUse { id, name, input } => {
|
||||
LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
})
|
||||
}
|
||||
})
|
||||
.map_err(|err| anyhow!(err))
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
@ -443,6 +425,120 @@ impl LanguageModel for AnthropicModel {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
struct State {
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
},
|
||||
|mut state| async move {
|
||||
while let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => match event {
|
||||
Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ResponseContent::Text { text } => {
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(text))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ResponseContent::ToolUse { id, name, .. } => {
|
||||
state.tool_uses_by_index.insert(
|
||||
index,
|
||||
RawToolUse {
|
||||
id,
|
||||
name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
|
||||
return Some((None, state));
|
||||
}
|
||||
},
|
||||
Event::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => {
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(text))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||
tool_use.input_json.push_str(&partial_json);
|
||||
return Some((None, state));
|
||||
}
|
||||
}
|
||||
},
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||
return Some((
|
||||
Some(maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id,
|
||||
name: tool_use.name,
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_use.input_json,
|
||||
)
|
||||
.map_err(|err| anyhow!(err))?,
|
||||
},
|
||||
))
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
Event::MessageDelta { delta, .. } => {
|
||||
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
||||
let stop_reason = match stop_reason {
|
||||
"end_turn" => StopReason::EndTurn,
|
||||
"max_tokens" => StopReason::MaxTokens,
|
||||
"tool_use" => StopReason::ToolUse,
|
||||
_ => StopReason::EndTurn,
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
Event::Error { error } => {
|
||||
return Some((
|
||||
Some(Err(anyhow!(AnthropicError::ApiError(error)))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((Some(Err(anyhow!(err))), state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
)
|
||||
.filter_map(|event| async move { event })
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: View<Editor>,
|
||||
state: gpui::Model<State>,
|
||||
|
@ -1,4 +1,5 @@
|
||||
use super::open_ai::count_open_ai_tokens;
|
||||
use crate::provider::anthropic::map_to_language_model_completion_events;
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
@ -33,10 +34,7 @@ use std::{
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::{prelude::*, TintColor};
|
||||
|
||||
use crate::{
|
||||
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
|
||||
LanguageModelToolUse,
|
||||
};
|
||||
use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
|
||||
@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(anthropic::extract_content_from_events(Box::pin(
|
||||
Ok(map_to_language_model_completion_events(Box::pin(
|
||||
response_lines(response).map_err(AnthropicError::Other),
|
||||
)))
|
||||
});
|
||||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| {
|
||||
result
|
||||
.map(|content| match content {
|
||||
anthropic::ResponseContent::Text { text } => {
|
||||
LanguageModelCompletionEvent::Text(text)
|
||||
}
|
||||
anthropic::ResponseContent::ToolUse { id, name, input } => {
|
||||
LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse { id, name, input },
|
||||
)
|
||||
}
|
||||
})
|
||||
.map_err(|err| anyhow!(err))
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
|
Loading…
Reference in New Issue
Block a user