mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-12 19:11:23 +03:00
Include modified buffers in semantic search results (#2970)
This pull request introduces an additional step to `SemanticIndex::search_project` that includes the content of buffers that are modified but haven't been saved yet. In most cases, the buffer will contain a small portion of changed spans that are potentially not included in the index. To reuse all the other spans that haven't changed, we will query the database for embeddings by their digest. This means we have to index spans by their digest, which means some penalty when writing, but in our tests this didn't seem to make indexing much slower. Release Notes: - Improved semantic search to include results from modified buffers. (preview-only)
This commit is contained in:
commit
a1250b8525
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -6739,6 +6739,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"ordered-float",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"picker",
|
||||
|
@ -912,7 +912,6 @@ impl Project {
|
||||
self.user_store.clone()
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> {
|
||||
self.opened_buffers
|
||||
.values()
|
||||
|
@ -23,6 +23,7 @@ settings = { path = "../settings" }
|
||||
anyhow.workspace = true
|
||||
postage.workspace = true
|
||||
futures.workspace = true
|
||||
ordered-float.workspace = true
|
||||
smol.workspace = true
|
||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
isahc.workspace = true
|
||||
|
@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::{search::PathMatcher, Fs};
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::params;
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
cmp::Reverse,
|
||||
future::Future,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
@ -190,6 +191,10 @@ impl VectorDatabase {
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
db.execute(
|
||||
"CREATE INDEX spans_digest ON spans (digest)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
@ -274,6 +279,39 @@ impl VectorDatabase {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_digests(
|
||||
&self,
|
||||
digests: Vec<SpanDigest>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
WHERE digest IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
let digests = Rc::new(
|
||||
digests
|
||||
.into_iter()
|
||||
.map(|p| Value::Blob(p.0.to_vec()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![digests], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
embeddings_by_digest.insert(row.0, row.1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_files(
|
||||
&self,
|
||||
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
|
||||
@ -370,16 +408,16 @@ impl VectorDatabase {
|
||||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
|
||||
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||
let query_embedding = query_embedding.clone();
|
||||
let file_ids = file_ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
|
||||
Self::for_each_span(db, &file_ids, |id, embedding| {
|
||||
let similarity = embedding.similarity(&query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
|
@ -7,6 +7,7 @@ use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::Mutex;
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> f32 {
|
||||
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
@ -58,7 +59,7 @@ impl Embedding {
|
||||
1,
|
||||
);
|
||||
}
|
||||
result
|
||||
OrderedFloat(result)
|
||||
}
|
||||
}
|
||||
|
||||
@ -379,13 +380,13 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
||||
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ use rusqlite::{
|
||||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::{self, Reverse},
|
||||
collections::HashSet,
|
||||
ops::Range,
|
||||
@ -16,7 +17,7 @@ use std::{
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub struct SpanDigest([u8; 20]);
|
||||
pub struct SpanDigest(pub [u8; 20]);
|
||||
|
||||
impl FromSql for SpanDigest {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
@ -94,12 +95,15 @@ impl CodeContextRetriever {
|
||||
|
||||
fn parse_entire_file(
|
||||
&self,
|
||||
relative_path: &Path,
|
||||
relative_path: Option<&Path>,
|
||||
language_name: Arc<str>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = ENTIRE_FILE_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
@ -114,9 +118,16 @@ impl CodeContextRetriever {
|
||||
}])
|
||||
}
|
||||
|
||||
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
|
||||
fn parse_markdown_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
@ -188,7 +199,7 @@ impl CodeContextRetriever {
|
||||
|
||||
pub fn parse_file_with_template(
|
||||
&mut self,
|
||||
relative_path: &Path,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Span>> {
|
||||
@ -196,14 +207,17 @@ impl CodeContextRetriever {
|
||||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
} else if language_name.as_ref() == "Markdown" {
|
||||
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
|
||||
return self.parse_markdown_file(relative_path, &content);
|
||||
}
|
||||
|
||||
let mut spans = self.parse_file(content, language)?;
|
||||
for span in &mut spans {
|
||||
let document_content = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
|
@ -16,14 +16,16 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
||||
use futures::{future, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::Mutex;
|
||||
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||
use postage::watch;
|
||||
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
cmp::Reverse,
|
||||
future::Future,
|
||||
mem,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, Weak},
|
||||
@ -37,7 +39,7 @@ use util::{
|
||||
};
|
||||
use workspace::WorkspaceCreated;
|
||||
|
||||
const SEMANTIC_INDEX_VERSION: usize = 10;
|
||||
const SEMANTIC_INDEX_VERSION: usize = 11;
|
||||
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
|
||||
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
|
||||
|
||||
@ -262,9 +264,11 @@ pub struct PendingFile {
|
||||
job_handle: JobHandle,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SearchResult {
|
||||
pub buffer: ModelHandle<Buffer>,
|
||||
pub range: Range<Anchor>,
|
||||
pub similarity: OrderedFloat<f32>,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
@ -402,7 +406,7 @@ impl SemanticIndex {
|
||||
|
||||
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
|
||||
if let Some(mut spans) = retriever
|
||||
.parse_file_with_template(&pending_file.relative_path, &content, language)
|
||||
.parse_file_with_template(Some(&pending_file.relative_path), &content, language)
|
||||
.log_err()
|
||||
{
|
||||
log::trace!(
|
||||
@ -422,7 +426,7 @@ impl SemanticIndex {
|
||||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
spans: spans,
|
||||
spans,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -687,39 +691,71 @@ impl SemanticIndex {
|
||||
pub fn search_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
phrase: String,
|
||||
query: String,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
excludes: Vec<PathMatcher>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
if query.is_empty() {
|
||||
return Task::ready(Ok(Vec::new()));
|
||||
}
|
||||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
index.await?;
|
||||
|
||||
let search_start = Instant::now();
|
||||
let modified_buffer_results = this.update(&mut cx, |this, cx| {
|
||||
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
|
||||
});
|
||||
let file_results = this.update(&mut cx, |this, cx| {
|
||||
this.search_files(project, query, limit, includes, excludes, cx)
|
||||
});
|
||||
let (modified_buffer_results, file_results) =
|
||||
futures::join!(modified_buffer_results, file_results);
|
||||
|
||||
// Weave together the results from modified buffers and files.
|
||||
let mut results = Vec::new();
|
||||
let mut modified_buffers = HashSet::default();
|
||||
for result in modified_buffer_results.log_err().unwrap_or_default() {
|
||||
modified_buffers.insert(result.buffer.clone());
|
||||
results.push(result);
|
||||
}
|
||||
for result in file_results.log_err().unwrap_or_default() {
|
||||
if !modified_buffers.contains(&result.buffer) {
|
||||
results.push(result);
|
||||
}
|
||||
}
|
||||
results.sort_by_key(|result| Reverse(result.similarity));
|
||||
results.truncate(limit);
|
||||
log::trace!("Semantic search took {:?}", search_start.elapsed());
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn search_files(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
excludes: Vec<PathMatcher>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let db_path = self.db.path().clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
|
||||
let t0 = Instant::now();
|
||||
let database =
|
||||
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
|
||||
|
||||
if phrase.len() == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![phrase])
|
||||
.await?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
log::trace!(
|
||||
"Embedding search phrase took: {:?} milliseconds",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
let worktree_db_ids = this.read_with(&cx, |this, _| {
|
||||
let project_state = this
|
||||
.projects
|
||||
@ -738,6 +774,7 @@ impl SemanticIndex {
|
||||
.collect::<Vec<i64>>();
|
||||
anyhow::Ok(worktree_db_ids)
|
||||
})?;
|
||||
|
||||
let file_ids = database
|
||||
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
|
||||
.await?;
|
||||
@ -756,26 +793,26 @@ impl SemanticIndex {
|
||||
let limit = limit.clone();
|
||||
let fs = fs.clone();
|
||||
let db_path = db_path.clone();
|
||||
let phrase_embedding = phrase_embedding.clone();
|
||||
let query = query.clone();
|
||||
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
batch_results.push(async move {
|
||||
db.top_k_search(&phrase_embedding, limit, batch.as_slice())
|
||||
.await
|
||||
db.top_k_search(&query, limit, batch.as_slice()).await
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let batch_results = futures::future::join_all(batch_results).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for batch_result in batch_results {
|
||||
if batch_result.is_ok() {
|
||||
for (id, similarity) in batch_result.unwrap() {
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
@ -785,7 +822,11 @@ impl SemanticIndex {
|
||||
}
|
||||
}
|
||||
|
||||
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
|
||||
let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
|
||||
let scores = results
|
||||
.into_iter()
|
||||
.map(|(_, score)| score)
|
||||
.collect::<Vec<_>>();
|
||||
let spans = database.spans_for_ids(ids.as_slice()).await?;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
@ -810,24 +851,106 @@ impl SemanticIndex {
|
||||
|
||||
let buffers = futures::future::join_all(tasks).await;
|
||||
|
||||
log::trace!(
|
||||
"Semantic Searching took: {:?} milliseconds in total",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
Ok(buffers
|
||||
.into_iter()
|
||||
.zip(ranges)
|
||||
.filter_map(|(buffer, range)| {
|
||||
.zip(scores)
|
||||
.filter_map(|((buffer, range), similarity)| {
|
||||
let buffer = buffer.log_err()?;
|
||||
let range = buffer.read_with(&cx, |buffer, _| {
|
||||
let start = buffer.clip_offset(range.start, Bias::Left);
|
||||
let end = buffer.clip_offset(range.end, Bias::Right);
|
||||
buffer.anchor_before(start)..buffer.anchor_after(end)
|
||||
});
|
||||
Some(SearchResult { buffer, range })
|
||||
Some(SearchResult {
|
||||
buffer,
|
||||
range,
|
||||
similarity,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
fn search_modified_buffers(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
excludes: &[PathMatcher],
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let modified_buffers = project
|
||||
.read(cx)
|
||||
.opened_buffers(cx)
|
||||
.into_iter()
|
||||
.filter_map(|buffer_handle| {
|
||||
let buffer = buffer_handle.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
|
||||
excludes.iter().any(|matcher| matcher.is_match(&path))
|
||||
});
|
||||
if buffer.is_dirty() && !excluded {
|
||||
Some((buffer_handle, snapshot))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||
for (buffer, snapshot) in modified_buffers {
|
||||
let language = snapshot
|
||||
.language_at(0)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
|
||||
let mut spans = retriever
|
||||
.parse_file_with_template(None, &snapshot.text(), language)
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
for span in spans {
|
||||
let similarity = span.embedding.unwrap().similarity(&query);
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |result| {
|
||||
Reverse(result.similarity)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
|
||||
let range = {
|
||||
let start = snapshot.clip_offset(span.range.start, Bias::Left);
|
||||
let end = snapshot.clip_offset(span.range.end, Bias::Right);
|
||||
snapshot.anchor_before(start)..snapshot.anchor_after(end)
|
||||
};
|
||||
|
||||
results.insert(
|
||||
ix,
|
||||
SearchResult {
|
||||
buffer: buffer.clone(),
|
||||
range,
|
||||
similarity,
|
||||
},
|
||||
);
|
||||
results.truncate(limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
@ -1011,6 +1134,63 @@ impl SemanticIndex {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
async fn embed_spans(
|
||||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
let digests = spans
|
||||
.iter()
|
||||
.map(|span| span.digest.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let embeddings_for_digests = db
|
||||
.embeddings_for_digests(digests)
|
||||
.await
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
|
||||
for span in &*spans {
|
||||
if embeddings_for_digests.contains_key(&span.digest) {
|
||||
continue;
|
||||
};
|
||||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
}
|
||||
|
||||
batch_tokens += span.token_count;
|
||||
batch.push(span.content.clone());
|
||||
}
|
||||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
}
|
||||
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for span in spans {
|
||||
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
|
||||
Some(embedding.clone())
|
||||
} else {
|
||||
embeddings.next()
|
||||
};
|
||||
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
|
||||
span.embedding = Some(embedding);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Entity for SemanticIndex {
|
||||
|
Loading…
Reference in New Issue
Block a user