initial outline for rate limiting status updates

KCaverly 2023-09-08 12:35:15 -04:00
parent e9747d0fea
commit a5ee8fc805
4 changed files with 106 additions and 8 deletions

View File

@@ -34,6 +34,7 @@ use std::{
ops::{Not, Range},
path::PathBuf,
sync::Arc,
time::Duration,
};
use util::ResultExt as _;
use workspace::{
@@ -319,11 +320,22 @@ impl View for ProjectSearchView {
let status = semantic.index_status;
match status {
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
SemanticIndexStatus::Indexing { remaining_files } => {
SemanticIndexStatus::Indexing {
remaining_files,
rate_limiting,
} => {
if remaining_files == 0 {
Some(format!("Indexing..."))
} else {
Some(format!("Remaining files to index: {}", remaining_files))
if rate_limiting > Duration::ZERO {
Some(format!(
"Remaining files to index (rate limit resets in {}s): {}",
rate_limiting.as_secs(),
remaining_files
))
} else {
Some(format!("Remaining files to index: {}", remaining_files))
}
}
}
SemanticIndexStatus::NotIndexed => None,
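
For reference, here is the resulting status logic as a self-contained sketch: the nested if is collapsed into an else-if, and the enum shape is mirrored from the semantic_index changes further down. This is an illustration, not code from the commit.

use std::time::Duration;

// Sketch-only mirror of the enum shape this commit introduces.
enum SemanticIndexStatus {
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limiting: Duration,
    },
}

// Standalone version of the match arm above.
fn status_message(status: SemanticIndexStatus) -> Option<String> {
    match status {
        SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
        SemanticIndexStatus::Indexing {
            remaining_files,
            rate_limiting,
        } => {
            if remaining_files == 0 {
                Some("Indexing...".to_string())
            } else if rate_limiting > Duration::ZERO {
                Some(format!(
                    "Remaining files to index (rate limit resets in {}s): {}",
                    rate_limiting.as_secs(),
                    remaining_files
                ))
            } else {
                Some(format!("Remaining files to index: {}", remaining_files))
            }
        }
        SemanticIndexStatus::NotIndexed => None,
    }
}

fn main() {
    let status = SemanticIndexStatus::Indexing {
        remaining_files: 12,
        rate_limiting: Duration::from_secs(20),
    };
    assert_eq!(
        status_message(status).as_deref(),
        Some("Remaining files to index (rate limit resets in 20s): 12")
    );
}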

View File

@@ -7,7 +7,9 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
@@ -82,6 +84,8 @@ impl ToSql for Embedding {
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<(Duration, usize)>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<(Duration, usize)>>>,
}
#[derive(Serialize)]
@@ -114,12 +118,16 @@ pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Duration;
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
fn rate_limit_expiration(&self) -> Duration {
Duration::ZERO
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the embedding size for OpenAI's ada models,
// the model we will likely be starting with.
@@ -149,6 +157,53 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with((Duration::ZERO, 0));
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn resolve_rate_limit(&self) {
let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
// saturating_sub guards against underflow if resolve runs without a matching update
let updated_count = delay_count.saturating_sub(1);
let updated_duration = if updated_count == 0 {
Duration::ZERO
} else {
current_delay
};
log::trace!(
"resolving rate limit: Count: {:?} Duration: {:?}",
updated_count,
updated_duration
);
*self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
}
fn update_rate_limit(&self, delay_duration: Duration, count_increase: usize) {
let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
let updated_count = delay_count + count_increase;
let updated_duration = if current_delay < delay_duration {
delay_duration
} else {
current_delay
};
log::trace!(
"updating rate limit: Count: {:?} Duration: {:?}",
updated_count,
updated_duration
);
*self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
}
async fn send_request(
&self,
api_key: &str,
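
Both helpers above funnel writes through one shared value: a (Duration, usize) pair held in the watch channel, with the sender behind Arc<Mutex<...>> so the &self methods can write to it. The same bookkeeping as a dependency-free sketch, using a hypothetical RateLimitTracker over std::sync::Mutex (the commit uses a postage watch channel instead, so readers can also observe updates):

use std::sync::Mutex;
use std::time::Duration;

// Hypothetical stand-in for the rate-limit bookkeeping above: track the
// longest advertised delay and how many requests observed a rate limit.
struct RateLimitTracker {
    state: Mutex<(Duration, usize)>,
}

impl RateLimitTracker {
    fn new() -> Self {
        Self {
            state: Mutex::new((Duration::ZERO, 0)),
        }
    }

    // update_rate_limit: keep the maximum delay seen so far, and bump the
    // count only for requests newly hitting the limit (count_increase = 1).
    fn update(&self, delay: Duration, count_increase: usize) {
        let mut state = self.state.lock().unwrap();
        state.0 = state.0.max(delay);
        state.1 += count_increase;
    }

    // resolve_rate_limit: one rate-limited request succeeded; clear the
    // delay once no rate-limited requests remain.
    fn resolve(&self) {
        let mut state = self.state.lock().unwrap();
        state.1 = state.1.saturating_sub(1);
        if state.1 == 0 {
            state.0 = Duration::ZERO;
        }
    }

    // rate_limit_expiration: what the status message reads.
    fn expiration(&self) -> Duration {
        self.state.lock().unwrap().0
    }
}

fn main() {
    let tracker = RateLimitTracker::new();
    tracker.update(Duration::from_secs(20), 1);
    tracker.update(Duration::from_secs(5), 0); // a shorter delay never shrinks the max
    assert_eq!(tracker.expiration(), Duration::from_secs(20));
    tracker.resolve();
    assert_eq!(tracker.expiration(), Duration::ZERO);
}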
@@ -179,6 +234,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
50000
}
fn rate_limit_expiration(&self) -> Duration {
let (duration, _) = *self.rate_limit_count_rx.borrow();
duration
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -203,6 +262,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
@@ -229,6 +289,12 @@ impl EmbeddingProvider for OpenAIEmbeddings {
response.usage.total_tokens
);
// If a request succeeds after we were previously rate limited,
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
@@ -254,6 +320,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
}
};
// If we've already been rate limited, extend the delay duration but don't increment the count
if rate_limiting {
self.update_rate_limit(delay_duration, 0);
} else {
self.update_rate_limit(delay_duration, 1);
}
rate_limiting = true;
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
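
The hunk above is truncated, but the retry loop's shape is: on a rate-limit response, record the parsed delay (counting the request only on its first hit), wait out the delay, and retry; a later success resolves the limit. A synchronous, dependency-free sketch of that control flow, with a scripted Vec<Outcome> standing in for the HTTP client and MAX_RETRIES assumed to be 4 (the diff only shows the name):

use std::time::Duration;

// Stand-in for the outcome of one embeddings request.
enum Outcome {
    Ok,
    RateLimited { retry_after: Duration },
}

// Control flow of the retry loop, made synchronous for the sketch.
// `requests` scripts the responses the HTTP client would return.
fn embed_with_retries(mut requests: Vec<Outcome>) -> Result<(), String> {
    const MAX_RETRIES: usize = 4; // assumed value
    let mut request_number = 0;
    let mut rate_limiting = false;

    while request_number < MAX_RETRIES {
        request_number += 1;
        match requests.remove(0) {
            Outcome::Ok => {
                // A success after being rate limited resolves the limit.
                if rate_limiting {
                    println!("resolving rate limit");
                }
                return Ok(());
            }
            Outcome::RateLimited { retry_after } => {
                // Count the request only on its first rate-limit hit;
                // repeats just extend the delay.
                let count_increase = if rate_limiting { 0 } else { 1 };
                println!("rate limited: delay {:?}, count +{}", retry_after, count_increase);
                rate_limiting = true;
                std::thread::sleep(retry_after);
            }
        }
    }
    Err("too many retries".into())
}

fn main() {
    let script = vec![
        Outcome::RateLimited {
            retry_after: Duration::from_millis(10),
        },
        Outcome::Ok,
    ];
    assert!(embed_with_retries(script).is_ok());
}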

View File

@@ -91,10 +91,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
Arc::new(OpenAIEmbeddings {
client: http_client,
executor: cx.background(),
}),
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -113,7 +110,10 @@ pub fn init(
pub enum SemanticIndexStatus {
NotIndexed,
Indexed,
Indexing { remaining_files: usize },
Indexing {
remaining_files: usize,
rate_limiting: Duration,
},
}
pub struct SemanticIndex {
@@ -132,6 +132,8 @@ struct ProjectState {
pending_file_count_rx: watch::Receiver<usize>,
pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
pending_index: usize,
rate_limiting_count_rx: watch::Receiver<usize>,
rate_limiting_count_tx: Arc<Mutex<watch::Sender<usize>>>,
_subscription: gpui::Subscription,
_observe_pending_file_count: Task<()>,
}
@@ -223,11 +225,15 @@ impl ProjectState {
fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
let (rate_limiting_count_tx, rate_limiting_count_rx) = watch::channel_with(0);
let rate_limiting_count_tx = Arc::new(Mutex::new(rate_limiting_count_tx));
Self {
worktrees: Default::default(),
pending_file_count_rx: pending_file_count_rx.clone(),
pending_file_count_tx,
pending_index: 0,
rate_limiting_count_rx: rate_limiting_count_rx.clone(),
rate_limiting_count_tx,
_subscription: subscription,
_observe_pending_file_count: cx.spawn_weak({
let mut pending_file_count_rx = pending_file_count_rx.clone();
@@ -293,6 +299,7 @@ impl SemanticIndex {
} else {
SemanticIndexStatus::Indexing {
remaining_files: *project_state.pending_file_count_rx.borrow(),
rate_limiting: self.embedding_provider.rate_limit_expiration(),
}
}
} else {
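
All of the counters in ProjectState follow the same pattern: a postage watch channel whose sender sits behind Arc<Mutex<...>> for writers, while receivers are cloned for cheap reads of the latest value. A minimal sketch of that pattern, assuming the postage and parking_lot crates already used in this diff as dependencies:

use parking_lot::Mutex;
use postage::watch;
use std::sync::Arc;

fn main() {
    // Same shape as pending_file_count_tx/rx above.
    let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0usize);
    let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));

    // Writers (e.g. the indexing task) publish the latest count.
    *pending_file_count_tx.lock().borrow_mut() = 12;

    // Readers (e.g. status()) borrow the most recent value.
    let remaining_files: usize = *pending_file_count_rx.borrow();
    assert_eq!(remaining_files, 12);
}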

View File

@@ -21,7 +21,7 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
time::SystemTime,
time::{Duration, SystemTime},
};
use unindent::Unindent;
use util::RandomCharIter;
@@ -1275,6 +1275,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
200
}
fn rate_limit_expiration(&self) -> Duration {
Duration::ZERO
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
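
A test double built on the new trait method could make the expiration settable, letting a test drive the "rate limit resets in {}s" status path without a real API. A hypothetical sketch, not part of this commit:

use std::sync::Mutex;
use std::time::Duration;

// Hypothetical variation on FakeEmbeddingProvider with a controllable
// rate-limit expiration for tests.
struct SettableRateLimitProvider {
    expiration: Mutex<Duration>,
}

impl SettableRateLimitProvider {
    fn set_rate_limit(&self, remaining: Duration) {
        *self.expiration.lock().unwrap() = remaining;
    }

    // Mirrors EmbeddingProvider::rate_limit_expiration.
    fn rate_limit_expiration(&self) -> Duration {
        *self.expiration.lock().unwrap()
    }
}

fn main() {
    let provider = SettableRateLimitProvider {
        expiration: Mutex::new(Duration::ZERO),
    };
    provider.set_rate_limit(Duration::from_secs(30));
    assert_eq!(provider.rate_limit_expiration(), Duration::from_secs(30));
}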