mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-16 00:47:39 +03:00
assistant: Fix Google AI provider not respecting low_speed_timeout_in_seconds
(#17423)
Release Notes: - Fixed an issue when using Google Gemini models, where the setting `low_speed_timeout_in_seconds` was not respected
This commit is contained in:
parent
a1c676128a
commit
f413ea90bf
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -4928,6 +4928,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.30",
|
||||
"http_client",
|
||||
"isahc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -380,6 +380,7 @@ async fn perform_completion(
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(¶ms.provider_request.get())?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
@ -4540,6 +4540,7 @@ async fn count_language_model_tokens(
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ schemars = ["dep:schemars"]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
@ -2,8 +2,10 @@ mod supported_countries;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::HttpClient;
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
@ -14,6 +16,7 @@ pub async fn stream_generate_content(
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
mut request: GenerateContentRequest,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
||||
let uri = format!(
|
||||
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
|
||||
@ -21,8 +24,17 @@ pub async fn stream_generate_content(
|
||||
);
|
||||
request.model.clear();
|
||||
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
};
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
@ -59,13 +71,25 @@ pub async fn count_tokens(
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: CountTokensRequest,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<CountTokensResponse> {
|
||||
let uri = format!(
|
||||
"{}/v1beta/models/gemini-pro:countTokens?key={}",
|
||||
api_url, api_key
|
||||
);
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(&uri)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
}
|
||||
|
||||
let http_request = request_builder.body(AsyncBody::from(request))?;
|
||||
let mut response = client.send(http_request).await?;
|
||||
let mut text = String::new();
|
||||
response.body_mut().read_to_string(&mut text).await?;
|
||||
if response.status().is_success() {
|
||||
|
@ -257,10 +257,10 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
let request = request.into_google(self.model.id().to_string());
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.state.read(cx).api_key.clone();
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).google;
|
||||
let api_url = settings.api_url.clone();
|
||||
let low_speed_timeout = settings.low_speed_timeout;
|
||||
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
@ -271,6 +271,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
},
|
||||
low_speed_timeout,
|
||||
)
|
||||
.await?;
|
||||
Ok(response.total_tokens)
|
||||
@ -289,17 +290,26 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
let request = request.into_google(self.model.id().to_string());
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
|
||||
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).google;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
(
|
||||
state.api_key.clone(),
|
||||
settings.api_url.clone(),
|
||||
settings.low_speed_timeout,
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.rate_limiter.stream(async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let response =
|
||||
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = stream_generate_content(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let events = response.await?;
|
||||
Ok(google_ai::extract_text_from_events(events).boxed())
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user