add initial search inside modified buffers

This commit is contained in:
KCaverly 2023-09-14 14:58:34 -04:00
parent f86e5a987f
commit c19c8899fe
3 changed files with 216 additions and 65 deletions

View File

@ -278,6 +278,39 @@ impl VectorDatabase {
}) })
} }
/// Fetches the stored embedding for every digest in `digests` that has a row in `spans`.
///
/// The lookup runs inside the closure handed to `self.transact`. All digests are
/// bound to the single `rarray(?)` parameter as blob values. The returned map only
/// contains digests that matched a row; any row that comes back as `Err` from the
/// row mapper is dropped rather than failing the whole call.
pub fn embeddings_for_digests(
    &self,
    digests: Vec<SpanDigest>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
    self.transact(move |db| {
        // The SQL text is reproduced exactly as it appears in the original.
        let mut statement = db.prepare(
            "
SELECT digest, embedding
FROM spans
WHERE digest IN rarray(?)
",
        )?;

        // One blob value per digest, wrapped in an `Rc` for the `rarray` parameter.
        let digest_blobs: Vec<_> = digests
            .into_iter()
            .map(|digest| Value::Blob(digest.0.to_vec()))
            .collect();
        let digest_blobs = Rc::new(digest_blobs);

        let decoded_rows = statement.query_map(params![digest_blobs], |row| {
            let digest = row.get::<_, SpanDigest>(0)?;
            let embedding = row.get::<_, Embedding>(1)?;
            Ok((digest, embedding))
        })?;

        // `flatten` keeps the `Ok` rows and skips the `Err` ones.
        let mut found = HashMap::default();
        found.extend(decoded_rows.flatten());
        Ok(found)
    })
}
pub fn embeddings_for_files( pub fn embeddings_for_files(
&self, &self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>, worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,

View File

@ -17,7 +17,7 @@ use std::{
use tree_sitter::{Parser, QueryCursor}; use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)] #[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]); pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest { impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> { fn column_result(value: ValueRef) -> FromSqlResult<Self> {

View File

@ -263,9 +263,11 @@ pub struct PendingFile {
job_handle: JobHandle, job_handle: JobHandle,
} }
#[derive(Clone)]
pub struct SearchResult { pub struct SearchResult {
pub buffer: ModelHandle<Buffer>, pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>, pub range: Range<Anchor>,
pub similarity: f32,
} }
impl SemanticIndex { impl SemanticIndex {
@ -775,7 +777,8 @@ impl SemanticIndex {
.filter_map(|buffer_handle| { .filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx); let buffer = buffer_handle.read(cx);
if buffer.is_dirty() { if buffer.is_dirty() {
Some((buffer_handle.downgrade(), buffer.snapshot())) // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
Some((buffer_handle, buffer.snapshot()))
} else { } else {
None None
} }
@ -783,77 +786,133 @@ impl SemanticIndex {
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
}); });
cx.background() let buffer_results = if let Some(db) =
.spawn({ VectorDatabase::new(fs, db_path.clone(), cx.background())
let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); .await
let embedding_provider = embedding_provider.clone(); .log_err()
let phrase_embedding = phrase_embedding.clone(); {
async move { cx.background()
let mut results = Vec::new(); .spawn({
'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers { let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
let language = buffer_snapshot let embedding_provider = embedding_provider.clone();
.language_at(0) let phrase_embedding = phrase_embedding.clone();
.cloned() async move {
.unwrap_or_else(|| language::PLAIN_TEXT.clone()); let mut results = Vec::<SearchResult>::new();
if let Some(spans) = retriever 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
.parse_file_with_template(None, &buffer_snapshot.text(), language) let language = buffer_snapshot
.log_err() .language_at(0)
{ .cloned()
let mut batch = Vec::new(); .unwrap_or_else(|| language::PLAIN_TEXT.clone());
let mut batch_tokens = 0; if let Some(spans) = retriever
let mut embeddings = Vec::new(); .parse_file_with_template(
None,
&buffer_snapshot.text(),
language,
)
.log_err()
{
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
// TODO: query span digests in the database to avoid embedding them again. let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.map_or(Default::default(), |m| m);
for span in &spans { for span in &spans {
if span.embedding.is_some() { if embeddings_for_digests.contains_key(&span.digest) {
continue; continue;
};
if batch_tokens + span.token_count
> embedding_provider.max_tokens_per_batch()
{
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
batch_tokens = 0;
} else {
continue 'buffers;
}
}
batch_tokens += span.token_count;
batch.push(span.content.clone());
} }
if batch_tokens + span.token_count if let Some(batch_embeddings) = embedding_provider
> embedding_provider.max_tokens_per_batch() .embed_batch(mem::take(&mut batch))
.await
.log_err()
{ {
if let Some(batch_embeddings) = embedding_provider embeddings.extend(batch_embeddings);
.embed_batch(mem::take(&mut batch)) } else {
.await continue 'buffers;
.log_err() }
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) =
embeddings_for_digests.get(&span.digest)
{ {
embeddings.extend(batch_embeddings); Some(embedding.clone())
batch_tokens = 0;
} else { } else {
embeddings.next()
};
if let Some(embedding) = embedding {
let similarity =
embedding.similarity(&phrase_embedding);
let ix = match results.binary_search_by(|s| {
similarity
.partial_cmp(&s.similarity)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
let range = {
let start = buffer_snapshot
.clip_offset(span.range.start, Bias::Left);
let end = buffer_snapshot
.clip_offset(span.range.end, Bias::Right);
buffer_snapshot.anchor_before(start)
..buffer_snapshot.anchor_after(end)
};
results.insert(
ix,
SearchResult {
buffer: buffer_handle.clone(),
range,
similarity,
},
);
results.truncate(limit);
} else {
log::error!("failed to embed span");
continue 'buffers; continue 'buffers;
} }
} }
batch_tokens += span.token_count;
batch.push(span.content.clone());
}
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
} else {
continue 'buffers;
}
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = span.embedding.or_else(|| embeddings.next());
if let Some(embedding) = embedding {
todo!()
} else {
log::error!("failed to embed span");
continue 'buffers;
}
} }
} }
anyhow::Ok(results)
} }
} })
}) .await
.await; } else {
Ok(Vec::new())
};
let batch_results = futures::future::join_all(batch_results).await; let batch_results = futures::future::join_all(batch_results).await;
@ -873,7 +932,11 @@ impl SemanticIndex {
} }
} }
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>(); let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
let scores = results
.into_iter()
.map(|(_, score)| score)
.collect::<Vec<f32>>();
let spans = database.spans_for_ids(ids.as_slice()).await?; let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
@ -903,19 +966,74 @@ impl SemanticIndex {
t0.elapsed().as_millis() t0.elapsed().as_millis()
); );
Ok(buffers let database_results = buffers
.into_iter() .into_iter()
.zip(ranges) .zip(ranges)
.filter_map(|(buffer, range)| { .zip(scores)
.filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?; let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| { let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left); let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right); let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end) buffer.anchor_before(start)..buffer.anchor_after(end)
}); });
Some(SearchResult { buffer, range }) Some(SearchResult {
buffer,
range,
similarity,
})
}) })
.collect::<Vec<_>>()) .collect::<Vec<_>>();
// Stitch Together Database Results & Buffer Results
if let Ok(buffer_results) = buffer_results {
let mut buffer_map = HashMap::default();
for buffer_result in buffer_results {
buffer_map
.entry(buffer_result.clone().buffer)
.or_insert(Vec::new())
.push(buffer_result);
}
for db_result in database_results {
if !buffer_map.contains_key(&db_result.buffer) {
buffer_map
.entry(db_result.clone().buffer)
.or_insert(Vec::new())
.push(db_result);
}
}
let mut full_results = Vec::<SearchResult>::new();
for (_, results) in buffer_map {
for res in results.into_iter() {
let ix = match full_results.binary_search_by(|search_result| {
res.similarity
.partial_cmp(&search_result.similarity)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
full_results.insert(ix, res);
full_results.truncate(limit);
}
}
return Ok(full_results);
} else {
return Ok(database_results);
}
// let ix = match results.binary_search_by(|(_, s)| {
// similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
// }) {
// Ok(ix) => ix,
// Err(ix) => ix,
// };
// results.insert(ix, (id, similarity));
// results.truncate(limit);
}) })
} }