diff --git a/Cargo.lock b/Cargo.lock index 2105303f54..a516ad0a1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6826,6 +6826,7 @@ version = "0.1.0" dependencies = [ "anyhow", "futures 0.3.28", + "isahc", "schemars", "serde", "serde_json", diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index d29cd24eea..c531536729 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -153,6 +153,8 @@ pub enum AssistantProvider { default_model: OpenAiModel, #[serde(default = "open_ai_url")] api_url: String, + #[serde(default)] + low_speed_timeout_in_seconds: Option, }, } @@ -222,12 +224,14 @@ impl AssistantSettingsContent { Some(AssistantProvider::OpenAi { default_model: settings.default_open_ai_model.clone().unwrap_or_default(), api_url: open_ai_api_url.clone(), + low_speed_timeout_in_seconds: None, }) } else { settings.default_open_ai_model.clone().map(|open_ai_model| { AssistantProvider::OpenAi { default_model: open_ai_model, api_url: open_ai_url(), + low_speed_timeout_in_seconds: None, } }) }, @@ -364,14 +368,17 @@ impl Settings for AssistantSettings { AssistantProvider::OpenAi { default_model, api_url, + low_speed_timeout_in_seconds, }, AssistantProvider::OpenAi { default_model: default_model_override, api_url: api_url_override, + low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override, }, ) => { *default_model = default_model_override; *api_url = api_url_override; + *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override; } (merged, provider_override) => { *merged = provider_override; @@ -408,7 +415,8 @@ mod tests { AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { default_model: OpenAiModel::FourTurbo, - api_url: open_ai_url() + api_url: open_ai_url(), + low_speed_timeout_in_seconds: None, } ); @@ -429,7 +437,8 @@ mod tests { AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { default_model: OpenAiModel::FourTurbo, - api_url: "test-url".into() + api_url: "test-url".into(), + low_speed_timeout_in_seconds: None, } ); cx.update_global::(|store, cx| { @@ -448,7 +457,8 @@ mod tests { AssistantSettings::get_global(cx).provider, AssistantProvider::OpenAi { default_model: OpenAiModel::Four, - api_url: open_ai_url() + api_url: open_ai_url(), + low_speed_timeout_in_seconds: None, } ); diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 73fd7b52d1..534709358e 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/assistant/src/completion_provider.rs @@ -18,6 +18,7 @@ use futures::{future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext}; use settings::{Settings, SettingsStore}; use std::sync::Arc; +use std::time::Duration; pub fn init(client: Arc, cx: &mut AppContext) { let mut settings_version = 0; @@ -33,10 +34,12 @@ pub fn init(client: Arc, cx: &mut AppContext) { AssistantProvider::OpenAi { default_model, api_url, + low_speed_timeout_in_seconds, } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new( default_model.clone(), api_url.clone(), client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, )), }; @@ -51,9 +54,15 @@ pub fn init(client: Arc, cx: &mut AppContext) { AssistantProvider::OpenAi { default_model, api_url, + low_speed_timeout_in_seconds, }, ) => { - provider.update(default_model.clone(), api_url.clone(), settings_version); + provider.update( + default_model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ); } ( CompletionProvider::ZedDotDev(provider), @@ -74,12 +83,14 @@ pub fn init(client: Arc, cx: &mut AppContext) { AssistantProvider::OpenAi { default_model, api_url, + low_speed_timeout_in_seconds, }, ) => { *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new( default_model.clone(), api_url.clone(), client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, )); } diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs index 9a7398ef7f..c92085c866 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -7,6 +7,7 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace}; use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; use settings::Settings; +use std::time::Duration; use std::{env, sync::Arc}; use theme::ThemeSettings; use ui::prelude::*; @@ -17,6 +18,7 @@ pub struct OpenAiCompletionProvider { api_url: String, default_model: OpenAiModel, http_client: Arc, + low_speed_timeout: Option, settings_version: usize, } @@ -25,6 +27,7 @@ impl OpenAiCompletionProvider { default_model: OpenAiModel, api_url: String, http_client: Arc, + low_speed_timeout: Option, settings_version: usize, ) -> Self { Self { @@ -32,13 +35,21 @@ impl OpenAiCompletionProvider { api_url, default_model, http_client, + low_speed_timeout, settings_version, } } - pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) { + pub fn update( + &mut self, + default_model: OpenAiModel, + api_url: String, + low_speed_timeout: Option, + settings_version: usize, + ) { self.default_model = default_model; self.api_url = api_url; + self.low_speed_timeout = low_speed_timeout; self.settings_version = settings_version; } @@ -112,9 +123,16 @@ impl OpenAiCompletionProvider { let http_client = self.http_client.clone(); let api_key = self.api_key.clone(); let api_url = self.api_url.clone(); + let low_speed_timeout = self.low_speed_timeout; async move { let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + low_speed_timeout, + ); let response = request.await?; let stream = response .filter_map(|response| async move { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c7b8f32bde..40bb862e68 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4344,6 +4344,7 @@ async fn complete_with_open_ai( OPEN_AI_API_URL, &api_key, crate::ai::language_model_request_to_open_ai(request)?, + None, ) .await .context("open_ai::stream_completion request failed within collab")?; diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index a1b01c0b9e..1560cfcfcc 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -15,6 +15,7 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true futures.workspace = true +isahc.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index bdc6d3cb9b..b9dfda171d 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,7 +1,9 @@ use anyhow::{anyhow, Context, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use isahc::config::Configurable; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +use std::time::Duration; use std::{convert::TryFrom, future::Future}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; @@ -206,14 +208,20 @@ pub async fn stream_completion( api_url: &str, api_key: &str, request: Request, + low_speed_timeout: Option, ) -> Result>> { let uri = format!("{api_url}/chat/completions"); - let request = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(AsyncBody::from(serde_json::to_string(&request)?))?; + .header("Authorization", format!("Bearer {}", api_key)); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body());