batch search queries in the vector database

This commit is contained in:
KCaverly 2023-07-26 16:36:39 -04:00
parent 6cd10f3d5e
commit 98fde36834
3 changed files with 106 additions and 45 deletions
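Instead of filtering and scoring every document inside a single database call, the search now resolves the matching file ids first, splits them into one batch per CPU, runs a top-k similarity search over each batch on a background task, and merges the per-batch results into one ranked list. Below is a minimal, self-contained sketch of that score-per-batch-then-merge pattern; the helper names are hypothetical and the batches are processed sequentially for clarity, so it illustrates the idea rather than reproducing the committed code.

use std::cmp::Ordering;

// Hypothetical helpers illustrating the pattern; not the committed implementation.

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// Insert (id, similarity) into a list kept sorted by descending similarity,
// retaining at most `limit` entries.
fn insert_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    let ix = results
        .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
        .unwrap_or_else(|ix| ix);
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}

fn main() {
    let query = vec![1.0_f32, 0.0];
    let embeddings = vec![
        (1_i64, vec![0.9_f32, 0.1]),
        (2, vec![0.1, 0.9]),
        (3, vec![0.7, 0.7]),
        (4, vec![0.8, 0.2]),
    ];
    let limit = 2;

    // Score each batch independently (the commit does this on one background
    // task per CPU, each with its own database connection), keeping a local top-k.
    let per_batch: Vec<Vec<(i64, f32)>> = embeddings
        .chunks(2)
        .map(|batch| {
            let mut results = Vec::with_capacity(limit + 1);
            for (id, embedding) in batch {
                insert_top_k(&mut results, *id, dot(embedding, &query), limit);
            }
            results
        })
        .collect();

    // Merge the per-batch lists into a single global top-k.
    let mut top = Vec::with_capacity(limit + 1);
    for (id, similarity) in per_batch.into_iter().flatten() {
        insert_top_k(&mut top, id, similarity, limit);
    }
    println!("global top-{limit}: {top:?}");
}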

View File

@@ -30,6 +30,7 @@ use std::{
ops::{Not, Range},
path::PathBuf,
sync::Arc,
time::Instant,
};
use util::ResultExt as _;
use workspace::{
@@ -192,6 +193,7 @@ impl ProjectSearch {
exclude_files: Vec<GlobMatcher>,
cx: &mut ModelContext<Self>,
) {
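// Record when the semantic search started so the elapsed time can be logged.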
let t0 = Instant::now();
let search = SemanticIndex::global(cx).map(|index| {
index.update(cx, |semantic_index, cx| {
semantic_index.search_project(
@@ -208,6 +210,7 @@ impl ProjectSearch {
self.match_ranges.clear();
self.pending_search = Some(cx.spawn(|this, mut cx| async move {
let results = search?.await.log_err()?;
log::trace!("semantic search elapsed: {:?}", t0.elapsed().as_millis());
let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| {
this.excerpts.update(cx, |excerpts, cx| {

View File

@@ -267,41 +267,56 @@ impl VectorDatabase {
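/// Scores the documents belonging to `file_ids` against `query_embedding` and
/// returns the `limit` best (document id, similarity) pairs, letting callers
/// shard the candidate ids into batches.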
pub fn top_k_search(
&self,
worktree_ids: &[i64],
query_embedding: &Vec<f32>,
limit: usize,
include_globs: Vec<GlobMatcher>,
exclude_globs: Vec<GlobMatcher>,
) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
file_ids: &[i64],
) -> Result<Vec<(i64, f32)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
self.for_each_document(
&worktree_ids,
include_globs,
exclude_globs,
|id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
},
)?;
self.for_each_document(file_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let ix = match results
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
self.get_documents_by_ids(&ids)
Ok(results)
}
fn for_each_document(
// pub fn top_k_search(
// &self,
// worktree_ids: &[i64],
// query_embedding: &Vec<f32>,
// limit: usize,
// file_ids: Vec<i64>,
// ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
// let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
// self.for_each_document(&worktree_ids, file_ids, |id, embedding| {
// let similarity = dot(&embedding, &query_embedding);
// let ix = match results
// .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
// {
// Ok(ix) => ix,
// Err(ix) => ix,
// };
// results.insert(ix, (id, similarity));
// results.truncate(limit);
// })?;
// let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
// self.get_documents_by_ids(&ids)
// }
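/// Applies the worktree and glob filters up front and returns the ids of all
/// matching files, so the similarity search can be split into batches.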
pub fn retrieve_included_file_ids(
&self,
worktree_ids: &[i64],
include_globs: Vec<GlobMatcher>,
exclude_globs: Vec<GlobMatcher>,
mut f: impl FnMut(i64, Vec<f32>),
) -> Result<()> {
) -> Result<Vec<i64>> {
let mut file_query = self.db.prepare(
"
SELECT
@@ -315,6 +330,7 @@ impl VectorDatabase {
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
@@ -330,6 +346,10 @@ impl VectorDatabase {
}
}
Ok(file_ids)
}
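/// Calls `f` with the id and embedding of every document stored for the given file ids.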
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
let mut query_statement = self.db.prepare(
"
SELECT
@@ -350,7 +370,7 @@ impl VectorDatabase {
Ok(())
}
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
let mut statement = self.db.prepare(
"
SELECT

View File

@@ -20,6 +20,7 @@ use postage::watch;
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
cmp::Ordering,
collections::HashMap,
mem,
ops::Range,
@@ -704,27 +705,64 @@ impl SemanticIndex {
let database_url = self.database_url.clone();
let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move {
let documents = cx
.background()
.spawn(async move {
let database = VectorDatabase::new(fs, database_url).await?;
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
.await?
.into_iter()
.next()
.unwrap();
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
.await?
.into_iter()
.next()
.unwrap();
database.top_k_search(
&worktree_db_ids,
&phrase_embedding,
limit,
include_globs,
exclude_globs,
)
})
.await?;
let file_ids = database.retrieve_included_file_ids(
&worktree_db_ids,
include_globs,
exclude_globs,
)?;
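// Shard the candidate file ids into one batch per CPU and score each batch
// against the query embedding on its own background task.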
let batch_n = cx.background().num_cpus();
let batch_size = (file_ids.len() / batch_n).max(1); // avoid a zero chunk size when there are fewer files than CPUs
let mut result_tasks = Vec::new();
for batch in file_ids.chunks(batch_size) {
let batch = batch.to_vec();
let limit = limit.clone();
let fs = fs.clone();
let database_url = database_url.clone();
let phrase_embedding = phrase_embedding.clone();
let task = cx.background().spawn(async move {
let database = match VectorDatabase::new(fs, database_url).await.log_err() {
Some(database) => database,
None => return Err(anyhow!("failed to acquire database connection")),
};
database.top_k_search(&phrase_embedding, limit, batch.as_slice())
});
result_tasks.push(task);
}
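// Merge the per-batch results into a single list, kept sorted by descending
// similarity and truncated to `limit`.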
let batch_results = futures::future::join_all(result_tasks).await;
let mut results = Vec::new();
for batch_result in batch_results {
if let Ok(batch_result) = batch_result {
for (id, similarity) in batch_result {
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
}
}
}
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
let documents = database.get_documents_by_ids(ids.as_slice())?;
let mut tasks = Vec::new();
let mut ranges = Vec::new();