From c19c8899fe1c27f3029dda4ea3a071460f1b9560 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 14 Sep 2023 14:58:34 -0400 Subject: [PATCH] add initial search inside modified buffers --- crates/semantic_index/src/db.rs | 33 +++ crates/semantic_index/src/parsing.rs | 2 +- crates/semantic_index/src/semantic_index.rs | 246 +++++++++++++++----- 3 files changed, 216 insertions(+), 65 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 15172323c2..cad0734e76 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -278,6 +278,39 @@ impl VectorDatabase { }) } + pub fn embeddings_for_digests( + &self, + digests: Vec, + ) -> impl Future>> { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM spans + WHERE digest IN rarray(?) + ", + )?; + let mut embeddings_by_digest = HashMap::default(); + let digests = Rc::new( + digests + .into_iter() + .map(|p| Value::Blob(p.0.to_vec())) + .collect::>(), + ); + let rows = query.query_map(params![digests], |row| { + Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) + })?; + + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + + Ok(embeddings_by_digest) + }) + } + pub fn embeddings_for_files( &self, worktree_id_file_paths: HashMap>>, diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 49d748a07c..9f5a339b23 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -17,7 +17,7 @@ use std::{ use tree_sitter::{Parser, QueryCursor}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub struct SpanDigest([u8; 20]); +pub struct SpanDigest(pub [u8; 20]); impl FromSql for SpanDigest { fn column_result(value: ValueRef) -> FromSqlResult { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6dd5572ab0..056f6a3386 100644 --- 
a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -263,9 +263,11 @@ pub struct PendingFile { job_handle: JobHandle, } +#[derive(Clone)] pub struct SearchResult { pub buffer: ModelHandle, pub range: Range, + pub similarity: f32, } impl SemanticIndex { @@ -775,7 +777,8 @@ impl SemanticIndex { .filter_map(|buffer_handle| { let buffer = buffer_handle.read(cx); if buffer.is_dirty() { - Some((buffer_handle.downgrade(), buffer.snapshot())) + // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly + Some((buffer_handle, buffer.snapshot())) } else { None } @@ -783,77 +786,133 @@ impl SemanticIndex { .collect::>() }); - cx.background() - .spawn({ - let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - let embedding_provider = embedding_provider.clone(); - let phrase_embedding = phrase_embedding.clone(); - async move { - let mut results = Vec::new(); - 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers { - let language = buffer_snapshot - .language_at(0) - .cloned() - .unwrap_or_else(|| language::PLAIN_TEXT.clone()); - if let Some(spans) = retriever - .parse_file_with_template(None, &buffer_snapshot.text(), language) - .log_err() - { - let mut batch = Vec::new(); - let mut batch_tokens = 0; - let mut embeddings = Vec::new(); + let buffer_results = if let Some(db) = + VectorDatabase::new(fs, db_path.clone(), cx.background()) + .await + .log_err() + { + cx.background() + .spawn({ + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); + let embedding_provider = embedding_provider.clone(); + let phrase_embedding = phrase_embedding.clone(); + async move { + let mut results = Vec::::new(); + 'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers { + let language = buffer_snapshot + .language_at(0) + .cloned() + .unwrap_or_else(|| language::PLAIN_TEXT.clone()); + if let Some(spans) = retriever + .parse_file_with_template( + None, 
&buffer_snapshot.text(), + language, + ) + .log_err() + { + let mut batch = Vec::new(); + let mut batch_tokens = 0; + let mut embeddings = Vec::new(); - // TODO: query span digests in the database to avoid embedding them again. + let digests = spans + .iter() + .map(|span| span.digest.clone()) + .collect::>(); + let embeddings_for_digests = db + .embeddings_for_digests(digests) + .await + .map_or(Default::default(), |m| m); - for span in &spans { - if span.embedding.is_some() { - continue; + for span in &spans { + if embeddings_for_digests.contains_key(&span.digest) { + continue; + }; + + if batch_tokens + span.token_count + > embedding_provider.max_tokens_per_batch() + { + if let Some(batch_embeddings) = embedding_provider + .embed_batch(mem::take(&mut batch)) + .await + .log_err() + { + embeddings.extend(batch_embeddings); + batch_tokens = 0; + } else { + continue 'buffers; + } + } + + batch_tokens += span.token_count; + batch.push(span.content.clone()); } - if batch_tokens + span.token_count - > embedding_provider.max_tokens_per_batch() + if let Some(batch_embeddings) = embedding_provider + .embed_batch(mem::take(&mut batch)) + .await + .log_err() { - if let Some(batch_embeddings) = embedding_provider - .embed_batch(mem::take(&mut batch)) - .await - .log_err() + embeddings.extend(batch_embeddings); + } else { + continue 'buffers; + } + + let mut embeddings = embeddings.into_iter(); + for span in spans { + let embedding = if let Some(embedding) = + embeddings_for_digests.get(&span.digest) { - embeddings.extend(batch_embeddings); - batch_tokens = 0; + Some(embedding.clone()) } else { + embeddings.next() + }; + + if let Some(embedding) = embedding { + let similarity = + embedding.similarity(&phrase_embedding); + + let ix = match results.binary_search_by(|s| { + similarity + .partial_cmp(&s.similarity) + .unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + + let range = { + let start = buffer_snapshot + .clip_offset(span.range.start, 
Bias::Left); + let end = buffer_snapshot + .clip_offset(span.range.end, Bias::Right); + buffer_snapshot.anchor_before(start) + ..buffer_snapshot.anchor_after(end) + }; + + results.insert( + ix, + SearchResult { + buffer: buffer_handle.clone(), + range, + similarity, + }, + ); + results.truncate(limit); + } else { + log::error!("failed to embed span"); continue 'buffers; } } - - batch_tokens += span.token_count; - batch.push(span.content.clone()); - } - - if let Some(batch_embeddings) = embedding_provider - .embed_batch(mem::take(&mut batch)) - .await - .log_err() - { - embeddings.extend(batch_embeddings); - } else { - continue 'buffers; - } - - let mut embeddings = embeddings.into_iter(); - for span in spans { - let embedding = span.embedding.or_else(|| embeddings.next()); - if let Some(embedding) = embedding { - todo!() - } else { - log::error!("failed to embed span"); - continue 'buffers; - } } } + anyhow::Ok(results) } - } - }) - .await; + }) + .await + } else { + Ok(Vec::new()) + }; let batch_results = futures::future::join_all(batch_results).await; @@ -873,7 +932,11 @@ impl SemanticIndex { } } - let ids = results.into_iter().map(|(id, _)| id).collect::>(); + let ids = results.iter().map(|(id, _)| *id).collect::>(); + let scores = results + .into_iter() + .map(|(_, score)| score) + .collect::>(); let spans = database.spans_for_ids(ids.as_slice()).await?; let mut tasks = Vec::new(); @@ -903,19 +966,74 @@ impl SemanticIndex { t0.elapsed().as_millis() ); - Ok(buffers + let database_results = buffers .into_iter() .zip(ranges) - .filter_map(|(buffer, range)| { + .zip(scores) + .filter_map(|((buffer, range), similarity)| { let buffer = buffer.log_err()?; let range = buffer.read_with(&cx, |buffer, _| { let start = buffer.clip_offset(range.start, Bias::Left); let end = buffer.clip_offset(range.end, Bias::Right); buffer.anchor_before(start)..buffer.anchor_after(end) }); - Some(SearchResult { buffer, range }) + Some(SearchResult { + buffer, + range, + similarity, + }) }) 
- .collect::>()) + .collect::>(); + + // Stitch Together Database Results & Buffer Results + if let Ok(buffer_results) = buffer_results { + let mut buffer_map = HashMap::default(); + for buffer_result in buffer_results { + buffer_map + .entry(buffer_result.clone().buffer) + .or_insert(Vec::new()) + .push(buffer_result); + } + + for db_result in database_results { + if !buffer_map.contains_key(&db_result.buffer) { + buffer_map + .entry(db_result.clone().buffer) + .or_insert(Vec::new()) + .push(db_result); + } + } + + let mut full_results = Vec::::new(); + + for (_, results) in buffer_map { + for res in results.into_iter() { + let ix = match full_results.binary_search_by(|search_result| { + res.similarity + .partial_cmp(&search_result.similarity) + .unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + full_results.insert(ix, res); + full_results.truncate(limit); + } + } + + return Ok(full_results); + } else { + return Ok(database_results); + } + + // let ix = match results.binary_search_by(|(_, s)| { + // similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + // }) { + // Ok(ix) => ix, + // Err(ix) => ix, + // }; + // results.insert(ix, (id, similarity)); + // results.truncate(limit); }) }