From 4cb8d6f40ed1427353dcba5a10fcc4b22da1a365 Mon Sep 17 00:00:00 2001
From: Kyle Kelley
Date: Tue, 11 Jun 2024 17:35:27 -0700
Subject: [PATCH] Ollama Provider for Assistant (#12902)

Closes #4424.

A few design decisions that may need some rethinking or later PRs:

* Other providers have a check for authentication. I use this opportunity to fetch the models, which doubles as a way of finding out whether the Ollama server is running.
* Ollama has _no_ API for getting the max tokens per model.
* Ollama has _no_ API for getting the current token count: https://github.com/ollama/ollama/issues/1716
* Ollama does allow setting `num_ctx`, so I've defaulted it to 2048 (see `Model::new` below); it can be overridden in settings.
* Ollama models are "slow" to start inference because the model must first be loaded into memory; subsequent requests are faster. There is currently no UI affordance to show that the model is being loaded.

Release Notes:

- Added an Ollama Provider for the assistant. If you have [Ollama](https://ollama.com/) running locally on your machine, you can enable it in your settings under:

```jsonc
"assistant": {
  "version": "1",
  "provider": {
    "name": "ollama",
    // Recommended setting to allow for model startup
    "low_speed_timeout_in_seconds": 30,
  }
}
```

Chat like usual, and interact with any model from the [Ollama Library](https://ollama.com/library). Open up the terminal to download new models via `ollama pull`:

![image](https://github.com/zed-industries/zed/assets/836375/af7ec411-76bf-41c7-ba81-64bbaeea98a8)
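For completeness, here is a sketch of a fuller configuration based on the settings types this PR adds (`AssistantProviderContent::Ollama` and `ollama::Model`). The model name and parameter size are placeholders for whatever you have pulled locally:

```jsonc
"assistant": {
  "version": "1",
  "provider": {
    "name": "ollama",
    "api_url": "http://localhost:11434",
    "low_speed_timeout_in_seconds": 30,
    // Raises `num_ctx` above the built-in default of 2048
    "default_model": {
      "name": "llama3",
      "parameter_size": "8B",
      "max_tokens": 4096,
      "keep_alive": "10m"
    }
  }
}
```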
["schemars"] } open_ai = { workspace = true, features = ["schemars"] } ordered-float.workspace = true parking_lot.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index bb795b1034..07488fdc5b 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -12,7 +12,7 @@ mod streaming_diff; pub use assistant_panel::AssistantPanel; -use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel}; +use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; @@ -91,6 +91,7 @@ pub enum LanguageModel { Cloud(CloudModel), OpenAi(OpenAiModel), Anthropic(AnthropicModel), + Ollama(OllamaModel), } impl Default for LanguageModel { @@ -105,6 +106,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), + LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), } } @@ -113,6 +115,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.display_name().into(), LanguageModel::Anthropic(model) => model.display_name().into(), LanguageModel::Cloud(model) => model.display_name().into(), + LanguageModel::Ollama(model) => model.display_name().into(), } } @@ -121,6 +124,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.max_token_count(), LanguageModel::Anthropic(model) => model.max_token_count(), LanguageModel::Cloud(model) => model.max_token_count(), + LanguageModel::Ollama(model) => model.max_token_count(), } } @@ -129,6 +133,7 @@ impl LanguageModel { LanguageModel::OpenAi(model) => model.id(), LanguageModel::Anthropic(model) => model.id(), LanguageModel::Cloud(model) => model.id(), + LanguageModel::Ollama(model) => model.id(), } } } @@ -179,6 +184,7 @@ impl LanguageModelRequest { match &self.model { LanguageModel::OpenAi(_) => {} LanguageModel::Anthropic(_) => {} + LanguageModel::Ollama(_) => {} LanguageModel::Cloud(model) => match model { CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => { preprocess_anthropic_request(self); diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 3efaff100d..efc726fe22 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -2,6 +2,7 @@ use std::fmt; pub use anthropic::Model as AnthropicModel; use gpui::Pixels; +pub use ollama::Model as OllamaModel; pub use open_ai::Model as OpenAiModel; use schemars::{ schema::{InstanceType, Metadata, Schema, SchemaObject}, @@ -168,6 +169,11 @@ pub enum AssistantProvider { api_url: String, low_speed_timeout_in_seconds: Option, }, + Ollama { + model: OllamaModel, + api_url: String, + low_speed_timeout_in_seconds: Option, + }, } impl Default for AssistantProvider { @@ -197,6 +203,12 @@ pub enum AssistantProviderContent { api_url: Option, low_speed_timeout_in_seconds: Option, }, + #[serde(rename = "ollama")] + Ollama { + default_model: Option, + api_url: Option, + low_speed_timeout_in_seconds: Option, + }, } #[derive(Debug, Default)] @@ -328,6 +340,13 @@ impl AssistantSettingsContent { low_speed_timeout_in_seconds: None, }) } + LanguageModel::Ollama(model) => { + *provider = Some(AssistantProviderContent::Ollama { + default_model: Some(model), + 
+                        api_url: None,
+                        low_speed_timeout_in_seconds: None,
+                    })
+                }
             },
         },
     },
@@ -472,6 +491,27 @@ impl Settings for AssistantSettings {
                         Some(low_speed_timeout_in_seconds_override);
                 }
             }
+            (
+                AssistantProvider::Ollama {
+                    model,
+                    api_url,
+                    low_speed_timeout_in_seconds,
+                },
+                AssistantProviderContent::Ollama {
+                    default_model: model_override,
+                    api_url: api_url_override,
+                    low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+                },
+            ) => {
+                merge(model, model_override);
+                merge(api_url, api_url_override);
+                if let Some(low_speed_timeout_in_seconds_override) =
+                    low_speed_timeout_in_seconds_override
+                {
+                    *low_speed_timeout_in_seconds =
+                        Some(low_speed_timeout_in_seconds_override);
+                }
+            }
             (
                 AssistantProvider::Anthropic {
                     model,
@@ -519,6 +559,15 @@ impl Settings for AssistantSettings {
                     .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                 low_speed_timeout_in_seconds,
             },
+            AssistantProviderContent::Ollama {
+                default_model: model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => AssistantProvider::Ollama {
+                model: model.unwrap_or_default(),
+                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
+                low_speed_timeout_in_seconds,
+            },
         };
     }
 }
diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs
index 01ea6325ad..78b22556ac 100644
--- a/crates/assistant/src/completion_provider.rs
+++ b/crates/assistant/src/completion_provider.rs
@@ -2,12 +2,14 @@ mod anthropic;
 mod cloud;
 #[cfg(test)]
 mod fake;
+mod ollama;
 mod open_ai;
 
 pub use anthropic::*;
 pub use cloud::*;
 #[cfg(test)]
 pub use fake::*;
+pub use ollama::*;
 pub use open_ai::*;
 
 use crate::{
@@ -50,6 +52,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
         )),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        )),
     };
     cx.set_global(provider);
 
@@ -87,6 +100,23 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         settings_version,
                     );
                 }
+
+                (
+                    CompletionProvider::Ollama(provider),
+                    AssistantProvider::Ollama {
+                        model,
+                        api_url,
+                        low_speed_timeout_in_seconds,
+                    },
+                ) => {
+                    provider.update(
+                        model.clone(),
+                        api_url.clone(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
+                        settings_version,
+                    );
+                }
+
                 (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                     provider.update(model.clone(), settings_version);
                 }
@@ -130,6 +160,22 @@
                         settings_version,
                     ));
                 }
+                (
+                    _,
+                    AssistantProvider::Ollama {
+                        model,
+                        api_url,
+                        low_speed_timeout_in_seconds,
+                    },
+                ) => {
+                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
+                        model.clone(),
+                        api_url.clone(),
+                        client.http_client(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
+                        settings_version,
+                    ));
+                }
             }
         })
     })
@@ -142,6 +188,7 @@ pub enum CompletionProvider {
     Cloud(CloudCompletionProvider),
     #[cfg(test)]
     Fake(FakeCompletionProvider),
+    Ollama(OllamaCompletionProvider),
 }
 
 impl gpui::Global for CompletionProvider {}
@@ -165,6 +212,10 @@ impl CompletionProvider {
                 .available_models()
                 .map(LanguageModel::Cloud)
                 .collect(),
+            CompletionProvider::Ollama(provider) => provider
+                .available_models()
+                .map(|model| LanguageModel::Ollama(model.clone()))
+                .collect(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -175,6 +226,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.settings_version(),
             CompletionProvider::Anthropic(provider) => provider.settings_version(),
             CompletionProvider::Cloud(provider) => provider.settings_version(),
+            CompletionProvider::Ollama(provider) => provider.settings_version(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -185,6 +237,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
             CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
             CompletionProvider::Cloud(provider) => provider.is_authenticated(),
+            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => true,
         }
@@ -195,6 +248,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
             CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
             CompletionProvider::Cloud(provider) => provider.authenticate(cx),
+            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -205,6 +259,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
             CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
             CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
+            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -215,6 +270,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
             CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
             CompletionProvider::Cloud(_) => Task::ready(Ok(())),
+            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -225,6 +281,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
             CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
             CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
+            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
             #[cfg(test)]
             CompletionProvider::Fake(_) => LanguageModel::default(),
         }
@@ -239,6 +296,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
             CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
             CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
+            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
         }
@@ -252,6 +310,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.complete(request),
             CompletionProvider::Anthropic(provider) => provider.complete(request),
             CompletionProvider::Cloud(provider) => provider.complete(request),
+            CompletionProvider::Ollama(provider) => provider.complete(request),
             #[cfg(test)]
             CompletionProvider::Fake(provider) => provider.complete(),
         }
diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/assistant/src/completion_provider/ollama.rs
new file mode 100644
index 0000000000..74524da6dd
--- /dev/null
+++ b/crates/assistant/src/completion_provider/ollama.rs
@@ -0,0 +1,246 @@
+use crate::{
+    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
+};
+use anyhow::Result;
+use futures::StreamExt as _;
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
+use gpui::{AnyView, AppContext, Task};
+use http::HttpClient;
+use ollama::{
+    get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+};
+use std::sync::Arc;
+use std::time::Duration;
+use ui::{prelude::*, ButtonLike, ElevationIndex};
+
+const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
+
+pub struct OllamaCompletionProvider {
+    api_url: String,
+    model: OllamaModel,
+    http_client: Arc<dyn HttpClient>,
+    low_speed_timeout: Option<Duration>,
+    settings_version: usize,
+    available_models: Vec<OllamaModel>,
+}
+
+impl OllamaCompletionProvider {
+    pub fn new(
+        model: OllamaModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+            available_models: Default::default(),
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: OllamaModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
+    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
+        self.available_models.iter()
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+
+        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
+        cx.spawn(|mut cx| async move {
+            let models = get_models(http_client.as_ref(), &api_url, None).await?;
+
+            let mut models: Vec<OllamaModel> = models
+                .into_iter()
+                // Since there is no metadata from the Ollama API
+                // indicating which models are embedding models,
+                // simply filter out models with "-embed" in their name
+                .filter(|model| !model.name.contains("-embed"))
+                .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+                .collect();
+
+            models.sort_by(|a, b| a.name.cmp(&b.name));
+
+            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                if let CompletionProvider::Ollama(provider) = provider {
+                    provider.available_models = models;
+                }
+            })
+        })
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
+    }
+
+    pub fn model(&self) -> OllamaModel {
+        self.model.clone()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+
+        async move { Ok(token_count) }.boxed()
+    }
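The division by 4 in `count_tokens` is the usual rough characters-per-token heuristic, not a real tokenizer. A tiny illustration of the estimate it produces (values are illustrative only; the divisor is a rough average for English text, not a per-model constant):

```rust
// ~4 characters per token: a 19-character message
// is estimated at 19 / 4 = 4 tokens (integer division).
let content = "Hello, how are you?";
let estimated_tokens = content.chars().count() / 4;
assert_eq!(estimated_tokens, 4);
```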
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
+        let model = match request.model {
+            LanguageModel::Ollama(model) => model,
+            _ => self.model(),
+        };
+
+        ChatRequest {
+            model: model.name,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => ChatMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => ChatMessage::Assistant {
+                        content: msg.content,
+                    },
+                    Role::System => ChatMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            keep_alive: model.keep_alive,
+            stream: true,
+            options: Some(ChatOptions {
+                num_ctx: Some(model.max_tokens),
+                stop: Some(request.stop),
+                temperature: Some(request.temperature),
+                ..Default::default()
+            }),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => OllamaRole::User,
+            Role::Assistant => OllamaRole::Assistant,
+            Role::System => OllamaRole::System,
+        }
+    }
+}
+
+struct DownloadOllamaMessage {}
+
+impl DownloadOllamaMessage {
+    pub fn new(_cx: &mut ViewContext<Self>) -> Self {
+        Self {}
+    }
+
+    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+        ButtonLike::new("download_ollama_button")
+            .style(ButtonStyle::Filled)
+            .size(ButtonSize::Large)
+            .layer(ElevationIndex::ModalSurface)
+            .child(Label::new("Get Ollama"))
+            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
+    }
+}
+
+impl Render for DownloadOllamaMessage {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        v_flex()
+            .p_4()
+            .size_full()
+            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
+            .child(
+                h_flex()
+                    .w_full()
+                    .p_4()
+                    .justify_center()
+                    .child(
+                        self.render_download_button(cx)
+                    )
+            )
+            .into_any()
+    }
+}
diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml
new file mode 100644
index 0000000000..2ff329df00
--- /dev/null
+++ b/crates/ollama/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "ollama"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/ollama.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http.workspace = true
+isahc.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
new file mode 100644
index 0000000000..141d7fe000
--- /dev/null
+++ b/crates/ollama/src/ollama.rs
@@ -0,0 +1,224 @@
+use anyhow::{anyhow, Context, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use isahc::config::Configurable;
+use serde::{Deserialize, Serialize};
+use std::{convert::TryFrom, time::Duration};
+
+pub const OLLAMA_API_URL: &str = "http://localhost:11434";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            "system" => Ok(Self::System),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+            Role::System => "system".to_owned(),
+        }
+    }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+    pub name: String,
+    pub parameter_size: String,
+    pub max_tokens: usize,
+    pub keep_alive: Option<String>,
+}
+
+impl Model {
+    pub fn new(name: &str, parameter_size: &str) -> Self {
+        Self {
+            name: name.to_owned(),
+            parameter_size: parameter_size.to_owned(),
+            // todo: determine if there's an endpoint to find the max tokens
+            // I'm not seeing it in the API docs, but it's on the model cards
+            max_tokens: 2048,
+            keep_alive: Some("10m".to_owned()),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        &self.name
+    }
+
+    pub fn display_name(&self) -> &str {
+        &self.name
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        self.max_tokens
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant { content: String },
+    User { content: String },
+    System { content: String },
+}
+
+#[derive(Serialize)]
+pub struct ChatRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub stream: bool,
+    pub keep_alive: Option<String>,
+    pub options: Option<ChatOptions>,
+}
+
+// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+#[derive(Serialize, Default)]
+pub struct ChatOptions {
+    pub num_ctx: Option<usize>,
+    pub num_predict: Option<isize>,
+    pub stop: Option<Vec<String>>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+}
+
+#[derive(Deserialize)]
+pub struct ChatResponseDelta {
+    #[allow(unused)]
+    pub model: String,
+    #[allow(unused)]
+    pub created_at: String,
+    pub message: ChatMessage,
+    #[allow(unused)]
+    pub done_reason: Option<String>,
+    #[allow(unused)]
+    pub done: bool,
+}
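To make the wire format concrete: with the derives above, a `ChatRequest` serializes to the JSON body that `stream_chat_completion` below POSTs to `/api/chat`. Values here are illustrative, and `None` fields are omitted for brevity (with these derives they would actually serialize as `null`):

```jsonc
// Request body for POST {api_url}/api/chat
{
  "model": "llama3",
  "messages": [
    { "role": "user", "content": "Hello there!" }
  ],
  "stream": true,
  "keep_alive": "10m",
  "options": { "num_ctx": 2048, "stop": [], "temperature": 1.0 }
}
```

The response body is newline-delimited JSON, one object per line, which is why the streaming code below reads lines rather than a single document. Each line deserializes into one `ChatResponseDelta`, e.g.:

```jsonc
{"model":"llama3","created_at":"2024-06-11T17:35:00Z","message":{"role":"assistant","content":"Hi"},"done":false}
```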
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelsResponse {
+    pub models: Vec<LocalModelListing>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelListing {
+    pub name: String,
+    pub modified_at: String,
+    pub size: u64,
+    pub digest: String,
+    pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModel {
+    pub modelfile: String,
+    pub parameters: String,
+    pub template: String,
+    pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ModelDetails {
+    pub format: String,
+    pub family: String,
+    pub families: Option<Vec<String>>,
+    pub parameter_size: String,
+    pub quantization_level: String,
+}
+
+pub async fn stream_chat_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: ChatRequest,
+    low_speed_timeout: Option<Duration>,
+) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
+    let uri = format!("{api_url}/api/chat");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
+                    }
+                    Err(e) => Some(Err(e.into())),
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
+
+pub async fn get_models(
+    client: &dyn HttpClient,
+    api_url: &str,
+    low_speed_timeout: Option<Duration>,
+) -> Result<Vec<LocalModelListing>> {
+    let uri = format!("{api_url}/api/tags");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Accept", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::default())?;
+
+    let mut response = client.send(request).await?;
+
+    let mut body = String::new();
+    response.body_mut().read_to_string(&mut body).await?;
+
+    if response.status().is_success() {
+        let response: LocalModelsResponse =
+            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
+
+        Ok(response.models)
+    } else {
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
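For reference, `get_models` expects `GET {api_url}/api/tags` to return a listing shaped like `LocalModelsResponse` above. A sketch of such a response (field values are illustrative placeholders, matching the Ollama tags API and the structs in this file):

```jsonc
{
  "models": [
    {
      "name": "llama3:latest",
      "modified_at": "2024-06-11T17:35:00.000000000-07:00",
      "size": 4661224676,
      "digest": "sha256-placeholder",
      "details": {
        "format": "gguf",
        "family": "llama",
        "families": ["llama"],
        "parameter_size": "8.0B",
        "quantization_level": "Q4_0"
      }
    }
  ]
}
```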