Mirror of https://github.com/zed-industries/zed.git (synced 2024-11-07 20:39:04 +03:00)
assistant: Stream tool uses as structured data (#17322)
This PR changes how tool uses are encoded in the completion response: rather than injecting them into the response stream as text, we now return them as structured data. In #17170 we encoded the tool uses as XML and inserted them as text, which meant they had to be re-parsed out of the buffer before they could be used. With this PR, `stream_completion` returns a stream of `LanguageModelCompletionEvent`s, each of which is either text or a tool use. A new `stream_completion_text` method has been added to `LanguageModel` for scenarios where only the textual content matters (currently, everywhere that isn't the Assistant context editor).

Release Notes:

- N/A
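As a rough sketch of the consumer side (the types come from the diff below; the `print_events` helper and its setup are hypothetical, not part of this PR), handling the structured event stream looks like this:

use anyhow::Result;
use futures::StreamExt;
use gpui::AsyncAppContext;
use language_model::{LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest};

// Hypothetical consumer: drain the structured completion stream, printing
// text chunks as they arrive and logging any tool uses.
async fn print_events(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    cx: &AsyncAppContext,
) -> Result<()> {
    let mut events = model.stream_completion(request, cx).await?;
    while let Some(event) = events.next().await {
        match event? {
            LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                // `input` arrives as an already-parsed `serde_json::Value`;
                // no XML re-parsing step is needed anymore.
                println!("tool use {} ({}): {}", tool_use.id, tool_use.name, tool_use.input);
            }
        }
    }
    Ok(())
}

Callers that only want text can instead call the new `stream_completion_text`, which filters the `ToolUse` events out of the same underlying stream.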
This commit is contained in:
parent 132e8e8064
commit 452272e5df

Cargo.lock (generated)
@@ -243,6 +243,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "chrono",
+ "collections",
  "futures 0.3.30",
  "http_client",
  "isahc",
@@ -18,6 +18,7 @@ path = "src/anthropic.rs"
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
+collections.workspace = true
 futures.workspace = true
 http_client.workspace = true
 isahc.workspace = true
@@ -1,17 +1,19 @@
 mod supported_countries;
 
+use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
+
 use anyhow::{anyhow, Context, Result};
 use chrono::{DateTime, Utc};
+use collections::HashMap;
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use isahc::http::{HeaderMap, HeaderValue};
 use serde::{Deserialize, Serialize};
-use std::time::Duration;
-use std::{pin::Pin, str::FromStr};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
-use util::ResultExt as _;
+use util::{maybe, ResultExt as _};
 
 pub use supported_countries::*;

@@ -332,19 +334,22 @@ pub async fn stream_completion_with_rate_limit_info(
 
 pub fn extract_content_from_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<String, AnthropicError>> {
+) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
+    struct RawToolUse {
+        id: String,
+        name: String,
+        input_json: String,
+    }
+
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-        current_tool_use_index: Option<usize>,
+        tool_uses_by_index: HashMap<usize, RawToolUse>,
     }
 
-    const INDENT: &str = "  ";
-    const NEWLINE: char = '\n';
-
     futures::stream::unfold(
         State {
             events,
-            current_tool_use_index: None,
+            tool_uses_by_index: HashMap::default(),
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {

@@ -355,62 +360,56 @@ pub fn extract_content_from_events(
                         content_block,
                     } => match content_block {
                         ResponseContent::Text { text } => {
-                            return Some((Ok(text), state));
+                            return Some((Some(Ok(ResponseContent::Text { text })), state));
                         }
                         ResponseContent::ToolUse { id, name, .. } => {
-                            state.current_tool_use_index = Some(index);
+                            state.tool_uses_by_index.insert(
+                                index,
+                                RawToolUse {
+                                    id,
+                                    name,
+                                    input_json: String::new(),
+                                },
+                            );
 
-                            let mut text = String::new();
-                            text.push(NEWLINE);
-
-                            text.push_str("<tool_use>");
-                            text.push(NEWLINE);
-
-                            text.push_str(INDENT);
-                            text.push_str("<id>");
-                            text.push_str(&id);
-                            text.push_str("</id>");
-                            text.push(NEWLINE);
-
-                            text.push_str(INDENT);
-                            text.push_str("<name>");
-                            text.push_str(&name);
-                            text.push_str("</name>");
-                            text.push(NEWLINE);
-
-                            text.push_str(INDENT);
-                            text.push_str("<input>");
-
-                            return Some((Ok(text), state));
+                            return Some((None, state));
                         }
                     },
                     Event::ContentBlockDelta { index, delta } => match delta {
                         ContentDelta::TextDelta { text } => {
-                            return Some((Ok(text), state));
+                            return Some((Some(Ok(ResponseContent::Text { text })), state));
                         }
                         ContentDelta::InputJsonDelta { partial_json } => {
-                            if Some(index) == state.current_tool_use_index {
-                                return Some((Ok(partial_json), state));
+                            if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+                                tool_use.input_json.push_str(&partial_json);
+                                return Some((None, state));
                             }
                         }
                     },
                     Event::ContentBlockStop { index } => {
-                        if Some(index) == state.current_tool_use_index.take() {
-                            let mut text = String::new();
-                            text.push_str("</input>");
-                            text.push(NEWLINE);
-                            text.push_str("</tool_use>");
-
-                            return Some((Ok(text), state));
+                        if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+                            return Some((
+                                Some(maybe!({
+                                    Ok(ResponseContent::ToolUse {
+                                        id: tool_use.id,
+                                        name: tool_use.name,
+                                        input: serde_json::Value::from_str(
+                                            &tool_use.input_json,
+                                        )
+                                        .map_err(|err| anyhow!(err))?,
+                                    })
+                                })),
+                                state,
+                            ));
                         }
                     }
                     Event::Error { error } => {
-                        return Some((Err(AnthropicError::ApiError(error)), state));
+                        return Some((Some(Err(AnthropicError::ApiError(error))), state));
                    }
                     _ => {}
                 },
                 Err(err) => {
-                    return Some((Err(err), state));
+                    return Some((Some(Err(err)), state));
                 }
             }
         }

@@ -418,6 +417,7 @@ pub fn extract_content_from_events(
             None
         },
     )
+    .filter_map(|event| async move { event })
 }
 
 pub async fn extract_tool_args_from_events(
@@ -25,8 +25,9 @@ use gpui::{
 
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
+    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    MessageContent, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};

@@ -1950,13 +1951,13 @@ impl Context {
         let mut response_latency = None;
         let stream_completion = async {
             let request_start = Instant::now();
-            let mut chunks = stream.await?;
+            let mut events = stream.await?;
 
-            while let Some(chunk) = chunks.next().await {
+            while let Some(event) = events.next().await {
                 if response_latency.is_none() {
                     response_latency = Some(request_start.elapsed());
                 }
-                let chunk = chunk?;
+                let event = event?;
 
                 this.update(&mut cx, |this, cx| {
                     let message_ix = this

@@ -1970,11 +1971,36 @@ impl Context {
                         .map_or(buffer.len(), |message| {
                             message.start.to_offset(buffer).saturating_sub(1)
                         });
-                    buffer.edit(
-                        [(message_old_end_offset..message_old_end_offset, chunk)],
-                        None,
-                        cx,
-                    );
+
+                    match event {
+                        LanguageModelCompletionEvent::Text(chunk) => {
+                            buffer.edit(
+                                [(
+                                    message_old_end_offset..message_old_end_offset,
+                                    chunk,
+                                )],
+                                None,
+                                cx,
+                            );
+                        }
+                        LanguageModelCompletionEvent::ToolUse(tool_use) => {
+                            let mut text = String::new();
+                            text.push('\n');
+                            text.push_str(
+                                &serde_json::to_string_pretty(&tool_use)
+                                    .expect("failed to serialize tool use to JSON"),
+                            );
+
+                            buffer.edit(
+                                [(
+                                    message_old_end_offset..message_old_end_offset,
+                                    text,
+                                )],
+                                None,
+                                cx,
+                            );
+                        }
+                    }
                 });
 
                 cx.emit(ContextEvent::StreamedCompletion);

@@ -2406,7 +2432,7 @@ impl Context {
 
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let stream = model.stream_completion(request, &cx);
+                let stream = model.stream_completion_text(request, &cx);
                 let mut messages = stream.await?;
 
                 let mut replaced = !replace_old;
@@ -2344,7 +2344,7 @@ impl Codegen {
             self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
 
             let chunks =
-                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+                cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
             async move { Ok(chunks.await?.boxed()) }.boxed_local()
         };
         self.handle_stream(telemetry_id, edit_range, chunks, cx);
@@ -1010,7 +1010,7 @@ impl Codegen {
         self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
             let model_telemetry_id = model.telemetry_id();
-            let response = model.stream_completion(prompt, &cx).await;
+            let response = model.stream_completion_text(prompt, &cx).await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 
@@ -8,7 +8,8 @@ pub mod settings;
 
 use anyhow::Result;
 use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
+use futures::FutureExt;
+use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 };

@@ -51,6 +52,20 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: usize,
 }
 
+/// A completion event from a language model.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum LanguageModelCompletionEvent {
+    Text(String),
+    ToolUse(LanguageModelToolUse),
+}
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUse {
+    pub id: String,
+    pub name: String,
+    pub input: serde_json::Value,
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;

@@ -82,7 +97,29 @@ pub trait LanguageModel: Send + Sync {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+
+    fn stream_completion_text(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let events = self.stream_completion(request, cx);
+
+        async move {
+            Ok(events
+                .await?
+                .filter_map(|result| async move {
+                    match result {
+                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                        Err(err) => Some(Err(err)),
+                    }
+                })
+                .boxed())
+        }
+        .boxed()
+    }
 
     fn use_any_tool(
         &self,
@@ -3,6 +3,7 @@ use crate::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;

@@ -364,7 +365,7 @@ impl LanguageModel for AnthropicModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request =
             request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
         let request = self.stream_completion(request, cx);

@@ -375,7 +376,22 @@ impl LanguageModel for AnthropicModel {
         async move {
             Ok(future
                 .await?
-                .map(|result| result.map_err(|err| anyhow!(err)))
+                .map(|result| {
+                    result
+                        .map(|content| match content {
+                            anthropic::ResponseContent::Text { text } => {
+                                LanguageModelCompletionEvent::Text(text)
+                            }
+                            anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                                    id,
+                                    name,
+                                    input,
+                                })
+                            }
+                        })
+                        .map_err(|err| anyhow!(err))
+                })
                 .boxed())
         }
         .boxed()
@@ -33,7 +33,10 @@ use std::{
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
 
-use crate::{LanguageModelAvailability, LanguageModelProvider};
+use crate::{
+    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
+    LanguageModelToolUse,
+};
 
 use super::anthropic::count_anthropic_tokens;
 

@@ -496,7 +499,7 @@ impl LanguageModel for CloudLanguageModel {
         &self,
         request: LanguageModelRequest,
         _cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());

@@ -522,7 +525,20 @@ impl LanguageModel for CloudLanguageModel {
                 async move {
                     Ok(future
                         .await?
-                        .map(|result| result.map_err(|err| anyhow!(err)))
+                        .map(|result| {
+                            result
+                                .map(|content| match content {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        LanguageModelCompletionEvent::Text(text)
+                                    }
+                                    anthropic::ResponseContent::ToolUse { id, name, input } => {
+                                        LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse { id, name, input },
+                                        )
+                                    }
+                                })
+                                .map_err(|err| anyhow!(err))
+                        })
                         .boxed())
                 }
                 .boxed()

@@ -546,7 +562,13 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();

@@ -569,7 +591,13 @@ impl LanguageModel for CloudLanguageModel {
                         response,
                     )))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();

@@ -591,7 +619,13 @@ impl LanguageModel for CloudLanguageModel {
                     .await?;
                     Ok(open_ai::extract_text_from_events(response_lines(response)))
                 });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                async move {
+                    Ok(future
+                        .await?
+                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                        .boxed())
+                }
+                .boxed()
             }
         }
     }
@@ -24,11 +24,11 @@ use ui::{
 };
 
 use crate::settings::AllLanguageModelSettings;
-use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
+use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
 use super::open_ai::count_open_ai_tokens;
 

@@ -192,7 +192,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         if let Some(message) = request.messages.last() {
             if message.contents_empty() {
                 const EMPTY_PROMPT_MSG: &str =

@@ -243,7 +243,13 @@ impl LanguageModel for CopilotChatLanguageModel {
             }).await
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
@@ -1,7 +1,7 @@
 use crate::{
-    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest,
 };
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};

@@ -170,10 +170,15 @@ impl LanguageModel for FakeLanguageModel {
         &self,
         request: LanguageModelRequest,
         _: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let (tx, rx) = mpsc::unbounded();
         self.current_completion_txs.lock().push((request, tx));
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+        async move {
+            Ok(rx
+                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
@@ -17,6 +17,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -281,7 +282,10 @@ impl LanguageModel for GoogleLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_google(self.model.id().to_string());
 
         let http_client = self.http_client.clone();

@@ -299,7 +303,13 @@ impl LanguageModel for GoogleLanguageModel {
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
         });
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
@@ -13,6 +13,7 @@ use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -302,7 +303,7 @@ impl LanguageModel for OllamaLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();

@@ -335,7 +336,13 @@ impl LanguageModel for OllamaLanguageModel {
             Ok(stream)
         });
 
-        async move { Ok(future.await?.boxed()) }.boxed()
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(
@@ -19,6 +19,7 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, Tooltip};
 use util::ResultExt;
 
+use crate::LanguageModelCompletionEvent;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,

@@ -293,10 +294,18 @@ impl LanguageModel for OpenAiLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+    > {
         let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
+        async move {
+            Ok(open_ai::extract_text_from_events(completions.await?)
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
     }
 
     fn use_any_tool(