From 5831d80f513fe374673dabe1346588ea5ad8cba1 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld <maxbrunsfeld@gmail.com>
Date: Wed, 1 May 2024 12:15:44 -0700
Subject: [PATCH] Return an error from project index tool when embedding query
 fails (#11264)

Previously, a failure to embed the search query (due to a rate-limit
error, for example) looked the same as a search that returned no
results.

* Avoid repeatedly embedding the search query for each worktree
* Unify tasks for searching all worktrees

Release Notes:

- N/A
---
 crates/assistant2/src/tools/project_index.rs |   5 +-
 crates/semantic_index/examples/index.rs      |   5 +-
 crates/semantic_index/src/chunking.rs        |   9 +-
 crates/semantic_index/src/semantic_index.rs  | 250 ++++++++++---------
 4 files changed, 137 insertions(+), 132 deletions(-)

diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs
index 20f6c51add..d9f3e8271d 100644
--- a/crates/assistant2/src/tools/project_index.rs
+++ b/crates/assistant2/src/tools/project_index.rs
@@ -140,10 +140,9 @@ impl LanguageModelTool for ProjectIndexTool {
 
     fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
-        let status = project_index.status();
 
         let results = project_index.search(
-            query.query.as_str(),
+            query.query.clone(),
             query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
             cx,
         );
@@ -151,7 +150,7 @@
         let fs = self.fs.clone();
 
         cx.spawn(|cx| async move {
-            let results = results.await;
+            let results = results.await?;
 
             let excerpts = results.into_iter().map(|result| {
                 let abs_path = result
diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs
index ed5f461377..d166a4d7d0 100644
--- a/crates/semantic_index/examples/index.rs
+++ b/crates/semantic_index/examples/index.rs
@@ -92,10 +92,11 @@ fn main() {
         .update(|cx| {
             let project_index = project_index.read(cx);
             let query = "converting an anchor to a point";
-            project_index.search(query, 4, cx)
+            project_index.search(query.into(), 4, cx)
         })
         .unwrap()
-        .await;
+        .await
+        .unwrap();
 
     for search_result in results {
         let path = search_result.path.clone();
diff --git a/crates/semantic_index/src/chunking.rs b/crates/semantic_index/src/chunking.rs
index b0ddd73122..9918bb1d2c 100644
--- a/crates/semantic_index/src/chunking.rs
+++ b/crates/semantic_index/src/chunking.rs
@@ -98,12 +98,9 @@ fn chunk_lines(text: &str) -> Vec<Chunk> {
 
     chunk_ranges
         .into_iter()
-        .map(|range| {
-            let mut hasher = Sha256::new();
-            hasher.update(&text[range.clone()]);
-            let mut digest = [0u8; 32];
-            digest.copy_from_slice(hasher.finalize().as_slice());
-            Chunk { range, digest }
+        .map(|range| Chunk {
+            digest: Sha256::digest(&text[range.clone()]).into(),
+            range,
         })
         .collect()
 }
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index f8082de07b..d17ede7162 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -15,7 +15,7 @@ use gpui::{
 use heed::types::{SerdeBincode, Str};
 use language::LanguageRegistry;
 use parking_lot::Mutex;
-use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
 use std::{
@@ -156,6 +156,10 @@ impl ProjectIndex {
         self.last_status
     }
 
+    pub fn project(&self) -> WeakModel<Project> {
+        self.project.clone()
+    }
+
     fn handle_project_event(
         &mut self,
         _: Model<Project>,
@@ -259,30 +263,126 @@ impl ProjectIndex {
         }
     }
 
-    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
-        let mut worktree_searches = Vec::new();
+    pub fn search(
+        &self,
+        query: String,
+        limit: usize,
+        cx: &AppContext,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let (chunks_tx, chunks_rx) = channel::bounded(1024);
+        let mut worktree_scan_tasks = Vec::new();
         for worktree_index in self.worktree_indices.values() {
             if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
-                worktree_searches
-                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+                let chunks_tx = chunks_tx.clone();
+                index.read_with(cx, |index, cx| {
+                    let worktree_id = index.worktree.read(cx).id();
+                    let db_connection = index.db_connection.clone();
+                    let db = index.db;
+                    worktree_scan_tasks.push(cx.background_executor().spawn({
+                        async move {
+                            let txn = db_connection
+                                .read_txn()
+                                .context("failed to create read transaction")?;
+                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
+                            for db_entry in db_entries {
+                                let (_key, db_embedded_file) = db_entry?;
+                                for chunk in db_embedded_file.chunks {
+                                    chunks_tx
+                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
+                                        .await?;
+                                }
+                            }
+                            anyhow::Ok(())
+                        }
+                    }));
+                })
             }
         }
+        drop(chunks_tx);
 
-        cx.spawn(|_| async move {
-            let mut results = Vec::new();
-            let worktree_searches = futures::future::join_all(worktree_searches).await;
+        let project = self.project.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        cx.spawn(|cx| async move {
+            #[cfg(debug_assertions)]
+            let embedding_query_start = std::time::Instant::now();
+            log::info!("Searching for {query}");
 
-            for worktree_search_results in worktree_searches {
-                if let Some(worktree_search_results) = worktree_search_results.log_err() {
-                    results.extend(worktree_search_results);
-                }
+            let query_embeddings = embedding_provider
+                .embed(&[TextToEmbed::new(&query)])
+                .await?;
+            let query_embedding = query_embeddings
+                .into_iter()
+                .next()
+                .ok_or_else(|| anyhow!("no embedding for query"))?;
+
+            let mut results_by_worker = Vec::new();
+            for _ in 0..cx.background_executor().num_cpus() {
+                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
             }
 
-            results
-                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
-            results.truncate(limit);
+            #[cfg(debug_assertions)]
+            let search_start = std::time::Instant::now();
 
-            results
+            cx.background_executor()
+                .scoped(|cx| {
+                    for results in results_by_worker.iter_mut() {
+                        cx.spawn(async {
+                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
+                                let score = chunk.embedding.similarity(&query_embedding);
+                                let ix = match results.binary_search_by(|probe| {
+                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+                                }) {
+                                    Ok(ix) | Err(ix) => ix,
+                                };
+                                results.insert(
+                                    ix,
+                                    WorktreeSearchResult {
+                                        worktree_id,
+                                        path: path.clone(),
+                                        range: chunk.chunk.range.clone(),
+                                        score,
+                                    },
+                                );
+                                results.truncate(limit);
+                            }
+                        });
+                    }
+                })
+                .await;
+
+            futures::future::try_join_all(worktree_scan_tasks).await?;
+
+            project.read_with(&cx, |project, cx| {
+                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
+                for worker_results in results_by_worker {
+                    search_results.extend(worker_results.into_iter().filter_map(|result| {
+                        Some(SearchResult {
+                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
+                            path: result.path,
+                            range: result.range,
+                            score: result.score,
+                        })
+                    }));
+                }
+                search_results.sort_unstable_by(|a, b| {
+                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
+                });
+                search_results.truncate(limit);
+
+                #[cfg(debug_assertions)]
+                {
+                    let search_elapsed = search_start.elapsed();
+                    log::debug!(
+                        "searched {} entries in {:?}",
+                        search_results.len(),
+                        search_elapsed
+                    );
+                    let embedding_query_elapsed = embedding_query_start.elapsed();
+                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
+                }
+
+                search_results
+            })
         })
     }
 
@@ -327,6 +427,13 @@ pub struct SearchResult {
     pub score: f32,
 }
 
+pub struct WorktreeSearchResult {
+    pub worktree_id: WorktreeId,
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub score: f32,
+}
+
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Status {
     Idle,
@@ -764,107 +871,6 @@ impl WorktreeIndex {
         })
     }
 
-    fn search(
-        &self,
-        query: &str,
-        limit: usize,
-        cx: &AppContext,
-    ) -> Task<Result<Vec<SearchResult>>> {
-        let (chunks_tx, chunks_rx) = channel::bounded(1024);
-
-        let db_connection = self.db_connection.clone();
-        let db = self.db;
-        let scan_chunks = cx.background_executor().spawn({
-            async move {
-                let txn = db_connection
-                    .read_txn()
-                    .context("failed to create read transaction")?;
-                let db_entries = db.iter(&txn).context("failed to iterate database")?;
-                for db_entry in db_entries {
-                    let (_key, db_embedded_file) = db_entry?;
-                    for chunk in db_embedded_file.chunks {
-                        chunks_tx
-                            .send((db_embedded_file.path.clone(), chunk))
-                            .await?;
-                    }
-                }
-                anyhow::Ok(())
-            }
-        });
-
-        let query = query.to_string();
-        let embedding_provider = self.embedding_provider.clone();
-        let worktree = self.worktree.clone();
-        cx.spawn(|cx| async move {
-            #[cfg(debug_assertions)]
-            let embedding_query_start = std::time::Instant::now();
-            log::info!("Searching for {query}");
-
-            let mut query_embeddings = embedding_provider
-                .embed(&[TextToEmbed::new(&query)])
-                .await?;
-            let query_embedding = query_embeddings
-                .pop()
-                .ok_or_else(|| anyhow!("no embedding for query"))?;
-            let mut workers = Vec::new();
-            for _ in 0..cx.background_executor().num_cpus() {
-                workers.push(Vec::<SearchResult>::new());
-            }
-
-            #[cfg(debug_assertions)]
-            let search_start = std::time::Instant::now();
-
-            cx.background_executor()
-                .scoped(|cx| {
-                    for worker_results in workers.iter_mut() {
-                        cx.spawn(async {
-                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
-                                let score = embedded_chunk.embedding.similarity(&query_embedding);
-                                let ix = match worker_results.binary_search_by(|probe| {
-                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
-                                }) {
-                                    Ok(ix) | Err(ix) => ix,
-                                };
-                                worker_results.insert(
-                                    ix,
-                                    SearchResult {
-                                        worktree: worktree.clone(),
-                                        path: path.clone(),
-                                        range: embedded_chunk.chunk.range.clone(),
-                                        score,
-                                    },
-                                );
-                                worker_results.truncate(limit);
-                            }
-                        });
-                    }
-                })
-                .await;
-            scan_chunks.await?;
-
-            let mut search_results = Vec::with_capacity(workers.len() * limit);
-            for worker_results in workers {
-                search_results.extend(worker_results);
-            }
-            search_results
-                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
-            search_results.truncate(limit);
-            #[cfg(debug_assertions)]
-            {
-                let search_elapsed = search_start.elapsed();
-                log::debug!(
-                    "searched {} entries in {:?}",
-                    search_results.len(),
-                    search_elapsed
-                );
-                let embedding_query_elapsed = embedding_query_start.elapsed();
-                log::debug!("embedding query took {:?}", embedding_query_elapsed);
-            }
-
-            Ok(search_results)
-        })
-    }
-
     fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let connection = self.db_connection.clone();
         let db = self.db;
@@ -1093,9 +1099,10 @@ mod tests {
            .update(|cx| {
                 let project_index = project_index.read(cx);
                 let query = "garbage in, garbage out";
garbage out"; - project_index.search(query, 4, cx) + project_index.search(query.into(), 4, cx) }) - .await; + .await + .unwrap(); assert!(results.len() > 1, "should have found some results"); @@ -1112,9 +1119,10 @@ mod tests { let content = cx .update(|cx| { let worktree = search_result.worktree.read(cx); - let entry_abs_path = worktree.abs_path().join(search_result.path.clone()); + let entry_abs_path = worktree.abs_path().join(&search_result.path); let fs = project.read(cx).fs().clone(); - cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() }) + cx.background_executor() + .spawn(async move { fs.load(&entry_abs_path).await.unwrap() }) }) .await;