mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
moved semantic index to use embeddings queue to batch and managed for atomic database writes
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
76ce52df4e
commit
5abad58b0d
@ -1,10 +1,8 @@
|
||||
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
|
||||
|
||||
use gpui::AppContext;
|
||||
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
|
||||
use gpui::executor::Background;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
|
||||
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
|
||||
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileToEmbed {
|
||||
@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed {
|
||||
pub struct EmbeddingQueue {
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
pending_batch: Vec<FileToEmbedFragment>,
|
||||
executor: Arc<Background>,
|
||||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
@ -49,10 +48,11 @@ pub struct FileToEmbedFragment {
|
||||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
executor,
|
||||
pending_batch: Vec::new(),
|
||||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
@ -60,7 +60,12 @@ impl EmbeddingQueue {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.documents.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let file = Arc::new(Mutex::new(file));
|
||||
|
||||
self.pending_batch.push(FileToEmbedFragment {
|
||||
@ -73,7 +78,7 @@ impl EmbeddingQueue {
|
||||
let next_token_count = self.pending_batch_token_count + document.token_count;
|
||||
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
|
||||
let range_end = fragment_range.end;
|
||||
self.flush(cx);
|
||||
self.flush();
|
||||
self.pending_batch.push(FileToEmbedFragment {
|
||||
file: file.clone(),
|
||||
document_range: range_end..range_end,
|
||||
@ -86,7 +91,7 @@ impl EmbeddingQueue {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(&mut self, cx: &mut AppContext) {
|
||||
pub fn flush(&mut self) {
|
||||
let batch = mem::take(&mut self.pending_batch);
|
||||
self.pending_batch_token_count = 0;
|
||||
if batch.is_empty() {
|
||||
@ -95,7 +100,7 @@ impl EmbeddingQueue {
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
cx.background().spawn(async move {
|
||||
self.executor.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
|
@ -1,5 +1,6 @@
|
||||
mod db;
|
||||
mod embedding;
|
||||
mod embedding_queue;
|
||||
mod parsing;
|
||||
pub mod semantic_index_settings;
|
||||
|
||||
@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use anyhow::{anyhow, Result};
|
||||
use db::VectorDatabase;
|
||||
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
|
||||
use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
||||
use futures::{channel::oneshot, Future};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use language::{Anchor, Buffer, Language, LanguageRegistry};
|
||||
@ -23,7 +25,6 @@ use smol::channel;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{BTreeMap, HashMap},
|
||||
mem,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, Weak},
|
||||
@ -38,7 +39,6 @@ use util::{
|
||||
use workspace::WorkspaceCreated;
|
||||
|
||||
const SEMANTIC_INDEX_VERSION: usize = 7;
|
||||
const EMBEDDINGS_BATCH_SIZE: usize = 80;
|
||||
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600);
|
||||
|
||||
pub fn init(
|
||||
@ -106,9 +106,8 @@ pub struct SemanticIndex {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
db_update_tx: channel::Sender<DbOperation>,
|
||||
parsing_files_tx: channel::Sender<PendingFile>,
|
||||
_embedding_task: Task<()>,
|
||||
_db_update_task: Task<()>,
|
||||
_embed_batch_tasks: Vec<Task<()>>,
|
||||
_batch_files_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
}
|
||||
@ -128,7 +127,7 @@ struct ChangedPathInfo {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct JobHandle {
|
||||
pub struct JobHandle {
|
||||
/// The outer Arc is here to count the clones of a JobHandle instance;
|
||||
/// when the last handle to a given job is dropped, we decrement a counter (just once).
|
||||
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
|
||||
@ -230,17 +229,6 @@ enum DbOperation {
|
||||
},
|
||||
}
|
||||
|
||||
enum EmbeddingJob {
|
||||
Enqueue {
|
||||
worktree_id: i64,
|
||||
path: PathBuf,
|
||||
mtime: SystemTime,
|
||||
documents: Vec<Document>,
|
||||
job_handle: JobHandle,
|
||||
},
|
||||
Flush,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
|
||||
if cx.has_global::<ModelHandle<Self>>() {
|
||||
@ -287,52 +275,35 @@ impl SemanticIndex {
|
||||
}
|
||||
});
|
||||
|
||||
// Group documents into batches and send them to the embedding provider.
|
||||
let (embed_batch_tx, embed_batch_rx) =
|
||||
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
|
||||
let mut _embed_batch_tasks = Vec::new();
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
let embed_batch_rx = embed_batch_rx.clone();
|
||||
_embed_batch_tasks.push(cx.background().spawn({
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
async move {
|
||||
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
|
||||
Self::compute_embeddings_for_batch(
|
||||
embeddings_queue,
|
||||
&embedding_provider,
|
||||
&db_update_tx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||
let _embedding_task = cx.background().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
async move {
|
||||
while let Ok(file) = embedded_files.recv().await {
|
||||
db_update_tx
|
||||
.try_send(DbOperation::InsertFile {
|
||||
worktree_id: file.worktree_id,
|
||||
documents: file.documents,
|
||||
path: file.path,
|
||||
mtime: file.mtime,
|
||||
job_handle: file.job_handle,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Group documents into batches and send them to the embedding provider.
|
||||
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
|
||||
let _batch_files_task = cx.background().spawn(async move {
|
||||
let mut queue_len = 0;
|
||||
let mut embeddings_queue = vec![];
|
||||
while let Ok(job) = batch_files_rx.recv().await {
|
||||
Self::enqueue_documents_to_embed(
|
||||
job,
|
||||
&mut queue_len,
|
||||
&mut embeddings_queue,
|
||||
&embed_batch_tx,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// Parse files into embeddable documents.
|
||||
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
|
||||
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
|
||||
let mut _parsing_files_tasks = Vec::new();
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
let fs = fs.clone();
|
||||
let parsing_files_rx = parsing_files_rx.clone();
|
||||
let batch_files_tx = batch_files_tx.clone();
|
||||
let db_update_tx = db_update_tx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
let embedding_queue = embedding_queue.clone();
|
||||
_parsing_files_tasks.push(cx.background().spawn(async move {
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||
while let Ok(pending_file) = parsing_files_rx.recv().await {
|
||||
@ -340,9 +311,8 @@ impl SemanticIndex {
|
||||
&fs,
|
||||
pending_file,
|
||||
&mut retriever,
|
||||
&batch_files_tx,
|
||||
&embedding_queue,
|
||||
&parsing_files_rx,
|
||||
&db_update_tx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -361,8 +331,7 @@ impl SemanticIndex {
|
||||
db_update_tx,
|
||||
parsing_files_tx,
|
||||
_db_update_task,
|
||||
_embed_batch_tasks,
|
||||
_batch_files_task,
|
||||
_embedding_task,
|
||||
_parsing_files_tasks,
|
||||
projects: HashMap::new(),
|
||||
}
|
||||
@ -403,136 +372,12 @@ impl SemanticIndex {
|
||||
}
|
||||
}
|
||||
|
||||
async fn compute_embeddings_for_batch(
|
||||
mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
|
||||
embedding_provider: &Arc<dyn EmbeddingProvider>,
|
||||
db_update_tx: &channel::Sender<DbOperation>,
|
||||
) {
|
||||
let mut batch_documents = vec![];
|
||||
for (_, documents, _, _, _) in embeddings_queue.iter() {
|
||||
batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
|
||||
}
|
||||
|
||||
if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
|
||||
log::trace!(
|
||||
"created {} embeddings for {} files",
|
||||
embeddings.len(),
|
||||
embeddings_queue.len(),
|
||||
);
|
||||
|
||||
let mut i = 0;
|
||||
let mut j = 0;
|
||||
|
||||
for embedding in embeddings.iter() {
|
||||
while embeddings_queue[i].1.len() == j {
|
||||
i += 1;
|
||||
j = 0;
|
||||
}
|
||||
|
||||
embeddings_queue[i].1[j].embedding = embedding.to_owned();
|
||||
j += 1;
|
||||
}
|
||||
|
||||
for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
|
||||
db_update_tx
|
||||
.send(DbOperation::InsertFile {
|
||||
worktree_id,
|
||||
documents,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
} else {
|
||||
// Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
|
||||
for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
|
||||
db_update_tx
|
||||
.send(DbOperation::InsertFile {
|
||||
worktree_id,
|
||||
documents: vec![],
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enqueue_documents_to_embed(
|
||||
job: EmbeddingJob,
|
||||
queue_len: &mut usize,
|
||||
embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
|
||||
embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
|
||||
) {
|
||||
// Handle edge case where individual file has more documents than max batch size
|
||||
let should_flush = match job {
|
||||
EmbeddingJob::Enqueue {
|
||||
documents,
|
||||
worktree_id,
|
||||
path,
|
||||
mtime,
|
||||
job_handle,
|
||||
} => {
|
||||
// If documents is greater than embeddings batch size, recursively batch existing rows.
|
||||
if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
|
||||
let first_job = EmbeddingJob::Enqueue {
|
||||
documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
mtime,
|
||||
job_handle: job_handle.clone(),
|
||||
};
|
||||
|
||||
Self::enqueue_documents_to_embed(
|
||||
first_job,
|
||||
queue_len,
|
||||
embeddings_queue,
|
||||
embed_batch_tx,
|
||||
);
|
||||
|
||||
let second_job = EmbeddingJob::Enqueue {
|
||||
documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
mtime,
|
||||
job_handle: job_handle.clone(),
|
||||
};
|
||||
|
||||
Self::enqueue_documents_to_embed(
|
||||
second_job,
|
||||
queue_len,
|
||||
embeddings_queue,
|
||||
embed_batch_tx,
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
*queue_len += &documents.len();
|
||||
embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
|
||||
*queue_len >= EMBEDDINGS_BATCH_SIZE
|
||||
}
|
||||
}
|
||||
EmbeddingJob::Flush => true,
|
||||
};
|
||||
|
||||
if should_flush {
|
||||
embed_batch_tx
|
||||
.try_send(mem::take(embeddings_queue))
|
||||
.unwrap();
|
||||
*queue_len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_file(
|
||||
fs: &Arc<dyn Fs>,
|
||||
pending_file: PendingFile,
|
||||
retriever: &mut CodeContextRetriever,
|
||||
batch_files_tx: &channel::Sender<EmbeddingJob>,
|
||||
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
|
||||
parsing_files_rx: &channel::Receiver<PendingFile>,
|
||||
db_update_tx: &channel::Sender<DbOperation>,
|
||||
) {
|
||||
let Some(language) = pending_file.language else {
|
||||
return;
|
||||
@ -549,33 +394,18 @@ impl SemanticIndex {
|
||||
documents.len()
|
||||
);
|
||||
|
||||
if documents.len() == 0 {
|
||||
db_update_tx
|
||||
.send(DbOperation::InsertFile {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
documents,
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
batch_files_tx
|
||||
.try_send(EmbeddingJob::Enqueue {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
documents,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
embedding_queue.lock().push(FileToEmbed {
|
||||
worktree_id: pending_file.worktree_db_id,
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
documents,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if parsing_files_rx.len() == 0 {
|
||||
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
|
||||
embedding_queue.lock().flush();
|
||||
}
|
||||
}
|
||||
|
||||
@ -881,7 +711,7 @@ impl SemanticIndex {
|
||||
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
|
||||
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![&phrase])
|
||||
.embed_batch(vec![phrase])
|
||||
.await?
|
||||
.into_iter()
|
||||
.next()
|
||||
|
@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone());
|
||||
|
||||
let finished_files = cx.update(|cx| {
|
||||
for file in &files {
|
||||
queue.push(file.clone(), cx);
|
||||
}
|
||||
queue.flush(cx);
|
||||
queue.finished_files()
|
||||
});
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
queue.flush();
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
let finished_files = queue.finished_files();
|
||||
let mut embedded_files: Vec<_> = files
|
||||
.iter()
|
||||
.map(|_| finished_files.try_recv().expect("no finished file"))
|
||||
|
Loading…
Reference in New Issue
Block a user