Fill embeddings with database values and skip already-embedded documents in the embedding queue

This commit is contained in:
KCaverly 2023-08-31 13:19:17 -04:00
parent 220533ff1a
commit 50cfb067e7
2 changed files with 48 additions and 21 deletions

View File

@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileToEmbedFragment {
file: Arc<Mutex<FileToEmbed>>,
document_range: Range<usize>,
@ -74,8 +75,16 @@ impl EmbeddingQueue {
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
let mut saved_tokens = 0;
for (ix, document) in file.lock().documents.iter().enumerate() {
let next_token_count = self.pending_batch_token_count + document.token_count;
let document_token_count = if document.embedding.is_none() {
document.token_count
} else {
saved_tokens += document.token_count;
0
};
let next_token_count = self.pending_batch_token_count + document_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
@ -87,8 +96,9 @@ impl EmbeddingQueue {
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += document.token_count;
self.pending_batch_token_count += document_token_count;
}
log::trace!("Saved Tokens: {:?}", saved_tokens);
}
pub fn flush(&mut self) {
@ -100,25 +110,41 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor.spawn(async move {
let mut spans = Vec::new();
let mut document_count = 0;
for fragment in &batch {
let file = fragment.file.lock();
document_count += file.documents[fragment.document_range.clone()].len();
spans.extend(
{
file.documents[fragment.document_range.clone()]
.iter()
.iter().filter(|d| d.embedding.is_none())
.map(|d| d.content.clone())
}
);
}
log::trace!("Documents Length: {:?}", document_count);
log::trace!("Span Length: {:?}", spans.clone().len());
// If there are no spans to embed, just send the file to the finished files if this is its last fragment.
if spans.len() == 0 {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for document in
&mut fragment.file.lock().documents[fragment.document_range.clone()]
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
document.embedding = Some(embedding);

View File

@ -255,6 +255,7 @@ impl SemanticIndex {
let parsing_files_rx = parsing_files_rx.clone();
let embedding_provider = embedding_provider.clone();
let embedding_queue = embedding_queue.clone();
let db = db.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
while let Ok(pending_file) = parsing_files_rx.recv().await {
@ -264,6 +265,7 @@ impl SemanticIndex {
&mut retriever,
&embedding_queue,
&parsing_files_rx,
&db,
)
.await;
}
@ -293,13 +295,14 @@ impl SemanticIndex {
retriever: &mut CodeContextRetriever,
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
parsing_files_rx: &channel::Receiver<PendingFile>,
db: &VectorDatabase,
) {
let Some(language) = pending_file.language else {
return;
};
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
if let Some(documents) = retriever
if let Some(mut documents) = retriever
.parse_file_with_template(&pending_file.relative_path, &content, language)
.log_err()
{
@ -309,22 +312,20 @@ impl SemanticIndex {
documents.len()
);
todo!();
// if let Some(embeddings) = db
// .embeddings_for_documents(
// pending_file.worktree_db_id,
// pending_file.relative_path,
// &documents,
// )
// .await
// .log_err()
// {
// for (document, embedding) in documents.iter_mut().zip(embeddings) {
// if let Some(embedding) = embedding {
// document.embedding = embedding;
// }
// }
// }
if let Some(sha_to_embeddings) = db
.embeddings_for_file(
pending_file.worktree_db_id,
pending_file.relative_path.clone(),
)
.await
.log_err()
{
for document in documents.iter_mut() {
if let Some(embedding) = sha_to_embeddings.get(&document.digest) {
document.embedding = Some(embedding.to_owned());
}
}
}
embedding_queue.lock().push(FileToEmbed {
worktree_id: pending_file.worktree_db_id,