Include modified buffers in semantic search results (#2970)

This pull request adds a step to `SemanticIndex::search_project` that also
searches the content of buffers that have been modified but not yet saved. In
most cases such a buffer contains only a small number of changed spans that
may be missing from the index, so to reuse the embeddings of all the spans
that haven't changed, we query the database for embeddings by span digest.
This requires indexing spans by their digest, which adds some overhead on
writes, but in our tests it didn't noticeably slow down indexing.
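
As a minimal sketch of that reuse (an illustration only: the in-memory `HashMap` stands in for the SQLite-backed lookup added below, and the `embed` closure for the embedding provider):

```rust
use std::collections::HashMap;

use sha1::{Digest, Sha1};

type SpanDigest = [u8; 20];
type Embedding = Vec<f32>;

/// Digest a span's content; unchanged spans hash to the same key.
fn span_digest(content: &str) -> SpanDigest {
    let mut hasher = Sha1::new();
    hasher.update(content.as_bytes());
    hasher.finalize().into()
}

/// Reuse cached embeddings for unchanged spans; embed only new ones.
fn embed_spans(
    spans: &[&str],
    cache: &HashMap<SpanDigest, Embedding>,
    embed: impl Fn(&str) -> Embedding,
) -> Vec<Embedding> {
    spans
        .iter()
        .map(|content| {
            cache
                .get(&span_digest(content))
                .cloned()
                .unwrap_or_else(|| embed(content))
        })
        .collect()
}
```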

Release Notes:

- Improved semantic search to include results from modified buffers.
(preview-only)
Antonio Scandurra, 2023-09-15 12:24:10 +02:00, committed by GitHub
commit a1250b8525
7 changed files with 294 additions and 60 deletions

Cargo.lock (generated)

@@ -6739,6 +6739,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
"ordered-float",
"parking_lot 0.11.2",
"parse_duration",
"picker",


@@ -912,7 +912,6 @@ impl Project {
self.user_store.clone()
}
#[cfg(any(test, feature = "test-support"))]
pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> {
self.opened_buffers
.values()


@@ -23,6 +23,7 @@ settings = { path = "../settings" }
anyhow.workspace = true
postage.workspace = true
futures.workspace = true
ordered-float.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true


@@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
cmp::Ordering,
cmp::Reverse,
future::Future,
ops::Range,
path::{Path, PathBuf},
@@ -190,6 +191,10 @@ impl VectorDatabase {
)",
[],
)?;
db.execute(
"CREATE INDEX spans_digest ON spans (digest)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
@@ -274,6 +279,39 @@
})
}
pub fn embeddings_for_digests(
&self,
digests: Vec<SpanDigest>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
WHERE digest IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
let digests = Rc::new(
digests
.into_iter()
.map(|p| Value::Blob(p.0.to_vec()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![digests], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
Ok(embeddings_by_digest)
})
}
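
The `IN rarray(?)` form comes from rusqlite's `array` feature (enabled in the Cargo.toml hunk above), which binds an `Rc<Vec<Value>>` parameter as a virtual table. A standalone sketch of the pattern, assuming the array module has been loaded on the connection; the query is simplified and the embedding is returned as a raw blob:

```rust
use std::rc::Rc;

use rusqlite::{params, types::Value, Connection, Result};

fn embeddings_for_digests(conn: &Connection, digests: &[[u8; 20]]) -> Result<Vec<Vec<u8>>> {
    // Load the array virtual-table module once per connection so that
    // `rarray()` is available in SQL (requires the "array" feature).
    rusqlite::vtab::array::load_module(conn)?;
    let keys = Rc::new(
        digests
            .iter()
            .map(|digest| Value::Blob(digest.to_vec()))
            .collect::<Vec<_>>(),
    );
    let mut stmt = conn.prepare("SELECT embedding FROM spans WHERE digest IN rarray(?)")?;
    let rows = stmt.query_map(params![keys], |row| row.get(0))?;
    rows.collect()
}
```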
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
@@ -370,16 +408,16 @@ impl VectorDatabase {
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
let ix = match results
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
{
Ok(ix) => ix,
Err(ix) => ix,
};
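
`top_k_search` keeps a running top-k list sorted from most to least similar. `binary_search_by_key` assumes an ascending order, so wrapping the key in `Reverse` flips the comparison; the pattern in isolation (a minimal sketch, not the crate's actual helper):

```rust
use std::cmp::Reverse;

use ordered_float::OrderedFloat;

/// Insert `(id, score)` into `results`, keeping the list sorted by
/// descending score and capped at `limit` entries.
fn insert_top_k(
    results: &mut Vec<(i64, OrderedFloat<f32>)>,
    id: i64,
    score: OrderedFloat<f32>,
    limit: usize,
) {
    let ix = match results.binary_search_by_key(&Reverse(score), |(_, s)| Reverse(*s)) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, score));
    results.truncate(limit);
}
```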


@@ -7,6 +7,7 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
@@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
@@ -58,7 +59,7 @@ impl Embedding {
1,
);
}
result
OrderedFloat(result)
}
}
@@ -379,13 +380,13 @@ mod tests {
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}
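
Why `OrderedFloat`: similarity scores are now used as sort and binary-search keys, and plain `f32` is only `PartialOrd` because of `NaN`. `OrderedFloat<f32>` implements `Ord`, `Eq`, and `Hash` by sorting `NaN` above every other value. A small standalone sketch:

```rust
use ordered_float::OrderedFloat;

fn main() {
    // `Vec::sort` requires `Ord`, so this would not compile with plain f32.
    let mut scores = vec![OrderedFloat(0.2_f32), OrderedFloat(f32::NAN), OrderedFloat(0.9)];
    scores.sort();
    assert_eq!(scores[0], OrderedFloat(0.2_f32));
    assert!(scores[2].is_nan()); // NaN sorts last under this total order
}
```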


@@ -7,6 +7,7 @@ use rusqlite::{
};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
cmp::{self, Reverse},
collections::HashSet,
ops::Range,
@@ -16,7 +17,7 @@ use std::{
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]);
pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@@ -94,12 +95,15 @@ impl CodeContextRetriever {
fn parse_entire_file(
&self,
relative_path: &Path,
relative_path: Option<&Path>,
language_name: Arc<str>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
@@ -114,9 +118,16 @@
}])
}
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
fn parse_markdown_file(
&self,
relative_path: Option<&Path>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
@@ -188,7 +199,7 @@ impl CodeContextRetriever {
pub fn parse_file_with_template(
&mut self,
relative_path: &Path,
relative_path: Option<&Path>,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Span>> {
@@ -196,14 +207,17 @@
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content);
} else if language_name.as_ref() == "Markdown" {
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
return self.parse_markdown_file(relative_path, &content);
}
let mut spans = self.parse_file(content, language)?;
for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
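
Unsaved buffers have no file path, so `relative_path` becomes an `Option<&Path>` and each template falls back to a placeholder. The fallback in isolation (a minimal sketch; "untitled" matches the placeholder used above):

```rust
use std::borrow::Cow;
use std::path::Path;

/// Render a possibly-missing relative path, falling back to "untitled" for
/// unsaved buffers. `Cow` avoids an allocation when the path is valid UTF-8.
fn display_path(relative_path: Option<&Path>) -> Cow<'_, str> {
    relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy())
}
```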


@@ -16,14 +16,16 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
use std::{
cmp::Ordering,
cmp::Reverse,
future::Future,
mem,
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Weak},
@@ -37,7 +39,7 @@ use util::{
};
use workspace::WorkspaceCreated;
const SEMANTIC_INDEX_VERSION: usize = 10;
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
@@ -262,9 +264,11 @@ pub struct PendingFile {
job_handle: JobHandle,
}
#[derive(Clone)]
pub struct SearchResult {
pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>,
pub similarity: OrderedFloat<f32>,
}
impl SemanticIndex {
@@ -402,7 +406,7 @@ impl SemanticIndex {
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
if let Some(mut spans) = retriever
.parse_file_with_template(&pending_file.relative_path, &content, language)
.parse_file_with_template(Some(&pending_file.relative_path), &content, language)
.log_err()
{
log::trace!(
@@ -422,7 +426,7 @@
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
spans: spans,
spans,
});
}
}
@@ -687,39 +691,71 @@ impl SemanticIndex {
pub fn search_project(
&mut self,
project: ModelHandle<Project>,
phrase: String,
query: String,
limit: usize,
includes: Vec<PathMatcher>,
excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
if query.is_empty() {
return Task::ready(Ok(Vec::new()));
}
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
cx.spawn(|this, mut cx| async move {
let query = embedding_provider
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
index.await?;
let search_start = Instant::now();
let modified_buffer_results = this.update(&mut cx, |this, cx| {
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
});
let file_results = this.update(&mut cx, |this, cx| {
this.search_files(project, query, limit, includes, excludes, cx)
});
let (modified_buffer_results, file_results) =
futures::join!(modified_buffer_results, file_results);
// Weave together the results from modified buffers and files.
let mut results = Vec::new();
let mut modified_buffers = HashSet::default();
for result in modified_buffer_results.log_err().unwrap_or_default() {
modified_buffers.insert(result.buffer.clone());
results.push(result);
}
for result in file_results.log_err().unwrap_or_default() {
if !modified_buffers.contains(&result.buffer) {
results.push(result);
}
}
results.sort_by_key(|result| Reverse(result.similarity));
results.truncate(limit);
log::trace!("Semantic search took {:?}", search_start.elapsed());
Ok(results)
})
}
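
In outline, `search_project` now embeds the query once, runs the modified-buffer search and the on-disk search concurrently, and weaves the results together: a dirty buffer's results shadow its on-disk copy, and the merged list is re-ranked by similarity. A simplified sketch of the weaving step, with `SearchResult` reduced to a buffer id plus score:

```rust
use std::cmp::Reverse;
use std::collections::HashSet;

use ordered_float::OrderedFloat;

struct SearchResult {
    buffer_id: u64,
    similarity: OrderedFloat<f32>,
}

fn weave(modified: Vec<SearchResult>, files: Vec<SearchResult>, limit: usize) -> Vec<SearchResult> {
    let mut results = Vec::new();
    let mut modified_buffers = HashSet::new();
    for result in modified {
        modified_buffers.insert(result.buffer_id);
        results.push(result);
    }
    // Skip on-disk hits whose buffer already matched with unsaved edits.
    results.extend(
        files
            .into_iter()
            .filter(|result| !modified_buffers.contains(&result.buffer_id)),
    );
    results.sort_by_key(|result| Reverse(result.similarity));
    results.truncate(limit);
    results
}
```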
pub fn search_files(
&mut self,
project: ModelHandle<Project>,
query: Embedding,
limit: usize,
includes: Vec<PathMatcher>,
excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let db_path = self.db.path().clone();
let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let database =
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
if phrase.len() == 0 {
return Ok(Vec::new());
}
let phrase_embedding = embedding_provider
.embed_batch(vec![phrase])
.await?
.into_iter()
.next()
.unwrap();
log::trace!(
"Embedding search phrase took: {:?} milliseconds",
t0.elapsed().as_millis()
);
let worktree_db_ids = this.read_with(&cx, |this, _| {
let project_state = this
.projects
@@ -738,6 +774,7 @@ impl SemanticIndex {
.collect::<Vec<i64>>();
anyhow::Ok(worktree_db_ids)
})?;
let file_ids = database
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
.await?;
@@ -756,26 +793,26 @@ impl SemanticIndex {
let limit = limit.clone();
let fs = fs.clone();
let db_path = db_path.clone();
let phrase_embedding = phrase_embedding.clone();
let query = query.clone();
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
.await
.log_err()
{
batch_results.push(async move {
db.top_k_search(&phrase_embedding, limit, batch.as_slice())
.await
db.top_k_search(&query, limit, batch.as_slice()).await
});
}
}
let batch_results = futures::future::join_all(batch_results).await;
let mut results = Vec::new();
for batch_result in batch_results {
if batch_result.is_ok() {
for (id, similarity) in batch_result.unwrap() {
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
let ix = match results
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
{
Ok(ix) => ix,
Err(ix) => ix,
};
@@ -785,7 +822,11 @@
}
}
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
let scores = results
.into_iter()
.map(|(_, score)| score)
.collect::<Vec<_>>();
let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new();
@@ -810,24 +851,106 @@ impl SemanticIndex {
let buffers = futures::future::join_all(tasks).await;
log::trace!(
"Semantic Searching took: {:?} milliseconds in total",
t0.elapsed().as_millis()
);
Ok(buffers
.into_iter()
.zip(ranges)
.filter_map(|(buffer, range)| {
.zip(scores)
.filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end)
});
Some(SearchResult { buffer, range })
Some(SearchResult {
buffer,
range,
similarity,
})
})
.collect::<Vec<_>>())
.collect())
})
}
fn search_modified_buffers(
&self,
project: &ModelHandle<Project>,
query: Embedding,
limit: usize,
excludes: &[PathMatcher],
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let modified_buffers = project
.read(cx)
.opened_buffers(cx)
.into_iter()
.filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx);
let snapshot = buffer.snapshot();
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
excludes.iter().any(|matcher| matcher.is_match(&path))
});
if buffer.is_dirty() && !excluded {
Some((buffer_handle, snapshot))
} else {
None
}
})
.collect::<HashMap<_, _>>();
let embedding_provider = self.embedding_provider.clone();
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
for (buffer, snapshot) in modified_buffers {
let language = snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
let mut spans = retriever
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
.await
.log_err()
.is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
let ix = match results
.binary_search_by_key(&Reverse(similarity), |result| {
Reverse(result.similarity)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
let range = {
let start = snapshot.clip_offset(span.range.start, Bias::Left);
let end = snapshot.clip_offset(span.range.end, Bias::Right);
snapshot.anchor_before(start)..snapshot.anchor_after(end)
};
results.insert(
ix,
SearchResult {
buffer: buffer.clone(),
range,
similarity,
},
);
results.truncate(limit);
}
}
}
Ok(results)
})
}
@@ -1011,6 +1134,63 @@ impl SemanticIndex {
Ok(())
})
}
async fn embed_spans(
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.log_err()
.unwrap_or_default();
for span in &*spans {
if embeddings_for_digests.contains_key(&span.digest) {
continue;
};
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
}
batch_tokens += span.token_count;
batch.push(span.content.clone());
}
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
}
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
Some(embedding.clone())
} else {
embeddings.next()
};
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
span.embedding = Some(embedding);
}
Ok(())
}
}
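
`embed_spans` only sends spans whose digests missed the database to the provider, and it batches those requests, flushing whenever the next span would exceed the provider's token budget. The batching logic in isolation (a hedged sketch; `max_tokens` and `embed_batch` stand in for the `EmbeddingProvider` trait):

```rust
/// Batch (content, token_count) pairs under a token budget, flushing a
/// batch before it would overflow. Returns one embedding per input span.
fn embed_in_batches(
    spans: &[(String, usize)],
    max_tokens: usize,
    mut embed_batch: impl FnMut(Vec<String>) -> Vec<Vec<f32>>,
) -> Vec<Vec<f32>> {
    let mut batch = Vec::new();
    let mut batch_tokens = 0;
    let mut embeddings = Vec::new();
    for (content, token_count) in spans {
        if batch_tokens + token_count > max_tokens && !batch.is_empty() {
            embeddings.extend(embed_batch(std::mem::take(&mut batch)));
            batch_tokens = 0;
        }
        batch_tokens += token_count;
        batch.push(content.clone());
    }
    if !batch.is_empty() {
        embeddings.extend(embed_batch(batch));
    }
    embeddings
}
```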
impl Entity for SemanticIndex {