From 38b9d5cc36f2377dc89da54ea6a4df9e167019c1 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 30 Apr 2024 10:55:38 -0700 Subject: [PATCH] Fix some semantic index issues (#11216) * [x] Fixed an issue where embeddings would be assigned incorrectly to files if a subset of embedding batches failed * [x] Added a command to debug which paths are present in the semantic index * [x] Determined why so many paths are often missing from the semantic index * we were erroring out if an embedding batch contained multiple texts that were the same, which can happen if a worktree contains multiple copies of the same text (e.g. a license). Release Notes: - N/A --------- Co-authored-by: Marshall Co-authored-by: Nathan Co-authored-by: Kyle Co-authored-by: Kyle Kelley --- .../examples/chat_with_functions.rs | 2 +- crates/assistant2/src/assistant2.rs | 26 +- crates/semantic_index/src/embedding/cloud.rs | 9 +- crates/semantic_index/src/semantic_index.rs | 261 ++++++++++++++---- 4 files changed, 237 insertions(+), 61 deletions(-) diff --git a/crates/assistant2/examples/chat_with_functions.rs b/crates/assistant2/examples/chat_with_functions.rs index 7c7011caa3..1b8afa1973 100644 --- a/crates/assistant2/examples/chat_with_functions.rs +++ b/crates/assistant2/examples/chat_with_functions.rs @@ -365,7 +365,7 @@ impl Example { ) -> Self { Self { assistant_panel: cx.new_view(|cx| { - AssistantPanel::new(language_registry, tool_registry, user_store, cx) + AssistantPanel::new(language_registry, tool_registry, user_store, None, cx) }), } } diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 54ffc686b1..3593ddd874 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -19,7 +19,7 @@ use gpui::{ use language::{language_settings::SoftWrap, LanguageRegistry}; use open_ai::{FunctionContent, ToolCall, ToolCallContent}; use rich_text::RichText; -use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; +use
semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex}; use serde::Deserialize; use settings::Settings; use std::sync::Arc; @@ -51,7 +51,7 @@ pub enum SubmitMode { Codebase, } -gpui::actions!(assistant2, [Cancel, ToggleFocus]); +gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex]); gpui::impl_actions!(assistant2, [Submit]); pub fn init(client: Arc, cx: &mut AppContext) { @@ -131,7 +131,13 @@ impl AssistantPanel { let tool_registry = Arc::new(tool_registry); - Self::new(app_state.languages.clone(), tool_registry, user_store, cx) + Self::new( + app_state.languages.clone(), + tool_registry, + user_store, + Some(project_index), + cx, + ) }) }) } @@ -140,6 +146,7 @@ impl AssistantPanel { language_registry: Arc, tool_registry: Arc, user_store: Model, + project_index: Option>, cx: &mut ViewContext, ) -> Self { let chat = cx.new_view(|cx| { @@ -147,6 +154,7 @@ impl AssistantPanel { language_registry.clone(), tool_registry.clone(), user_store, + project_index, cx, ) }); @@ -225,6 +233,7 @@ struct AssistantChat { collapsed_messages: HashMap, pending_completion: Option>, tool_registry: Arc, + project_index: Option>, } impl AssistantChat { @@ -232,6 +241,7 @@ impl AssistantChat { language_registry: Arc, tool_registry: Arc, user_store: Model, + project_index: Option>, cx: &mut ViewContext, ) -> Self { let model = CompletionProvider::get(cx).default_model(); @@ -258,6 +268,7 @@ impl AssistantChat { list_state, user_store, language_registry, + project_index, next_message_id: MessageId(0), collapsed_messages: HashMap::default(), pending_completion: None, @@ -342,6 +353,14 @@ impl AssistantChat { self.pending_completion.is_none() } + fn debug_project_index(&mut self, _: &DebugProjectIndex, cx: &mut ViewContext) { + if let Some(index) = &self.project_index { + index.update(cx, |project_index, cx| { + project_index.debug(cx).detach_and_log_err(cx) + }); + } + } + async fn request_completion( this: WeakView, mode: SubmitMode, @@ -686,6 +705,7 @@ impl 
Render for AssistantChat { .key_context("AssistantChat") .on_action(cx.listener(Self::submit)) .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::debug_project_index)) .text_color(Color::Default.color(cx)) .child(list(self.list_state.clone()).flex_1()) .child(Composer::new( diff --git a/crates/semantic_index/src/embedding/cloud.rs b/crates/semantic_index/src/embedding/cloud.rs index 2a1df705c8..ea09adea82 100644 --- a/crates/semantic_index/src/embedding/cloud.rs +++ b/crates/semantic_index/src/embedding/cloud.rs @@ -72,10 +72,11 @@ impl EmbeddingProvider for CloudEmbeddingProvider { texts .iter() .map(|to_embed| { - let dimensions = embeddings.remove(&to_embed.digest).with_context(|| { - format!("server did not return an embedding for {:?}", to_embed) - })?; - Ok(Embedding::new(dimensions)) + let embedding = + embeddings.get(&to_embed.digest).cloned().with_context(|| { + format!("server did not return an embedding for {:?}", to_embed) + })?; + Ok(Embedding::new(embedding)) }) .collect() } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index e26ca0a0a7..f8082de07b 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -21,6 +21,7 @@ use smol::channel; use std::{ cmp::Ordering, future::Future, + iter, num::NonZeroUsize, ops::Range, path::{Path, PathBuf}, @@ -295,6 +296,28 @@ impl ProjectIndex { } Ok(result) } + + pub fn debug(&self, cx: &mut ModelContext) -> Task> { + let indices = self + .worktree_indices + .values() + .filter_map(|worktree_index| { + if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index { + Some(index.clone()) + } else { + None + } + }) + .collect::>(); + + cx.spawn(|_, mut cx| async move { + eprintln!("semantic index contents:"); + for index in indices { + index.update(&mut cx, |index, cx| index.debug(cx))?.await? 
+ } + Ok(()) + }) + } } pub struct SearchResult { @@ -419,7 +442,7 @@ impl WorktreeIndex { let worktree_abs_path = worktree.abs_path().clone(); let scan = self.scan_entries(worktree.clone(), cx); let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); - let embed = self.embed_files(chunk.files, cx); + let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx); let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); async move { futures::try_join!(scan.task, chunk.task, embed.task, persist)?; @@ -436,7 +459,7 @@ impl WorktreeIndex { let worktree_abs_path = worktree.abs_path().clone(); let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx); let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); - let embed = self.embed_files(chunk.files, cx); + let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx); let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); async move { futures::try_join!(scan.task, chunk.task, embed.task, persist)?; @@ -500,7 +523,7 @@ impl WorktreeIndex { } if entry.mtime != saved_mtime { - let handle = entries_being_indexed.insert(&entry); + let handle = entries_being_indexed.insert(entry.id); updated_entries_tx.send((entry.clone(), handle)).await?; } } @@ -539,7 +562,7 @@ impl WorktreeIndex { | project::PathChange::AddedOrUpdated => { if let Some(entry) = worktree.entry_for_id(*entry_id) { if entry.is_file() { - let handle = entries_being_indexed.insert(&entry); + let handle = entries_being_indexed.insert(entry.id); updated_entries_tx.send((entry.clone(), handle)).await?; } } @@ -601,7 +624,8 @@ impl WorktreeIndex { let chunked_file = ChunkedFile { chunks: chunk_text(&text, grammar), handle, - entry, + path: entry.path, + mtime: entry.mtime, text, }; @@ -623,11 +647,11 @@ impl WorktreeIndex { } fn embed_files( - &self, + embedding_provider: Arc, chunked_files: channel::Receiver, cx: 
&AppContext, ) -> EmbedFiles { - let embedding_provider = self.embedding_provider.clone(); + let embedding_provider = embedding_provider.clone(); let (embedded_files_tx, embedded_files_rx) = channel::bounded(512); let task = cx.background_executor().spawn(async move { let mut chunked_file_batches = @@ -635,9 +659,10 @@ impl WorktreeIndex { while let Some(chunked_files) = chunked_file_batches.next().await { // View the batch of files as a vec of chunks // Flatten out to a vec of chunks that we can subdivide into batch sized pieces - // Once those are done, reassemble it back into which files they belong to + // Once those are done, reassemble them back into the files in which they belong + // If any embeddings fail for a file, the entire file is discarded - let chunks = chunked_files + let chunks: Vec = chunked_files .iter() .flat_map(|file| { file.chunks.iter().map(|chunk| TextToEmbed { @@ -647,36 +672,50 @@ impl WorktreeIndex { }) .collect::>(); - let mut embeddings = Vec::new(); + let mut embeddings: Vec> = Vec::new(); for embedding_batch in chunks.chunks(embedding_provider.batch_size()) { if let Some(batch_embeddings) = embedding_provider.embed(embedding_batch).await.log_err() { - embeddings.extend_from_slice(&batch_embeddings); + if batch_embeddings.len() == embedding_batch.len() { + embeddings.extend(batch_embeddings.into_iter().map(Some)); + continue; + } + log::error!( + "embedding provider returned unexpected embedding count {}, expected {}", + batch_embeddings.len(), embedding_batch.len() + ); } + + embeddings.extend(iter::repeat(None).take(embedding_batch.len())); } let mut embeddings = embeddings.into_iter(); for chunked_file in chunked_files { - let chunk_embeddings = embeddings - .by_ref() - .take(chunked_file.chunks.len()) - .collect::>(); - let embedded_chunks = chunked_file - .chunks - .into_iter() - .zip(chunk_embeddings) - .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding }) - .collect(); - let embedded_file = EmbeddedFile { - path: 
chunked_file.entry.path.clone(), - mtime: chunked_file.entry.mtime, - chunks: embedded_chunks, + let mut embedded_file = EmbeddedFile { + path: chunked_file.path, + mtime: chunked_file.mtime, + chunks: Vec::new(), }; - embedded_files_tx - .send((embedded_file, chunked_file.handle)) - .await?; + let mut embedded_all_chunks = true; + for (chunk, embedding) in + chunked_file.chunks.into_iter().zip(embeddings.by_ref()) + { + if let Some(embedding) = embedding { + embedded_file + .chunks + .push(EmbeddedChunk { chunk, embedding }); + } else { + embedded_all_chunks = false; + } + } + + if embedded_all_chunks { + embedded_files_tx + .send((embedded_file, chunked_file.handle)) + .await?; + } } } Ok(()) @@ -826,6 +865,21 @@ impl WorktreeIndex { }) } + fn debug(&mut self, cx: &mut ModelContext) -> Task> { + let connection = self.db_connection.clone(); + let db = self.db; + cx.background_executor().spawn(async move { + let tx = connection + .read_txn() + .context("failed to create read transaction")?; + for record in db.iter(&tx)? { + let (key, _) = record?; + eprintln!("{}", path_for_db_key(key)); + } + Ok(()) + }) + } + #[cfg(test)] fn path_count(&self) -> Result { let txn = self @@ -848,7 +902,8 @@ struct ChunkFiles { } struct ChunkedFile { - pub entry: Entry, + pub path: Arc, + pub mtime: Option, pub handle: IndexingEntryHandle, pub text: String, pub chunks: Vec, @@ -872,11 +927,14 @@ struct EmbeddedChunk { embedding: Embedding, } +/// The set of entries that are currently being indexed. struct IndexingEntrySet { entry_ids: Mutex>, tx: channel::Sender<()>, } +/// When dropped, removes the entry from the set of entries that are being indexed. 
+#[derive(Clone)] struct IndexingEntryHandle { entry_id: ProjectEntryId, set: Weak, @@ -890,11 +948,11 @@ impl IndexingEntrySet { } } - fn insert(self: &Arc, entry: &project::Entry) -> IndexingEntryHandle { - self.entry_ids.lock().insert(entry.id); + fn insert(self: &Arc, entry_id: ProjectEntryId) -> IndexingEntryHandle { + self.entry_ids.lock().insert(entry_id); self.tx.send_blocking(()).ok(); IndexingEntryHandle { - entry_id: entry.id, + entry_id, set: Arc::downgrade(self), } } @@ -917,6 +975,10 @@ fn db_key_for_path(path: &Arc) -> String { path.to_string_lossy().replace('/', "\0") } +fn path_for_db_key(key: &str) -> String { + key.replace('\0', "/") +} + #[cfg(test)] mod tests { use super::*; @@ -939,7 +1001,22 @@ mod tests { }); } - pub struct TestEmbeddingProvider; + pub struct TestEmbeddingProvider { + batch_size: usize, + compute_embedding: Box Result + Send + Sync>, + } + + impl TestEmbeddingProvider { + pub fn new( + batch_size: usize, + compute_embedding: impl 'static + Fn(&str) -> Result + Send + Sync, + ) -> Self { + return Self { + batch_size, + compute_embedding: Box::new(compute_embedding), + }; + } + } impl EmbeddingProvider for TestEmbeddingProvider { fn embed<'a>( @@ -948,29 +1025,13 @@ mod tests { ) -> BoxFuture<'a, Result>> { let embeddings = texts .iter() - .map(|text| { - let mut embedding = vec![0f32; 2]; - // if the text contains garbage, give it a 1 in the first dimension - if text.text.contains("garbage in") { - embedding[0] = 0.9; - } else { - embedding[0] = -0.9; - } - - if text.text.contains("garbage out") { - embedding[1] = 0.9; - } else { - embedding[1] = -0.9; - } - - Embedding::new(embedding) - }) + .map(|to_embed| (self.compute_embedding)(to_embed.text)) .collect(); - future::ready(Ok(embeddings)).boxed() + future::ready(embeddings).boxed() } fn batch_size(&self) -> usize { - 16 + self.batch_size } } @@ -984,7 +1045,23 @@ mod tests { let mut semantic_index = SemanticIndex::new( temp_dir.path().into(), - 
Arc::new(TestEmbeddingProvider), + Arc::new(TestEmbeddingProvider::new(16, |text| { + let mut embedding = vec![0f32; 2]; + // if the text contains garbage, give it a 1 in the first dimension + if text.contains("garbage in") { + embedding[0] = 0.9; + } else { + embedding[0] = -0.9; + } + + if text.contains("garbage out") { + embedding[1] = 0.9; + } else { + embedding[1] = -0.9; + } + + Ok(Embedding::new(embedding)) + })), &mut cx.to_async(), ) .await @@ -1046,4 +1123,82 @@ mod tests { assert!(content.contains("garbage in, garbage out")); } + + #[gpui::test] + async fn test_embed_files(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let provider = Arc::new(TestEmbeddingProvider::new(3, |text| { + if text.contains('g') { + Err(anyhow!("cannot embed text containing a 'g' character")) + } else { + Ok(Embedding::new( + ('a'..'z') + .map(|char| text.chars().filter(|c| *c == char).count() as f32) + .collect(), + )) + } + })); + + let (indexing_progress_tx, _) = channel::unbounded(); + let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx)); + + let (chunked_files_tx, chunked_files_rx) = channel::unbounded::(); + chunked_files_tx + .send_blocking(ChunkedFile { + path: Path::new("test1.md").into(), + mtime: None, + handle: indexing_entries.insert(ProjectEntryId::from_proto(0)), + text: "abcdefghijklmnop".to_string(), + chunks: [0..4, 4..8, 8..12, 12..16] + .into_iter() + .map(|range| Chunk { + range, + digest: Default::default(), + }) + .collect(), + }) + .unwrap(); + chunked_files_tx + .send_blocking(ChunkedFile { + path: Path::new("test2.md").into(), + mtime: None, + handle: indexing_entries.insert(ProjectEntryId::from_proto(1)), + text: "qrstuvwxyz".to_string(), + chunks: [0..4, 4..8, 8..10] + .into_iter() + .map(|range| Chunk { + range, + digest: Default::default(), + }) + .collect(), + }) + .unwrap(); + chunked_files_tx.close(); + + let embed_files_task = + cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, 
cx)); + embed_files_task.task.await.unwrap(); + + let mut embedded_files_rx = embed_files_task.files; + let mut embedded_files = Vec::new(); + while let Some((embedded_file, _)) = embedded_files_rx.next().await { + embedded_files.push(embedded_file); + } + + assert_eq!(embedded_files.len(), 1); + assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md")); + assert_eq!( + embedded_files[0] + .chunks + .iter() + .map(|embedded_chunk| { embedded_chunk.embedding.clone() }) + .collect::>(), + vec![ + (provider.compute_embedding)("qrst").unwrap(), + (provider.compute_embedding)("uvwx").unwrap(), + (provider.compute_embedding)("yz").unwrap(), + ], + ); + } }