diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 854afe5b6e..35a6a689ae 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -32,6 +32,7 @@ async-trait.workspace = true bincode = "1.3.3" matrixmultiply = "0.3.7" tiktoken-rs = "0.5.0" +rand.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 72b30d9424..029a6cdf61 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -2,15 +2,20 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::serde_json; +use isahc::http::StatusCode; use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use std::env; use std::sync::Arc; -use std::{env, time::Instant}; +use std::time::Duration; +use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } #[derive(Clone)] @@ -60,69 +65,100 @@ impl EmbeddingProvider for DummyEmbeddings { } } -// impl OpenAIEmbeddings { -// async fn truncate(span: &str) -> String { -// let bpe = cl100k_base().unwrap(); -// let mut tokens = bpe.encode_with_special_tokens(span); -// if tokens.len() > 8192 { -// tokens.truncate(8192); -// let result = bpe.decode(tokens); -// if result.is_ok() { -// return result.unwrap(); -// } -// } +impl OpenAIEmbeddings { + async fn truncate(span: String) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); + if tokens.len() > 8190 { + tokens.truncate(8190); + let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); + if result.is_ok() { + let transformed = result.unwrap(); + // assert_ne!(transformed, span); + return transformed; + } + } -// return span.to_string(); -// } -// } - -#[async_trait] -impl EmbeddingProvider for OpenAIEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { - // Truncate spans to 8192 if needed - // let t0 = Instant::now(); - // let mut truncated_spans = vec![]; - // for span in spans { - // truncated_spans.push(Self::truncate(span)); - // } - // let spans = futures::future::join_all(truncated_spans).await; - // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs()); - - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + return span.to_string(); + } + async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( serde_json::to_string(&OpenAIEmbeddingRequest { - input: spans, + input: spans.clone(), model: "text-embedding-ada-002", }) .unwrap() .into(), )?; - let mut response = self.client.send(request).await?; - if !response.status().is_success() { - return Err(anyhow!("openai embedding failed {}", response.status())); - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; - - log::info!( - "openai embedding completed. tokens: {:?}", - response.usage.total_tokens - ); - - Ok(response - .data - .into_iter() - .map(|embedding| embedding.embedding) - .collect()) + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360]; + const MAX_RETRIES: usize = 3; + + let api_key = OPENAI_API_KEY + .as_ref() + .ok_or_else(|| anyhow!("no api key"))?; + + let mut request_number = 0; + let mut response: Response; + let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); + while request_number < MAX_RETRIES { + response = self + .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .await?; + request_number += 1; + + if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { + return Err(anyhow!( + "openai max retries, error: {:?}", + &response.status() + )); + } + + match response.status() { + StatusCode::TOO_MANY_REQUESTS => { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + std::thread::sleep(delay); + } + StatusCode::BAD_REQUEST => { + log::info!("BAD REQUEST: {:?}", &response.status()); + // Don't worry about delaying bad request, as we can assume + // we haven't been rate limited yet. + for span in spans.iter_mut() { + *span = Self::truncate(span.to_string()).await; + } + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::info!( + "openai embedding completed. tokens: {:?}", + response.usage.total_tokens + ); + return Ok(response + .data + .into_iter() + .map(|embedding| embedding.embedding) + .collect()); + } + _ => { + return Err(anyhow!("openai embedding failed {}", response.status())); + } + } + } + + Err(anyhow!("openai embedding failed")) } } diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index a63674bc34..5141451e64 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -74,7 +74,6 @@ pub fn init( cx.subscribe_global::({ let vector_store = vector_store.clone(); move |event, cx| { - let t0 = Instant::now(); let workspace = &event.0; if let Some(workspace) = workspace.upgrade(cx) { let project = workspace.read(cx).project().clone(); @@ -126,9 +125,7 @@ pub struct VectorStore { language_registry: Arc, db_update_tx: channel::Sender, // embed_batch_tx: channel::Sender)>>, - batch_files_tx: channel::Sender<(i64, IndexedFile, Vec)>, parsing_files_tx: channel::Sender<(i64, PathBuf, Arc, SystemTime)>, - parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc, SystemTime)>, _db_update_task: Task<()>, _embed_batch_task: Vec>, _batch_files_task: Task<()>, @@ -220,14 +217,13 @@ impl VectorStore { let (embed_batch_tx, embed_batch_rx) = channel::unbounded::)>>(); let mut _embed_batch_task = Vec::new(); - for _ in 0..cx.background().num_cpus() { + for _ in 0..1 { + //cx.background().num_cpus() { let db_update_tx = db_update_tx.clone(); let embed_batch_rx = embed_batch_rx.clone(); let embedding_provider = embedding_provider.clone(); _embed_batch_task.push(cx.background().spawn(async move { while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - log::info!("Embedding Batch! "); - // Construct Batch let mut embeddings_queue = embeddings_queue.clone(); let mut document_spans = vec![]; @@ -235,20 +231,20 @@ impl VectorStore { document_spans.extend(document_span); } - if let Some(mut embeddings) = embedding_provider + if let Ok(embeddings) = embedding_provider .embed_batch(document_spans.iter().map(|x| &**x).collect()) .await - .log_err() { let mut i = 0; let mut j = 0; - while let Some(embedding) = embeddings.pop() { + + for embedding in embeddings.iter() { while embeddings_queue[i].1.documents.len() == j { i += 1; j = 0; } - embeddings_queue[i].1.documents[j].embedding = embedding; + embeddings_queue[i].1.documents[j].embedding = embedding.to_owned(); j += 1; } @@ -283,7 +279,6 @@ impl VectorStore { while let Ok((worktree_id, indexed_file, document_spans)) = batch_files_rx.recv().await { - log::info!("Batching File: {:?}", &indexed_file.path); queue_len += &document_spans.len(); embeddings_queue.push((worktree_id, indexed_file, document_spans)); if queue_len >= EMBEDDINGS_BATCH_SIZE { @@ -338,10 +333,7 @@ impl VectorStore { embedding_provider, language_registry, db_update_tx, - // embed_batch_tx, - batch_files_tx, parsing_files_tx, - parsing_files_rx, _db_update_task, _embed_batch_task, _batch_files_task, @@ -449,8 +441,6 @@ impl VectorStore { let database_url = self.database_url.clone(); let db_update_tx = self.db_update_tx.clone(); let parsing_files_tx = self.parsing_files_tx.clone(); - let parsing_files_rx = self.parsing_files_rx.clone(); - let batch_files_tx = self.batch_files_tx.clone(); cx.spawn(|this, mut cx| async move { let t0 = Instant::now(); @@ -553,37 +543,6 @@ impl VectorStore { }) .detach(); - // cx.background() - // .scoped(|scope| { - // for _ in 0..cx.background().num_cpus() { - // scope.spawn(async { - // let mut parser = Parser::new(); - // let mut cursor = QueryCursor::new(); - // while let Ok((worktree_id, file_path, language, mtime)) = - // parsing_files_rx.recv().await - // { - // log::info!("Parsing File: {:?}", &file_path); - // if let Some((indexed_file, document_spans)) = Self::index_file( - // &mut cursor, - // &mut parser, - // &fs, - // language, - // file_path.clone(), - // mtime, - // ) - // .await - // .log_err() - // { - // batch_files_tx - // .try_send((worktree_id, indexed_file, document_spans)) - // .unwrap(); - // } - // } - // }); - // } - // }) - // .await; - this.update(&mut cx, |this, cx| { // The below is managing for updated on save // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is @@ -592,90 +551,90 @@ impl VectorStore { if let Some(project_state) = this.projects.get(&project.downgrade()) { let worktree_db_ids = project_state.worktree_db_ids.clone(); - // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event - // { - // // Iterate through changes - // let language_registry = this.language_registry.clone(); + if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event + { + // Iterate through changes + let language_registry = this.language_registry.clone(); - // let db = - // VectorDatabase::new(this.database_url.to_string_lossy().into()); - // if db.is_err() { - // return; - // } - // let db = db.unwrap(); + let db = + VectorDatabase::new(this.database_url.to_string_lossy().into()); + if db.is_err() { + return; + } + let db = db.unwrap(); - // let worktree_db_id: Option = { - // let mut found_db_id = None; - // for (w_id, db_id) in worktree_db_ids.into_iter() { - // if &w_id == worktree_id { - // found_db_id = Some(db_id); - // } - // } + let worktree_db_id: Option = { + let mut found_db_id = None; + for (w_id, db_id) in worktree_db_ids.into_iter() { + if &w_id == worktree_id { + found_db_id = Some(db_id); + } + } - // found_db_id - // }; + found_db_id + }; - // if worktree_db_id.is_none() { - // return; - // } - // let worktree_db_id = worktree_db_id.unwrap(); + if worktree_db_id.is_none() { + return; + } + let worktree_db_id = worktree_db_id.unwrap(); - // let file_mtimes = db.get_file_mtimes(worktree_db_id); - // if file_mtimes.is_err() { - // return; - // } + let file_mtimes = db.get_file_mtimes(worktree_db_id); + if file_mtimes.is_err() { + return; + } - // let file_mtimes = file_mtimes.unwrap(); - // let paths_tx = this.paths_tx.clone(); + let file_mtimes = file_mtimes.unwrap(); + let parsing_files_tx = this.parsing_files_tx.clone(); - // smol::block_on(async move { - // for change in changes.into_iter() { - // let change_path = change.0.clone(); - // log::info!("Change: {:?}", &change_path); - // if let Ok(language) = language_registry - // .language_for_file(&change_path.to_path_buf(), None) - // .await - // { - // if language - // .grammar() - // .and_then(|grammar| grammar.embedding_config.as_ref()) - // .is_none() - // { - // continue; - // } + smol::block_on(async move { + for change in changes.into_iter() { + let change_path = change.0.clone(); + log::info!("Change: {:?}", &change_path); + if let Ok(language) = language_registry + .language_for_file(&change_path.to_path_buf(), None) + .await + { + if language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } - // // TODO: Make this a bit more defensive - // let modified_time = - // change_path.metadata().unwrap().modified().unwrap(); - // let existing_time = - // file_mtimes.get(&change_path.to_path_buf()); - // let already_stored = - // existing_time.map_or(false, |existing_time| { - // if &modified_time != existing_time - // && existing_time.elapsed().unwrap().as_secs() - // > REINDEXING_DELAY - // { - // false - // } else { - // true - // } - // }); + // TODO: Make this a bit more defensive + let modified_time = + change_path.metadata().unwrap().modified().unwrap(); + let existing_time = + file_mtimes.get(&change_path.to_path_buf()); + let already_stored = + existing_time.map_or(false, |existing_time| { + if &modified_time != existing_time + && existing_time.elapsed().unwrap().as_secs() + > REINDEXING_DELAY + { + false + } else { + true + } + }); - // if !already_stored { - // log::info!("Need to reindex: {:?}", &change_path); - // paths_tx - // .try_send(( - // worktree_db_id, - // change_path.to_path_buf(), - // language, - // modified_time, - // )) - // .unwrap(); - // } - // } - // } - // }) - // } + if !already_stored { + log::info!("Need to reindex: {:?}", &change_path); + parsing_files_tx + .try_send(( + worktree_db_id, + change_path.to_path_buf(), + language, + modified_time, + )) + .unwrap(); + } + } + } + }) + } } });