Add configurable low-speed timeout for OpenAI provider (#11668)

This PR adds a setting to allow configuring the low-speed timeout for
the Assistant when using the OpenAI provider.

The `low_speed_timeout_in_seconds` setting accepts the number of seconds the
HTTP client is allowed to stay below a minimum transfer speed (currently set
to 100 bytes/second) before the request times out.

```json
{
  "assistant": {
    "version": "1",
    "provider": { "name": "openai", "low_speed_timeout_in_seconds": 60 }
  }
}
```
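
As a rough illustration of how a tagged provider object like this deserializes (simplified stand-in types; Zed's real `AssistantProvider` enum also carries `default_model` and `api_url`):

```rust
use serde::Deserialize;

// A simplified stand-in for Zed's `AssistantProvider`; illustration only.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum Provider {
    OpenAi {
        // Absent in JSON deserializes to `None`, i.e. no low-speed timeout.
        #[serde(default)]
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

fn main() -> serde_json::Result<()> {
    let with_timeout: Provider =
        serde_json::from_str(r#"{ "name": "openai", "low_speed_timeout_in_seconds": 60 }"#)?;
    println!("{with_timeout:?}"); // OpenAi { low_speed_timeout_in_seconds: Some(60) }

    let without_timeout: Provider = serde_json::from_str(r#"{ "name": "openai" }"#)?;
    println!("{without_timeout:?}"); // OpenAi { low_speed_timeout_in_seconds: None }
    Ok(())
}
```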

This should help in cases where the `openai` provider is used with a local
model that requires longer timeouts.

Issue: https://github.com/zed-industries/zed/issues/9913

Release Notes:

- Added a `low_speed_timeout_in_seconds` setting to the Assistant's
OpenAI provider
([#9913](https://github.com/zed-industries/zed/issues/9913)).
7 changed files with 59 additions and 9 deletions

Cargo.lock (generated):

```diff
@@ -6826,6 +6826,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "futures 0.3.28",
+ "isahc",
  "schemars",
  "serde",
  "serde_json",
```

Assistant settings (`AssistantProvider` / `AssistantSettings`):

```diff
@@ -153,6 +153,8 @@ pub enum AssistantProvider {
         default_model: OpenAiModel,
         #[serde(default = "open_ai_url")]
         api_url: String,
+        #[serde(default)]
+        low_speed_timeout_in_seconds: Option<u64>,
     },
 }
@@ -222,12 +224,14 @@ impl AssistantSettingsContent {
                 Some(AssistantProvider::OpenAi {
                     default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
                     api_url: open_ai_api_url.clone(),
+                    low_speed_timeout_in_seconds: None,
                 })
             } else {
                 settings.default_open_ai_model.clone().map(|open_ai_model| {
                     AssistantProvider::OpenAi {
                         default_model: open_ai_model,
                         api_url: open_ai_url(),
+                        low_speed_timeout_in_seconds: None,
                     }
                 })
             },
@@ -364,14 +368,17 @@ impl Settings for AssistantSettings {
                 AssistantProvider::OpenAi {
                     default_model,
                     api_url,
+                    low_speed_timeout_in_seconds,
                 },
                 AssistantProvider::OpenAi {
                     default_model: default_model_override,
                     api_url: api_url_override,
+                    low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                 },
             ) => {
                 *default_model = default_model_override;
                 *api_url = api_url_override;
+                *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
             }
             (merged, provider_override) => {
                 *merged = provider_override;
@@ -408,7 +415,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::FourTurbo,
-                api_url: open_ai_url()
+                api_url: open_ai_url(),
+                low_speed_timeout_in_seconds: None,
             }
         );
@@ -429,7 +437,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::FourTurbo,
-                api_url: "test-url".into()
+                api_url: "test-url".into(),
+                low_speed_timeout_in_seconds: None,
             }
         );
         cx.update_global::<SettingsStore, _>(|store, cx| {
@@ -448,7 +457,8 @@ mod tests {
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::OpenAi {
                 default_model: OpenAiModel::Four,
-                api_url: open_ai_url()
+                api_url: open_ai_url(),
+                low_speed_timeout_in_seconds: None,
             }
         );
```
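
The third hunk above merges an override into the base settings field by field when both sides are the `OpenAi` variant, and otherwise replaces the value wholesale. A runnable sketch of that pattern with simplified stand-in types (not Zed's actual ones; the localhost URL is a hypothetical local endpoint):

```rust
// Simplified sketch of the field-wise merge in `Settings::load` above.
#[derive(Debug, PartialEq)]
enum Provider {
    OpenAi {
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[allow(dead_code)]
    ZedDotDev, // stand-in second variant
}

fn merge(merged: &mut Provider, provider_override: Provider) {
    match (merged, provider_override) {
        // Both sides are OpenAi: overwrite field by field.
        (
            Provider::OpenAi {
                api_url,
                low_speed_timeout_in_seconds,
            },
            Provider::OpenAi {
                api_url: api_url_override,
                low_speed_timeout_in_seconds: timeout_override,
            },
        ) => {
            *api_url = api_url_override;
            *low_speed_timeout_in_seconds = timeout_override;
        }
        // Variants differ: the override replaces the value wholesale.
        (merged, provider_override) => *merged = provider_override,
    }
}

fn main() {
    let mut provider = Provider::OpenAi {
        api_url: "https://api.openai.com/v1".into(),
        low_speed_timeout_in_seconds: None,
    };
    merge(
        &mut provider,
        Provider::OpenAi {
            api_url: "http://localhost:8080/v1".into(), // hypothetical local endpoint
            low_speed_timeout_in_seconds: Some(60),
        },
    );
    assert_eq!(
        provider,
        Provider::OpenAi {
            api_url: "http://localhost:8080/v1".into(),
            low_speed_timeout_in_seconds: Some(60),
        }
    );
}
```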

Completion provider setup (`init`):

```diff
@@ -18,6 +18,7 @@ use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
 use std::sync::Arc;
+use std::time::Duration;
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     let mut settings_version = 0;
@@ -33,10 +34,12 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
         AssistantProvider::OpenAi {
             default_model,
             api_url,
+            low_speed_timeout_in_seconds,
         } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
             default_model.clone(),
             api_url.clone(),
             client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
             settings_version,
         )),
     };
@@ -51,9 +54,15 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 AssistantProvider::OpenAi {
                     default_model,
                     api_url,
+                    low_speed_timeout_in_seconds,
                 },
             ) => {
-                provider.update(default_model.clone(), api_url.clone(), settings_version);
+                provider.update(
+                    default_model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    settings_version,
+                );
             }
             (
                 CompletionProvider::ZedDotDev(provider),
@@ -74,12 +83,14 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
                 AssistantProvider::OpenAi {
                     default_model,
                     api_url,
+                    low_speed_timeout_in_seconds,
                 },
             ) => {
                 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                     default_model.clone(),
                     api_url.clone(),
                     client.http_client(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
                     settings_version,
                 ));
             }
```
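
Throughout the settings layer the timeout stays a plain `u64` of seconds; each call site above converts it at the boundary with `Option::map` and `Duration::from_secs`. A tiny standalone illustration:

```rust
use std::time::Duration;

fn main() {
    // Mirrors `low_speed_timeout_in_seconds.map(Duration::from_secs)` above.
    let low_speed_timeout_in_seconds: Option<u64> = Some(60);
    let low_speed_timeout = low_speed_timeout_in_seconds.map(Duration::from_secs);
    assert_eq!(low_speed_timeout, Some(Duration::from_secs(60)));

    // An unset setting stays `None`, and no low-speed timeout is applied.
    assert_eq!(None::<u64>.map(Duration::from_secs), None);
}
```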

OpenAI completion provider (`OpenAiCompletionProvider`):

```diff
@@ -7,6 +7,7 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
 use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 use settings::Settings;
+use std::time::Duration;
 use std::{env, sync::Arc};
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -17,6 +18,7 @@ pub struct OpenAiCompletionProvider {
     api_url: String,
     default_model: OpenAiModel,
     http_client: Arc<dyn HttpClient>,
+    low_speed_timeout: Option<Duration>,
     settings_version: usize,
 }
@@ -25,6 +27,7 @@ impl OpenAiCompletionProvider {
         default_model: OpenAiModel,
         api_url: String,
         http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
         settings_version: usize,
     ) -> Self {
         Self {
@@ -32,13 +35,21 @@ impl OpenAiCompletionProvider {
             api_url,
             default_model,
             http_client,
+            low_speed_timeout,
             settings_version,
         }
     }
 
-    pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
+    pub fn update(
+        &mut self,
+        default_model: OpenAiModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
         self.default_model = default_model;
         self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
         self.settings_version = settings_version;
     }
@@ -112,9 +123,16 @@ impl OpenAiCompletionProvider {
         let http_client = self.http_client.clone();
         let api_key = self.api_key.clone();
         let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
         async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
             let response = request.await?;
             let stream = response
                 .filter_map(|response| async move {
```

Collab server (`complete_with_open_ai`), passing `None` to preserve the previous behavior of no low-speed timeout:

```diff
@@ -4344,6 +4344,7 @@ async fn complete_with_open_ai(
         OPEN_AI_API_URL,
         &api_key,
         crate::ai::language_model_request_to_open_ai(request)?,
+        None,
     )
     .await
     .context("open_ai::stream_completion request failed within collab")?;
```

The `open_ai` crate's `Cargo.toml`, adding the `isahc` dependency:

```diff
@@ -15,6 +15,7 @@ schemars = ["dep:schemars"]
 [dependencies]
 anyhow.workspace = true
 futures.workspace = true
+isahc.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
```

The `open_ai` crate's `stream_completion`, where the timeout is applied to the request builder:

```diff
@@ -1,7 +1,9 @@
 use anyhow::{anyhow, Context, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
+use std::time::Duration;
 use std::{convert::TryFrom, future::Future};
 use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -206,14 +208,20 @@ pub async fn stream_completion(
     api_url: &str,
     api_key: &str,
     request: Request,
+    low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
     let uri = format!("{api_url}/chat/completions");
-    let request = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+        .header("Authorization", format!("Bearer {}", api_key));
+
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
```
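
For reference, `low_speed_timeout(100, duration)` is isahc's `Configurable` option (hence the new import), which aborts a transfer that stays below 100 bytes/second for the given duration. A minimal standalone sketch of the same option outside Zed's HTTP wrapper, assuming isahc 1.x and a placeholder URL:

```rust
use std::time::Duration;

use isahc::{prelude::*, Request};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Fail the transfer if throughput stays below 100 bytes/second for
    // 60 seconds, matching the floor used in `stream_completion` above.
    let mut response = Request::get("https://example.com") // placeholder URL
        .low_speed_timeout(100, Duration::from_secs(60))
        .body(())?
        .send()?;

    println!("status: {}", response.status());
    println!("body: {}", response.text()?);
    Ok(())
}
```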