mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
fill embeddings with database values and skip during embeddings queue
This commit is contained in:
parent
220533ff1a
commit
50cfb067e7
@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileToEmbedFragment {
|
||||
file: Arc<Mutex<FileToEmbed>>,
|
||||
document_range: Range<usize>,
|
||||
@ -74,8 +75,16 @@ impl EmbeddingQueue {
|
||||
});
|
||||
|
||||
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
|
||||
let mut saved_tokens = 0;
|
||||
for (ix, document) in file.lock().documents.iter().enumerate() {
|
||||
let next_token_count = self.pending_batch_token_count + document.token_count;
|
||||
let document_token_count = if document.embedding.is_none() {
|
||||
document.token_count
|
||||
} else {
|
||||
saved_tokens += document.token_count;
|
||||
0
|
||||
};
|
||||
|
||||
let next_token_count = self.pending_batch_token_count + document_token_count;
|
||||
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
|
||||
let range_end = fragment_range.end;
|
||||
self.flush();
|
||||
@ -87,8 +96,9 @@ impl EmbeddingQueue {
|
||||
}
|
||||
|
||||
fragment_range.end = ix + 1;
|
||||
self.pending_batch_token_count += document.token_count;
|
||||
self.pending_batch_token_count += document_token_count;
|
||||
}
|
||||
log::trace!("Saved Tokens: {:?}", saved_tokens);
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
@ -100,25 +110,41 @@ impl EmbeddingQueue {
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
let mut document_count = 0;
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
document_count += file.documents[fragment.document_range.clone()].len();
|
||||
spans.extend(
|
||||
{
|
||||
file.documents[fragment.document_range.clone()]
|
||||
.iter()
|
||||
.iter().filter(|d| d.embedding.is_none())
|
||||
.map(|d| d.content.clone())
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
log::trace!("Documents Length: {:?}", document_count);
|
||||
log::trace!("Span Length: {:?}", spans.clone().len());
|
||||
|
||||
// If spans is 0, just send the fragment to the finished files if its the last one.
|
||||
if spans.len() == 0 {
|
||||
for fragment in batch.clone() {
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
for document in
|
||||
&mut fragment.file.lock().documents[fragment.document_range.clone()]
|
||||
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
|
||||
{
|
||||
if let Some(embedding) = embeddings.next() {
|
||||
document.embedding = Some(embedding);
|
||||
|
@ -255,6 +255,7 @@ impl SemanticIndex {
|
||||
let parsing_files_rx = parsing_files_rx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
let embedding_queue = embedding_queue.clone();
|
||||
let db = db.clone();
|
||||
_parsing_files_tasks.push(cx.background().spawn(async move {
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||
while let Ok(pending_file) = parsing_files_rx.recv().await {
|
||||
@ -264,6 +265,7 @@ impl SemanticIndex {
|
||||
&mut retriever,
|
||||
&embedding_queue,
|
||||
&parsing_files_rx,
|
||||
&db,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -293,13 +295,14 @@ impl SemanticIndex {
|
||||
retriever: &mut CodeContextRetriever,
|
||||
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
|
||||
parsing_files_rx: &channel::Receiver<PendingFile>,
|
||||
db: &VectorDatabase,
|
||||
) {
|
||||
let Some(language) = pending_file.language else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
|
||||
if let Some(documents) = retriever
|
||||
if let Some(mut documents) = retriever
|
||||
.parse_file_with_template(&pending_file.relative_path, &content, language)
|
||||
.log_err()
|
||||
{
|
||||
@ -309,22 +312,20 @@ impl SemanticIndex {
|
||||
documents.len()
|
||||
);
|
||||
|
||||
todo!();
|
||||
// if let Some(embeddings) = db
|
||||
// .embeddings_for_documents(
|
||||
// pending_file.worktree_db_id,
|
||||
// pending_file.relative_path,
|
||||
// &documents,
|
||||
// )
|
||||
// .await
|
||||
// .log_err()
|
||||
// {
|
||||
// for (document, embedding) in documents.iter_mut().zip(embeddings) {
|
||||
// if let Some(embedding) = embedding {
|
||||
// document.embedding = embedding;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
if let Some(sha_to_embeddings) = db
|
||||
.embeddings_for_file(
|
||||
pending_file.worktree_db_id,
|
||||
pending_file.relative_path.clone(),
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
for document in documents.iter_mut() {
|
||||
if let Some(embedding) = sha_to_embeddings.get(&document.digest) {
|
||||
document.embedding = Some(embedding.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
embedding_queue.lock().push(FileToEmbed {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
|
Loading…
Reference in New Issue
Block a user