diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index c52be64141..977ead8c9e 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -34,6 +34,7 @@ use std::{ ops::{Not, Range}, path::PathBuf, sync::Arc, + time::Duration, }; use util::ResultExt as _; use workspace::{ @@ -319,11 +320,22 @@ impl View for ProjectSearchView { let status = semantic.index_status; match status { SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()), - SemanticIndexStatus::Indexing { remaining_files } => { + SemanticIndexStatus::Indexing { + remaining_files, + rate_limiting, + } => { if remaining_files == 0 { Some(format!("Indexing...")) } else { - Some(format!("Remaining files to index: {}", remaining_files)) + if rate_limiting > Duration::ZERO { + Some(format!( + "Remaining files to index (rate limit resets in {}s): {}", + rate_limiting.as_secs(), + remaining_files + )) + } else { + Some(format!("Remaining files to index: {}", remaining_files)) + } } } SemanticIndexStatus::NotIndexed => None, diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 7228738525..6affac2556 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -7,7 +7,9 @@ use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; +use parking_lot::Mutex; use parse_duration::parse; +use postage::watch; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; use serde::{Deserialize, Serialize}; @@ -82,6 +84,8 @@ impl ToSql for Embedding { pub struct OpenAIEmbeddings { pub client: Arc, pub executor: Arc, + rate_limit_count_rx: watch::Receiver<(Duration, usize)>, + rate_limit_count_tx: Arc>>, } #[derive(Serialize)] @@ -114,12 +118,16 @@ pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); + fn rate_limit_expiration(&self) -> Duration; } pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { + fn rate_limit_expiration(&self) -> Duration { + Duration::ZERO + } async fn embed_batch(&self, spans: Vec) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. @@ -149,6 +157,53 @@ impl EmbeddingProvider for DummyEmbeddings { const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with((Duration::ZERO, 0)); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + OpenAIEmbeddings { + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn resolve_rate_limit(&self) { + let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow(); + let updated_count = delay_count - 1; + let updated_duration = if updated_count == 0 { + Duration::ZERO + } else { + current_delay + }; + + log::trace!( + "resolving rate limit: Count: {:?} Duration: {:?}", + updated_count, + updated_duration + ); + + *self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count); + } + + fn update_rate_limit(&self, delay_duration: Duration, count_increase: usize) { + let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow(); + let updated_count = delay_count + count_increase; + let updated_duration = if current_delay < delay_duration { + delay_duration + } else { + current_delay + }; + + log::trace!( + "updating rate limit: Count: {:?} Duration: {:?}", + updated_count, + updated_duration + ); + + *self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count); + } async fn send_request( &self, api_key: &str, @@ -179,6 +234,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { 50000 } + fn rate_limit_expiration(&self) -> Duration { + let (duration, _) = *self.rate_limit_count_rx.borrow(); + duration + } fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); let output = if tokens.len() > OPENAI_INPUT_LIMIT { @@ -203,6 +262,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; + let mut rate_limiting = false; let mut request_timeout: u64 = 15; let mut response: Response; while request_number < MAX_RETRIES { @@ -229,6 +289,12 @@ impl EmbeddingProvider for OpenAIEmbeddings { response.usage.total_tokens ); + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + return Ok(response .data .into_iter() @@ -254,6 +320,15 @@ impl EmbeddingProvider for OpenAIEmbeddings { } }; + // If we've previously rate limited, increment the duration but not the count + if rate_limiting { + self.update_rate_limit(delay_duration, 0); + } else { + self.update_rate_limit(delay_duration, 1); + } + + rate_limiting = true; + log::trace!( "openai rate limiting: waiting {:?} until lifted", &delay_duration diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 0e18c42049..8fba7de0f0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -91,10 +91,7 @@ pub fn init( let semantic_index = SemanticIndex::new( fs, db_file_path, - Arc::new(OpenAIEmbeddings { - client: http_client, - executor: cx.background(), - }), + Arc::new(OpenAIEmbeddings::new(http_client, cx.background())), language_registry, cx.clone(), ) @@ -113,7 +110,10 @@ pub fn init( pub enum SemanticIndexStatus { NotIndexed, Indexed, - Indexing { remaining_files: usize }, + Indexing { + remaining_files: usize, + rate_limiting: Duration, + }, } pub struct SemanticIndex { @@ -132,6 +132,8 @@ struct ProjectState { pending_file_count_rx: watch::Receiver, pending_file_count_tx: Arc>>, pending_index: usize, + rate_limiting_count_rx: watch::Receiver, + rate_limiting_count_tx: Arc>>, _subscription: gpui::Subscription, _observe_pending_file_count: Task<()>, } @@ -223,11 +225,15 @@ impl ProjectState { fn new(subscription: gpui::Subscription, cx: &mut ModelContext) -> Self { let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0); let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx)); + let (rate_limiting_count_tx, rate_limiting_count_rx) = watch::channel_with(0); + let rate_limiting_count_tx = Arc::new(Mutex::new(rate_limiting_count_tx)); Self { worktrees: Default::default(), pending_file_count_rx: pending_file_count_rx.clone(), pending_file_count_tx, pending_index: 0, + rate_limiting_count_rx: rate_limiting_count_rx.clone(), + rate_limiting_count_tx, _subscription: subscription, _observe_pending_file_count: cx.spawn_weak({ let mut pending_file_count_rx = pending_file_count_rx.clone(); @@ -293,6 +299,7 @@ impl SemanticIndex { } else { SemanticIndexStatus::Indexing { remaining_files: project_state.pending_file_count_rx.borrow().clone(), + rate_limiting: self.embedding_provider.rate_limit_expiration(), } } } else { diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index ffd8db8781..09c94b9a94 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -21,7 +21,7 @@ use std::{ atomic::{self, AtomicUsize}, Arc, }, - time::SystemTime, + time::{Duration, SystemTime}, }; use unindent::Unindent; use util::RandomCharIter; @@ -1275,6 +1275,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { 200 } + fn rate_limit_expiration(&self) -> Duration { + Duration::ZERO + } + async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst);