mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Return an error from project index tool when embedding query fails (#11264)
Previously, a failure to embed the search query (due to a rate limit error) would appear the same as if there were no results. * Avoid repeatedly embedding the search query for each worktree * Unify tasks for searching all worktree Release Notes: - N/A
This commit is contained in:
parent
4b767697af
commit
5831d80f51
@ -140,10 +140,9 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
|
||||
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||
let project_index = self.project_index.read(cx);
|
||||
|
||||
let status = project_index.status();
|
||||
let results = project_index.search(
|
||||
query.query.as_str(),
|
||||
query.query.clone(),
|
||||
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
||||
cx,
|
||||
);
|
||||
@ -151,7 +150,7 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||
let fs = self.fs.clone();
|
||||
|
||||
cx.spawn(|cx| async move {
|
||||
let results = results.await;
|
||||
let results = results.await?;
|
||||
|
||||
let excerpts = results.into_iter().map(|result| {
|
||||
let abs_path = result
|
||||
|
@ -92,10 +92,11 @@ fn main() {
|
||||
.update(|cx| {
|
||||
let project_index = project_index.read(cx);
|
||||
let query = "converting an anchor to a point";
|
||||
project_index.search(query, 4, cx)
|
||||
project_index.search(query.into(), 4, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for search_result in results {
|
||||
let path = search_result.path.clone();
|
||||
|
@ -98,12 +98,9 @@ fn chunk_lines(text: &str) -> Vec<Chunk> {
|
||||
|
||||
chunk_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&text[range.clone()]);
|
||||
let mut digest = [0u8; 32];
|
||||
digest.copy_from_slice(hasher.finalize().as_slice());
|
||||
Chunk { range, digest }
|
||||
.map(|range| Chunk {
|
||||
digest: Sha256::digest(&text[range.clone()]).into(),
|
||||
range,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ use gpui::{
|
||||
use heed::types::{SerdeBincode, Str};
|
||||
use language::LanguageRegistry;
|
||||
use parking_lot::Mutex;
|
||||
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
|
||||
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
@ -156,6 +156,10 @@ impl ProjectIndex {
|
||||
self.last_status
|
||||
}
|
||||
|
||||
pub fn project(&self) -> WeakModel<Project> {
|
||||
self.project.clone()
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_: Model<Project>,
|
||||
@ -259,30 +263,126 @@ impl ProjectIndex {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
|
||||
let mut worktree_searches = Vec::new();
|
||||
pub fn search(
|
||||
&self,
|
||||
query: String,
|
||||
limit: usize,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let (chunks_tx, chunks_rx) = channel::bounded(1024);
|
||||
let mut worktree_scan_tasks = Vec::new();
|
||||
for worktree_index in self.worktree_indices.values() {
|
||||
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
|
||||
worktree_searches
|
||||
.push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
|
||||
let chunks_tx = chunks_tx.clone();
|
||||
index.read_with(cx, |index, cx| {
|
||||
let worktree_id = index.worktree.read(cx).id();
|
||||
let db_connection = index.db_connection.clone();
|
||||
let db = index.db;
|
||||
worktree_scan_tasks.push(cx.background_executor().spawn({
|
||||
async move {
|
||||
let txn = db_connection
|
||||
.read_txn()
|
||||
.context("failed to create read transaction")?;
|
||||
let db_entries = db.iter(&txn).context("failed to iterate database")?;
|
||||
for db_entry in db_entries {
|
||||
let (_key, db_embedded_file) = db_entry?;
|
||||
for chunk in db_embedded_file.chunks {
|
||||
chunks_tx
|
||||
.send((worktree_id, db_embedded_file.path.clone(), chunk))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
}));
|
||||
})
|
||||
}
|
||||
}
|
||||
drop(chunks_tx);
|
||||
|
||||
let project = self.project.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
cx.spawn(|cx| async move {
|
||||
#[cfg(debug_assertions)]
|
||||
let embedding_query_start = std::time::Instant::now();
|
||||
log::info!("Searching for {query}");
|
||||
|
||||
let query_embeddings = embedding_provider
|
||||
.embed(&[TextToEmbed::new(&query)])
|
||||
.await?;
|
||||
let query_embedding = query_embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no embedding for query"))?;
|
||||
|
||||
let mut results_by_worker = Vec::new();
|
||||
for _ in 0..cx.background_executor().num_cpus() {
|
||||
results_by_worker.push(Vec::<WorktreeSearchResult>::new());
|
||||
}
|
||||
|
||||
cx.spawn(|_| async move {
|
||||
let mut results = Vec::new();
|
||||
let worktree_searches = futures::future::join_all(worktree_searches).await;
|
||||
#[cfg(debug_assertions)]
|
||||
let search_start = std::time::Instant::now();
|
||||
|
||||
for worktree_search_results in worktree_searches {
|
||||
if let Some(worktree_search_results) = worktree_search_results.log_err() {
|
||||
results.extend(worktree_search_results);
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
|
||||
cx.background_executor()
|
||||
.scoped(|cx| {
|
||||
for results in results_by_worker.iter_mut() {
|
||||
cx.spawn(async {
|
||||
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
|
||||
let score = chunk.embedding.similarity(&query_embedding);
|
||||
let ix = match results.binary_search_by(|probe| {
|
||||
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) | Err(ix) => ix,
|
||||
};
|
||||
results.insert(
|
||||
ix,
|
||||
WorktreeSearchResult {
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
range: chunk.chunk.range.clone(),
|
||||
score,
|
||||
},
|
||||
);
|
||||
results.truncate(limit);
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
results
|
||||
futures::future::try_join_all(worktree_scan_tasks).await?;
|
||||
|
||||
project.read_with(&cx, |project, cx| {
|
||||
let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
|
||||
for worker_results in results_by_worker {
|
||||
search_results.extend(worker_results.into_iter().filter_map(|result| {
|
||||
Some(SearchResult {
|
||||
worktree: project.worktree_for_id(result.worktree_id, cx)?,
|
||||
path: result.path,
|
||||
range: result.range,
|
||||
score: result.score,
|
||||
})
|
||||
}));
|
||||
}
|
||||
search_results.sort_unstable_by(|a, b| {
|
||||
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
|
||||
});
|
||||
search_results.truncate(limit);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let search_elapsed = search_start.elapsed();
|
||||
log::debug!(
|
||||
"searched {} entries in {:?}",
|
||||
search_results.len(),
|
||||
search_elapsed
|
||||
);
|
||||
let embedding_query_elapsed = embedding_query_start.elapsed();
|
||||
log::debug!("embedding query took {:?}", embedding_query_elapsed);
|
||||
}
|
||||
|
||||
search_results
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -327,6 +427,13 @@ pub struct SearchResult {
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
pub struct WorktreeSearchResult {
|
||||
pub worktree_id: WorktreeId,
|
||||
pub path: Arc<Path>,
|
||||
pub range: Range<usize>,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum Status {
|
||||
Idle,
|
||||
@ -764,107 +871,6 @@ impl WorktreeIndex {
|
||||
})
|
||||
}
|
||||
|
||||
fn search(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let (chunks_tx, chunks_rx) = channel::bounded(1024);
|
||||
|
||||
let db_connection = self.db_connection.clone();
|
||||
let db = self.db;
|
||||
let scan_chunks = cx.background_executor().spawn({
|
||||
async move {
|
||||
let txn = db_connection
|
||||
.read_txn()
|
||||
.context("failed to create read transaction")?;
|
||||
let db_entries = db.iter(&txn).context("failed to iterate database")?;
|
||||
for db_entry in db_entries {
|
||||
let (_key, db_embedded_file) = db_entry?;
|
||||
for chunk in db_embedded_file.chunks {
|
||||
chunks_tx
|
||||
.send((db_embedded_file.path.clone(), chunk))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
let query = query.to_string();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let worktree = self.worktree.clone();
|
||||
cx.spawn(|cx| async move {
|
||||
#[cfg(debug_assertions)]
|
||||
let embedding_query_start = std::time::Instant::now();
|
||||
log::info!("Searching for {query}");
|
||||
|
||||
let mut query_embeddings = embedding_provider
|
||||
.embed(&[TextToEmbed::new(&query)])
|
||||
.await?;
|
||||
let query_embedding = query_embeddings
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("no embedding for query"))?;
|
||||
let mut workers = Vec::new();
|
||||
for _ in 0..cx.background_executor().num_cpus() {
|
||||
workers.push(Vec::<SearchResult>::new());
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let search_start = std::time::Instant::now();
|
||||
|
||||
cx.background_executor()
|
||||
.scoped(|cx| {
|
||||
for worker_results in workers.iter_mut() {
|
||||
cx.spawn(async {
|
||||
while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
|
||||
let score = embedded_chunk.embedding.similarity(&query_embedding);
|
||||
let ix = match worker_results.binary_search_by(|probe| {
|
||||
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) | Err(ix) => ix,
|
||||
};
|
||||
worker_results.insert(
|
||||
ix,
|
||||
SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: path.clone(),
|
||||
range: embedded_chunk.chunk.range.clone(),
|
||||
score,
|
||||
},
|
||||
);
|
||||
worker_results.truncate(limit);
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
.await;
|
||||
scan_chunks.await?;
|
||||
|
||||
let mut search_results = Vec::with_capacity(workers.len() * limit);
|
||||
for worker_results in workers {
|
||||
search_results.extend(worker_results);
|
||||
}
|
||||
search_results
|
||||
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
|
||||
search_results.truncate(limit);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let search_elapsed = search_start.elapsed();
|
||||
log::debug!(
|
||||
"searched {} entries in {:?}",
|
||||
search_results.len(),
|
||||
search_elapsed
|
||||
);
|
||||
let embedding_query_elapsed = embedding_query_start.elapsed();
|
||||
log::debug!("embedding query took {:?}", embedding_query_elapsed);
|
||||
}
|
||||
|
||||
Ok(search_results)
|
||||
})
|
||||
}
|
||||
|
||||
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let connection = self.db_connection.clone();
|
||||
let db = self.db;
|
||||
@ -1093,9 +1099,10 @@ mod tests {
|
||||
.update(|cx| {
|
||||
let project_index = project_index.read(cx);
|
||||
let query = "garbage in, garbage out";
|
||||
project_index.search(query, 4, cx)
|
||||
project_index.search(query.into(), 4, cx)
|
||||
})
|
||||
.await;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(results.len() > 1, "should have found some results");
|
||||
|
||||
@ -1112,9 +1119,10 @@ mod tests {
|
||||
let content = cx
|
||||
.update(|cx| {
|
||||
let worktree = search_result.worktree.read(cx);
|
||||
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
|
||||
let entry_abs_path = worktree.abs_path().join(&search_result.path);
|
||||
let fs = project.read(cx).fs().clone();
|
||||
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
|
||||
cx.background_executor()
|
||||
.spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
|
||||
})
|
||||
.await;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user