mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-19 02:17:35 +03:00
assistant: Use tools in other providers (#15803)
- [x] OpenAI - [ ] ~Google~ Moved into a separate branch at: https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've ran into issues with having the API digest our schema without tripping over itself - the function call parameters are malformed and whatnot. We can resume from that branch if needed. - [x] Ollama - [x] Cloud - [ ] ~Copilot Chat (?)~ Release Notes: - Added tool calling capabilities to OpenAI and Ollama models.
This commit is contained in:
parent
be514f23e1
commit
874f0c0712
@ -4,7 +4,7 @@ use crate::{
|
|||||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, bail, Context as _, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
|
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
|
||||||
@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel {
|
|||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(_) => {
|
CloudModel::OpenAi(model) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
|
let client = self.client.clone();
|
||||||
|
let mut function = open_ai::FunctionDefinition {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: None,
|
||||||
|
parameters: None,
|
||||||
|
};
|
||||||
|
let func = open_ai::ToolDefinition::Function {
|
||||||
|
function: function.clone(),
|
||||||
|
};
|
||||||
|
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||||
|
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||||
|
function.description = Some(tool_description);
|
||||||
|
function.parameters = Some(input_schema);
|
||||||
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let request = serde_json::to_string(&request)?;
|
||||||
|
let response = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
|
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||||
|
request,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
// Call arguments are gonna be streamed in over multiple chunks.
|
||||||
|
let mut load_state = None;
|
||||||
|
let mut response = response.map(
|
||||||
|
|item: Result<
|
||||||
|
proto::StreamCompleteWithLanguageModelResponse,
|
||||||
|
anyhow::Error,
|
||||||
|
>| {
|
||||||
|
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||||
|
serde_json::from_str(&item?.event)?,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
while let Some(Ok(part)) = response.next().await {
|
||||||
|
for choice in part.choices {
|
||||||
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for call in tool_calls {
|
||||||
|
if let Some(func) = call.function {
|
||||||
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
|
load_state = Some((String::default(), call.index));
|
||||||
|
}
|
||||||
|
if let Some((arguments, (output, index))) =
|
||||||
|
func.arguments.zip(load_state.as_mut())
|
||||||
|
{
|
||||||
|
if call.index == *index {
|
||||||
|
output.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some((arguments, _)) = load_state {
|
||||||
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
|
} else {
|
||||||
|
bail!("tool not used");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::Google(_) => {
|
CloudModel::Google(_) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
}
|
}
|
||||||
CloudModel::Zed(_) => {
|
CloudModel::Zed(model) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
|
// All Zed models are OpenAI-based at the time of writing.
|
||||||
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
|
let client = self.client.clone();
|
||||||
|
let mut function = open_ai::FunctionDefinition {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: None,
|
||||||
|
parameters: None,
|
||||||
|
};
|
||||||
|
let func = open_ai::ToolDefinition::Function {
|
||||||
|
function: function.clone(),
|
||||||
|
};
|
||||||
|
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||||
|
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||||
|
function.description = Some(tool_description);
|
||||||
|
function.parameters = Some(input_schema);
|
||||||
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let request = serde_json::to_string(&request)?;
|
||||||
|
let response = client
|
||||||
|
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||||
|
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||||
|
request,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
// Call arguments are gonna be streamed in over multiple chunks.
|
||||||
|
let mut load_state = None;
|
||||||
|
let mut response = response.map(
|
||||||
|
|item: Result<
|
||||||
|
proto::StreamCompleteWithLanguageModelResponse,
|
||||||
|
anyhow::Error,
|
||||||
|
>| {
|
||||||
|
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||||
|
serde_json::from_str(&item?.event)?,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
while let Some(Ok(part)) = response.next().await {
|
||||||
|
for choice in part.choices {
|
||||||
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for call in tool_calls {
|
||||||
|
if let Some(func) = call.function {
|
||||||
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
|
load_state = Some((String::default(), call.index));
|
||||||
|
}
|
||||||
|
if let Some((arguments, (output, index))) =
|
||||||
|
func.arguments.zip(load_state.as_mut())
|
||||||
|
{
|
||||||
|
if call.index == *index {
|
||||||
|
output.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some((arguments, _)) = load_state {
|
||||||
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
|
} else {
|
||||||
|
bail!("tool not used");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use ollama::{
|
use ollama::{
|
||||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||||
|
ChatResponseDelta, OllamaToolCall,
|
||||||
};
|
};
|
||||||
|
use serde_json::Value;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{future, sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
use ui::{prelude::*, ButtonLike, Indicator};
|
use ui::{prelude::*, ButtonLike, Indicator};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
@ -184,6 +186,7 @@ impl OllamaLanguageModel {
|
|||||||
},
|
},
|
||||||
Role::Assistant => ChatMessage::Assistant {
|
Role::Assistant => ChatMessage::Assistant {
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
|
tool_calls: None,
|
||||||
},
|
},
|
||||||
Role::System => ChatMessage::System {
|
Role::System => ChatMessage::System {
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
@ -198,8 +201,25 @@ impl OllamaLanguageModel {
|
|||||||
temperature: Some(request.temperature),
|
temperature: Some(request.temperature),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
|
tools: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn request_completion(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
|
||||||
|
let Ok(api_url) = cx.update(|cx| {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
|
settings.api_url.clone()
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
|
async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for OllamaLanguageModel {
|
impl LanguageModel for OllamaLanguageModel {
|
||||||
@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel {
|
|||||||
Ok(delta) => {
|
Ok(delta) => {
|
||||||
let content = match delta.message {
|
let content = match delta.message {
|
||||||
ChatMessage::User { content } => content,
|
ChatMessage::User { content } => content,
|
||||||
ChatMessage::Assistant { content } => content,
|
ChatMessage::Assistant { content, .. } => content,
|
||||||
ChatMessage::System { content } => content,
|
ChatMessage::System { content } => content,
|
||||||
};
|
};
|
||||||
Some(Ok(content))
|
Some(Ok(content))
|
||||||
@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel {
|
|||||||
|
|
||||||
fn use_any_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
_name: String,
|
tool_name: String,
|
||||||
_description: String,
|
tool_description: String,
|
||||||
_schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
use ollama::{OllamaFunctionTool, OllamaTool};
|
||||||
|
let function = OllamaFunctionTool {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: Some(tool_description),
|
||||||
|
parameters: Some(schema),
|
||||||
|
};
|
||||||
|
let tools = vec![OllamaTool::Function { function }];
|
||||||
|
let request = self.to_ollama_request(request).with_tools(tools);
|
||||||
|
let response = self.request_completion(request, cx);
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let response = response.await?;
|
||||||
|
let ChatMessage::Assistant {
|
||||||
|
tool_calls,
|
||||||
|
content,
|
||||||
|
} = response.message
|
||||||
|
else {
|
||||||
|
bail!("message does not have an assistant role");
|
||||||
|
};
|
||||||
|
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
|
||||||
|
for call in tool_calls {
|
||||||
|
let OllamaToolCall::Function(function) = call;
|
||||||
|
if function.name == tool_name {
|
||||||
|
return Ok(function.arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
|
||||||
|
// Parse content as arguments.
|
||||||
|
return Ok(args);
|
||||||
|
} else {
|
||||||
|
bail!("assistant message does not have any tool calls");
|
||||||
|
};
|
||||||
|
|
||||||
|
bail!("tool not used")
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||||
@ -7,11 +7,13 @@ use gpui::{
|
|||||||
View, WhiteSpace,
|
View, WhiteSpace,
|
||||||
};
|
};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use open_ai::stream_completion;
|
use open_ai::{
|
||||||
|
stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
|
||||||
|
};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{future, sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Indicator};
|
use ui::{prelude::*, Indicator};
|
||||||
@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel {
|
|||||||
request_limiter: RateLimiter,
|
request_limiter: RateLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl OpenAiLanguageModel {
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
request: open_ai::Request,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||||
|
{
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||||
|
(
|
||||||
|
state.api_key.clone(),
|
||||||
|
settings.api_url.clone(),
|
||||||
|
settings.low_speed_timeout,
|
||||||
|
)
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
|
let future = self.request_limiter.stream(async move {
|
||||||
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
|
let request = stream_completion(
|
||||||
|
http_client.as_ref(),
|
||||||
|
&api_url,
|
||||||
|
&api_key,
|
||||||
|
request,
|
||||||
|
low_speed_timeout,
|
||||||
|
);
|
||||||
|
let response = request.await?;
|
||||||
|
Ok(response)
|
||||||
|
});
|
||||||
|
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
impl LanguageModel for OpenAiLanguageModel {
|
impl LanguageModel for OpenAiLanguageModel {
|
||||||
fn id(&self) -> LanguageModelId {
|
fn id(&self) -> LanguageModelId {
|
||||||
self.id.clone()
|
self.id.clone()
|
||||||
@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let request = request.into_open_ai(self.model.id().into());
|
let request = request.into_open_ai(self.model.id().into());
|
||||||
|
let completions = self.stream_completion(request, cx);
|
||||||
let http_client = self.http_client.clone();
|
async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
|
||||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
|
||||||
(
|
|
||||||
state.api_key.clone(),
|
|
||||||
settings.api_url.clone(),
|
|
||||||
settings.low_speed_timeout,
|
|
||||||
)
|
|
||||||
}) else {
|
|
||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
|
||||||
};
|
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
|
||||||
let request = stream_completion(
|
|
||||||
http_client.as_ref(),
|
|
||||||
&api_url,
|
|
||||||
&api_key,
|
|
||||||
request,
|
|
||||||
low_speed_timeout,
|
|
||||||
);
|
|
||||||
let response = request.await?;
|
|
||||||
Ok(open_ai::extract_text_from_events(response).boxed())
|
|
||||||
});
|
|
||||||
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_any_tool(
|
fn use_any_tool(
|
||||||
&self,
|
&self,
|
||||||
_request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
_name: String,
|
tool_name: String,
|
||||||
_description: String,
|
tool_description: String,
|
||||||
_schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
let mut request = request.into_open_ai(self.model.id().into());
|
||||||
|
let mut function = FunctionDefinition {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: None,
|
||||||
|
parameters: None,
|
||||||
|
};
|
||||||
|
let func = ToolDefinition::Function {
|
||||||
|
function: function.clone(),
|
||||||
|
};
|
||||||
|
request.tool_choice = Some(ToolChoice::Other(func.clone()));
|
||||||
|
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||||
|
function.description = Some(tool_description);
|
||||||
|
function.parameters = Some(schema);
|
||||||
|
request.tools = vec![ToolDefinition::Function { function }];
|
||||||
|
let response = self.stream_completion(request, cx);
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let mut response = response.await?;
|
||||||
|
|
||||||
|
// Call arguments are gonna be streamed in over multiple chunks.
|
||||||
|
let mut load_state = None;
|
||||||
|
while let Some(Ok(part)) = response.next().await {
|
||||||
|
for choice in part.choices {
|
||||||
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for call in tool_calls {
|
||||||
|
if let Some(func) = call.function {
|
||||||
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
|
load_state = Some((String::default(), call.index));
|
||||||
|
}
|
||||||
|
if let Some((arguments, (output, index))) =
|
||||||
|
func.arguments.zip(load_state.as_mut())
|
||||||
|
{
|
||||||
|
if call.index == *index {
|
||||||
|
output.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some((arguments, _)) = load_state {
|
||||||
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
|
} else {
|
||||||
|
bail!("tool not used");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
||||||
|
|
||||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||||
@ -94,22 +95,63 @@ impl Model {
|
|||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
#[serde(tag = "role", rename_all = "lowercase")]
|
#[serde(tag = "role", rename_all = "lowercase")]
|
||||||
pub enum ChatMessage {
|
pub enum ChatMessage {
|
||||||
Assistant { content: String },
|
Assistant {
|
||||||
User { content: String },
|
content: String,
|
||||||
System { content: String },
|
tool_calls: Option<Vec<OllamaToolCall>>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
System {
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum OllamaToolCall {
|
||||||
|
Function(OllamaFunctionCall),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct OllamaFunctionCall {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct OllamaFunctionTool {
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "lowercase")]
|
||||||
|
pub enum OllamaTool {
|
||||||
|
Function { function: OllamaFunctionTool },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Debug)]
|
||||||
pub struct ChatRequest {
|
pub struct ChatRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<ChatMessage>,
|
pub messages: Vec<ChatMessage>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
pub keep_alive: KeepAlive,
|
pub keep_alive: KeepAlive,
|
||||||
pub options: Option<ChatOptions>,
|
pub options: Option<ChatOptions>,
|
||||||
|
pub tools: Vec<OllamaTool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatRequest {
|
||||||
|
pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
|
||||||
|
self.stream = false;
|
||||||
|
self.tools = tools;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
|
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
|
||||||
#[derive(Serialize, Default)]
|
#[derive(Serialize, Default, Debug)]
|
||||||
pub struct ChatOptions {
|
pub struct ChatOptions {
|
||||||
pub num_ctx: Option<usize>,
|
pub num_ctx: Option<usize>,
|
||||||
pub num_predict: Option<isize>,
|
pub num_predict: Option<isize>,
|
||||||
@ -118,7 +160,7 @@ pub struct ChatOptions {
|
|||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct ChatResponseDelta {
|
pub struct ChatResponseDelta {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub model: String,
|
pub model: String,
|
||||||
@ -162,6 +204,38 @@ pub struct ModelDetails {
|
|||||||
pub quantization_level: String,
|
pub quantization_level: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn complete(
|
||||||
|
client: &dyn HttpClient,
|
||||||
|
api_url: &str,
|
||||||
|
request: ChatRequest,
|
||||||
|
) -> Result<ChatResponseDelta> {
|
||||||
|
let uri = format!("{api_url}/api/chat");
|
||||||
|
let request_builder = HttpRequest::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
|
let serialized_request = serde_json::to_string(&request)?;
|
||||||
|
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||||
|
|
||||||
|
let mut response = client.send(request).await?;
|
||||||
|
if response.status().is_success() {
|
||||||
|
let mut body = Vec::new();
|
||||||
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
|
let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
|
||||||
|
Ok(response_message)
|
||||||
|
} else {
|
||||||
|
let mut body = Vec::new();
|
||||||
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
|
let body_str = std::str::from_utf8(&body)?;
|
||||||
|
Err(anyhow!(
|
||||||
|
"Failed to connect to API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body_str
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn stream_chat_completion(
|
pub async fn stream_chat_completion(
|
||||||
client: &dyn HttpClient,
|
client: &dyn HttpClient,
|
||||||
api_url: &str,
|
api_url: &str,
|
||||||
|
@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
|
|||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{Map, Value};
|
use serde_json::Value;
|
||||||
use std::{convert::TryFrom, future::Future, time::Duration};
|
use std::{convert::TryFrom, future::Future, time::Duration};
|
||||||
use strum::EnumIter;
|
use strum::EnumIter;
|
||||||
|
|
||||||
@ -121,25 +121,34 @@ pub struct Request {
|
|||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_choice: Option<String>,
|
pub tool_choice: Option<ToolChoice>,
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
pub tools: Vec<ToolDefinition>,
|
pub tools: Vec<ToolDefinition>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct FunctionDefinition {
|
#[serde(untagged)]
|
||||||
pub name: String,
|
pub enum ToolChoice {
|
||||||
pub description: Option<String>,
|
Auto,
|
||||||
pub parameters: Option<Map<String, Value>>,
|
Required,
|
||||||
|
None,
|
||||||
|
Other(ToolDefinition),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug)]
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub enum ToolDefinition {
|
pub enum ToolDefinition {
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
Function { function: FunctionDefinition },
|
Function { function: FunctionDefinition },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionDefinition {
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
#[serde(tag = "role", rename_all = "lowercase")]
|
#[serde(tag = "role", rename_all = "lowercase")]
|
||||||
pub enum RequestMessage {
|
pub enum RequestMessage {
|
||||||
|
Loading…
Reference in New Issue
Block a user