Return an error from project index tool when embedding query fails (#11264)

Previously, a failure to embed the search query (due to a rate limit
error, for example) would appear the same as if there were no results.
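
In short, the search task now carries a Result, so an embedding failure is reported to the caller instead of being flattened into an empty result list. A minimal before/after sketch (signatures taken from the diff below; bodies and multi-line formatting elided):

// Before: per-worktree errors were logged and dropped, so callers only ever saw a Vec.
pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>>

// After: the task yields a Result, and the project index tool propagates it with `?`.
pub fn search(&self, query: String, limit: usize, cx: &AppContext) -> Task<Result<Vec<SearchResult>>>

// At the tool's call site:
let results = results.await?; // was: let results = results.await;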

* Avoid repeatedly embedding the search query for each worktree
* Unify tasks for searching all worktrees (sketched below)
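
A condensed sketch of the unified search, based on the reworked ProjectIndex::search in the last file of the diff (the per-worktree database scan and the final merge are compressed into comments, so this is not compilable on its own):

// Every loaded worktree index spawns one background scan task that streams its stored
// (worktree_id, path, chunk) tuples into a shared bounded channel.
let (chunks_tx, chunks_rx) = channel::bounded(1024);
// ... one scan task per loaded worktree, each holding a clone of chunks_tx ...
drop(chunks_tx);

// The query is embedded exactly once, instead of once per worktree; a failure here
// makes the whole search task return an error.
let query_embedding = embedding_provider
    .embed(&[TextToEmbed::new(&query)])
    .await?
    .into_iter()
    .next()
    .ok_or_else(|| anyhow!("no embedding for query"))?;

// One worker per CPU drains the channel, scores each chunk against the single query
// embedding, and keeps only its own top `limit` hits (sorted insert + truncate).
// Finally the scan tasks are awaited with try_join_all, the per-worker lists are merged,
// resolved to worktree handles, sorted by score, and truncated to `limit`.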

Release Notes:

- N/A
Max Brunsfeld 2024-05-01 12:15:44 -07:00 committed by GitHub
parent 4b767697af
commit 5831d80f51
4 changed files with 137 additions and 132 deletions

Changed file 1 of 4

@@ -140,10 +140,9 @@ impl LanguageModelTool for ProjectIndexTool {
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
let status = project_index.status();
let results = project_index.search(
- query.query.as_str(),
+ query.query.clone(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
@@ -151,7 +150,7 @@ impl LanguageModelTool for ProjectIndexTool {
let fs = self.fs.clone();
cx.spawn(|cx| async move {
- let results = results.await;
+ let results = results.await?;
let excerpts = results.into_iter().map(|result| {
let abs_path = result

Changed file 2 of 4

@@ -92,10 +92,11 @@ fn main() {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
- project_index.search(query, 4, cx)
+ project_index.search(query.into(), 4, cx)
})
.unwrap()
- .await;
+ .await
+ .unwrap();
for search_result in results {
let path = search_result.path.clone();

Changed file 3 of 4

@@ -98,12 +98,9 @@ fn chunk_lines(text: &str) -> Vec<Chunk> {
chunk_ranges
.into_iter()
- .map(|range| {
- let mut hasher = Sha256::new();
- hasher.update(&text[range.clone()]);
- let mut digest = [0u8; 32];
- digest.copy_from_slice(hasher.finalize().as_slice());
- Chunk { range, digest }
+ .map(|range| Chunk {
+ digest: Sha256::digest(&text[range.clone()]).into(),
+ range,
})
.collect()
}

Changed file 4 of 4

@@ -15,7 +15,7 @@ use gpui::{
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use parking_lot::Mutex;
- use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
+ use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
@@ -156,6 +156,10 @@ impl ProjectIndex {
self.last_status
}
+ pub fn project(&self) -> WeakModel<Project> {
+ self.project.clone()
+ }
fn handle_project_event(
&mut self,
_: Model<Project>,
@@ -259,30 +263,126 @@ impl ProjectIndex {
}
}
- pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
- let mut worktree_searches = Vec::new();
+ pub fn search(
+ &self,
+ query: String,
+ limit: usize,
+ cx: &AppContext,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let (chunks_tx, chunks_rx) = channel::bounded(1024);
+ let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
- worktree_searches
- .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+ let chunks_tx = chunks_tx.clone();
+ index.read_with(cx, |index, cx| {
+ let worktree_id = index.worktree.read(cx).id();
+ let db_connection = index.db_connection.clone();
+ let db = index.db;
+ worktree_scan_tasks.push(cx.background_executor().spawn({
+ async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let db_entries = db.iter(&txn).context("failed to iterate database")?;
+ for db_entry in db_entries {
+ let (_key, db_embedded_file) = db_entry?;
+ for chunk in db_embedded_file.chunks {
+ chunks_tx
+ .send((worktree_id, db_embedded_file.path.clone(), chunk))
+ .await?;
+ }
+ }
+ anyhow::Ok(())
+ }
+ }));
+ })
}
}
+ drop(chunks_tx);
- cx.spawn(|_| async move {
- let mut results = Vec::new();
- let worktree_searches = futures::future::join_all(worktree_searches).await;
+ let project = self.project.clone();
+ let embedding_provider = self.embedding_provider.clone();
+ cx.spawn(|cx| async move {
+ #[cfg(debug_assertions)]
+ let embedding_query_start = std::time::Instant::now();
+ log::info!("Searching for {query}");
- for worktree_search_results in worktree_searches {
- if let Some(worktree_search_results) = worktree_search_results.log_err() {
- results.extend(worktree_search_results);
- }
+ let query_embeddings = embedding_provider
+ .embed(&[TextToEmbed::new(&query)])
+ .await?;
+ let query_embedding = query_embeddings
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("no embedding for query"))?;
+ let mut results_by_worker = Vec::new();
+ for _ in 0..cx.background_executor().num_cpus() {
+ results_by_worker.push(Vec::<WorktreeSearchResult>::new());
}
- results
- .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
- results.truncate(limit);
+ #[cfg(debug_assertions)]
+ let search_start = std::time::Instant::now();
- results
+ cx.background_executor()
+ .scoped(|cx| {
+ for results in results_by_worker.iter_mut() {
+ cx.spawn(async {
+ while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
+ let score = chunk.embedding.similarity(&query_embedding);
+ let ix = match results.binary_search_by(|probe| {
+ score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) | Err(ix) => ix,
+ };
+ results.insert(
+ ix,
+ WorktreeSearchResult {
+ worktree_id,
+ path: path.clone(),
+ range: chunk.chunk.range.clone(),
+ score,
+ },
+ );
+ results.truncate(limit);
+ }
+ });
+ }
+ })
+ .await;
+ futures::future::try_join_all(worktree_scan_tasks).await?;
+ project.read_with(&cx, |project, cx| {
+ let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
+ for worker_results in results_by_worker {
+ search_results.extend(worker_results.into_iter().filter_map(|result| {
+ Some(SearchResult {
+ worktree: project.worktree_for_id(result.worktree_id, cx)?,
+ path: result.path,
+ range: result.range,
+ score: result.score,
+ })
+ }));
+ }
+ search_results.sort_unstable_by(|a, b| {
+ b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
+ });
+ search_results.truncate(limit);
+ #[cfg(debug_assertions)]
+ {
+ let search_elapsed = search_start.elapsed();
+ log::debug!(
+ "searched {} entries in {:?}",
+ search_results.len(),
+ search_elapsed
+ );
+ let embedding_query_elapsed = embedding_query_start.elapsed();
+ log::debug!("embedding query took {:?}", embedding_query_elapsed);
+ }
+ search_results
+ })
})
}
@@ -327,6 +427,13 @@ pub struct SearchResult {
pub score: f32,
}
+ pub struct WorktreeSearchResult {
+ pub worktree_id: WorktreeId,
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub score: f32,
+ }
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
Idle,
@@ -764,107 +871,6 @@ impl WorktreeIndex {
})
}
- fn search(
- &self,
- query: &str,
- limit: usize,
- cx: &AppContext,
- ) -> Task<Result<Vec<SearchResult>>> {
- let (chunks_tx, chunks_rx) = channel::bounded(1024);
- let db_connection = self.db_connection.clone();
- let db = self.db;
- let scan_chunks = cx.background_executor().spawn({
- async move {
- let txn = db_connection
- .read_txn()
- .context("failed to create read transaction")?;
- let db_entries = db.iter(&txn).context("failed to iterate database")?;
- for db_entry in db_entries {
- let (_key, db_embedded_file) = db_entry?;
- for chunk in db_embedded_file.chunks {
- chunks_tx
- .send((db_embedded_file.path.clone(), chunk))
- .await?;
- }
- }
- anyhow::Ok(())
- }
- });
- let query = query.to_string();
- let embedding_provider = self.embedding_provider.clone();
- let worktree = self.worktree.clone();
- cx.spawn(|cx| async move {
- #[cfg(debug_assertions)]
- let embedding_query_start = std::time::Instant::now();
- log::info!("Searching for {query}");
- let mut query_embeddings = embedding_provider
- .embed(&[TextToEmbed::new(&query)])
- .await?;
- let query_embedding = query_embeddings
- .pop()
- .ok_or_else(|| anyhow!("no embedding for query"))?;
- let mut workers = Vec::new();
- for _ in 0..cx.background_executor().num_cpus() {
- workers.push(Vec::<SearchResult>::new());
- }
- #[cfg(debug_assertions)]
- let search_start = std::time::Instant::now();
- cx.background_executor()
- .scoped(|cx| {
- for worker_results in workers.iter_mut() {
- cx.spawn(async {
- while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
- let score = embedded_chunk.embedding.similarity(&query_embedding);
- let ix = match worker_results.binary_search_by(|probe| {
- score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) | Err(ix) => ix,
- };
- worker_results.insert(
- ix,
- SearchResult {
- worktree: worktree.clone(),
- path: path.clone(),
- range: embedded_chunk.chunk.range.clone(),
- score,
- },
- );
- worker_results.truncate(limit);
- }
- });
- }
- })
- .await;
- scan_chunks.await?;
- let mut search_results = Vec::with_capacity(workers.len() * limit);
- for worker_results in workers {
- search_results.extend(worker_results);
- }
- search_results
- .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
- search_results.truncate(limit);
- #[cfg(debug_assertions)]
- {
- let search_elapsed = search_start.elapsed();
- log::debug!(
- "searched {} entries in {:?}",
- search_results.len(),
- search_elapsed
- );
- let embedding_query_elapsed = embedding_query_start.elapsed();
- log::debug!("embedding query took {:?}", embedding_query_elapsed);
- }
- Ok(search_results)
- })
- }
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let connection = self.db_connection.clone();
let db = self.db;
@@ -1093,9 +1099,10 @@ mod tests {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
- project_index.search(query, 4, cx)
+ project_index.search(query.into(), 4, cx)
})
- .await;
+ .await
+ .unwrap();
assert!(results.len() > 1, "should have found some results");
@@ -1112,9 +1119,10 @@
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
- let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+ let entry_abs_path = worktree.abs_path().join(&search_result.path);
let fs = project.read(cx).fs().clone();
- cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+ cx.background_executor()
+ .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
})
.await;