mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Add configurable low-speed timeout for OpenAI provider (#11668)
This PR adds a setting to allow configuring the low-speed timeout for the Assistant when using the OpenAI provider. The `low_speed_timeout_in_seconds` accepts a number of seconds that the HTTP client can go below a minimum speed limit (currently set to 100 bytes/second) before it times out. ```json { "assistant": { "version": "1", "provider": { "name": "openai", "low_speed_timeout_in_seconds": 60 } }, } ``` This should help the case where the `openai` provider is being used with a local model that requires higher timeouts. Issue: https://github.com/zed-industries/zed/issues/9913 Release Notes: - Added a `low_speed_timeout_in_seconds` setting to the Assistant's OpenAI provider ([#9913](https://github.com/zed-industries/zed/issues/9913)).
This commit is contained in:
parent
19994fc190
commit
0d26beb91b
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -6826,6 +6826,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.28",
|
||||
"isahc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -153,6 +153,8 @@ pub enum AssistantProvider {
|
||||
default_model: OpenAiModel,
|
||||
#[serde(default = "open_ai_url")]
|
||||
api_url: String,
|
||||
#[serde(default)]
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
@ -222,12 +224,14 @@ impl AssistantSettingsContent {
|
||||
Some(AssistantProvider::OpenAi {
|
||||
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
|
||||
api_url: open_ai_api_url.clone(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
})
|
||||
} else {
|
||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: open_ai_model,
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
}
|
||||
})
|
||||
},
|
||||
@ -364,14 +368,17 @@ impl Settings for AssistantSettings {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: default_model_override,
|
||||
api_url: api_url_override,
|
||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
*api_url = api_url_override;
|
||||
*low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
|
||||
}
|
||||
(merged, provider_override) => {
|
||||
*merged = provider_override;
|
||||
@ -408,7 +415,8 @@ mod tests {
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: open_ai_url()
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
}
|
||||
);
|
||||
|
||||
@ -429,7 +437,8 @@ mod tests {
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: "test-url".into()
|
||||
api_url: "test-url".into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
}
|
||||
);
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
@ -448,7 +457,8 @@ mod tests {
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::Four,
|
||||
api_url: open_ai_url()
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -18,6 +18,7 @@ use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let mut settings_version = 0;
|
||||
@ -33,10 +34,12 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
)),
|
||||
};
|
||||
@ -51,9 +54,15 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
provider.update(default_model.clone(), api_url.clone(), settings_version);
|
||||
provider.update(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
);
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(provider),
|
||||
@ -74,12 +83,14 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
));
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
@ -17,6 +18,7 @@ pub struct OpenAiCompletionProvider {
|
||||
api_url: String,
|
||||
default_model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
@ -25,6 +27,7 @@ impl OpenAiCompletionProvider {
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
@ -32,13 +35,21 @@ impl OpenAiCompletionProvider {
|
||||
api_url,
|
||||
default_model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
|
||||
pub fn update(
|
||||
&mut self,
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.default_model = default_model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
@ -112,9 +123,16 @@ impl OpenAiCompletionProvider {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
|
@ -4344,6 +4344,7 @@ async fn complete_with_open_ai(
|
||||
OPEN_AI_API_URL,
|
||||
&api_key,
|
||||
crate::ai::language_model_request_to_open_ai(request)?,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.context("open_ai::stream_completion request failed within collab")?;
|
||||
|
@ -15,6 +15,7 @@ schemars = ["dep:schemars"]
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
isahc.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
@ -1,7 +1,9 @@
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use std::time::Duration;
|
||||
use std::{convert::TryFrom, future::Future};
|
||||
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
|
||||
@ -206,14 +208,20 @@ pub async fn stream_completion(
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request = HttpRequest::builder()
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
.header("Authorization", format!("Bearer {}", api_key));
|
||||
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
};
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
|
Loading…
Reference in New Issue
Block a user