Mirror of https://github.com/zed-industries/zed.git (synced 2024-11-08 07:35:01 +03:00)
Ollama improvements (#12921)

Attempt to load the model early when the user has switched models. This is a follow-up to #12902.

Release Notes:

- N/A
parent 113546f766
commit bee3441c78
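Background: Ollama loads a model into memory the first time it is used, which makes the first completion after switching models noticeably slow. Per the Ollama API docs, a /api/generate request with no prompt causes the server to load the model, and the keep_alive field controls how long it stays resident. A minimal sketch of the payload this change sends (the model name is a placeholder, not taken from the diff):

    // Sketch only: JSON body POSTed to {api_url}/api/generate to warm up a model.
    fn preload_payload(model: &str) -> serde_json::Value {
        serde_json::json!({
            "model": model,       // e.g. "llama3" (placeholder)
            "keep_alive": "15m",  // how long Ollama keeps the model loaded
        })
    }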
@@ -62,6 +62,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
+            cx,
         )),
     };
     cx.set_global(provider);
@@ -114,6 +115,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             api_url.clone(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
+            cx,
         );
     }
 
@@ -174,6 +176,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
             client.http_client(),
             low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
+            cx,
         ));
     }
 }
@@ -7,7 +7,8 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
 use gpui::{AnyView, AppContext, Task};
 use http::HttpClient;
 use ollama::{
-    get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
+    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+    Role as OllamaRole,
 };
 use std::sync::Arc;
 use std::time::Duration;
@@ -31,7 +32,17 @@ impl OllamaCompletionProvider {
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) -> Self {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         Self {
             api_url,
             model,
@@ -48,7 +59,17 @@ impl OllamaCompletionProvider {
         api_url: String,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        cx: &AppContext,
     ) {
+        cx.spawn({
+            let api_url = api_url.clone();
+            let client = self.http_client.clone();
+            let model = model.name.clone();
+
+            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
+        })
+        .detach_and_log_err(cx);
+
         self.model = model;
         self.api_url = api_url;
         self.low_speed_timeout = low_speed_timeout;
@@ -93,7 +114,7 @@ impl OllamaCompletionProvider {
             // indicating which models are embedding models,
             // simply filter out models with "-embed" in their name
             .filter(|model| !model.name.contains("-embed"))
-            .map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
+            .map(|model| OllamaModel::new(&model.name))
             .collect();
 
         models.sort_by(|a, b| a.name.cmp(&b.name));
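As an aside, the "-embed" name check above is a heuristic over model names; a quick illustration with invented tags:

    // Hypothetical tag list; only non-embedding models survive the filter.
    let names = vec!["llama3:latest", "nomic-embed-text", "mistral:7b"];
    let chat_models: Vec<&str> = names
        .into_iter()
        .filter(|name| !name.contains("-embed"))
        .collect();
    assert_eq!(chat_models, vec!["llama3:latest", "mistral:7b"]);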
@@ -42,18 +42,14 @@ impl From<Role> for String {
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct Model {
     pub name: String,
-    pub parameter_size: String,
     pub max_tokens: usize,
     pub keep_alive: Option<String>,
 }
 
 impl Model {
-    pub fn new(name: &str, parameter_size: &str) -> Self {
+    pub fn new(name: &str) -> Self {
         Self {
             name: name.to_owned(),
-            parameter_size: parameter_size.to_owned(),
-            // todo: determine if there's an endpoint to find the max tokens
-            // I'm not seeing it in the API docs but it's on the model cards
             max_tokens: 2048,
             keep_alive: Some("10m".to_owned()),
         }
@@ -222,3 +218,43 @@ pub async fn get_models(
         ))
     }
 }
+
+/// Sends an empty request to Ollama to trigger loading the model
+pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+    let uri = format!("{api_url}/api/generate");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(
+            &serde_json::json!({
+                "model": model,
+                "keep_alive": "15m",
+            }),
+        )?))?;
+
+    let mut response = match client.send(request).await {
+        Ok(response) => response,
+        Err(err) => {
+            // Be ok with a timeout during preload of the model
+            if err.is_timeout() {
+                return Ok(());
+            } else {
+                return Err(err.into());
+            }
+        }
+    };
+
+    if response.status().is_success() {
+        Ok(())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to Ollama API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
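To check that a server preloads as expected outside of Zed, an equivalent request can be sent standalone. A sketch assuming reqwest (with its "json" feature), serde_json, tokio, and anyhow as dependencies, none of which are part of this change:

    use anyhow::{anyhow, Result};

    // Standalone equivalent of preload_model, using reqwest instead of
    // Zed's HttpClient. The endpoint below is Ollama's default local
    // address and the model name is a placeholder.
    async fn preload(api_url: &str, model: &str) -> Result<()> {
        let response = reqwest::Client::new()
            .post(format!("{api_url}/api/generate"))
            .json(&serde_json::json!({ "model": model, "keep_alive": "15m" }))
            .send()
            .await?;
        if response.status().is_success() {
            Ok(())
        } else {
            Err(anyhow!("Ollama preload failed: {}", response.status()))
        }
    }

    #[tokio::main]
    async fn main() -> Result<()> {
        preload("http://localhost:11434", "llama3").await
    }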