diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs
index 78b22556ac..82fcc88a81 100644
--- a/crates/assistant/src/completion_provider.rs
+++ b/crates/assistant/src/completion_provider.rs
@@ -62,6 +62,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
+            cx,
         )),
     };
     cx.set_global(provider);
@@ -114,6 +115,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     api_url.clone(),
                     low_speed_timeout_in_seconds.map(Duration::from_secs),
                     settings_version,
+                    cx,
                 );
             }
 
@@ -174,6 +176,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                     client.http_client(),
                     low_speed_timeout_in_seconds.map(Duration::from_secs),
                     settings_version,
+                    cx,
                 ));
             }
         }
diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/assistant/src/completion_provider/ollama.rs
index 74524da6dd..2275785b06 100644
--- a/crates/assistant/src/completion_provider/ollama.rs
+++ b/crates/assistant/src/completion_provider/ollama.rs
@@ -7,7 +7,8 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
 use gpui::{AnyView, AppContext, Task};
 use http::HttpClient;
 use ollama::{
-    get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+    Role as OllamaRole,
 };
 use std::sync::Arc;
 use std::time::Duration;
@@ -31,7 +32,17 @@ impl OllamaCompletionProvider {
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) -> Self {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         Self {
             api_url,
             model,
@@ -48,7 +59,17 @@ impl OllamaCompletionProvider {
         api_url: String,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = self.http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         self.model = model;
         self.api_url = api_url;
         self.low_speed_timeout = low_speed_timeout;
@@ -93,7 +114,7 @@ impl OllamaCompletionProvider {
                     // indicating which models are embedding models,
                     // simply filter out models with "-embed" in their name
                     .filter(|model| !model.name.contains("-embed"))
-                    .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+                    .map(|model| OllamaModel::new(&model.name))
                     .collect();
 
                 models.sort_by(|a, b| a.name.cmp(&b.name));
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index 141d7fe000..ff2f568097 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -42,18 +42,14 @@ impl From<Role> for String {
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct Model {
     pub name: String,
-    pub parameter_size: String,
     pub max_tokens: usize,
     pub keep_alive: Option<String>,
 }
 
 impl Model {
-    pub fn new(name: &str, parameter_size: &str) -> Self {
+    pub fn new(name: &str) -> Self {
         Self {
             name: name.to_owned(),
-            parameter_size: parameter_size.to_owned(),
-            // todo: determine if there's an endpoint to find the max tokens
-            // I'm not seeing it in the API docs but it's on the model cards
             max_tokens: 2048,
             keep_alive: Some("10m".to_owned()),
         }
@@ -222,3 +218,43 @@ pub async fn get_models(
         ))
     }
 }
+
+/// Sends an empty request to Ollama to trigger loading the model
+pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+    let uri = format!("{api_url}/api/generate");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(
+            &serde_json::json!({
+                "model": model,
+                "keep_alive": "15m",
+            }),
+        )?))?;
+
+    let mut response = match client.send(request).await {
+        Ok(response) => response,
+        Err(err) => {
+            // Be ok with a timeout during preload of the model
+            if err.is_timeout() {
+                return Ok(());
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+
+    if response.status().is_success() {
+        Ok(())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
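
Note on the preload mechanism: Ollama's API documents that a `/api/generate` request carrying only a `model` (and no `prompt`) loads the model into memory and keeps it resident for the requested `keep_alive` window, which is what `preload_model` relies on. Below is a minimal standalone sketch of the same request shape, using `reqwest` and `tokio` as stand-ins for Zed's `HttpClient` abstraction (both crate choices, the helper names, and the `llama3` model name are assumptions for illustration, not part of this change):

```rust
use anyhow::{anyhow, Result};
use serde_json::json;

/// Sketch only: mirrors the request that `preload_model` sends.
async fn preload_sketch(api_url: &str, model: &str) -> Result<()> {
    let response = reqwest::Client::new()
        .post(format!("{api_url}/api/generate"))
        // No "prompt" field: Ollama treats this as a load-only request
        // and keeps the model resident for the keep_alive duration.
        .json(&json!({ "model": model, "keep_alive": "15m" }))
        .send()
        .await?;

    if response.status().is_success() {
        Ok(())
    } else {
        Err(anyhow!("preload of {model} failed: {}", response.status()))
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // Assumes a local Ollama server on its default port.
    preload_sketch("http://localhost:11434", "llama3").await
}
```

As in the diff above, a caller would likely want to treat timeouts as non-fatal: loading a large model can legitimately exceed the client's low-speed timeout, and preloading is best-effort.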