Antonio Scandurra 2023-09-14 17:42:30 +02:00
parent 6a271617b4
commit f86e5a987f
4 changed files with 94 additions and 3 deletions

View File

@@ -912,7 +912,6 @@ impl Project {
         self.user_store.clone()
     }

-    #[cfg(any(test, feature = "test-support"))]
     pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> {
         self.opened_buffers
             .values()

View File

@@ -190,6 +190,10 @@ impl VectorDatabase {
             )",
             [],
         )?;
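+        // Index spans by digest, presumably so that previously embedded spans can be
+        // looked up by digest and reused instead of being re-embedded.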
+        db.execute(
+            "CREATE INDEX spans_digest ON spans (digest)",
+            [],
+        )?;

         log::trace!("vector database initialized with updated schema.");
         Ok(())

View File

@@ -207,7 +207,7 @@ impl CodeContextRetriever {
         if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
            return self.parse_entire_file(relative_path, language_name, &content);
-        } else if language_name.as_ref() == "Markdown" {
+        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
             return self.parse_markdown_file(relative_path, &content);
         }

View File

@@ -24,6 +24,7 @@ use smol::channel;
 use std::{
     cmp::Ordering,
     future::Future,
+    mem,
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
@@ -37,7 +38,7 @@ use util::{
 };
 use workspace::WorkspaceCreated;

-const SEMANTIC_INDEX_VERSION: usize = 10;
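+// Bumped from 10, most likely so that existing vector databases are
+// recreated with the new spans digest index.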
+const SEMANTIC_INDEX_VERSION: usize = 11;
 const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
 const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
@@ -767,6 +768,93 @@ impl SemanticIndex {
                     });
                 }
             }
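+
+            // Gather snapshots of every open buffer that has unsaved changes; their
+            // current, in-memory contents are parsed and embedded below.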
+            let dirty_buffers = project.read_with(&cx, |project, cx| {
+                project
+                    .opened_buffers(cx)
+                    .into_iter()
+                    .filter_map(|buffer_handle| {
+                        let buffer = buffer_handle.read(cx);
+                        if buffer.is_dirty() {
+                            Some((buffer_handle.downgrade(), buffer.snapshot()))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<HashMap<_, _>>()
+            });
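+
+            // Parse and embed the dirty buffers on the background executor so the
+            // foreground thread is not blocked while waiting on the embedding provider.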
+            cx.background()
+                .spawn({
+                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                    let embedding_provider = embedding_provider.clone();
+                    let phrase_embedding = phrase_embedding.clone();
+                    async move {
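+                        // Each dirty buffer is parsed with the same CodeContextRetriever used
+                        // for on-disk files, falling back to plain text when no language is
+                        // detected at the top of the buffer.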
+                        let mut results = Vec::new();
+                        'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
+                            let language = buffer_snapshot
+                                .language_at(0)
+                                .cloned()
+                                .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                            if let Some(spans) = retriever
+                                .parse_file_with_template(None, &buffer_snapshot.text(), language)
+                                .log_err()
+                            {
+                                let mut batch = Vec::new();
+                                let mut batch_tokens = 0;
+                                let mut embeddings = Vec::new();
+                                // TODO: query span digests in the database to avoid embedding them again.
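+                                // Accumulate span contents into a batch, flushing it to the
+                                // embedding provider whenever adding the next span would exceed
+                                // the provider's per-request token budget.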
+                                for span in &spans {
+                                    if span.embedding.is_some() {
+                                        continue;
+                                    }
+
+                                    if batch_tokens + span.token_count
+                                        > embedding_provider.max_tokens_per_batch()
+                                    {
+                                        if let Some(batch_embeddings) = embedding_provider
+                                            .embed_batch(mem::take(&mut batch))
+                                            .await
+                                            .log_err()
+                                        {
+                                            embeddings.extend(batch_embeddings);
+                                            batch_tokens = 0;
+                                        } else {
+                                            continue 'buffers;
+                                        }
+                                    }
+
+                                    batch_tokens += span.token_count;
+                                    batch.push(span.content.clone());
+                                }
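+
+                                // Flush whatever remains in the final, partially filled batch.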
+                                if let Some(batch_embeddings) = embedding_provider
+                                    .embed_batch(mem::take(&mut batch))
+                                    .await
+                                    .log_err()
+                                {
+                                    embeddings.extend(batch_embeddings);
+                                } else {
+                                    continue 'buffers;
+                                }
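+
+                                // Walk the spans again, pairing each span without an embedding
+                                // with the next embedding returned by the provider.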
+                                let mut embeddings = embeddings.into_iter();
+                                for span in spans {
+                                    let embedding = span.embedding.or_else(|| embeddings.next());
+                                    if let Some(embedding) = embedding {
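+                                        // Placeholder in this commit; presumably the span's embedding
+                                        // will be scored against `phrase_embedding` and the best
+                                        // matches pushed into `results`.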
+                                        todo!()
+                                    } else {
+                                        log::error!("failed to embed span");
+                                        continue 'buffers;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                })
+                .await;

             let batch_results = futures::future::join_all(batch_results).await;

             let mut results = Vec::new();