moved semantic index to use an embedding queue to batch embeddings and manage atomic database writes

Co-authored-by: Max <max@zed.dev>
KCaverly 2023-08-30 16:58:45 -04:00
parent 76ce52df4e
commit 5abad58b0d
3 changed files with 55 additions and 222 deletions
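
For orientation, a minimal usage sketch of the EmbeddingQueue API this commit introduces (not part of the diff itself; `provider`, `executor`, and `files` are placeholders for values the caller already has):

// Build the queue against an embedding provider and a background executor.
let mut queue = EmbeddingQueue::new(provider.clone(), executor.clone());
for file in files {
    // Documents are grouped into batches bounded by provider.max_tokens_per_batch().
    queue.push(file);
}
// Force-embed whatever is still pending in the current batch.
queue.flush();
// Once the executor has run the spawned embedding work, finished files come
// back on a channel, each one ready for a single database write.
let finished = queue.finished_files();
while let Ok(file) = finished.try_recv() {
    // file.documents now carry their embeddings
}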

View File

@@ -1,10 +1,8 @@
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
use gpui::AppContext;
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
@@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileToEmbedFragment>,
executor: Arc<Background>,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
@@ -49,10 +48,11 @@ pub struct FileToEmbedFragment {
}
impl EmbeddingQueue {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
@@ -60,7 +60,12 @@ impl EmbeddingQueue {
}
}
pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
pub fn push(&mut self, file: FileToEmbed) {
if file.documents.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileToEmbedFragment {
@@ -73,7 +78,7 @@
let next_token_count = self.pending_batch_token_count + document.token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush(cx);
self.flush();
self.pending_batch.push(FileToEmbedFragment {
file: file.clone(),
document_range: range_end..range_end,
@@ -86,7 +91,7 @@
}
}
pub fn flush(&mut self, cx: &mut AppContext) {
pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
@@ -95,7 +100,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
cx.background().spawn(async move {
self.executor.spawn(async move {
let mut spans = Vec::new();
for fragment in &batch {
let file = fragment.file.lock();
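
The consumer side of this finished-files channel is wired up in the next file: a single background task drains the queue and turns every completed file into one database insert, which is what makes the per-file writes atomic. A sketch assembled from the hunks below:

// One task owns the receiving end; each FileToEmbed arrives with all of its
// documents embedded, so the database sees exactly one InsertFile per file.
let _embedding_task = cx.background().spawn({
    let embedded_files = embedding_queue.finished_files();
    let db_update_tx = db_update_tx.clone();
    async move {
        while let Ok(file) = embedded_files.recv().await {
            db_update_tx
                .try_send(DbOperation::InsertFile {
                    worktree_id: file.worktree_id,
                    documents: file.documents,
                    path: file.path,
                    mtime: file.mtime,
                    job_handle: file.job_handle,
                })
                .ok();
        }
    }
});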

View File

@@ -1,5 +1,6 @@
mod db;
mod embedding;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
@@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{channel::oneshot, Future};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry};
@@ -23,7 +25,6 @@ use smol::channel;
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
mem,
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Weak},
@@ -38,7 +39,6 @@ use util::{
use workspace::WorkspaceCreated;
const SEMANTIC_INDEX_VERSION: usize = 7;
const EMBEDDINGS_BATCH_SIZE: usize = 80;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600);
pub fn init(
@@ -106,9 +106,8 @@ pub struct SemanticIndex {
language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>,
_embedding_task: Task<()>,
_db_update_task: Task<()>,
_embed_batch_tasks: Vec<Task<()>>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
@@ -128,7 +127,7 @@ struct ChangedPathInfo {
}
#[derive(Clone)]
struct JobHandle {
pub struct JobHandle {
/// The outer Arc is here to count the clones of a JobHandle instance;
/// when the last handle to a given job is dropped, we decrement a counter (just once).
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@@ -230,17 +229,6 @@ enum DbOperation {
},
}
enum EmbeddingJob {
Enqueue {
worktree_id: i64,
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
job_handle: JobHandle,
},
Flush,
}
impl SemanticIndex {
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
@@ -287,52 +275,35 @@ impl SemanticIndex {
}
});
// Group documents into batches and send them to the embedding provider.
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
let mut _embed_batch_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let embed_batch_rx = embed_batch_rx.clone();
_embed_batch_tasks.push(cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
Self::compute_embeddings_for_batch(
embeddings_queue,
&embedding_provider,
&db_update_tx,
)
.await;
}
let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db_update_tx = db_update_tx.clone();
async move {
while let Ok(file) = embedded_files.recv().await {
db_update_tx
.try_send(DbOperation::InsertFile {
worktree_id: file.worktree_id,
documents: file.documents,
path: file.path,
mtime: file.mtime,
job_handle: file.job_handle,
})
.ok();
}
}));
}
// Group documents into batches and send them to the embedding provider.
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
let mut queue_len = 0;
let mut embeddings_queue = vec![];
while let Ok(job) = batch_files_rx.recv().await {
Self::enqueue_documents_to_embed(
job,
&mut queue_len,
&mut embeddings_queue,
&embed_batch_tx,
);
}
});
// Parse files into embeddable documents.
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
let embedding_queue = embedding_queue.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
while let Ok(pending_file) = parsing_files_rx.recv().await {
@@ -340,9 +311,8 @@ impl SemanticIndex {
&fs,
pending_file,
&mut retriever,
&batch_files_tx,
&embedding_queue,
&parsing_files_rx,
&db_update_tx,
)
.await;
}
@@ -361,8 +331,7 @@ impl SemanticIndex {
db_update_tx,
parsing_files_tx,
_db_update_task,
_embed_batch_tasks,
_batch_files_task,
_embedding_task,
_parsing_files_tasks,
projects: HashMap::new(),
}
@@ -403,136 +372,12 @@ impl SemanticIndex {
}
}
async fn compute_embeddings_for_batch(
mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embedding_provider: &Arc<dyn EmbeddingProvider>,
db_update_tx: &channel::Sender<DbOperation>,
) {
let mut batch_documents = vec![];
for (_, documents, _, _, _) in embeddings_queue.iter() {
batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
}
if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
log::trace!(
"created {} embeddings for {} files",
embeddings.len(),
embeddings_queue.len(),
);
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
})
.await
.unwrap();
}
} else {
// Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents: vec![],
path,
mtime,
job_handle,
})
.await
.unwrap();
}
}
}
fn enqueue_documents_to_embed(
job: EmbeddingJob,
queue_len: &mut usize,
embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
) {
// Handle edge case where individual file has more documents than max batch size
let should_flush = match job {
EmbeddingJob::Enqueue {
documents,
worktree_id,
path,
mtime,
job_handle,
} => {
// If documents is greater than embeddings batch size, recursively batch existing rows.
if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
let first_job = EmbeddingJob::Enqueue {
documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
worktree_id,
path: path.clone(),
mtime,
job_handle: job_handle.clone(),
};
Self::enqueue_documents_to_embed(
first_job,
queue_len,
embeddings_queue,
embed_batch_tx,
);
let second_job = EmbeddingJob::Enqueue {
documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
worktree_id,
path: path.clone(),
mtime,
job_handle: job_handle.clone(),
};
Self::enqueue_documents_to_embed(
second_job,
queue_len,
embeddings_queue,
embed_batch_tx,
);
return;
} else {
*queue_len += &documents.len();
embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
*queue_len >= EMBEDDINGS_BATCH_SIZE
}
}
EmbeddingJob::Flush => true,
};
if should_flush {
embed_batch_tx
.try_send(mem::take(embeddings_queue))
.unwrap();
*queue_len = 0;
}
}
async fn parse_file(
fs: &Arc<dyn Fs>,
pending_file: PendingFile,
retriever: &mut CodeContextRetriever,
batch_files_tx: &channel::Sender<EmbeddingJob>,
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
parsing_files_rx: &channel::Receiver<PendingFile>,
db_update_tx: &channel::Sender<DbOperation>,
) {
let Some(language) = pending_file.language else {
return;
@@ -549,33 +394,18 @@ impl SemanticIndex {
documents.len()
);
if documents.len() == 0 {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id: pending_file.worktree_db_id,
documents,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
})
.await
.unwrap();
} else {
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
documents,
})
.unwrap();
}
embedding_queue.lock().push(FileToEmbed {
worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
documents,
});
}
}
if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
embedding_queue.lock().flush();
}
}
@@ -881,7 +711,7 @@ impl SemanticIndex {
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
.embed_batch(vec![phrase])
.await?
.into_iter()
.next()

View File

@@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.collect::<Vec<_>>();
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone());
let finished_files = cx.update(|cx| {
for file in &files {
queue.push(file.clone(), cx);
}
queue.flush(cx);
queue.finished_files()
});
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
queue.flush();
cx.foreground().run_until_parked();
let finished_files = queue.finished_files();
let mut embedded_files: Vec<_> = files
.iter()
.map(|_| finished_files.try_recv().expect("no finished file"))