From 50cfb067e7c536636ed5bf7e119968d50843b287 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 31 Aug 2023 13:19:17 -0400 Subject: [PATCH] fill embeddings with database values and skip during embeddings queue --- crates/semantic_index/src/embedding_queue.rs | 34 ++++++++++++++++--- crates/semantic_index/src/semantic_index.rs | 35 ++++++++++---------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 4c82ced918..96493fc4d3 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -42,6 +42,7 @@ pub struct EmbeddingQueue { finished_files_rx: channel::Receiver, } +#[derive(Clone)] pub struct FileToEmbedFragment { file: Arc>, document_range: Range, @@ -74,8 +75,16 @@ impl EmbeddingQueue { }); let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + let mut saved_tokens = 0; for (ix, document) in file.lock().documents.iter().enumerate() { - let next_token_count = self.pending_batch_token_count + document.token_count; + let document_token_count = if document.embedding.is_none() { + document.token_count + } else { + saved_tokens += document.token_count; + 0 + }; + + let next_token_count = self.pending_batch_token_count + document_token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; self.flush(); @@ -87,8 +96,9 @@ impl EmbeddingQueue { } fragment_range.end = ix + 1; - self.pending_batch_token_count += document.token_count; + self.pending_batch_token_count += document_token_count; } + log::trace!("Saved Tokens: {:?}", saved_tokens); } pub fn flush(&mut self) { @@ -100,25 +110,41 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); + self.executor.spawn(async move { let mut spans = Vec::new(); + let mut document_count = 0; for fragment in &batch { let file = fragment.file.lock(); + document_count += file.documents[fragment.document_range.clone()].len(); spans.extend( { file.documents[fragment.document_range.clone()] - .iter() + .iter().filter(|d| d.embedding.is_none()) .map(|d| d.content.clone()) } ); } + log::trace!("Documents Length: {:?}", document_count); + log::trace!("Span Length: {:?}", spans.clone().len()); + + // If spans is 0, just send the fragment to the finished files if its the last one. + if spans.len() == 0 { + for fragment in batch.clone() { + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + return; + }; + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { for document in - &mut fragment.file.lock().documents[fragment.document_range.clone()] + &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) { if let Some(embedding) = embeddings.next() { document.embedding = Some(embedding); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 58166c1a22..726b04583a 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -255,6 +255,7 @@ impl SemanticIndex { let parsing_files_rx = parsing_files_rx.clone(); let embedding_provider = embedding_provider.clone(); let embedding_queue = embedding_queue.clone(); + let db = db.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -264,6 +265,7 @@ impl SemanticIndex { &mut retriever, &embedding_queue, &parsing_files_rx, + &db, ) .await; } @@ -293,13 +295,14 @@ impl SemanticIndex { retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, + db: &VectorDatabase, ) { let Some(language) = pending_file.language else { return; }; if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(documents) = retriever + if let Some(mut documents) = retriever .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { @@ -309,22 +312,20 @@ impl SemanticIndex { documents.len() ); - todo!(); - // if let Some(embeddings) = db - // .embeddings_for_documents( - // pending_file.worktree_db_id, - // pending_file.relative_path, - // &documents, - // ) - // .await - // .log_err() - // { - // for (document, embedding) in documents.iter_mut().zip(embeddings) { - // if let Some(embedding) = embedding { - // document.embedding = embedding; - // } - // } - // } + if let Some(sha_to_embeddings) = db + .embeddings_for_file( + pending_file.worktree_db_id, + pending_file.relative_path.clone(), + ) + .await + .log_err() + { + for document in documents.iter_mut() { + if let Some(embedding) = sha_to_embeddings.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); + } + } + } embedding_queue.lock().push(FileToEmbed { worktree_id: pending_file.worktree_db_id,