mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-19 02:17:35 +03:00
Return an error from project index tool when embedding query fails (#11264)
Previously, a failure to embed the search query (due to a rate limit error) would appear the same as if there were no results.

* Avoid repeatedly embedding the search query for each worktree
* Unify tasks for searching all worktrees

Release Notes:

- N/A
This commit is contained in:
parent
4b767697af
commit
5831d80f51
@ -140,10 +140,9 @@ impl LanguageModelTool for ProjectIndexTool {
|
|||||||
|
|
||||||
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||||
let project_index = self.project_index.read(cx);
|
let project_index = self.project_index.read(cx);
|
||||||
|
|
||||||
let status = project_index.status();
|
let status = project_index.status();
|
||||||
let results = project_index.search(
|
let results = project_index.search(
|
||||||
query.query.as_str(),
|
query.query.clone(),
|
||||||
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
@ -151,7 +150,7 @@ impl LanguageModelTool for ProjectIndexTool {
|
|||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
|
|
||||||
cx.spawn(|cx| async move {
|
cx.spawn(|cx| async move {
|
||||||
let results = results.await;
|
let results = results.await?;
|
||||||
|
|
||||||
let excerpts = results.into_iter().map(|result| {
|
let excerpts = results.into_iter().map(|result| {
|
||||||
let abs_path = result
|
let abs_path = result
|
||||||
|
@ -92,10 +92,11 @@ fn main() {
|
|||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let project_index = project_index.read(cx);
|
let project_index = project_index.read(cx);
|
||||||
let query = "converting an anchor to a point";
|
let query = "converting an anchor to a point";
|
||||||
project_index.search(query, 4, cx)
|
project_index.search(query.into(), 4, cx)
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.await;
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for search_result in results {
|
for search_result in results {
|
||||||
let path = search_result.path.clone();
|
let path = search_result.path.clone();
|
||||||
|
@ -98,12 +98,9 @@ fn chunk_lines(text: &str) -> Vec<Chunk> {
|
|||||||
|
|
||||||
chunk_ranges
|
chunk_ranges
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|range| {
|
.map(|range| Chunk {
|
||||||
let mut hasher = Sha256::new();
|
digest: Sha256::digest(&text[range.clone()]).into(),
|
||||||
hasher.update(&text[range.clone()]);
|
range,
|
||||||
let mut digest = [0u8; 32];
|
|
||||||
digest.copy_from_slice(hasher.finalize().as_slice());
|
|
||||||
Chunk { range, digest }
|
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ use gpui::{
|
|||||||
use heed::types::{SerdeBincode, Str};
|
use heed::types::{SerdeBincode, Str};
|
||||||
use language::LanguageRegistry;
|
use language::LanguageRegistry;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
|
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use smol::channel;
|
use smol::channel;
|
||||||
use std::{
|
use std::{
|
||||||
@ -156,6 +156,10 @@ impl ProjectIndex {
|
|||||||
self.last_status
|
self.last_status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn project(&self) -> WeakModel<Project> {
|
||||||
|
self.project.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_project_event(
|
fn handle_project_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: Model<Project>,
|
_: Model<Project>,
|
||||||
@ -259,30 +263,126 @@ impl ProjectIndex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
|
pub fn search(
|
||||||
let mut worktree_searches = Vec::new();
|
&self,
|
||||||
|
query: String,
|
||||||
|
limit: usize,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Task<Result<Vec<SearchResult>>> {
|
||||||
|
let (chunks_tx, chunks_rx) = channel::bounded(1024);
|
||||||
|
let mut worktree_scan_tasks = Vec::new();
|
||||||
for worktree_index in self.worktree_indices.values() {
|
for worktree_index in self.worktree_indices.values() {
|
||||||
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
|
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
|
||||||
worktree_searches
|
let chunks_tx = chunks_tx.clone();
|
||||||
.push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
|
index.read_with(cx, |index, cx| {
|
||||||
|
let worktree_id = index.worktree.read(cx).id();
|
||||||
|
let db_connection = index.db_connection.clone();
|
||||||
|
let db = index.db;
|
||||||
|
worktree_scan_tasks.push(cx.background_executor().spawn({
|
||||||
|
async move {
|
||||||
|
let txn = db_connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
let db_entries = db.iter(&txn).context("failed to iterate database")?;
|
||||||
|
for db_entry in db_entries {
|
||||||
|
let (_key, db_embedded_file) = db_entry?;
|
||||||
|
for chunk in db_embedded_file.chunks {
|
||||||
|
chunks_tx
|
||||||
|
.send((worktree_id, db_embedded_file.path.clone(), chunk))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
drop(chunks_tx);
|
||||||
|
|
||||||
cx.spawn(|_| async move {
|
let project = self.project.clone();
|
||||||
let mut results = Vec::new();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
let worktree_searches = futures::future::join_all(worktree_searches).await;
|
cx.spawn(|cx| async move {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
let embedding_query_start = std::time::Instant::now();
|
||||||
|
log::info!("Searching for {query}");
|
||||||
|
|
||||||
for worktree_search_results in worktree_searches {
|
let query_embeddings = embedding_provider
|
||||||
if let Some(worktree_search_results) = worktree_search_results.log_err() {
|
.embed(&[TextToEmbed::new(&query)])
|
||||||
results.extend(worktree_search_results);
|
.await?;
|
||||||
}
|
let query_embedding = query_embeddings
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("no embedding for query"))?;
|
||||||
|
|
||||||
|
let mut results_by_worker = Vec::new();
|
||||||
|
for _ in 0..cx.background_executor().num_cpus() {
|
||||||
|
results_by_worker.push(Vec::<WorktreeSearchResult>::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
results
|
#[cfg(debug_assertions)]
|
||||||
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
|
let search_start = std::time::Instant::now();
|
||||||
results.truncate(limit);
|
|
||||||
|
|
||||||
results
|
cx.background_executor()
|
||||||
|
.scoped(|cx| {
|
||||||
|
for results in results_by_worker.iter_mut() {
|
||||||
|
cx.spawn(async {
|
||||||
|
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
|
||||||
|
let score = chunk.embedding.similarity(&query_embedding);
|
||||||
|
let ix = match results.binary_search_by(|probe| {
|
||||||
|
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
|
||||||
|
}) {
|
||||||
|
Ok(ix) | Err(ix) => ix,
|
||||||
|
};
|
||||||
|
results.insert(
|
||||||
|
ix,
|
||||||
|
WorktreeSearchResult {
|
||||||
|
worktree_id,
|
||||||
|
path: path.clone(),
|
||||||
|
range: chunk.chunk.range.clone(),
|
||||||
|
score,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
results.truncate(limit);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
futures::future::try_join_all(worktree_scan_tasks).await?;
|
||||||
|
|
||||||
|
project.read_with(&cx, |project, cx| {
|
||||||
|
let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
|
||||||
|
for worker_results in results_by_worker {
|
||||||
|
search_results.extend(worker_results.into_iter().filter_map(|result| {
|
||||||
|
Some(SearchResult {
|
||||||
|
worktree: project.worktree_for_id(result.worktree_id, cx)?,
|
||||||
|
path: result.path,
|
||||||
|
range: result.range,
|
||||||
|
score: result.score,
|
||||||
|
})
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
search_results.sort_unstable_by(|a, b| {
|
||||||
|
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
|
||||||
|
});
|
||||||
|
search_results.truncate(limit);
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let search_elapsed = search_start.elapsed();
|
||||||
|
log::debug!(
|
||||||
|
"searched {} entries in {:?}",
|
||||||
|
search_results.len(),
|
||||||
|
search_elapsed
|
||||||
|
);
|
||||||
|
let embedding_query_elapsed = embedding_query_start.elapsed();
|
||||||
|
log::debug!("embedding query took {:?}", embedding_query_elapsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
search_results
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,6 +427,13 @@ pub struct SearchResult {
|
|||||||
pub score: f32,
|
pub score: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct WorktreeSearchResult {
|
||||||
|
pub worktree_id: WorktreeId,
|
||||||
|
pub path: Arc<Path>,
|
||||||
|
pub range: Range<usize>,
|
||||||
|
pub score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||||
pub enum Status {
|
pub enum Status {
|
||||||
Idle,
|
Idle,
|
||||||
@ -764,107 +871,6 @@ impl WorktreeIndex {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search(
|
|
||||||
&self,
|
|
||||||
query: &str,
|
|
||||||
limit: usize,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> Task<Result<Vec<SearchResult>>> {
|
|
||||||
let (chunks_tx, chunks_rx) = channel::bounded(1024);
|
|
||||||
|
|
||||||
let db_connection = self.db_connection.clone();
|
|
||||||
let db = self.db;
|
|
||||||
let scan_chunks = cx.background_executor().spawn({
|
|
||||||
async move {
|
|
||||||
let txn = db_connection
|
|
||||||
.read_txn()
|
|
||||||
.context("failed to create read transaction")?;
|
|
||||||
let db_entries = db.iter(&txn).context("failed to iterate database")?;
|
|
||||||
for db_entry in db_entries {
|
|
||||||
let (_key, db_embedded_file) = db_entry?;
|
|
||||||
for chunk in db_embedded_file.chunks {
|
|
||||||
chunks_tx
|
|
||||||
.send((db_embedded_file.path.clone(), chunk))
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let query = query.to_string();
|
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
|
||||||
let worktree = self.worktree.clone();
|
|
||||||
cx.spawn(|cx| async move {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
let embedding_query_start = std::time::Instant::now();
|
|
||||||
log::info!("Searching for {query}");
|
|
||||||
|
|
||||||
let mut query_embeddings = embedding_provider
|
|
||||||
.embed(&[TextToEmbed::new(&query)])
|
|
||||||
.await?;
|
|
||||||
let query_embedding = query_embeddings
|
|
||||||
.pop()
|
|
||||||
.ok_or_else(|| anyhow!("no embedding for query"))?;
|
|
||||||
let mut workers = Vec::new();
|
|
||||||
for _ in 0..cx.background_executor().num_cpus() {
|
|
||||||
workers.push(Vec::<SearchResult>::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
let search_start = std::time::Instant::now();
|
|
||||||
|
|
||||||
cx.background_executor()
|
|
||||||
.scoped(|cx| {
|
|
||||||
for worker_results in workers.iter_mut() {
|
|
||||||
cx.spawn(async {
|
|
||||||
while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
|
|
||||||
let score = embedded_chunk.embedding.similarity(&query_embedding);
|
|
||||||
let ix = match worker_results.binary_search_by(|probe| {
|
|
||||||
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
|
|
||||||
}) {
|
|
||||||
Ok(ix) | Err(ix) => ix,
|
|
||||||
};
|
|
||||||
worker_results.insert(
|
|
||||||
ix,
|
|
||||||
SearchResult {
|
|
||||||
worktree: worktree.clone(),
|
|
||||||
path: path.clone(),
|
|
||||||
range: embedded_chunk.chunk.range.clone(),
|
|
||||||
score,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
worker_results.truncate(limit);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
scan_chunks.await?;
|
|
||||||
|
|
||||||
let mut search_results = Vec::with_capacity(workers.len() * limit);
|
|
||||||
for worker_results in workers {
|
|
||||||
search_results.extend(worker_results);
|
|
||||||
}
|
|
||||||
search_results
|
|
||||||
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
|
|
||||||
search_results.truncate(limit);
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
let search_elapsed = search_start.elapsed();
|
|
||||||
log::debug!(
|
|
||||||
"searched {} entries in {:?}",
|
|
||||||
search_results.len(),
|
|
||||||
search_elapsed
|
|
||||||
);
|
|
||||||
let embedding_query_elapsed = embedding_query_start.elapsed();
|
|
||||||
log::debug!("embedding query took {:?}", embedding_query_elapsed);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(search_results)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||||
let connection = self.db_connection.clone();
|
let connection = self.db_connection.clone();
|
||||||
let db = self.db;
|
let db = self.db;
|
||||||
@ -1093,9 +1099,10 @@ mod tests {
|
|||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let project_index = project_index.read(cx);
|
let project_index = project_index.read(cx);
|
||||||
let query = "garbage in, garbage out";
|
let query = "garbage in, garbage out";
|
||||||
project_index.search(query, 4, cx)
|
project_index.search(query.into(), 4, cx)
|
||||||
})
|
})
|
||||||
.await;
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(results.len() > 1, "should have found some results");
|
assert!(results.len() > 1, "should have found some results");
|
||||||
|
|
||||||
@ -1112,9 +1119,10 @@ mod tests {
|
|||||||
let content = cx
|
let content = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let worktree = search_result.worktree.read(cx);
|
let worktree = search_result.worktree.read(cx);
|
||||||
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
|
let entry_abs_path = worktree.abs_path().join(&search_result.path);
|
||||||
let fs = project.read(cx).fs().clone();
|
let fs = project.read(cx).fs().clone();
|
||||||
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
|
cx.background_executor()
|
||||||
|
.spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user