mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Make LanguageModel::use_any_tool return a stream of chunks (#16262)
This PR is a refactor to pave the way for allowing the user to view and edit workflow step resolutions. I've made tool calls work more like normal streaming completions for all providers. The `use_any_tool` method returns a stream of strings (which contain chunks of JSON). I've also done some minor cleanup of language model providers in general, removing the duplication around handling streaming responses. Release Notes: - N/A
This commit is contained in:
parent
1117d89057
commit
4c390b82fb
@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use std::{pin::Pin, str::FromStr};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
@ -241,6 +241,50 @@ pub fn extract_text_from_events(
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn extract_tool_args_from_events(
|
||||
tool_name: String,
|
||||
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
|
||||
) -> Result<impl Send + Stream<Item = Result<String>>> {
|
||||
let mut tool_use_index = None;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Event::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} = event?
|
||||
{
|
||||
if let Content::ToolUse { name, .. } = content_block {
|
||||
if name == tool_name {
|
||||
tool_use_index = Some(index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tool_use_index) = tool_use_index else {
|
||||
return Err(anyhow!("tool not used"));
|
||||
};
|
||||
|
||||
Ok(events.filter_map(move |event| {
|
||||
let result = match event {
|
||||
Err(error) => Some(Err(error)),
|
||||
Ok(Event::ContentBlockDelta { index, delta }) => match delta {
|
||||
ContentDelta::TextDelta { .. } => None,
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if index == tool_use_index {
|
||||
Some(Ok(partial_json))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
async move { result }
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
|
||||
InlineAssistId, InlineAssistant, MessageId, MessageStatus,
|
||||
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
|
||||
InlineAssistant, MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
@ -3342,7 +3342,7 @@ mod tests {
|
||||
|
||||
model
|
||||
.as_fake()
|
||||
.respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
|
||||
.respond_to_last_tool_use(tool::WorkflowStepResolution {
|
||||
step_title: "Title".into(),
|
||||
suggestions: vec![tool::WorkflowSuggestion {
|
||||
path: "/root/hello.rs".into(),
|
||||
@ -3352,8 +3352,7 @@ mod tests {
|
||||
description: "Extract a greeting function".into(),
|
||||
},
|
||||
}],
|
||||
})
|
||||
.unwrap()));
|
||||
});
|
||||
|
||||
// Wait for tool use to be processed.
|
||||
cx.run_until_parked();
|
||||
@ -4084,44 +4083,4 @@ mod tool {
|
||||
symbol: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl WorkflowSuggestionKind {
|
||||
pub fn symbol(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Update { symbol, .. } => Some(symbol),
|
||||
Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
|
||||
Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
|
||||
Self::PrependChild { symbol, .. } => symbol.as_deref(),
|
||||
Self::AppendChild { symbol, .. } => symbol.as_deref(),
|
||||
Self::Delete { symbol } => Some(symbol),
|
||||
Self::Create { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn description(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Update { description, .. } => Some(description),
|
||||
Self::Create { description } => Some(description),
|
||||
Self::InsertSiblingBefore { description, .. } => Some(description),
|
||||
Self::InsertSiblingAfter { description, .. } => Some(description),
|
||||
Self::PrependChild { description, .. } => Some(description),
|
||||
Self::AppendChild { description, .. } => Some(description),
|
||||
Self::Delete { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
||||
match self {
|
||||
WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
|
||||
Some(InitialInsertion::NewlineAfter)
|
||||
}
|
||||
WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
|
||||
Some(InitialInsertion::NewlineBefore)
|
||||
}
|
||||
WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum InitialInsertion {
|
||||
NewlineBefore,
|
||||
NewlineAfter,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct InlineAssistId(usize);
|
||||
|
||||
|
@ -351,10 +351,13 @@ impl Asset for ImageAsset {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::from_utf8_lossy(&body).into_owned();
|
||||
let first_line = body.lines().next().unwrap_or("").trim_end();
|
||||
body.truncate(first_line.len());
|
||||
return Err(ImageCacheError::BadStatus {
|
||||
uri,
|
||||
status: response.status(),
|
||||
body: String::from_utf8_lossy(&body).into_owned(),
|
||||
body,
|
||||
});
|
||||
}
|
||||
body
|
||||
|
@ -8,7 +8,7 @@ pub mod settings;
|
||||
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
|
||||
use gpui::{
|
||||
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
|
||||
};
|
||||
@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
description: String,
|
||||
schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
|
||||
@ -92,10 +92,11 @@ impl dyn LanguageModel {
|
||||
) -> impl 'static + Future<Output = Result<T>> {
|
||||
let schema = schemars::schema_for!(T);
|
||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
||||
let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(serde_json::from_value(response)?)
|
||||
let stream = stream.await?;
|
||||
let response = stream.try_collect::<String>().await?;
|
||||
Ok(serde_json::from_str(&response)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ use anthropic::AnthropicError;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::BTreeMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
|
||||
use gpui::{
|
||||
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
|
||||
View, WhiteSpace,
|
||||
@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
|
||||
}
|
||||
|
||||
impl AnthropicModel {
|
||||
fn request_completion(
|
||||
&self,
|
||||
request: anthropic::Request,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<anthropic::Response>> {
|
||||
let http_client = self.http_client.clone();
|
||||
|
||||
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
|
||||
.await
|
||||
.context("failed to retrieve completion")
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: anthropic::Request,
|
||||
@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
|
||||
tool_description: String,
|
||||
input_schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let mut request = request.into_anthropic(self.model.tool_model_id().into());
|
||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||
name: tool_name.clone(),
|
||||
@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
|
||||
input_schema,
|
||||
}];
|
||||
|
||||
let response = self.request_completion(request, cx);
|
||||
let response = self.stream_completion(request, cx);
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = response.await?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.context("tool not used")
|
||||
Ok(anthropic::extract_tool_args_from_events(
|
||||
tool_name,
|
||||
Box::pin(response.map_err(|e| anyhow!(e))),
|
||||
)
|
||||
.await?
|
||||
.boxed())
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
@ -5,18 +5,21 @@ use crate::{
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||
};
|
||||
use anthropic::AnthropicError;
|
||||
use anyhow::{anyhow, bail, Context as _, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlagAppExt, ZedPro};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
||||
use futures::{
|
||||
future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
|
||||
TryStreamExt as _,
|
||||
};
|
||||
use gpui::{
|
||||
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
||||
Subscription, Task,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use smol::{
|
||||
@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: anthropic::Event = serde_json::from_str(&buffer)
|
||||
.context("failed to parse Anthropic event")?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(err) => Err(AnthropicError::Other(err.into())),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(anthropic::extract_text_from_events(stream))
|
||||
Ok(anthropic::extract_text_from_events(
|
||||
response_lines(response).map_err(AnthropicError::Other),
|
||||
))
|
||||
});
|
||||
async move {
|
||||
Ok(future
|
||||
@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(open_ai::extract_text_from_events(stream))
|
||||
Ok(open_ai::extract_text_from_events(response_lines(response)))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: google_ai::GenerateContentResponse =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(google_ai::extract_text_from_events(stream))
|
||||
Ok(google_ai::extract_text_from_events(response_lines(
|
||||
response,
|
||||
)))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let body = BufReader::new(response.into_body());
|
||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
||||
let mut buffer = String::new();
|
||||
match body.read_line(&mut buffer).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&buffer)?;
|
||||
Ok(Some((event, body)))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(open_ai::extract_text_from_events(stream))
|
||||
Ok(open_ai::extract_text_from_events(response_lines(response)))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||
tool_description: String,
|
||||
input_schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let mut request = request.into_anthropic(model.tool_model_id().into());
|
||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||
name: tool_name.clone(),
|
||||
@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||
input_schema,
|
||||
}];
|
||||
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut tool_use_index = None;
|
||||
let mut tool_input = String::new();
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let event: anthropic::Event = serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
match event {
|
||||
anthropic::Event::ContentBlockStart {
|
||||
content_block,
|
||||
index,
|
||||
} => {
|
||||
if let anthropic::Content::ToolUse { name, .. } = content_block
|
||||
{
|
||||
if name == tool_name {
|
||||
tool_use_index = Some(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::Event::ContentBlockDelta { index, delta } => match delta
|
||||
{
|
||||
anthropic::ContentDelta::TextDelta { .. } => {}
|
||||
anthropic::ContentDelta::InputJsonDelta { partial_json } => {
|
||||
if Some(index) == tool_use_index {
|
||||
tool_input.push_str(&partial_json);
|
||||
}
|
||||
}
|
||||
},
|
||||
anthropic::Event::ContentBlockStop { index } => {
|
||||
if Some(index) == tool_use_index {
|
||||
return Ok(serde_json::from_str(&tool_input)?);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if tool_use_index.is_some() {
|
||||
Err(anyhow!("tool content incomplete"))
|
||||
} else {
|
||||
Err(anyhow!("tool not used"))
|
||||
}
|
||||
Ok(anthropic::extract_tool_args_from_events(
|
||||
tool_name,
|
||||
Box::pin(response_lines(response)),
|
||||
)
|
||||
.await?
|
||||
.boxed())
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
let client = self.client.clone();
|
||||
let mut function = open_ai::FunctionDefinition {
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||
open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: None,
|
||||
parameters: None,
|
||||
};
|
||||
let func = open_ai::ToolDefinition::Function {
|
||||
function: function.clone(),
|
||||
};
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
},
|
||||
},
|
||||
));
|
||||
request.tools = vec![open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: Some(tool_description),
|
||||
parameters: Some(input_schema),
|
||||
},
|
||||
}];
|
||||
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
let mut load_state = None;
|
||||
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
Ok(open_ai::extract_tool_args_from_events(
|
||||
tool_name,
|
||||
Box::pin(response_lines(response)),
|
||||
)
|
||||
.await?
|
||||
.boxed())
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
|
||||
CloudModel::Zed(model) => {
|
||||
// All Zed models are OpenAI-based at the time of writing.
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
let client = self.client.clone();
|
||||
let mut function = open_ai::FunctionDefinition {
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||
open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: None,
|
||||
parameters: None,
|
||||
};
|
||||
let func = open_ai::ToolDefinition::Function {
|
||||
function: function.clone(),
|
||||
};
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
},
|
||||
},
|
||||
));
|
||||
request.tools = vec![open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: Some(tool_description),
|
||||
parameters: Some(input_schema),
|
||||
},
|
||||
}];
|
||||
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
let mut load_state = None;
|
||||
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
Ok(open_ai::extract_tool_args_from_events(
|
||||
tool_name,
|
||||
Box::pin(response_lines(response)),
|
||||
)
|
||||
.await?
|
||||
.boxed())
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
}
|
||||
|
||||
fn response_lines<T: DeserializeOwned>(
|
||||
response: Response<AsyncBody>,
|
||||
) -> impl Stream<Item = Result<T>> {
|
||||
futures::stream::try_unfold(
|
||||
(String::new(), BufReader::new(response.into_body())),
|
||||
move |(mut line, mut body)| async {
|
||||
match body.read_line(&mut line).await {
|
||||
Ok(0) => Ok(None),
|
||||
Ok(_) => {
|
||||
let event: T = serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
Ok(Some((event, (line, body))))
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
impl LlmApiToken {
|
||||
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||
let lock = self.0.upgradable_read().await;
|
||||
|
@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||
_description: String,
|
||||
_schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||
}
|
||||
}
|
||||
|
@ -3,16 +3,11 @@ use crate::{
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future::BoxFuture,
|
||||
stream::BoxStream,
|
||||
FutureExt, StreamExt,
|
||||
};
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||
use http_client::Result;
|
||||
use parking_lot::Mutex;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
|
||||
@ -90,7 +85,7 @@ pub struct ToolUseRequest {
|
||||
#[derive(Default)]
|
||||
pub struct FakeLanguageModel {
|
||||
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
||||
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
|
||||
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
|
||||
}
|
||||
|
||||
impl FakeLanguageModel {
|
||||
@ -130,25 +125,11 @@ impl FakeLanguageModel {
|
||||
self.end_completion_stream(self.pending_completions().last().unwrap());
|
||||
}
|
||||
|
||||
pub fn respond_to_tool_use(
|
||||
&self,
|
||||
tool_call: &ToolUseRequest,
|
||||
response: Result<serde_json::Value>,
|
||||
) {
|
||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
||||
if let Some(index) = current_tool_call_txs
|
||||
.iter()
|
||||
.position(|(call, _)| call == tool_call)
|
||||
{
|
||||
let (_, tx) = current_tool_call_txs.remove(index);
|
||||
tx.send(response).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
|
||||
pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
|
||||
let response = serde_json::to_string(&response).unwrap();
|
||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
||||
let (_, tx) = current_tool_call_txs.pop().unwrap();
|
||||
tx.send(response).unwrap();
|
||||
tx.unbounded_send(response).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
|
||||
description: String,
|
||||
schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let tool_call = ToolUseRequest {
|
||||
request,
|
||||
name,
|
||||
@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
|
||||
schema,
|
||||
};
|
||||
self.current_tool_use_txs.lock().push((tool_call, tx));
|
||||
async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &Self {
|
||||
|
@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
_description: String,
|
||||
_schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ use ollama::{
|
||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||
ChatResponseDelta, OllamaToolCall,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use ui::{prelude::*, ButtonLike, Indicator};
|
||||
@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
tool_description: String,
|
||||
schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
use ollama::{OllamaFunctionTool, OllamaTool};
|
||||
let function = OllamaFunctionTool {
|
||||
name: tool_name.clone(),
|
||||
@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = response.await?;
|
||||
let ChatMessage::Assistant {
|
||||
tool_calls,
|
||||
content,
|
||||
} = response.message
|
||||
else {
|
||||
let ChatMessage::Assistant { tool_calls, .. } = response.message else {
|
||||
bail!("message does not have an assistant role");
|
||||
};
|
||||
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
|
||||
for call in tool_calls {
|
||||
let OllamaToolCall::Function(function) = call;
|
||||
if function.name == tool_name {
|
||||
return Ok(function.arguments);
|
||||
return Ok(futures::stream::once(async move {
|
||||
Ok(function.arguments.to_string())
|
||||
})
|
||||
.boxed());
|
||||
}
|
||||
}
|
||||
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
|
||||
// Parse content as arguments.
|
||||
return Ok(args);
|
||||
} else {
|
||||
bail!("assistant message does not have any tool calls");
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::BTreeMap;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||
@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAiLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
tool_description: String,
|
||||
schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||
let mut request = request.into_open_ai(self.model.id().into());
|
||||
let mut function = FunctionDefinition {
|
||||
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
|
||||
function: FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: None,
|
||||
parameters: None,
|
||||
};
|
||||
let func = ToolDefinition::Function {
|
||||
function: function.clone(),
|
||||
};
|
||||
request.tool_choice = Some(ToolChoice::Other(func.clone()));
|
||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||
function.description = Some(tool_description);
|
||||
function.parameters = Some(schema);
|
||||
request.tools = vec![ToolDefinition::Function { function }];
|
||||
},
|
||||
}));
|
||||
request.tools = vec![ToolDefinition::Function {
|
||||
function: FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: Some(tool_description),
|
||||
parameters: Some(schema),
|
||||
},
|
||||
}];
|
||||
|
||||
let response = self.stream_completion(request, cx);
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let mut response = response.await?;
|
||||
|
||||
// Call arguments are gonna be streamed in over multiple chunks.
|
||||
let mut load_state = None;
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
let response = response.await?;
|
||||
Ok(
|
||||
open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
|
||||
.await?
|
||||
.boxed(),
|
||||
)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
||||
|
||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||
@ -92,7 +92,7 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum ChatMessage {
|
||||
Assistant {
|
||||
@ -107,16 +107,16 @@ pub enum ChatMessage {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OllamaToolCall {
|
||||
Function(OllamaFunctionCall),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OllamaFunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub arguments: Box<RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{convert::TryFrom, future::Future, time::Duration};
|
||||
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
|
||||
use strum::EnumIter;
|
||||
|
||||
pub use supported_countries::*;
|
||||
@ -384,6 +384,57 @@ pub fn embed<'a>(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn extract_tool_args_from_events(
|
||||
tool_name: String,
|
||||
mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
) -> Result<impl Send + Stream<Item = Result<String>>> {
|
||||
let mut tool_use_index = None;
|
||||
let mut first_chunk = None;
|
||||
while let Some(event) = events.next().await {
|
||||
let call = event?.choices.into_iter().find_map(|choice| {
|
||||
choice.delta.tool_calls?.into_iter().find_map(|call| {
|
||||
if call.function.as_ref()?.name.as_deref()? == tool_name {
|
||||
Some(call)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
if let Some(call) = call {
|
||||
tool_use_index = Some(call.index);
|
||||
first_chunk = call.function.and_then(|func| func.arguments);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let Some(tool_use_index) = tool_use_index else {
|
||||
return Err(anyhow!("tool not used"));
|
||||
};
|
||||
|
||||
Ok(events.filter_map(move |event| {
|
||||
let result = match event {
|
||||
Err(error) => Some(Err(error)),
|
||||
Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
|
||||
choice.delta.tool_calls?.into_iter().find_map(|call| {
|
||||
if call.index == tool_use_index {
|
||||
let func = call.function?;
|
||||
let mut arguments = func.arguments?;
|
||||
if let Some(mut first_chunk) = first_chunk.take() {
|
||||
first_chunk.push_str(&arguments);
|
||||
arguments = first_chunk
|
||||
}
|
||||
Some(Ok(arguments))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}),
|
||||
};
|
||||
|
||||
async move { result }
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn extract_text_from_events(
|
||||
response: impl Stream<Item = Result<ResponseStreamEvent>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
|
Loading…
Reference in New Issue
Block a user