From f413ea90bf353456e918c4b2eec32ee6f0fbb059 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Thu, 5 Sep 2024 18:16:30 +0200 Subject: [PATCH] assistant: Fix Google AI provider not respecting `low_speed_timeout_in_seconds` (#17423) Release Notes: - Fixed an issue when using Google Gemini models, where the setting `low_speed_timeout_in_seconds` was not respected --- Cargo.lock | 1 + crates/collab/src/llm.rs | 1 + crates/collab/src/rpc.rs | 1 + crates/google_ai/Cargo.toml | 1 + crates/google_ai/src/google_ai.rs | 32 +++++++++++++++++--- crates/language_model/src/provider/google.rs | 26 +++++++++++----- 6 files changed, 50 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 20b9f4ab2d..1d8b7d874a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4928,6 +4928,7 @@ dependencies = [ "anyhow", "futures 0.3.30", "http_client", + "isahc", "schemars", "serde", "serde_json", diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 320d7418ee..2bcefc8477 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -380,6 +380,7 @@ async fn perform_completion( google_ai::API_URL, api_key, serde_json::from_str(&params.provider_request.get())?, + None, ) .await?; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c0ed66d129..1c841d0401 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4540,6 +4540,7 @@ async fn count_language_model_tokens( google_ai::API_URL, api_key, serde_json::from_str(&request.request)?, + None, ) .await? 
} diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index f923e0ec91..2a52f1968d 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -18,6 +18,7 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +isahc.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 631a6b20ca..f0803b4029 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -2,8 +2,10 @@ mod supported_countries; use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; -use http_client::HttpClient; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use isahc::config::Configurable; use serde::{Deserialize, Serialize}; +use std::time::Duration; pub use supported_countries::*; @@ -14,6 +16,7 @@ pub async fn stream_generate_content( api_url: &str, api_key: &str, mut request: GenerateContentRequest, + low_speed_timeout: Option<Duration>, ) -> Result>> { let uri = format!( "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", @@ -21,8 +24,17 @@ pub async fn stream_generate_content( ); request.model.clear(); - let request = serde_json::to_string(&request)?; - let mut response = client.post_json(&uri, request.into()).await?; + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + }; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); 
Ok(reader @@ -59,13 +71,25 @@ pub async fn count_tokens( api_url: &str, api_key: &str, request: CountTokensRequest, + low_speed_timeout: Option<Duration>, ) -> Result { let uri = format!( "{}/v1beta/models/gemini-pro:countTokens?key={}", api_url, api_key ); let request = serde_json::to_string(&request)?; - let mut response = client.post_json(&uri, request.into()).await?; + + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(&uri) + .header("Content-Type", "application/json"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + } + + let http_request = request_builder.body(AsyncBody::from(request))?; + let mut response = client.send(http_request).await?; let mut text = String::new(); response.body_mut().read_to_string(&mut text).await?; if response.status().is_success() { diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index b59d97e036..3aa5917c14 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -257,10 +257,10 @@ impl LanguageModel for GoogleLanguageModel { let request = request.into_google(self.model.id().to_string()); let http_client = self.http_client.clone(); let api_key = self.state.read(cx).api_key.clone(); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); + + let settings = &AllLanguageModelSettings::get_global(cx).google; + let api_url = settings.api_url.clone(); + let low_speed_timeout = settings.low_speed_timeout; async move { let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; @@ -271,6 +271,7 @@ impl LanguageModel for GoogleLanguageModel { google_ai::CountTokensRequest { contents: request.contents, }, + low_speed_timeout, ) .await?; Ok(response.total_tokens) @@ -289,17 +290,26 @@ impl LanguageModel for GoogleLanguageModel { let request = 
request.into_google(self.model.id().to_string()); let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| { + let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { let settings = &AllLanguageModelSettings::get_global(cx).google; - (state.api_key.clone(), settings.api_url.clone()) + ( + state.api_key.clone(), + settings.api_url.clone(), + settings.low_speed_timeout, + ) }) else { return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.rate_limiter.stream(async move { let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; - let response = - stream_generate_content(http_client.as_ref(), &api_url, &api_key, request); + let response = stream_generate_content( + http_client.as_ref(), + &api_url, + &api_key, + request, + low_speed_timeout, + ); let events = response.await?; Ok(google_ai::extract_text_from_events(events).boxed()) });