From 6a271617b4f23d024f5dd71efbd44e8a68676aea Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Thu, 14 Sep 2023 17:09:08 +0200
Subject: [PATCH 1/5] Make path optional when parsing file

Co-Authored-By: Kyle Caverly
---
 crates/semantic_index/src/parsing.rs        | 26 ++++++++++++++++-----
 crates/semantic_index/src/semantic_index.rs |  4 ++--
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index b6fc000e1d..281b683853 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -7,6 +7,7 @@ use rusqlite::{
 };
 use sha1::{Digest, Sha1};
 use std::{
+    borrow::Cow,
     cmp::{self, Reverse},
     collections::HashSet,
     ops::Range,
@@ -94,12 +95,15 @@ impl CodeContextRetriever {
     fn parse_entire_file(
         &self,
-        relative_path: &Path,
+        relative_path: Option<&Path>,
         language_name: Arc<str>,
         content: &str,
     ) -> Result<Vec<Span>> {
         let document_span = ENTIRE_FILE_TEMPLATE
-            .replace("<path>", relative_path.to_string_lossy().as_ref())
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
@@ -114,9 +118,16 @@ impl CodeContextRetriever {
         }])
     }

-    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
+    fn parse_markdown_file(
+        &self,
+        relative_path: Option<&Path>,
+        content: &str,
+    ) -> Result<Vec<Span>> {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
-            .replace("<path>", relative_path.to_string_lossy().as_ref())
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
@@ -188,7 +199,7 @@ impl CodeContextRetriever {

     pub fn parse_file_with_template(
         &mut self,
-        relative_path: &Path,
+        relative_path: Option<&Path>,
         content: &str,
         language: Arc<Language>,
     ) -> Result<Vec<Span>> {
@@ -203,7 +214,10 @@ impl CodeContextRetriever {
         let mut spans = self.parse_file(content, language)?;
         for span in &mut spans {
             let document_content = CODE_CONTEXT_TEMPLATE
-                .replace("<path>", relative_path.to_string_lossy().as_ref())
+                .replace(
+                    "<path>",
+                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+                )
                 .replace("<language>", language_name.as_ref())
                 .replace("<item>", &span.content);

diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 115bf5d7a8..53df3476d3 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -402,7 +402,7 @@ impl SemanticIndex {
             if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
                 if let Some(mut spans) = retriever
-                    .parse_file_with_template(&pending_file.relative_path, &content, language)
+                    .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                     .log_err()
                 {
                     log::trace!(
@@ -422,7 +422,7 @@ impl SemanticIndex {
                         path: pending_file.relative_path,
                         mtime: pending_file.modified_time,
                         job_handle: pending_file.job_handle,
-                        spans: spans,
+                        spans,
                     });
                 }
            }
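The Cow-based fallback this patch introduces reads more clearly in isolation. Below is a
minimal sketch of the same pattern; the helper name display_path is illustrative and not
part of the patch, only the "untitled" placeholder and the Option<&Path> parameter come
from it:

use std::borrow::Cow;
use std::path::Path;

// An unsaved buffer has no path, so fall back to "untitled". Cow keeps the
// common case allocation-free: to_string_lossy() only allocates when the
// path is not valid UTF-8.
fn display_path(relative_path: Option<&Path>) -> Cow<'_, str> {
    relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy())
}

fn main() {
    assert_eq!(display_path(None), "untitled");
    assert_eq!(display_path(Some(Path::new("src/lib.rs"))), "src/lib.rs");
}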
From f86e5a987fd5d0b30a7149c09fe6dbd37d6e64eb Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Thu, 14 Sep 2023 17:42:30 +0200
Subject: [PATCH 2/5] WIP

---
 crates/project/src/project.rs               |  1 -
 crates/semantic_index/src/db.rs             |  4 ++
 crates/semantic_index/src/parsing.rs        |  2 +-
 crates/semantic_index/src/semantic_index.rs | 90 ++++++++++++++++++++-
 4 files changed, 94 insertions(+), 3 deletions(-)

diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index 0690cc9188..b4e698e08a 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -912,7 +912,6 @@ impl Project {
         self.user_store.clone()
     }

-    #[cfg(any(test, feature = "test-support"))]
     pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> {
         self.opened_buffers
             .values()

diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index c53a3e1ba9..15172323c2 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -190,6 +190,10 @@ impl VectorDatabase {
             )",
             [],
         )?;
+        db.execute(
+            "CREATE INDEX spans_digest ON spans (digest)",
+            [],
+        )?;

         log::trace!("vector database initialized with updated schema.");
         Ok(())

diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index 281b683853..49d748a07c 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -207,7 +207,7 @@ impl CodeContextRetriever {

         if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
             return self.parse_entire_file(relative_path, language_name, &content);
-        } else if language_name.as_ref() == "Markdown" {
+        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
             return self.parse_markdown_file(relative_path, &content);
         }

diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 53df3476d3..6dd5572ab0 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -24,6 +24,7 @@ use smol::channel;
 use std::{
     cmp::Ordering,
     future::Future,
+    mem,
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
@@ -37,7 +38,7 @@ use util::{
 };
 use workspace::WorkspaceCreated;

-const SEMANTIC_INDEX_VERSION: usize = 10;
+const SEMANTIC_INDEX_VERSION: usize = 11;
 const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
 const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
@@ -767,6 +768,93 @@ impl SemanticIndex {
                     });
                 }
             }
+            let dirty_buffers = project.read_with(&cx, |project, cx| {
+                project
+                    .opened_buffers(cx)
+                    .into_iter()
+                    .filter_map(|buffer_handle| {
+                        let buffer = buffer_handle.read(cx);
+                        if buffer.is_dirty() {
+                            Some((buffer_handle.downgrade(), buffer.snapshot()))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            });
+
+            cx.background()
+                .spawn({
+                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                    let embedding_provider = embedding_provider.clone();
+                    let phrase_embedding = phrase_embedding.clone();
+                    async move {
+                        let mut results = Vec::new();
+                        'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
+                            let language = buffer_snapshot
+                                .language_at(0)
+                                .cloned()
+                                .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                            if let Some(spans) = retriever
+                                .parse_file_with_template(None, &buffer_snapshot.text(), language)
+                                .log_err()
+                            {
+                                let mut batch = Vec::new();
+                                let mut batch_tokens = 0;
+                                let mut embeddings = Vec::new();
+
+                                // TODO: query span digests in the database to avoid embedding them again.
+
+                                for span in &spans {
+                                    if span.embedding.is_some() {
+                                        continue;
+                                    }
+
+                                    if batch_tokens + span.token_count
+                                        > embedding_provider.max_tokens_per_batch()
+                                    {
+                                        if let Some(batch_embeddings) = embedding_provider
+                                            .embed_batch(mem::take(&mut batch))
+                                            .await
+                                            .log_err()
+                                        {
+                                            embeddings.extend(batch_embeddings);
+                                            batch_tokens = 0;
+                                        } else {
+                                            continue 'buffers;
+                                        }
+                                    }
+
+                                    batch_tokens += span.token_count;
+                                    batch.push(span.content.clone());
+                                }
+
+                                if let Some(batch_embeddings) = embedding_provider
+                                    .embed_batch(mem::take(&mut batch))
+                                    .await
+                                    .log_err()
+                                {
+                                    embeddings.extend(batch_embeddings);
+                                } else {
+                                    continue 'buffers;
+                                }
+
+                                let mut embeddings = embeddings.into_iter();
+                                for span in spans {
+                                    let embedding = span.embedding.or_else(|| embeddings.next());
+                                    if let Some(embedding) = embedding {
+                                        todo!()
+                                    } else {
+                                        log::error!("failed to embed span");
+                                        continue 'buffers;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                })
+                .await;

             let batch_results = futures::future::join_all(batch_results).await;

             let mut results = Vec::new();
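The WIP block above batches span contents so that no single embed_batch call exceeds the
provider's token budget. A simplified, synchronous sketch of that loop; the trait below is
a stand-in for the crate's EmbeddingProvider, and spans are reduced to (content,
token_count) pairs:

// Accumulate span contents until the next span would exceed the token budget,
// flush the batch, and continue; a final flush handles the remainder.
trait EmbeddingProvider {
    fn max_tokens_per_batch(&self) -> usize;
    fn embed_batch(&self, batch: Vec<String>) -> Vec<Vec<f32>>;
}

fn embed_all(spans: &[(String, usize)], provider: &dyn EmbeddingProvider) -> Vec<Vec<f32>> {
    let mut batch = Vec::new();
    let mut batch_tokens = 0;
    let mut embeddings = Vec::new();
    for (content, token_count) in spans {
        // Flush before this span would push the batch past the budget.
        if batch_tokens + *token_count > provider.max_tokens_per_batch() {
            embeddings.extend(provider.embed_batch(std::mem::take(&mut batch)));
            batch_tokens = 0;
        }
        batch_tokens += *token_count;
        batch.push(content.clone());
    }
    if !batch.is_empty() {
        embeddings.extend(provider.embed_batch(batch));
    }
    embeddings
}

Note that as posted, the patch issues the final flush unconditionally, even for an empty
batch; patch 5 later adds the is_empty guard that the sketch shows.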
From c19c8899fe1c27f3029dda4ea3a071460f1b9560 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Thu, 14 Sep 2023 14:58:34 -0400
Subject: [PATCH 3/5] add initial search inside modified buffers

---
 crates/semantic_index/src/db.rs             |  33 +++
 crates/semantic_index/src/parsing.rs        |   2 +-
 crates/semantic_index/src/semantic_index.rs | 246 +++++++++++++++-----
 3 files changed, 216 insertions(+), 65 deletions(-)

diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index 15172323c2..cad0734e76 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -278,6 +278,39 @@ impl VectorDatabase {
         })
     }

+    pub fn embeddings_for_digests(
+        &self,
+        digests: Vec<SpanDigest>,
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
+        self.transact(move |db| {
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM spans
+                WHERE digest IN rarray(?)
+                ",
+            )?;
+            let mut embeddings_by_digest = HashMap::default();
+            let digests = Rc::new(
+                digests
+                    .into_iter()
+                    .map(|p| Value::Blob(p.0.to_vec()))
+                    .collect::<Vec<_>>(),
+            );
+            let rows = query.query_map(params![digests], |row| {
+                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
+            })?;
+
+            for row in rows {
+                if let Ok(row) = row {
+                    embeddings_by_digest.insert(row.0, row.1);
+                }
+            }
+
+            Ok(embeddings_by_digest)
+        })
+    }
+
     pub fn embeddings_for_files(
         &self,
         worktree_id_file_paths: HashMap<WorktreeId, Vec<Arc<Path>>>,

diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index 49d748a07c..9f5a339b23 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -17,7 +17,7 @@ use std::{
 use tree_sitter::{Parser, QueryCursor};

 #[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct SpanDigest([u8; 20]);
+pub struct SpanDigest(pub [u8; 20]);

 impl FromSql for SpanDigest {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {

diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 6dd5572ab0..056f6a3386 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -263,9 +263,11 @@ pub struct PendingFile {
     job_handle: JobHandle,
 }

+#[derive(Clone)]
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
+    pub similarity: f32,
 }

 impl SemanticIndex {
@@ -775,7 +777,8 @@ impl SemanticIndex {
                     .filter_map(|buffer_handle| {
                         let buffer = buffer_handle.read(cx);
                         if buffer.is_dirty() {
-                            Some((buffer_handle.downgrade(), buffer.snapshot()))
+                            // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
+                            Some((buffer_handle, buffer.snapshot()))
                         } else {
                             None
                         }
@@ -783,77 +786,133 @@ impl SemanticIndex {
                     })
                    .collect::<Vec<_>>()
             });

-            cx.background()
-                .spawn({
-                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                    let embedding_provider = embedding_provider.clone();
-                    let phrase_embedding = phrase_embedding.clone();
-                    async move {
-                        let mut results = Vec::new();
-                        'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
-                            let language = buffer_snapshot
-                                .language_at(0)
-                                .cloned()
-                                .unwrap_or_else(|| language::PLAIN_TEXT.clone());
-                            if let Some(spans) = retriever
-                                .parse_file_with_template(None, &buffer_snapshot.text(), language)
-                                .log_err()
-                            {
-                                let mut batch = Vec::new();
-                                let mut batch_tokens = 0;
-                                let mut embeddings = Vec::new();
+            let buffer_results = if let Some(db) =
+                VectorDatabase::new(fs, db_path.clone(), cx.background())
+                    .await
+                    .log_err()
+            {
+                cx.background()
+                    .spawn({
+                        let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                        let embedding_provider = embedding_provider.clone();
+                        let phrase_embedding = phrase_embedding.clone();
+                        async move {
+                            let mut results = Vec::<SearchResult>::new();
+                            'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
+                                let language = buffer_snapshot
+                                    .language_at(0)
+                                    .cloned()
+                                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                                if let Some(spans) = retriever
+                                    .parse_file_with_template(
+                                        None,
+                                        &buffer_snapshot.text(),
+                                        language,
+                                    )
+                                    .log_err()
+                                {
+                                    let mut batch = Vec::new();
+                                    let mut batch_tokens = 0;
+                                    let mut embeddings = Vec::new();

-                                // TODO: query span digests in the database to avoid embedding them again.
+                                    let digests = spans
+                                        .iter()
+                                        .map(|span| span.digest.clone())
+                                        .collect::<Vec<_>>();
+                                    let embeddings_for_digests = db
+                                        .embeddings_for_digests(digests)
+                                        .await
+                                        .map_or(Default::default(), |m| m);

-                                for span in &spans {
-                                    if span.embedding.is_some() {
-                                        continue;
-                                    }
+                                    for span in &spans {
+                                        if embeddings_for_digests.contains_key(&span.digest) {
+                                            continue;
+                                        };

-                                    if batch_tokens + span.token_count
-                                        > embedding_provider.max_tokens_per_batch()
-                                    {
-                                        if let Some(batch_embeddings) = embedding_provider
-                                            .embed_batch(mem::take(&mut batch))
-                                            .await
-                                            .log_err()
+                                        if batch_tokens + span.token_count
+                                            > embedding_provider.max_tokens_per_batch()
                                         {
-                                            embeddings.extend(batch_embeddings);
-                                            batch_tokens = 0;
-                                        } else {
-                                            continue 'buffers;
+                                            if let Some(batch_embeddings) = embedding_provider
+                                                .embed_batch(mem::take(&mut batch))
+                                                .await
+                                                .log_err()
+                                            {
+                                                embeddings.extend(batch_embeddings);
+                                                batch_tokens = 0;
+                                            } else {
+                                                continue 'buffers;
+                                            }
                                         }
+
+                                        batch_tokens += span.token_count;
+                                        batch.push(span.content.clone());
                                     }

-                                    batch_tokens += span.token_count;
-                                    batch.push(span.content.clone());
-                                }
-
-                                if let Some(batch_embeddings) = embedding_provider
-                                    .embed_batch(mem::take(&mut batch))
-                                    .await
-                                    .log_err()
-                                {
-                                    embeddings.extend(batch_embeddings);
-                                } else {
-                                    continue 'buffers;
-                                }
-
-                                let mut embeddings = embeddings.into_iter();
-                                for span in spans {
-                                    let embedding = span.embedding.or_else(|| embeddings.next());
-                                    if let Some(embedding) = embedding {
-                                        todo!()
+                                    if let Some(batch_embeddings) = embedding_provider
+                                        .embed_batch(mem::take(&mut batch))
+                                        .await
+                                        .log_err()
+                                    {
+                                        embeddings.extend(batch_embeddings);
                                     } else {
-                                        log::error!("failed to embed span");
                                         continue 'buffers;
                                     }
+
+                                    let mut embeddings = embeddings.into_iter();
+                                    for span in spans {
+                                        let embedding = if let Some(embedding) =
+                                            embeddings_for_digests.get(&span.digest)
+                                        {
+                                            Some(embedding.clone())
+                                        } else {
+                                            embeddings.next()
+                                        };
+
+                                        if let Some(embedding) = embedding {
+                                            let similarity =
+                                                embedding.similarity(&phrase_embedding);
+
+                                            let ix = match results.binary_search_by(|s| {
+                                                similarity
+                                                    .partial_cmp(&s.similarity)
+                                                    .unwrap_or(Ordering::Equal)
+                                            }) {
+                                                Ok(ix) => ix,
+                                                Err(ix) => ix,
+                                            };
+
+                                            let range = {
+                                                let start = buffer_snapshot
+                                                    .clip_offset(span.range.start, Bias::Left);
+                                                let end = buffer_snapshot
+                                                    .clip_offset(span.range.end, Bias::Right);
+                                                buffer_snapshot.anchor_before(start)
+                                                    ..buffer_snapshot.anchor_after(end)
+                                            };
+
+                                            results.insert(
+                                                ix,
+                                                SearchResult {
+                                                    buffer: buffer_handle.clone(),
+                                                    range,
+                                                    similarity,
+                                                },
+                                            );
+                                            results.truncate(limit);
+                                        } else {
+                                            log::error!("failed to embed span");
+                                            continue 'buffers;
+                                        }
+                                    }
                                 }
                             }
+                            anyhow::Ok(results)
                         }
-                    }
-                })
-                .await;
+                    })
+                    .await
+            } else {
+                Ok(Vec::new())
+            };

             let batch_results = futures::future::join_all(batch_results).await;
@@ -873,7 +932,11 @@ impl SemanticIndex {
                 }
             }

-            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
+            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
+            let scores = results
+                .into_iter()
+                .map(|(_, score)| score)
+                .collect::<Vec<f32>>();
             let spans = database.spans_for_ids(ids.as_slice()).await?;

             let mut tasks = Vec::new();
@@ -903,19 +966,74 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );

-            Ok(buffers
+            let database_results = buffers
                 .into_iter()
                 .zip(ranges)
-                .filter_map(|(buffer, range)| {
+                .zip(scores)
+                .filter_map(|((buffer, range), similarity)| {
                     let buffer = buffer.log_err()?;
                     let range = buffer.read_with(&cx, |buffer, _| {
                         let start = buffer.clip_offset(range.start, Bias::Left);
                         let end = buffer.clip_offset(range.end, Bias::Right);
                         buffer.anchor_before(start)..buffer.anchor_after(end)
                     });
-                    Some(SearchResult { buffer, range })
+                    Some(SearchResult {
+                        buffer,
+                        range,
+                        similarity,
+                    })
                 })
-                .collect::<Vec<_>>())
+                .collect::<Vec<_>>();
+
+            // Stitch Together Database Results & Buffer Results
+            if let Ok(buffer_results) = buffer_results {
+                let mut buffer_map = HashMap::default();
+                for buffer_result in buffer_results {
+                    buffer_map
+                        .entry(buffer_result.clone().buffer)
+                        .or_insert(Vec::new())
+                        .push(buffer_result);
+                }
+
+                for db_result in database_results {
+                    if !buffer_map.contains_key(&db_result.buffer) {
+                        buffer_map
+                            .entry(db_result.clone().buffer)
+                            .or_insert(Vec::new())
+                            .push(db_result);
+                    }
+                }
+
+                let mut full_results = Vec::<SearchResult>::new();
+
+                for (_, results) in buffer_map {
+                    for res in results.into_iter() {
+                        let ix = match full_results.binary_search_by(|search_result| {
+                            res.similarity
+                                .partial_cmp(&search_result.similarity)
+                                .unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        full_results.insert(ix, res);
+                        full_results.truncate(limit);
+                    }
+                }
+
+                return Ok(full_results);
+            } else {
+                return Ok(database_results);
+            }
+
+            // let ix = match results.binary_search_by(|(_, s)| {
+            //     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+            // }) {
+            //     Ok(ix) => ix,
+            //     Err(ix) => ix,
+            // };
+            // results.insert(ix, (id, similarity));
+            // results.truncate(limit);
         })
     }
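The new embeddings_for_digests query binds a whole Vec of digests as the right-hand side
of IN via rusqlite's rarray() virtual table, which comes from the "array" feature this
crate enables. A self-contained sketch of the same lookup shape, with i64 keys standing
in for the digest blobs (table and column names here are illustrative):

use std::rc::Rc;
use rusqlite::{params, types::Value, Connection};

// Bind a Rust vector as a SQL set: rarray(?) exposes an Rc<Vec<Value>> as a
// one-column virtual table, so the whole lookup is a single round trip.
fn rows_for_ids(db: &Connection, ids: Vec<i64>) -> rusqlite::Result<Vec<i64>> {
    rusqlite::vtab::array::load_module(db)?; // one-time registration per connection
    let ids = Rc::new(ids.into_iter().map(Value::from).collect::<Vec<Value>>());
    let mut query = db.prepare("SELECT id FROM spans WHERE id IN rarray(?)")?;
    let rows = query.query_map(params![ids], |row| row.get::<_, i64>(0))?;
    rows.collect()
}

Together with the spans_digest index added in patch 2, this turns the "have we already
embedded this span?" check into an indexed batch lookup instead of per-span queries.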
From 796bdd3da792c2373b814279634218504416d523 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Thu, 14 Sep 2023 19:42:06 -0400
Subject: [PATCH 4/5] update searching in modified buffers to accommodate for
 excluded paths

---
 crates/semantic_index/src/semantic_index.rs | 111 +++++++++-----------
 1 file changed, 51 insertions(+), 60 deletions(-)

diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 056f6a3386..063aff96e9 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -693,7 +693,7 @@ impl SemanticIndex {
         phrase: String,
         limit: usize,
         includes: Vec<PathMatcher>,
-        excludes: Vec<PathMatcher>,
+        mut excludes: Vec<PathMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
         let index = self.index_project(project.clone(), cx);
@@ -741,6 +741,43 @@ impl SemanticIndex {
                     .collect::<Vec<_>>();
                 anyhow::Ok(worktree_db_ids)
             })?;
+
+            let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
+                let mut dirty_paths = Vec::new();
+                let dirty_buffers = project
+                    .opened_buffers(cx)
+                    .into_iter()
+                    .filter_map(|buffer_handle| {
+                        let buffer = buffer_handle.read(cx);
+                        if buffer.is_dirty() {
+                            let snapshot = buffer.snapshot();
+                            if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
+                                let file_path = file_pathbuf.as_path();
+
+                                if excludes.iter().any(|glob| glob.is_match(file_path)) {
+                                    return None;
+                                }
+
+                                file_pathbuf
+                                    .to_str()
+                                    .and_then(|path| PathMatcher::new(path).log_err())
+                                    .and_then(|path_matcher| {
+                                        dirty_paths.push(path_matcher);
+                                        Some(())
+                                    });
+                            }
+                            // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
+                            Some((buffer_handle, buffer.snapshot()))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<_>>();
+
+                (dirty_buffers, dirty_paths)
+            });
+
+            excludes.extend(dirty_paths);
             let file_ids = database
                 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                 .await?;
@@ -770,21 +807,6 @@ impl SemanticIndex {
                     });
                 }
             }
-            let dirty_buffers = project.read_with(&cx, |project, cx| {
-                project
-                    .opened_buffers(cx)
-                    .into_iter()
-                    .filter_map(|buffer_handle| {
-                        let buffer = buffer_handle.read(cx);
-                        if buffer.is_dirty() {
-                            // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
-                            Some((buffer_handle, buffer.snapshot()))
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>()
-            });

             let buffer_results = if let Some(db) =
                 VectorDatabase::new(fs, db_path.clone(), cx.background())
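The dirty_paths bookkeeping above exists to de-duplicate results: a dirty buffer is
searched against its in-memory text, so its on-disk path has to be excluded from the
database pass or the same file could surface twice, once with stale content. A reduced
sketch of that partitioning, with local stand-ins for the project crate's Buffer and
matcher types (plain PathBufs replace the glob matchers):

use std::path::PathBuf;

struct Buffer {
    path: Option<PathBuf>,
    dirty: bool,
}

// Returns the buffers to search in memory, promoting each one's path into
// `excludes` so the on-disk index skips it. Buffers that already match an
// exclude are dropped from both passes.
fn split_dirty(buffers: Vec<Buffer>, excludes: &mut Vec<PathBuf>) -> Vec<Buffer> {
    buffers
        .into_iter()
        .filter(|buffer| buffer.dirty)
        .filter_map(|buffer| {
            if let Some(path) = &buffer.path {
                if excludes.contains(path) {
                    return None;
                }
                excludes.push(path.clone());
            }
            Some(buffer)
        })
        .collect()
}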
From ae85a520f2673892397da04b26e8d880601ac865 Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Fri, 15 Sep 2023 12:12:20 +0200
Subject: [PATCH 5/5] Refactor semantic searching of modified buffers

---
 Cargo.lock                                  |   1 +
 crates/semantic_index/Cargo.toml            |   1 +
 crates/semantic_index/src/db.rs             |  13 +-
 crates/semantic_index/src/embedding.rs      |  11 +-
 crates/semantic_index/src/semantic_index.rs | 417 ++++++++++----------
 5 files changed, 215 insertions(+), 228 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 20bc1c9d0d..2f549c568d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6739,6 +6739,7 @@ dependencies = [
  "lazy_static",
  "log",
  "matrixmultiply",
+ "ordered-float",
  "parking_lot 0.11.2",
  "parse_duration",
  "picker",

diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 72a36efd50..45b02722ac 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -23,6 +23,7 @@ settings = { path = "../settings" }
 anyhow.workspace = true
 postage.workspace = true
 futures.workspace = true
+ordered-float.workspace = true
 smol.workspace = true
 rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
 isahc.workspace = true

diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs
index cad0734e76..3e35284027 100644
--- a/crates/semantic_index/src/db.rs
+++ b/crates/semantic_index/src/db.rs
@@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
 use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
+use ordered_float::OrderedFloat;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::params;
 use rusqlite::types::Value;
 use std::{
-    cmp::Ordering,
+    cmp::Reverse,
     future::Future,
     ops::Range,
     path::{Path, PathBuf},
@@ -407,16 +408,16 @@ impl VectorDatabase {
         query_embedding: &Embedding,
         limit: usize,
         file_ids: &[i64],
-    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
         let query_embedding = query_embedding.clone();
         let file_ids = file_ids.to_vec();
         self.transact(move |db| {
-            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
             Self::for_each_span(db, &file_ids, |id, embedding| {
                 let similarity = embedding.similarity(&query_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
+                let ix = match results
+                    .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };

diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index 42d90f0fdb..b0124bf7df 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -7,6 +7,7 @@ use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
+use ordered_float::OrderedFloat;
 use parking_lot::Mutex;
 use parse_duration::parse;
 use postage::watch;
@@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
 }

 impl Embedding {
-    pub fn similarity(&self, other: &Self) -> f32 {
+    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
         let len = self.0.len();
         assert_eq!(len, other.0.len());
@@ -58,7 +59,7 @@ impl Embedding {
                 1,
             );
         }
-        result
+        OrderedFloat(result)
     }
 }
@@ -379,13 +380,13 @@ mod tests {
         );
     }

-    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+    fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
         let factor = (10.0 as f32).powi(decimal_places);
         (n * factor).round() / factor
     }

-    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
-        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+    fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+        OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
     }
 }

diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 063aff96e9..06c7aa53fa 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -16,13 +16,14 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
 use futures::{future, FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
+use ordered_float::OrderedFloat;
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
 use smol::channel;
 use std::{
-    cmp::Ordering,
+    cmp::Reverse,
     future::Future,
     mem,
     ops::Range,
@@ -267,7 +268,7 @@ pub struct PendingFile {
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
-    pub similarity: f32,
+    pub similarity: OrderedFloat<f32>,
 }

 impl SemanticIndex {
@@ -690,39 +691,71 @@ impl SemanticIndex {
     pub fn search_project(
         &mut self,
         project: ModelHandle<Project>,
-        phrase: String,
+        query: String,
         limit: usize,
         includes: Vec<PathMatcher>,
-        mut excludes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
+        if query.is_empty() {
+            return Task::ready(Ok(Vec::new()));
+        }
+
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
+
+        cx.spawn(|this, mut cx| async move {
+            let query = embedding_provider
+                .embed_batch(vec![query])
+                .await?
+                .pop()
+                .ok_or_else(|| anyhow!("could not embed query"))?;
+            index.await?;
+
+            let search_start = Instant::now();
+            let modified_buffer_results = this.update(&mut cx, |this, cx| {
+                this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
+            });
+            let file_results = this.update(&mut cx, |this, cx| {
+                this.search_files(project, query, limit, includes, excludes, cx)
+            });
+            let (modified_buffer_results, file_results) =
+                futures::join!(modified_buffer_results, file_results);
+
+            // Weave together the results from modified buffers and files.
+            let mut results = Vec::new();
+            let mut modified_buffers = HashSet::default();
+            for result in modified_buffer_results.log_err().unwrap_or_default() {
+                modified_buffers.insert(result.buffer.clone());
+                results.push(result);
+            }
+            for result in file_results.log_err().unwrap_or_default() {
+                if !modified_buffers.contains(&result.buffer) {
+                    results.push(result);
+                }
+            }
+            results.sort_by_key(|result| Reverse(result.similarity));
+            results.truncate(limit);
+            log::trace!("Semantic search took {:?}", search_start.elapsed());
+            Ok(results)
+        })
+    }
+
+    pub fn search_files(
+        &mut self,
+        project: ModelHandle<Project>,
+        query: Embedding,
+        limit: usize,
+        includes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
         let db_path = self.db.path().clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
-            index.await?;
-
-            let t0 = Instant::now();
             let database =
                 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;

-            if phrase.len() == 0 {
-                return Ok(Vec::new());
-            }
-
-            let phrase_embedding = embedding_provider
-                .embed_batch(vec![phrase])
-                .await?
-                .into_iter()
-                .next()
-                .unwrap();
-
-            log::trace!(
-                "Embedding search phrase took: {:?} milliseconds",
-                t0.elapsed().as_millis()
-            );
-
             let worktree_db_ids = this.read_with(&cx, |this, _| {
                 let project_state = this
                     .projects
@@ -742,42 +775,6 @@ impl SemanticIndex {
                     .collect::<Vec<_>>();
                 anyhow::Ok(worktree_db_ids)
             })?;
-
-            let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
-                let mut dirty_paths = Vec::new();
-                let dirty_buffers = project
-                    .opened_buffers(cx)
-                    .into_iter()
-                    .filter_map(|buffer_handle| {
-                        let buffer = buffer_handle.read(cx);
-                        if buffer.is_dirty() {
-                            let snapshot = buffer.snapshot();
-                            if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
-                                let file_path = file_pathbuf.as_path();
-
-                                if excludes.iter().any(|glob| glob.is_match(file_path)) {
-                                    return None;
-                                }
-
-                                file_pathbuf
-                                    .to_str()
-                                    .and_then(|path| PathMatcher::new(path).log_err())
-                                    .and_then(|path_matcher| {
-                                        dirty_paths.push(path_matcher);
-                                        Some(())
-                                    });
-                            }
-                            // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
-                            Some((buffer_handle, buffer.snapshot()))
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>();
-
-                (dirty_buffers, dirty_paths)
-            });
-
-            excludes.extend(dirty_paths);
             let file_ids = database
                 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                 .await?;
@@ -796,155 +793,26 @@ impl SemanticIndex {
                 let limit = limit.clone();
                 let fs = fs.clone();
                 let db_path = db_path.clone();
-                let phrase_embedding = phrase_embedding.clone();
+                let query = query.clone();
                 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                     .await
                     .log_err()
                 {
                     batch_results.push(async move {
-                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
-                            .await
+                        db.top_k_search(&query, limit, batch.as_slice()).await
                     });
                 }
             }

-            let buffer_results = if let Some(db) =
-                VectorDatabase::new(fs, db_path.clone(), cx.background())
-                    .await
-                    .log_err()
-            {
-                cx.background()
-                    .spawn({
-                        let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                        let embedding_provider = embedding_provider.clone();
-                        let phrase_embedding = phrase_embedding.clone();
-                        async move {
-                            let mut results = Vec::<SearchResult>::new();
-                            'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
-                                let language = buffer_snapshot
-                                    .language_at(0)
-                                    .cloned()
-                                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
-                                if let Some(spans) = retriever
-                                    .parse_file_with_template(
-                                        None,
-                                        &buffer_snapshot.text(),
-                                        language,
-                                    )
-                                    .log_err()
-                                {
-                                    let mut batch = Vec::new();
-                                    let mut batch_tokens = 0;
-                                    let mut embeddings = Vec::new();
-
-                                    let digests = spans
-                                        .iter()
-                                        .map(|span| span.digest.clone())
-                                        .collect::<Vec<_>>();
-                                    let embeddings_for_digests = db
-                                        .embeddings_for_digests(digests)
-                                        .await
-                                        .map_or(Default::default(), |m| m);
-
-                                    for span in &spans {
-                                        if embeddings_for_digests.contains_key(&span.digest) {
-                                            continue;
-                                        };
-
-                                        if batch_tokens + span.token_count
-                                            > embedding_provider.max_tokens_per_batch()
-                                        {
-                                            if let Some(batch_embeddings) = embedding_provider
-                                                .embed_batch(mem::take(&mut batch))
-                                                .await
-                                                .log_err()
-                                            {
-                                                embeddings.extend(batch_embeddings);
-                                                batch_tokens = 0;
-                                            } else {
-                                                continue 'buffers;
-                                            }
-                                        }
-
-                                        batch_tokens += span.token_count;
-                                        batch.push(span.content.clone());
-                                    }
-
-                                    if let Some(batch_embeddings) = embedding_provider
-                                        .embed_batch(mem::take(&mut batch))
-                                        .await
-                                        .log_err()
-                                    {
-                                        embeddings.extend(batch_embeddings);
-                                    } else {
-                                        continue 'buffers;
-                                    }
-
-                                    let mut embeddings = embeddings.into_iter();
-                                    for span in spans {
-                                        let embedding = if let Some(embedding) =
-                                            embeddings_for_digests.get(&span.digest)
-                                        {
-                                            Some(embedding.clone())
-                                        } else {
-                                            embeddings.next()
-                                        };
-
-                                        if let Some(embedding) = embedding {
-                                            let similarity =
-                                                embedding.similarity(&phrase_embedding);
-
-                                            let ix = match results.binary_search_by(|s| {
-                                                similarity
-                                                    .partial_cmp(&s.similarity)
-                                                    .unwrap_or(Ordering::Equal)
-                                            }) {
-                                                Ok(ix) => ix,
-                                                Err(ix) => ix,
-                                            };
-
-                                            let range = {
-                                                let start = buffer_snapshot
-                                                    .clip_offset(span.range.start, Bias::Left);
-                                                let end = buffer_snapshot
-                                                    .clip_offset(span.range.end, Bias::Right);
-                                                buffer_snapshot.anchor_before(start)
-                                                    ..buffer_snapshot.anchor_after(end)
-                                            };
-
-                                            results.insert(
-                                                ix,
-                                                SearchResult {
-                                                    buffer: buffer_handle.clone(),
-                                                    range,
-                                                    similarity,
-                                                },
-                                            );
-                                            results.truncate(limit);
-                                        } else {
-                                            log::error!("failed to embed span");
-                                            continue 'buffers;
-                                        }
-                                    }
-                                }
-                            }
-                            anyhow::Ok(results)
-                        }
-                    })
-                    .await
-            } else {
-                Ok(Vec::new())
-            };
-
             let batch_results = futures::future::join_all(batch_results).await;

             let mut results = Vec::new();
             for batch_result in batch_results {
                 if batch_result.is_ok() {
                     for (id, similarity) in batch_result.unwrap() {
-                        let ix = match results.binary_search_by(|(_, s)| {
-                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                        }) {
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                        {
                             Ok(ix) => ix,
                             Err(ix) => ix,
                         };
@@ -958,7 +826,7 @@ impl SemanticIndex {
             let scores = results
                 .into_iter()
                 .map(|(_, score)| score)
-                .collect::<Vec<f32>>();
+                .collect::<Vec<OrderedFloat<f32>>>();
             let spans = database.spans_for_ids(ids.as_slice()).await?;

             let mut tasks = Vec::new();
@@ -983,12 +851,7 @@ impl SemanticIndex {

             let buffers = futures::future::join_all(tasks).await;

-            log::trace!(
-                "Semantic Searching took: {:?} milliseconds in total",
-                t0.elapsed().as_millis()
-            );
-
-            let mut database_results = buffers
+            Ok(buffers
                 .into_iter()
                 .zip(ranges)
                 .zip(scores)
@@ -1005,26 +868,89 @@ impl SemanticIndex {
                         similarity,
                     })
                 })
-                .collect::<Vec<_>>();
-
-            // Stitch Together Database Results & Buffer Results
-            if let Ok(buffer_results) = buffer_results {
-                for buffer_result in buffer_results {
-                    let ix = match database_results.binary_search_by(|search_result| {
-                        buffer_result
-                            .similarity
-                            .partial_cmp(&search_result.similarity)
-                            .unwrap_or(Ordering::Equal)
-                    }) {
-                        Ok(ix) => ix,
-                        Err(ix) => ix,
-                    };
-                    database_results.insert(ix, buffer_result);
-                    database_results.truncate(limit);
-                }
-            }
-
-            Ok(database_results)
+                .collect())
         })
     }

+    fn search_modified_buffers(
+        &self,
+        project: &ModelHandle<Project>,
+        query: Embedding,
+        limit: usize,
+        excludes: &[PathMatcher],
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let modified_buffers = project
+            .read(cx)
+            .opened_buffers(cx)
+            .into_iter()
+            .filter_map(|buffer_handle| {
+                let buffer = buffer_handle.read(cx);
+                let snapshot = buffer.snapshot();
+                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                    excludes.iter().any(|matcher| matcher.is_match(&path))
+                });
+                if buffer.is_dirty() && !excluded {
+                    Some((buffer_handle, snapshot))
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let embedding_provider = self.embedding_provider.clone();
+        let fs = self.fs.clone();
+        let db_path = self.db.path().clone();
+        let background = cx.background().clone();
+        cx.background().spawn(async move {
+            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
+            let mut results = Vec::<SearchResult>::new();
+
+            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+            for (buffer, snapshot) in modified_buffers {
+                let language = snapshot
+                    .language_at(0)
+                    .cloned()
+                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                let mut spans = retriever
+                    .parse_file_with_template(None, &snapshot.text(), language)
+                    .log_err()
+                    .unwrap_or_default();
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
+                {
+                    for span in spans {
+                        let similarity = span.embedding.unwrap().similarity(&query);
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |result| {
+                                Reverse(result.similarity)
+                            }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+
+                        let range = {
+                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
+                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
+                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
+                        };
+
+                        results.insert(
+                            ix,
+                            SearchResult {
+                                buffer: buffer.clone(),
+                                range,
+                                similarity,
+                            },
+                        );
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            Ok(results)
+        })
+    }
@@ -1208,6 +1134,63 @@ impl SemanticIndex {
             Ok(())
         })
     }
+
+    async fn embed_spans(
+        spans: &mut [Span],
+        embedding_provider: &dyn EmbeddingProvider,
+        db: &VectorDatabase,
+    ) -> Result<()> {
+        let mut batch = Vec::new();
+        let mut batch_tokens = 0;
+        let mut embeddings = Vec::new();
+
+        let digests = spans
+            .iter()
+            .map(|span| span.digest.clone())
+            .collect::<Vec<_>>();
+        let embeddings_for_digests = db
+            .embeddings_for_digests(digests)
+            .await
+            .log_err()
+            .unwrap_or_default();
+
+        for span in &*spans {
+            if embeddings_for_digests.contains_key(&span.digest) {
+                continue;
+            };
+
+            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
+                let batch_embeddings = embedding_provider
+                    .embed_batch(mem::take(&mut batch))
+                    .await?;
+                embeddings.extend(batch_embeddings);
+                batch_tokens = 0;
+            }
+
+            batch_tokens += span.token_count;
+            batch.push(span.content.clone());
+        }
+
+        if !batch.is_empty() {
+            let batch_embeddings = embedding_provider
+                .embed_batch(mem::take(&mut batch))
+                .await?;
+
+            embeddings.extend(batch_embeddings);
+        }
+
+        let mut embeddings = embeddings.into_iter();
+        for span in spans {
+            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
+                Some(embedding.clone())
+            } else {
+                embeddings.next()
+            };
+            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
+            span.embedding = Some(embedding);
+        }
+        Ok(())
+    }
 }

 impl Entity for SemanticIndex {
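A closing note on the ordering scheme this refactor settles on: wrapping the cosine
similarity in ordered_float::OrderedFloat makes scores Ord, so Reverse(similarity) can
key a binary search over a best-first result list instead of the earlier partial_cmp
closures. A minimal sketch of the descending top-k insertion used in top_k_search,
search_files, and search_modified_buffers:

use ordered_float::OrderedFloat;
use std::cmp::Reverse;

// Keep `results` sorted best-first: insert at the binary-search position
// (Reverse flips the order so higher similarity sorts earlier) and drop
// anything past `limit`.
fn insert_top_k(
    results: &mut Vec<(i64, OrderedFloat<f32>)>,
    id: i64,
    similarity: OrderedFloat<f32>,
    limit: usize,
) {
    let ix = match results.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s)) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}

fn main() {
    let mut results = Vec::new();
    for (id, score) in [(1, 0.3), (2, 0.9), (3, 0.5)] {
        insert_top_k(&mut results, id, OrderedFloat(score), 2);
    }
    assert_eq!(results, vec![(2, OrderedFloat(0.9)), (3, OrderedFloat(0.5))]);
}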