mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Fix some semantic index issues (#11216)
* [x] Fixed an issue where embeddings would be assigned incorrectly to files if a subset of embedding batches failed * [x] Added a command to debug which paths are present in the semantic index * [x] Determine why so many paths are often missing from the semantic index * we erroring out if an embedding batch contained multiple texts that were the same, which can happen if a worktree contains multiple copies of the same text (e.g. a license). Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Nathan <nathan@zed.dev> Co-authored-by: Kyle <kylek@zed.dev> Co-authored-by: Kyle Kelley <rgbkrk@gmail.com>
This commit is contained in:
parent
d01428e69c
commit
38b9d5cc36
@ -365,7 +365,7 @@ impl Example {
|
||||
) -> Self {
|
||||
Self {
|
||||
assistant_panel: cx.new_view(|cx| {
|
||||
AssistantPanel::new(language_registry, tool_registry, user_store, cx)
|
||||
AssistantPanel::new(language_registry, tool_registry, user_store, None, cx)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ use gpui::{
|
||||
use language::{language_settings::SoftWrap, LanguageRegistry};
|
||||
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
|
||||
use rich_text::RichText;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
|
||||
use serde::Deserialize;
|
||||
use settings::Settings;
|
||||
use std::sync::Arc;
|
||||
@ -51,7 +51,7 @@ pub enum SubmitMode {
|
||||
Codebase,
|
||||
}
|
||||
|
||||
gpui::actions!(assistant2, [Cancel, ToggleFocus]);
|
||||
gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex]);
|
||||
gpui::impl_actions!(assistant2, [Submit]);
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
@ -131,7 +131,13 @@ impl AssistantPanel {
|
||||
|
||||
let tool_registry = Arc::new(tool_registry);
|
||||
|
||||
Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
|
||||
Self::new(
|
||||
app_state.languages.clone(),
|
||||
tool_registry,
|
||||
user_store,
|
||||
Some(project_index),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -140,6 +146,7 @@ impl AssistantPanel {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
user_store: Model<UserStore>,
|
||||
project_index: Option<Model<ProjectIndex>>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let chat = cx.new_view(|cx| {
|
||||
@ -147,6 +154,7 @@ impl AssistantPanel {
|
||||
language_registry.clone(),
|
||||
tool_registry.clone(),
|
||||
user_store,
|
||||
project_index,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@ -225,6 +233,7 @@ struct AssistantChat {
|
||||
collapsed_messages: HashMap<MessageId, bool>,
|
||||
pending_completion: Option<Task<()>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
project_index: Option<Model<ProjectIndex>>,
|
||||
}
|
||||
|
||||
impl AssistantChat {
|
||||
@ -232,6 +241,7 @@ impl AssistantChat {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
user_store: Model<UserStore>,
|
||||
project_index: Option<Model<ProjectIndex>>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let model = CompletionProvider::get(cx).default_model();
|
||||
@ -258,6 +268,7 @@ impl AssistantChat {
|
||||
list_state,
|
||||
user_store,
|
||||
language_registry,
|
||||
project_index,
|
||||
next_message_id: MessageId(0),
|
||||
collapsed_messages: HashMap::default(),
|
||||
pending_completion: None,
|
||||
@ -342,6 +353,14 @@ impl AssistantChat {
|
||||
self.pending_completion.is_none()
|
||||
}
|
||||
|
||||
fn debug_project_index(&mut self, _: &DebugProjectIndex, cx: &mut ViewContext<Self>) {
|
||||
if let Some(index) = &self.project_index {
|
||||
index.update(cx, |project_index, cx| {
|
||||
project_index.debug(cx).detach_and_log_err(cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_completion(
|
||||
this: WeakView<Self>,
|
||||
mode: SubmitMode,
|
||||
@ -686,6 +705,7 @@ impl Render for AssistantChat {
|
||||
.key_context("AssistantChat")
|
||||
.on_action(cx.listener(Self::submit))
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::debug_project_index))
|
||||
.text_color(Color::Default.color(cx))
|
||||
.child(list(self.list_state.clone()).flex_1())
|
||||
.child(Composer::new(
|
||||
|
@ -72,10 +72,11 @@ impl EmbeddingProvider for CloudEmbeddingProvider {
|
||||
texts
|
||||
.iter()
|
||||
.map(|to_embed| {
|
||||
let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
|
||||
let embedding =
|
||||
embeddings.get(&to_embed.digest).cloned().with_context(|| {
|
||||
format!("server did not return an embedding for {:?}", to_embed)
|
||||
})?;
|
||||
Ok(Embedding::new(dimensions))
|
||||
Ok(Embedding::new(embedding))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ use smol::channel;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
future::Future,
|
||||
iter,
|
||||
num::NonZeroUsize,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
@ -295,6 +296,28 @@ impl ProjectIndex {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn debug(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let indices = self
|
||||
.worktree_indices
|
||||
.values()
|
||||
.filter_map(|worktree_index| {
|
||||
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
|
||||
Some(index.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
eprintln!("semantic index contents:");
|
||||
for index in indices {
|
||||
index.update(&mut cx, |index, cx| index.debug(cx))?.await?
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SearchResult {
|
||||
@ -419,7 +442,7 @@ impl WorktreeIndex {
|
||||
let worktree_abs_path = worktree.abs_path().clone();
|
||||
let scan = self.scan_entries(worktree.clone(), cx);
|
||||
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
|
||||
let embed = self.embed_files(chunk.files, cx);
|
||||
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
|
||||
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
|
||||
async move {
|
||||
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
|
||||
@ -436,7 +459,7 @@ impl WorktreeIndex {
|
||||
let worktree_abs_path = worktree.abs_path().clone();
|
||||
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
|
||||
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
|
||||
let embed = self.embed_files(chunk.files, cx);
|
||||
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
|
||||
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
|
||||
async move {
|
||||
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
|
||||
@ -500,7 +523,7 @@ impl WorktreeIndex {
|
||||
}
|
||||
|
||||
if entry.mtime != saved_mtime {
|
||||
let handle = entries_being_indexed.insert(&entry);
|
||||
let handle = entries_being_indexed.insert(entry.id);
|
||||
updated_entries_tx.send((entry.clone(), handle)).await?;
|
||||
}
|
||||
}
|
||||
@ -539,7 +562,7 @@ impl WorktreeIndex {
|
||||
| project::PathChange::AddedOrUpdated => {
|
||||
if let Some(entry) = worktree.entry_for_id(*entry_id) {
|
||||
if entry.is_file() {
|
||||
let handle = entries_being_indexed.insert(&entry);
|
||||
let handle = entries_being_indexed.insert(entry.id);
|
||||
updated_entries_tx.send((entry.clone(), handle)).await?;
|
||||
}
|
||||
}
|
||||
@ -601,7 +624,8 @@ impl WorktreeIndex {
|
||||
let chunked_file = ChunkedFile {
|
||||
chunks: chunk_text(&text, grammar),
|
||||
handle,
|
||||
entry,
|
||||
path: entry.path,
|
||||
mtime: entry.mtime,
|
||||
text,
|
||||
};
|
||||
|
||||
@ -623,11 +647,11 @@ impl WorktreeIndex {
|
||||
}
|
||||
|
||||
fn embed_files(
|
||||
&self,
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
chunked_files: channel::Receiver<ChunkedFile>,
|
||||
cx: &AppContext,
|
||||
) -> EmbedFiles {
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
|
||||
let task = cx.background_executor().spawn(async move {
|
||||
let mut chunked_file_batches =
|
||||
@ -635,9 +659,10 @@ impl WorktreeIndex {
|
||||
while let Some(chunked_files) = chunked_file_batches.next().await {
|
||||
// View the batch of files as a vec of chunks
|
||||
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
|
||||
// Once those are done, reassemble it back into which files they belong to
|
||||
// Once those are done, reassemble them back into the files in which they belong
|
||||
// If any embeddings fail for a file, the entire file is discarded
|
||||
|
||||
let chunks = chunked_files
|
||||
let chunks: Vec<TextToEmbed> = chunked_files
|
||||
.iter()
|
||||
.flat_map(|file| {
|
||||
file.chunks.iter().map(|chunk| TextToEmbed {
|
||||
@ -647,38 +672,52 @@ impl WorktreeIndex {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut embeddings = Vec::new();
|
||||
let mut embeddings: Vec<Option<Embedding>> = Vec::new();
|
||||
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
|
||||
if let Some(batch_embeddings) =
|
||||
embedding_provider.embed(embedding_batch).await.log_err()
|
||||
{
|
||||
embeddings.extend_from_slice(&batch_embeddings);
|
||||
if batch_embeddings.len() == embedding_batch.len() {
|
||||
embeddings.extend(batch_embeddings.into_iter().map(Some));
|
||||
continue;
|
||||
}
|
||||
log::error!(
|
||||
"embedding provider returned unexpected embedding count {}, expected {}",
|
||||
batch_embeddings.len(), embedding_batch.len()
|
||||
);
|
||||
}
|
||||
|
||||
embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
|
||||
}
|
||||
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for chunked_file in chunked_files {
|
||||
let chunk_embeddings = embeddings
|
||||
.by_ref()
|
||||
.take(chunked_file.chunks.len())
|
||||
.collect::<Vec<_>>();
|
||||
let embedded_chunks = chunked_file
|
||||
.chunks
|
||||
.into_iter()
|
||||
.zip(chunk_embeddings)
|
||||
.map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
|
||||
.collect();
|
||||
let embedded_file = EmbeddedFile {
|
||||
path: chunked_file.entry.path.clone(),
|
||||
mtime: chunked_file.entry.mtime,
|
||||
chunks: embedded_chunks,
|
||||
let mut embedded_file = EmbeddedFile {
|
||||
path: chunked_file.path,
|
||||
mtime: chunked_file.mtime,
|
||||
chunks: Vec::new(),
|
||||
};
|
||||
|
||||
let mut embedded_all_chunks = true;
|
||||
for (chunk, embedding) in
|
||||
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
|
||||
{
|
||||
if let Some(embedding) = embedding {
|
||||
embedded_file
|
||||
.chunks
|
||||
.push(EmbeddedChunk { chunk, embedding });
|
||||
} else {
|
||||
embedded_all_chunks = false;
|
||||
}
|
||||
}
|
||||
|
||||
if embedded_all_chunks {
|
||||
embedded_files_tx
|
||||
.send((embedded_file, chunked_file.handle))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
|
||||
@ -826,6 +865,21 @@ impl WorktreeIndex {
|
||||
})
|
||||
}
|
||||
|
||||
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let connection = self.db_connection.clone();
|
||||
let db = self.db;
|
||||
cx.background_executor().spawn(async move {
|
||||
let tx = connection
|
||||
.read_txn()
|
||||
.context("failed to create read transaction")?;
|
||||
for record in db.iter(&tx)? {
|
||||
let (key, _) = record?;
|
||||
eprintln!("{}", path_for_db_key(key));
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn path_count(&self) -> Result<u64> {
|
||||
let txn = self
|
||||
@ -848,7 +902,8 @@ struct ChunkFiles {
|
||||
}
|
||||
|
||||
struct ChunkedFile {
|
||||
pub entry: Entry,
|
||||
pub path: Arc<Path>,
|
||||
pub mtime: Option<SystemTime>,
|
||||
pub handle: IndexingEntryHandle,
|
||||
pub text: String,
|
||||
pub chunks: Vec<Chunk>,
|
||||
@ -872,11 +927,14 @@ struct EmbeddedChunk {
|
||||
embedding: Embedding,
|
||||
}
|
||||
|
||||
/// The set of entries that are currently being indexed.
|
||||
struct IndexingEntrySet {
|
||||
entry_ids: Mutex<HashSet<ProjectEntryId>>,
|
||||
tx: channel::Sender<()>,
|
||||
}
|
||||
|
||||
/// When dropped, removes the entry from the set of entries that are being indexed.
|
||||
#[derive(Clone)]
|
||||
struct IndexingEntryHandle {
|
||||
entry_id: ProjectEntryId,
|
||||
set: Weak<IndexingEntrySet>,
|
||||
@ -890,11 +948,11 @@ impl IndexingEntrySet {
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
|
||||
self.entry_ids.lock().insert(entry.id);
|
||||
fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
|
||||
self.entry_ids.lock().insert(entry_id);
|
||||
self.tx.send_blocking(()).ok();
|
||||
IndexingEntryHandle {
|
||||
entry_id: entry.id,
|
||||
entry_id,
|
||||
set: Arc::downgrade(self),
|
||||
}
|
||||
}
|
||||
@ -917,6 +975,10 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
|
||||
path.to_string_lossy().replace('/', "\0")
|
||||
}
|
||||
|
||||
fn path_for_db_key(key: &str) -> String {
|
||||
key.replace('\0', "/")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -939,7 +1001,22 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
pub struct TestEmbeddingProvider;
|
||||
pub struct TestEmbeddingProvider {
|
||||
batch_size: usize,
|
||||
compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
|
||||
}
|
||||
|
||||
impl TestEmbeddingProvider {
|
||||
pub fn new(
|
||||
batch_size: usize,
|
||||
compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
|
||||
) -> Self {
|
||||
return Self {
|
||||
batch_size,
|
||||
compute_embedding: Box::new(compute_embedding),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider for TestEmbeddingProvider {
|
||||
fn embed<'a>(
|
||||
@ -948,29 +1025,13 @@ mod tests {
|
||||
) -> BoxFuture<'a, Result<Vec<Embedding>>> {
|
||||
let embeddings = texts
|
||||
.iter()
|
||||
.map(|text| {
|
||||
let mut embedding = vec![0f32; 2];
|
||||
// if the text contains garbage, give it a 1 in the first dimension
|
||||
if text.text.contains("garbage in") {
|
||||
embedding[0] = 0.9;
|
||||
} else {
|
||||
embedding[0] = -0.9;
|
||||
}
|
||||
|
||||
if text.text.contains("garbage out") {
|
||||
embedding[1] = 0.9;
|
||||
} else {
|
||||
embedding[1] = -0.9;
|
||||
}
|
||||
|
||||
Embedding::new(embedding)
|
||||
})
|
||||
.map(|to_embed| (self.compute_embedding)(to_embed.text))
|
||||
.collect();
|
||||
future::ready(Ok(embeddings)).boxed()
|
||||
future::ready(embeddings).boxed()
|
||||
}
|
||||
|
||||
fn batch_size(&self) -> usize {
|
||||
16
|
||||
self.batch_size
|
||||
}
|
||||
}
|
||||
|
||||
@ -984,7 +1045,23 @@ mod tests {
|
||||
|
||||
let mut semantic_index = SemanticIndex::new(
|
||||
temp_dir.path().into(),
|
||||
Arc::new(TestEmbeddingProvider),
|
||||
Arc::new(TestEmbeddingProvider::new(16, |text| {
|
||||
let mut embedding = vec![0f32; 2];
|
||||
// if the text contains garbage, give it a 1 in the first dimension
|
||||
if text.contains("garbage in") {
|
||||
embedding[0] = 0.9;
|
||||
} else {
|
||||
embedding[0] = -0.9;
|
||||
}
|
||||
|
||||
if text.contains("garbage out") {
|
||||
embedding[1] = 0.9;
|
||||
} else {
|
||||
embedding[1] = -0.9;
|
||||
}
|
||||
|
||||
Ok(Embedding::new(embedding))
|
||||
})),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
@ -1046,4 +1123,82 @@ mod tests {
|
||||
|
||||
assert!(content.contains("garbage in, garbage out"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_embed_files(cx: &mut TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
|
||||
if text.contains('g') {
|
||||
Err(anyhow!("cannot embed text containing a 'g' character"))
|
||||
} else {
|
||||
Ok(Embedding::new(
|
||||
('a'..'z')
|
||||
.map(|char| text.chars().filter(|c| *c == char).count() as f32)
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
}));
|
||||
|
||||
let (indexing_progress_tx, _) = channel::unbounded();
|
||||
let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));
|
||||
|
||||
let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
|
||||
chunked_files_tx
|
||||
.send_blocking(ChunkedFile {
|
||||
path: Path::new("test1.md").into(),
|
||||
mtime: None,
|
||||
handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
|
||||
text: "abcdefghijklmnop".to_string(),
|
||||
chunks: [0..4, 4..8, 8..12, 12..16]
|
||||
.into_iter()
|
||||
.map(|range| Chunk {
|
||||
range,
|
||||
digest: Default::default(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.unwrap();
|
||||
chunked_files_tx
|
||||
.send_blocking(ChunkedFile {
|
||||
path: Path::new("test2.md").into(),
|
||||
mtime: None,
|
||||
handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
|
||||
text: "qrstuvwxyz".to_string(),
|
||||
chunks: [0..4, 4..8, 8..10]
|
||||
.into_iter()
|
||||
.map(|range| Chunk {
|
||||
range,
|
||||
digest: Default::default(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.unwrap();
|
||||
chunked_files_tx.close();
|
||||
|
||||
let embed_files_task =
|
||||
cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
|
||||
embed_files_task.task.await.unwrap();
|
||||
|
||||
let mut embedded_files_rx = embed_files_task.files;
|
||||
let mut embedded_files = Vec::new();
|
||||
while let Some((embedded_file, _)) = embedded_files_rx.next().await {
|
||||
embedded_files.push(embedded_file);
|
||||
}
|
||||
|
||||
assert_eq!(embedded_files.len(), 1);
|
||||
assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
|
||||
assert_eq!(
|
||||
embedded_files[0]
|
||||
.chunks
|
||||
.iter()
|
||||
.map(|embedded_chunk| { embedded_chunk.embedding.clone() })
|
||||
.collect::<Vec<Embedding>>(),
|
||||
vec![
|
||||
(provider.compute_embedding)("qrst").unwrap(),
|
||||
(provider.compute_embedding)("uvwx").unwrap(),
|
||||
(provider.compute_embedding)("yz").unwrap(),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user