diff --git a/Cargo.lock b/Cargo.lock
index c18c904f68..fee0657369 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -359,6 +359,7 @@ dependencies = [
  "log",
  "menu",
  "multi_buffer",
+ "ollama",
  "open_ai",
  "ordered-float 2.10.0",
  "parking_lot",
@@ -6921,6 +6922,19 @@ dependencies = [
  "cc",
 ]
 
+[[package]]
+name = "ollama"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "http 0.1.0",
+ "isahc",
+ "schemars",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.19.0"
diff --git a/Cargo.toml b/Cargo.toml
index 5f001e5d29..79510e808e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,6 +61,7 @@ members = [
     "crates/multi_buffer",
     "crates/node_runtime",
     "crates/notifications",
+    "crates/ollama",
     "crates/open_ai",
     "crates/outline",
     "crates/picker",
@@ -207,6 +208,7 @@ menu = { path = "crates/menu" }
 multi_buffer = { path = "crates/multi_buffer" }
 node_runtime = { path = "crates/node_runtime" }
 notifications = { path = "crates/notifications" }
+ollama = { path = "crates/ollama" }
 open_ai = { path = "crates/open_ai" }
 outline = { path = "crates/outline" }
 picker = { path = "crates/picker" }
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index eaa6fc5b73..77f0bc4ae0 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -35,6 +35,7 @@ language.workspace = true
 log.workspace = true
 menu.workspace = true
 multi_buffer.workspace = true
+ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 ordered-float.workspace = true
 parking_lot.workspace = true
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index bb795b1034..07488fdc5b 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -12,7 +12,7 @@ mod streaming_diff;
 
 pub use assistant_panel::AssistantPanel;
 
-use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
+use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
@@ -91,6 +91,7 @@ pub enum LanguageModel {
     Cloud(CloudModel),
     OpenAi(OpenAiModel),
     Anthropic(AnthropicModel),
+    Ollama(OllamaModel),
 }
 
 impl Default for LanguageModel {
@@ -105,6 +106,7 @@ impl LanguageModel {
             LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
             LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
             LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
+            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
         }
     }
 
@@ -113,6 +115,7 @@ impl LanguageModel {
             LanguageModel::OpenAi(model) => model.display_name().into(),
             LanguageModel::Anthropic(model) => model.display_name().into(),
             LanguageModel::Cloud(model) => model.display_name().into(),
+            LanguageModel::Ollama(model) => model.display_name().into(),
         }
     }
 
@@ -121,6 +124,7 @@ impl LanguageModel {
             LanguageModel::OpenAi(model) => model.max_token_count(),
             LanguageModel::Anthropic(model) => model.max_token_count(),
             LanguageModel::Cloud(model) => model.max_token_count(),
+            LanguageModel::Ollama(model) => model.max_token_count(),
         }
     }
 
@@ -129,6 +133,7 @@ impl LanguageModel {
             LanguageModel::OpenAi(model) => model.id(),
             LanguageModel::Anthropic(model) => model.id(),
             LanguageModel::Cloud(model) => model.id(),
+            LanguageModel::Ollama(model) => model.id(),
         }
     }
 }
@@ -179,6 +184,7 @@ impl LanguageModelRequest {
         match &self.model {
             LanguageModel::OpenAi(_) => {}
             LanguageModel::Anthropic(_) => {}
+            LanguageModel::Ollama(_) => {}
             LanguageModel::Cloud(model) => match model {
                 CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
                     preprocess_anthropic_request(self);
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 3efaff100d..efc726fe22 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -2,6 +2,7 @@ use std::fmt;
 
 pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
+pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
 use schemars::{
     schema::{InstanceType, Metadata, Schema, SchemaObject},
@@ -168,6 +169,11 @@ pub enum AssistantProvider {
         api_url: String,
         low_speed_timeout_in_seconds: Option<u64>,
     },
+    Ollama {
+        model: OllamaModel,
+        api_url: String,
+        low_speed_timeout_in_seconds: Option<u64>,
+    },
 }
 
 impl Default for AssistantProvider {
@@ -197,6 +203,12 @@ pub enum AssistantProviderContent {
         api_url: Option<String>,
         low_speed_timeout_in_seconds: Option<u64>,
     },
+    #[serde(rename = "ollama")]
+    Ollama {
+        default_model: Option<OllamaModel>,
+        api_url: Option<String>,
+        low_speed_timeout_in_seconds: Option<u64>,
+    },
 }
 
 #[derive(Debug, Default)]
@@ -328,6 +340,13 @@ impl AssistantSettingsContent {
                             low_speed_timeout_in_seconds: None,
                         })
                     }
+                    LanguageModel::Ollama(model) => {
+                        *provider = Some(AssistantProviderContent::Ollama {
+                            default_model: Some(model),
+                            api_url: None,
+                            low_speed_timeout_in_seconds: None,
+                        })
+                    }
                 },
             },
         },
@@ -472,6 +491,27 @@ impl Settings for AssistantSettings {
                             Some(low_speed_timeout_in_seconds_override);
                     }
                 }
+                (
+                    AssistantProvider::Ollama {
+                        model,
+                        api_url,
+                        low_speed_timeout_in_seconds,
+                    },
+                    AssistantProviderContent::Ollama {
+                        default_model: model_override,
+                        api_url: api_url_override,
+                        low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
+                    },
+                ) => {
+                    merge(model, model_override);
+                    merge(api_url, api_url_override);
+                    if let Some(low_speed_timeout_in_seconds_override) =
+                        low_speed_timeout_in_seconds_override
+                    {
+                        *low_speed_timeout_in_seconds =
+                            Some(low_speed_timeout_in_seconds_override);
+                    }
+                }
                 (
                     AssistantProvider::Anthropic {
                         model,
@@ -519,6 +559,15 @@ impl Settings for AssistantSettings {
                         .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                     low_speed_timeout_in_seconds,
                 },
+                AssistantProviderContent::Ollama {
+                    default_model: model,
+                    api_url,
+                    low_speed_timeout_in_seconds,
+                } => AssistantProvider::Ollama {
+                    model: model.unwrap_or_default(),
+                    api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
+                    low_speed_timeout_in_seconds,
+                },
             };
         }
     }
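
For orientation, a user-facing settings entry that deserializes into the new AssistantProviderContent::Ollama variant should look roughly like the sketch below. This is an illustration, not an excerpt from the change: the exact key layout depends on how AssistantProviderContent is tagged (only the "ollama" rename is visible in this hunk), and the model values shown are placeholders.

{
  "assistant": {
    "provider": {
      "name": "ollama",
      "default_model": {
        "name": "llama3",
        "parameter_size": "8B",
        "max_tokens": 2048,
        "keep_alive": "10m"
      },
      "api_url": "http://localhost:11434",
      "low_speed_timeout_in_seconds": 60
    }
  }
}
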
diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs
index 01ea6325ad..78b22556ac 100644
--- a/crates/assistant/src/completion_provider.rs
+++ b/crates/assistant/src/completion_provider.rs
@@ -2,12 +2,14 @@ mod anthropic;
 mod cloud;
 #[cfg(test)]
 mod fake;
+mod ollama;
 mod open_ai;
 
 pub use anthropic::*;
 pub use cloud::*;
 #[cfg(test)]
 pub use fake::*;
+pub use ollama::*;
 pub use open_ai::*;
 
 use crate::{
@@ -50,6 +52,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
         )),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        )),
     };
 
     cx.set_global(provider);
@@ -87,6 +100,23 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         settings_version,
                     );
                 }
+
+                (
+                    CompletionProvider::Ollama(provider),
+                    AssistantProvider::Ollama {
+                        model,
+                        api_url,
+                        low_speed_timeout_in_seconds,
+                    },
+                ) => {
+                    provider.update(
+                        model.clone(),
+                        api_url.clone(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
+                        settings_version,
+                    );
+                }
+
                 (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                     provider.update(model.clone(), settings_version);
                 }
@@ -130,6 +160,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                         settings_version,
                     ));
                 }
+                (
+                    _,
+                    AssistantProvider::Ollama {
+                        model,
+                        api_url,
+                        low_speed_timeout_in_seconds,
+                    },
+                ) => {
+                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
+                        model.clone(),
+                        api_url.clone(),
+                        client.http_client(),
+                        low_speed_timeout_in_seconds.map(Duration::from_secs),
+                        settings_version,
+                    ));
+                }
             }
         })
     })
@@ -142,6 +188,7 @@ pub enum CompletionProvider {
     Cloud(CloudCompletionProvider),
     #[cfg(test)]
     Fake(FakeCompletionProvider),
+    Ollama(OllamaCompletionProvider),
 }
 
 impl gpui::Global for CompletionProvider {}
@@ -165,6 +212,10 @@ impl CompletionProvider {
                 .available_models()
                 .map(LanguageModel::Cloud)
                 .collect(),
+            CompletionProvider::Ollama(provider) => provider
+                .available_models()
+                .map(|model| LanguageModel::Ollama(model.clone()))
+                .collect(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -175,6 +226,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.settings_version(),
             CompletionProvider::Anthropic(provider) => provider.settings_version(),
             CompletionProvider::Cloud(provider) => provider.settings_version(),
+            CompletionProvider::Ollama(provider) => provider.settings_version(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -185,6 +237,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
             CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
             CompletionProvider::Cloud(provider) => provider.is_authenticated(),
+            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
             #[cfg(test)]
             CompletionProvider::Fake(_) => true,
         }
@@ -195,6 +248,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
             CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
             CompletionProvider::Cloud(provider) => provider.authenticate(cx),
+            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -205,6 +259,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
             CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
             CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
+            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => unimplemented!(),
         }
@@ -215,6 +270,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
             CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
             CompletionProvider::Cloud(_) => Task::ready(Ok(())),
+            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => Task::ready(Ok(())),
         }
@@ -225,6 +281,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
             CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
             CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
+            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
             #[cfg(test)]
             CompletionProvider::Fake(_) => LanguageModel::default(),
         }
@@ -239,6 +296,7 @@ impl CompletionProvider {
            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
             CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
             CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
+            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
             #[cfg(test)]
             CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
         }
@@ -252,6 +310,7 @@ impl CompletionProvider {
             CompletionProvider::OpenAi(provider) => provider.complete(request),
             CompletionProvider::Anthropic(provider) => provider.complete(request),
             CompletionProvider::Cloud(provider) => provider.complete(request),
+            CompletionProvider::Ollama(provider) => provider.complete(request),
             #[cfg(test)]
             CompletionProvider::Fake(provider) => provider.complete(),
         }
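
Because CompletionProvider stays a single gpui global, callers keep dispatching through the enum without caring which backend is active; the Ollama arms above just extend that dispatch. A minimal consumer-side sketch (list_ollama_models is a hypothetical helper, not part of this change; it assumes available_models returns Vec<LanguageModel>, as the arms above suggest):

use gpui::AppContext;

// Hypothetical helper: list the Ollama models the active provider has discovered.
// Yields an empty list whenever a non-Ollama provider is configured.
fn list_ollama_models(cx: &AppContext) -> Vec<LanguageModel> {
    cx.global::<CompletionProvider>()
        .available_models()
        .into_iter()
        .filter(|model| matches!(model, LanguageModel::Ollama(_)))
        .collect()
}
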
diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/assistant/src/completion_provider/ollama.rs
new file mode 100644
index 0000000000..74524da6dd
--- /dev/null
+++ b/crates/assistant/src/completion_provider/ollama.rs
@@ -0,0 +1,246 @@
+use crate::{
+    assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
+};
+use anyhow::Result;
+use futures::StreamExt as _;
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
+use gpui::{AnyView, AppContext, Task};
+use http::HttpClient;
+use ollama::{
+    get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+};
+use std::sync::Arc;
+use std::time::Duration;
+use ui::{prelude::*, ButtonLike, ElevationIndex};
+
+const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
+
+pub struct OllamaCompletionProvider {
+    api_url: String,
+    model: OllamaModel,
+    http_client: Arc<dyn HttpClient>,
+    low_speed_timeout: Option<Duration>,
+    settings_version: usize,
+    available_models: Vec<OllamaModel>,
+}
+
+impl OllamaCompletionProvider {
+    pub fn new(
+        model: OllamaModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+            available_models: Default::default(),
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: OllamaModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
+    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
+        self.available_models.iter()
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+
+        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
+        cx.spawn(|mut cx| async move {
+            let models = get_models(http_client.as_ref(), &api_url, None).await?;
+
+            let mut models: Vec<OllamaModel> = models
+                .into_iter()
+                // Since there is no metadata from the Ollama API
+                // indicating which models are embedding models,
+                // simply filter out models with "-embed" in their name
+                .filter(|model| !model.name.contains("-embed"))
+                .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+                .collect();
+
+            models.sort_by(|a, b| a.name.cmp(&b.name));
+
+            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+                if let CompletionProvider::Ollama(provider) = provider {
+                    provider.available_models = models;
+                }
+            })
+        })
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
+    }
+
+    pub fn model(&self) -> OllamaModel {
+        self.model.clone()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+
+        async move { Ok(token_count) }.boxed()
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
+        let model = match request.model {
+            LanguageModel::Ollama(model) => model,
+            _ => self.model(),
+        };
+
+        ChatRequest {
+            model: model.name,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => ChatMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => ChatMessage::Assistant {
+                        content: msg.content,
+                    },
+                    Role::System => ChatMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            keep_alive: model.keep_alive,
+            stream: true,
+            options: Some(ChatOptions {
+                num_ctx: Some(model.max_tokens),
+                stop: Some(request.stop),
+                temperature: Some(request.temperature),
+                ..Default::default()
+            }),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => OllamaRole::User,
+            Role::Assistant => OllamaRole::Assistant,
+            Role::System => OllamaRole::System,
+        }
+    }
+}
+
+struct DownloadOllamaMessage {}
+
+impl DownloadOllamaMessage {
+    pub fn new(_cx: &mut ViewContext<Self>) -> Self {
+        Self {}
+    }
+
+    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+        ButtonLike::new("download_ollama_button")
+            .style(ButtonStyle::Filled)
+            .size(ButtonSize::Large)
+            .layer(ElevationIndex::ModalSurface)
+            .child(Label::new("Get Ollama"))
+            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
+    }
+}
+
+impl Render for DownloadOllamaMessage {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        v_flex()
+            .p_4()
+            .size_full()
+            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
+            .child(
+                h_flex()
+                    .w_full()
+                    .p_4()
+                    .justify_center()
+                    .child(
+                        self.render_download_button(cx)
+                    )
+            )
+            .into_any()
+    }
+}
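
A note on what complete consumes: Ollama's /api/chat endpoint streams newline-delimited JSON, one object per line, and the provider reduces each ChatResponseDelta to its message content. A small sketch of that per-line shape (the JSON literal is illustrative, not captured output):

use ollama::{ChatMessage, ChatResponseDelta};

fn parse_example_delta() -> anyhow::Result<String> {
    // One streamed line as Ollama emits it; `done_reason` only shows up on the final line.
    let line = r#"{"model":"llama3","created_at":"2024-05-05T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}"#;
    let delta: ChatResponseDelta = serde_json::from_str(line)?;
    match delta.message {
        ChatMessage::Assistant { content } => Ok(content), // "Hel"
        other => anyhow::bail!("unexpected role in delta: {other:?}"),
    }
}
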
diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml
new file mode 100644
index 0000000000..2ff329df00
--- /dev/null
+++ b/crates/ollama/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "ollama"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/ollama.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http.workspace = true
+isahc.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
new file mode 100644
index 0000000000..141d7fe000
--- /dev/null
+++ b/crates/ollama/src/ollama.rs
@@ -0,0 +1,224 @@
+use anyhow::{anyhow, Context, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use isahc::config::Configurable;
+use serde::{Deserialize, Serialize};
+use std::{convert::TryFrom, time::Duration};
+
+pub const OLLAMA_API_URL: &str = "http://localhost:11434";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            "system" => Ok(Self::System),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+            Role::System => "system".to_owned(),
+        }
+    }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+    pub name: String,
+    pub parameter_size: String,
+    pub max_tokens: usize,
+    pub keep_alive: Option<String>,
+}
+
+impl Model {
+    pub fn new(name: &str, parameter_size: &str) -> Self {
+        Self {
+            name: name.to_owned(),
+            parameter_size: parameter_size.to_owned(),
+            // todo: determine if there's an endpoint to find the max tokens
+            // I'm not seeing it in the API docs but it's on the model cards
+            max_tokens: 2048,
+            keep_alive: Some("10m".to_owned()),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        &self.name
+    }
+
+    pub fn display_name(&self) -> &str {
+        &self.name
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        self.max_tokens
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant { content: String },
+    User { content: String },
+    System { content: String },
+}
+
+#[derive(Serialize)]
+pub struct ChatRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub stream: bool,
+    pub keep_alive: Option<String>,
+    pub options: Option<ChatOptions>,
+}
+
+// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+#[derive(Serialize, Default)]
+pub struct ChatOptions {
+    pub num_ctx: Option<usize>,
+    pub num_predict: Option<isize>,
+    pub stop: Option<Vec<String>>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+}
+
+#[derive(Deserialize)]
+pub struct ChatResponseDelta {
+    #[allow(unused)]
+    pub model: String,
+    #[allow(unused)]
+    pub created_at: String,
+    pub message: ChatMessage,
+    #[allow(unused)]
+    pub done_reason: Option<String>,
+    #[allow(unused)]
+    pub done: bool,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelsResponse {
+    pub models: Vec<LocalModelListing>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModelListing {
+    pub name: String,
+    pub modified_at: String,
+    pub size: u64,
+    pub digest: String,
+    pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LocalModel {
+    pub modelfile: String,
+    pub parameters: String,
+    pub template: String,
+    pub details: ModelDetails,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ModelDetails {
+    pub format: String,
+    pub family: String,
+    pub families: Option<Vec<String>>,
+    pub parameter_size: String,
+    pub quantization_level: String,
+}
+
+pub async fn stream_chat_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: ChatRequest,
+    low_speed_timeout: Option<Duration>,
+) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
+    let uri = format!("{api_url}/api/chat");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        Some(serde_json::from_str(&line).context("Unable to parse chat response"))
+                    }
+                    Err(e) => Some(Err(e.into())),
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
+
+pub async fn get_models(
+    client: &dyn HttpClient,
+    api_url: &str,
+    low_speed_timeout: Option<Duration>,
+) -> Result<Vec<LocalModelListing>> {
+    let uri = format!("{api_url}/api/tags");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Accept", "application/json");
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::default())?;
+
+    let mut response = client.send(request).await?;
+
+    let mut body = String::new();
+    response.body_mut().read_to_string(&mut body).await?;
+
+    if response.status().is_success() {
+        let response: LocalModelsResponse =
+            serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
+
+        Ok(response.models)
+    } else {
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
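
Taken together, get_models and stream_chat_completion are the whole public surface the assistant-side provider builds on. A usage sketch against that API (it assumes an implementation of the workspace http crate's HttpClient is already in hand; the model name and prompt are placeholders):

use anyhow::Result;
use futures::StreamExt;
use http::HttpClient;
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatRequest, OLLAMA_API_URL};

async fn demo(client: &dyn HttpClient) -> Result<()> {
    // List locally installed models, the same call the provider uses as its liveness check.
    for model in get_models(client, OLLAMA_API_URL, None).await? {
        println!("{} ({})", model.name, model.details.parameter_size);
    }

    // Stream a chat completion and print content chunks as they arrive.
    let request = ChatRequest {
        model: "llama3".to_string(),
        messages: vec![ChatMessage::User {
            content: "Why is the sky blue?".to_string(),
        }],
        stream: true,
        keep_alive: Some("10m".to_string()),
        options: None,
    };
    let mut deltas = stream_chat_completion(client, OLLAMA_API_URL, request, None).await?;
    while let Some(delta) = deltas.next().await {
        if let ChatMessage::Assistant { content } = delta?.message {
            print!("{content}");
        }
    }
    Ok(())
}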