diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 52ee12c26d..7f1a639a69 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -30,6 +30,7 @@ use std::{ ops::{Not, Range}, path::PathBuf, sync::Arc, + time::Instant, }; use util::ResultExt as _; use workspace::{ @@ -192,6 +193,7 @@ impl ProjectSearch { exclude_files: Vec, cx: &mut ModelContext, ) { + let t0 = Instant::now(); let search = SemanticIndex::global(cx).map(|index| { index.update(cx, |semantic_index, cx| { semantic_index.search_project( @@ -208,6 +210,7 @@ impl ProjectSearch { self.match_ranges.clear(); self.pending_search = Some(cx.spawn(|this, mut cx| async move { let results = search?.await.log_err()?; + log::trace!("semantic search elapsed: {:?}", t0.elapsed().as_millis()); let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| { this.excerpts.update(cx, |excerpts, cx| { diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 4bc97da0f0..85631e7fc6 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -267,41 +267,56 @@ impl VectorDatabase { pub fn top_k_search( &self, - worktree_ids: &[i64], query_embedding: &Vec, limit: usize, - include_globs: Vec, - exclude_globs: Vec, - ) -> Result)>> { + file_ids: &[i64], + ) -> Result> { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - self.for_each_document( - &worktree_ids, - include_globs, - exclude_globs, - |id, embedding| { - let similarity = dot(&embedding, &query_embedding); - let ix = match results.binary_search_by(|(_, s)| { - similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) - }) { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - }, - )?; + self.for_each_document(file_ids, |id, embedding| { + let similarity = dot(&embedding, &query_embedding); + let ix = match results + .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) + { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - let ids = results.into_iter().map(|(id, _)| id).collect::>(); - self.get_documents_by_ids(&ids) + Ok(results) } - fn for_each_document( + // pub fn top_k_search( + // &self, + // worktree_ids: &[i64], + // query_embedding: &Vec, + // limit: usize, + // file_ids: Vec, + // ) -> Result)>> { + // let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + // self.for_each_document(&worktree_ids, file_ids, |id, embedding| { + // let similarity = dot(&embedding, &query_embedding); + // let ix = match results + // .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) + // { + // Ok(ix) => ix, + // Err(ix) => ix, + // }; + // results.insert(ix, (id, similarity)); + // results.truncate(limit); + // })?; + + // let ids = results.into_iter().map(|(id, _)| id).collect::>(); + // self.get_documents_by_ids(&ids) + // } + + pub fn retrieve_included_file_ids( &self, worktree_ids: &[i64], include_globs: Vec, exclude_globs: Vec, - mut f: impl FnMut(i64, Vec), - ) -> Result<()> { + ) -> Result> { let mut file_query = self.db.prepare( " SELECT @@ -315,6 +330,7 @@ impl VectorDatabase { let mut file_ids = Vec::::new(); let mut rows = file_query.query([ids_to_sql(worktree_ids)])?; + while let Some(row) = rows.next()? { let file_id = row.get(0)?; let relative_path = row.get_ref(1)?.as_str()?; @@ -330,6 +346,10 @@ impl VectorDatabase { } } + Ok(file_ids) + } + + fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec)) -> Result<()> { let mut query_statement = self.db.prepare( " SELECT @@ -350,7 +370,7 @@ impl VectorDatabase { Ok(()) } - fn get_documents_by_ids(&self, ids: &[i64]) -> Result)>> { + pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result)>> { let mut statement = self.db.prepare( " SELECT diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index e4a307573a..d2b69a0329 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -20,6 +20,7 @@ use postage::watch; use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ + cmp::Ordering, collections::HashMap, mem, ops::Range, @@ -704,27 +705,64 @@ impl SemanticIndex { let database_url = self.database_url.clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { - let documents = cx - .background() - .spawn(async move { - let database = VectorDatabase::new(fs, database_url).await?; + let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; - let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) - .await? - .into_iter() - .next() - .unwrap(); + let phrase_embedding = embedding_provider + .embed_batch(vec![&phrase]) + .await? + .into_iter() + .next() + .unwrap(); - database.top_k_search( - &worktree_db_ids, - &phrase_embedding, - limit, - include_globs, - exclude_globs, - ) - }) - .await?; + let file_ids = database.retrieve_included_file_ids( + &worktree_db_ids, + include_globs, + exclude_globs, + )?; + + let batch_n = cx.background().num_cpus(); + let batch_size = file_ids.clone().len() / batch_n; + + let mut result_tasks = Vec::new(); + for batch in file_ids.chunks(batch_size) { + let batch = batch.into_iter().map(|v| *v).collect::>(); + let limit = limit.clone(); + let fs = fs.clone(); + let database_url = database_url.clone(); + let phrase_embedding = phrase_embedding.clone(); + let task = cx.background().spawn(async move { + let database = VectorDatabase::new(fs, database_url).await.log_err(); + if database.is_none() { + return Err(anyhow!("failed to acquire database connection")); + } else { + database + .unwrap() + .top_k_search(&phrase_embedding, limit, batch.as_slice()) + } + }); + result_tasks.push(task); + } + + let batch_results = futures::future::join_all(result_tasks).await; + + let mut results = Vec::new(); + for batch_result in batch_results { + if batch_result.is_ok() { + for (id, similarity) in batch_result.unwrap() { + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + } + } + } + + let ids = results.into_iter().map(|(id, _)| id).collect::>(); + let documents = database.get_documents_by_ids(ids.as_slice())?; let mut tasks = Vec::new(); let mut ranges = Vec::new();