Rename Document to Span

Antonio Scandurra 2023-09-06 16:48:53 +02:00
parent de0f53b39f
commit ce62173534
5 changed files with 109 additions and 112 deletions

View File

@@ -1,6 +1,6 @@
use crate::{
embedding::Embedding,
parsing::{Document, DocumentDigest},
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
};
use anyhow::{anyhow, Context, Result};
@@ -124,8 +124,8 @@ impl VectorDatabase {
}
log::trace!("vector database schema out of date. updating...");
db.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
db.execute("DROP TABLE IF EXISTS spans", [])
.context("failed to drop 'spans' table")?;
db.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
db.execute("DROP TABLE IF EXISTS worktrees", [])
@@ -174,7 +174,7 @@ impl VectorDatabase {
)?;
db.execute(
"CREATE TABLE documents (
"CREATE TABLE spans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
@@ -211,7 +211,7 @@ impl VectorDatabase {
worktree_id: i64,
path: Arc<Path>,
mtime: SystemTime,
documents: Vec<Document>,
spans: Vec<Span>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
// Return the existing ID, if both the file and mtime match
@@ -231,7 +231,7 @@ impl VectorDatabase {
let t0 = Instant::now();
let mut query = db.prepare(
"
INSERT INTO documents
INSERT INTO spans
(file_id, start_byte, end_byte, name, embedding, digest)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
@@ -241,14 +241,14 @@ impl VectorDatabase {
t0.elapsed().as_millis()
);
for document in documents {
for span in spans {
query.execute(params![
file_id,
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
document.embedding,
document.digest
span.range.start.to_string(),
span.range.end.to_string(),
span.name,
span.embedding,
span.digest
])?;
}
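
Read together, the `CREATE TABLE` and `INSERT` hunks above pin down the shape of the renamed table. Below is a minimal, self-contained sketch of the resulting schema and insert using rusqlite; the types of the columns the truncated `CREATE TABLE` hunk does not show (`end_byte`, `name`, `embedding`, `digest`) are assumptions, and only the column names are grounded in the statements above.

```rust
use rusqlite::{params, Connection};

fn main() -> rusqlite::Result<()> {
    let db = Connection::open_in_memory()?;
    db.execute(
        "CREATE TABLE spans (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            file_id INTEGER NOT NULL,
            start_byte INTEGER NOT NULL,
            end_byte INTEGER NOT NULL, -- assumed type
            name TEXT NOT NULL,        -- assumed type
            embedding BLOB,            -- assumed type
            digest BLOB                -- assumed type
        )",
        [],
    )?;
    // Mirrors the renamed INSERT above; note that the crate binds the byte
    // offsets as strings (`range.start.to_string()`).
    db.execute(
        "INSERT INTO spans (file_id, start_byte, end_byte, name, embedding, digest)
         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
        params![1_i64, "0", "42", "fn main", Vec::<u8>::new(), vec![0u8; 20]],
    )?;
    Ok(())
}
```
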
@@ -278,13 +278,13 @@ impl VectorDatabase {
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM documents
LEFT JOIN files ON files.id = documents.file_id
FROM spans
LEFT JOIN files ON files.id = spans.file_id
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
",
)?;
@@ -297,10 +297,7 @@ impl VectorDatabase {
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![worktree_id, file_paths], |row| {
Ok((
row.get::<_, DocumentDigest>(0)?,
row.get::<_, Embedding>(1)?,
))
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
@@ -379,7 +376,7 @@ impl VectorDatabase {
let file_ids = file_ids.to_vec();
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_document(db, &file_ids, |id, embedding| {
Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
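
The renamed traversal feeds a top-k selection: `results` stays sorted by descending similarity via binary search and is capped at `limit`. A self-contained sketch of that pattern, with illustrative names:

```rust
use std::cmp::Ordering;

// Keep `results` sorted from highest to lowest similarity, at most `limit` long.
fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    // Comparing the candidate similarity against each stored score yields a
    // descending sort order, matching the hunk above.
    let ix = match results.binary_search_by(|(_, s)| {
        similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
    }) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}

fn main() {
    let mut results = Vec::new();
    for (id, sim) in [(1, 0.2_f32), (2, 0.9), (3, 0.5)] {
        push_top_k(&mut results, id, sim, 2);
    }
    assert_eq!(results, vec![(2, 0.9), (3, 0.5)]);
}
```
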
@@ -434,7 +431,7 @@ impl VectorDatabase {
})
}
fn for_each_document(
fn for_each_span(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Embedding),
@@ -444,7 +441,7 @@ impl VectorDatabase {
SELECT
id, embedding
FROM
documents
spans
WHERE
file_id IN rarray(?)
",
@@ -459,7 +456,7 @@ impl VectorDatabase {
Ok(())
}
pub fn get_documents_by_ids(
pub fn spans_for_ids(
&self,
ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
@@ -468,16 +465,16 @@ impl VectorDatabase {
let mut statement = db.prepare(
"
SELECT
documents.id,
spans.id,
files.worktree_id,
files.relative_path,
documents.start_byte,
documents.end_byte
spans.start_byte,
spans.end_byte
FROM
documents, files
spans, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
spans.file_id = files.id AND
spans.id in rarray(?)
",
)?;
@@ -500,7 +497,7 @@ impl VectorDatabase {
for id in &ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
.ok_or(anyhow!("missing span id {}", id))?;
results.push(value);
}

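The digest-keyed map returned by `embeddings_for_files` backs the caching step later in this commit: a reparsed span whose content is unchanged can reuse its stored embedding instead of being re-sent to the provider. A self-contained sketch of that cache contract, with stand-in types for the crate's `SpanDigest` and `Embedding`:

```rust
use std::collections::HashMap;

type SpanDigest = [u8; 20];
type Embedding = Vec<f32>;

struct Span {
    digest: SpanDigest,
    embedding: Option<Embedding>,
}

// Mirrors the lookup in the semantic_index hunks further down: any span whose
// digest is already in the map gets its embedding filled in for free.
fn apply_cache(spans: &mut [Span], cache: &HashMap<SpanDigest, Embedding>) {
    for span in spans {
        if let Some(embedding) = cache.get(&span.digest) {
            span.embedding = Some(embedding.clone());
        }
    }
}

fn main() {
    let cached = [7u8; 20];
    let cache = HashMap::from([(cached, vec![0.25, 0.5])]);
    let mut spans = vec![
        Span { digest: cached, embedding: None },    // unchanged content: reused
        Span { digest: [0u8; 20], embedding: None }, // new content: must be embedded
    ];
    apply_cache(&mut spans, &cache);
    assert!(spans[0].embedding.is_some());
    assert!(spans[1].embedding.is_none());
}
```
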
View File

@@ -1,4 +1,4 @@
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@@ -9,7 +9,7 @@ pub struct FileToEmbed {
pub worktree_id: i64,
pub path: Arc<Path>,
pub mtime: SystemTime,
pub documents: Vec<Document>,
pub spans: Vec<Span>,
pub job_handle: JobHandle,
}
@@ -19,7 +19,7 @@ impl std::fmt::Debug for FileToEmbed {
.field("worktree_id", &self.worktree_id)
.field("path", &self.path)
.field("mtime", &self.mtime)
.field("document", &self.documents)
.field("spans", &self.spans)
.finish_non_exhaustive()
}
}
@@ -29,13 +29,13 @@ impl PartialEq for FileToEmbed {
self.worktree_id == other.worktree_id
&& self.path == other.path
&& self.mtime == other.mtime
&& self.documents == other.documents
&& self.spans == other.spans
}
}
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileToEmbedFragment>,
pending_batch: Vec<FileFragmentToEmbed>,
executor: Arc<Background>,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
@@ -43,9 +43,9 @@ pub struct EmbeddingQueue {
}
#[derive(Clone)]
pub struct FileToEmbedFragment {
pub struct FileFragmentToEmbed {
file: Arc<Mutex<FileToEmbed>>,
document_range: Range<usize>,
span_range: Range<usize>,
}
impl EmbeddingQueue {
@@ -62,41 +62,41 @@ impl EmbeddingQueue {
}
pub fn push(&mut self, file: FileToEmbed) {
if file.documents.is_empty() {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileToEmbedFragment {
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
document_range: 0..0,
span_range: 0..0,
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
let mut saved_tokens = 0;
for (ix, document) in file.lock().documents.iter().enumerate() {
let document_token_count = if document.embedding.is_none() {
document.token_count
for (ix, span) in file.lock().spans.iter().enumerate() {
let span_token_count = if span.embedding.is_none() {
span.token_count
} else {
saved_tokens += document.token_count;
saved_tokens += span.token_count;
0
};
let next_token_count = self.pending_batch_token_count + document_token_count;
let next_token_count = self.pending_batch_token_count + span_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
self.pending_batch.push(FileToEmbedFragment {
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
document_range: range_end..range_end,
span_range: range_end..range_end,
});
fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += document_token_count;
self.pending_batch_token_count += span_token_count;
}
log::trace!("Saved Tokens: {:?}", saved_tokens);
}
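
The `push` logic above batches spans against the provider's token budget, flushing and starting a new fragment (a sub-range of one file's spans) whenever the next span would overflow it; spans that already carry a cached embedding contribute zero tokens. An illustrative reduction of that rule, with simplified stand-in types rather than the crate's own:

```rust
struct Batcher {
    max_tokens_per_batch: usize,
    pending_batch_token_count: usize,
    flushed_batches: usize,
}

impl Batcher {
    fn push_span(&mut self, span_token_count: usize) {
        if self.pending_batch_token_count + span_token_count > self.max_tokens_per_batch {
            // Corresponds to `self.flush()` above, which also starts a new
            // FileFragmentToEmbed covering the remainder of the file's spans.
            self.flushed_batches += 1;
            self.pending_batch_token_count = 0;
        }
        self.pending_batch_token_count += span_token_count;
    }
}

fn main() {
    let mut batcher = Batcher {
        max_tokens_per_batch: 100,
        pending_batch_token_count: 0,
        flushed_batches: 0,
    };
    // A zero mirrors a span whose `embedding.is_some()`, which the queue
    // counts as saved tokens instead of batch load.
    for tokens in [60, 0, 50, 30, 80] {
        batcher.push_span(tokens);
    }
    assert_eq!(batcher.flushed_batches, 2); // flushed before the 50 and the 80
}
```
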
@@ -113,20 +113,20 @@ impl EmbeddingQueue {
self.executor.spawn(async move {
let mut spans = Vec::new();
let mut document_count = 0;
let mut span_count = 0;
for fragment in &batch {
let file = fragment.file.lock();
document_count += file.documents[fragment.document_range.clone()].len();
span_count += file.spans[fragment.span_range.clone()].len();
spans.extend(
{
file.documents[fragment.document_range.clone()]
file.spans[fragment.span_range.clone()]
.iter().filter(|d| d.embedding.is_none())
.map(|d| d.content.clone())
}
);
}
log::trace!("Documents Length: {:?}", document_count);
log::trace!("Documents Length: {:?}", span_count);
log::trace!("Span Length: {:?}", spans.clone().len());
// If there are no spans, just send the fragment to the finished files if it's the last one.
@@ -143,11 +143,11 @@ impl EmbeddingQueue {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for document in
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
for span in
&mut fragment.file.lock().spans[fragment.span_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
document.embedding = Some(embedding);
span.embedding = Some(embedding);
} else {
//
log::error!("number of embeddings returned different from number of documents");

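When a batch completes, the returned embeddings are zipped back onto exactly the spans that were sent, i.e. those still lacking an embedding. A minimal model of that re-attachment, with `Embedding` stubbed as `Vec<f32>`:

```rust
type Embedding = Vec<f32>;

struct Span {
    content: String,
    embedding: Option<Embedding>,
}

// Only spans without an embedding were sent to the provider, so the returned
// vectors are consumed in order by exactly that filtered subset.
fn attach_embeddings(spans: &mut [Span], returned: Vec<Embedding>) {
    let mut returned = returned.into_iter();
    for span in spans.iter_mut().filter(|s| s.embedding.is_none()) {
        if let Some(embedding) = returned.next() {
            span.embedding = Some(embedding);
        } else {
            // Mirrors the error branch above.
            eprintln!("number of embeddings returned different from number of spans");
        }
    }
}

fn main() {
    let mut spans = vec![
        Span { content: "fn a() {}".into(), embedding: Some(vec![0.1]) }, // cached
        Span { content: "fn b() {}".into(), embedding: None },            // needs work
    ];
    attach_embeddings(&mut spans, vec![vec![0.2]]);
    for span in &spans {
        println!("{}: {:?}", span.content, span.embedding);
    }
}
```
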
View File

@@ -16,9 +16,9 @@ use std::{
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct DocumentDigest([u8; 20]);
pub struct SpanDigest([u8; 20]);
impl FromSql for DocumentDigest {
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
@@ -27,17 +27,17 @@ impl FromSql for DocumentDigest {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(DocumentDigest(bytes));
return Ok(SpanDigest(bytes));
}
}
impl ToSql for DocumentDigest {
impl ToSql for SpanDigest {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for DocumentDigest {
impl From<&'_ str> for SpanDigest {
fn from(value: &'_ str) -> Self {
let mut sha1 = Sha1::new();
sha1.update(value);
@@ -46,12 +46,12 @@ impl From<&'_ str> for DocumentDigest {
}
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub struct Span {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Option<Embedding>,
pub digest: DocumentDigest,
pub digest: SpanDigest,
pub token_count: usize,
}
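
A hedged, self-contained sketch of how the renamed types fit together. The SHA-1 calls mirror the `From<&str>` impl above (here via the `sha1_smol` crate, an assumption; the diff names only `Sha1::new` and `update`), `Embedding` is stubbed, and the field names and 20-byte digest come straight from the hunks above:

```rust
use std::ops::Range;

type Embedding = Vec<f32>; // stand-in for the crate's Embedding type

#[derive(Debug, PartialEq, Eq, Clone, Hash)]
struct SpanDigest([u8; 20]);

impl From<&str> for SpanDigest {
    fn from(value: &str) -> Self {
        let mut sha1 = sha1_smol::Sha1::new();
        sha1.update(value.as_bytes());
        SpanDigest(sha1.digest().bytes())
    }
}

#[derive(Debug, PartialEq, Clone)]
struct Span {
    name: String,
    range: Range<usize>,
    content: String,
    embedding: Option<Embedding>,
    digest: SpanDigest,
    token_count: usize,
}

fn main() {
    let content = "fn main() {}";
    let span = Span {
        name: "fn main".into(),
        range: 0..content.len(),
        content: content.into(),
        embedding: None, // filled in later by the embedding queue
        digest: SpanDigest::from(content),
        token_count: 4, // illustrative
    };
    println!("{:?}", span);
}
```
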
@@ -97,14 +97,14 @@ impl CodeContextRetriever {
relative_path: &Path,
language_name: Arc<str>,
content: &str,
) -> Result<Vec<Document>> {
) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = DocumentDigest::from(document_span.as_str());
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: Default::default(),
@@ -114,13 +114,13 @@ impl CodeContextRetriever {
}])
}
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
let digest = DocumentDigest::from(document_span.as_str());
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: None,
@@ -191,32 +191,32 @@ impl CodeContextRetriever {
relative_path: &Path,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Document>> {
) -> Result<Vec<Span>> {
let language_name = language.name();
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content);
} else if &language_name.to_string() == &"Markdown".to_string() {
} else if language_name.as_ref() == "Markdown" {
return self.parse_markdown_file(relative_path, &content);
}
let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
let mut spans = self.parse_file(content, language)?;
for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
.replace("item", &span.content);
let (document_content, token_count) =
self.embedding_provider.truncate(&document_content);
document.content = document_content;
document.token_count = token_count;
span.content = document_content;
span.token_count = token_count;
}
Ok(documents)
Ok(spans)
}
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -227,7 +227,7 @@ impl CodeContextRetriever {
let language_scope = language.default_scope();
let placeholder = language_scope.collapsed_placeholder();
let mut documents = Vec::new();
let mut spans = Vec::new();
let mut collapsed_ranges_within = Vec::new();
let mut parsed_name_ranges = HashSet::new();
for (i, context_match) in matches.iter().enumerate() {
@@ -267,22 +267,22 @@ impl CodeContextRetriever {
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
let mut document_content = String::new();
let mut span_content = String::new();
for context_range in &context_match.context_ranges {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
context_range.clone(),
context_match.start_col,
);
document_content.push_str("\n");
span_content.push_str("\n");
}
let mut offset = item_range.start;
for collapsed_range in &collapsed_ranges_within {
if collapsed_range.start > offset {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
offset..collapsed_range.start,
context_match.start_col,
@@ -291,24 +291,24 @@ impl CodeContextRetriever {
}
if collapsed_range.end > offset {
document_content.push_str(placeholder);
span_content.push_str(placeholder);
offset = collapsed_range.end;
}
}
if offset < item_range.end {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
offset..item_range.end,
context_match.start_col,
);
}
let sha1 = DocumentDigest::from(document_content.as_str());
documents.push(Document {
let sha1 = SpanDigest::from(span_content.as_str());
spans.push(Span {
name,
content: document_content,
content: span_content,
range: item_range.clone(),
embedding: None,
digest: sha1,
@@ -316,7 +316,7 @@ impl CodeContextRetriever {
})
}
return Ok(documents);
return Ok(spans);
}
}

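Each parsed span is wrapped in a context template before embedding, as the `parse_file_with_template` hunk shows. A self-contained sketch of that substitution; the template literal below is illustrative, not the crate's exact `CODE_CONTEXT_TEMPLATE`:

```rust
// Illustrative stand-in; the real template lives in the crate and its output
// is additionally truncated by the embedding provider's token limit.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";

fn expand(path: &str, language: &str, span_content: &str) -> String {
    CODE_CONTEXT_TEMPLATE
        .replace("<path>", path)
        .replace("<language>", language)
        .replace("<item>", span_content)
}

fn main() {
    let expanded = expand("src/lib.rs", "Rust", "fn add(a: i32, b: i32) -> i32 { a + b }");
    println!("{expanded}");
}
```
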
View File

@@ -17,7 +17,7 @@ use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
@@ -36,7 +36,7 @@ use util::{
ResultExt,
};
const SEMANTIC_INDEX_VERSION: usize = 9;
const SEMANTIC_INDEX_VERSION: usize = 10;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
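
Bumping `SEMANTIC_INDEX_VERSION` from 9 to 10 is what forces existing databases through the drop-and-recreate path shown in the first file, since the `documents` table must be rebuilt as `spans`. An illustrative reduction of that version gate, not the crate's exact check:

```rust
const SEMANTIC_INDEX_VERSION: usize = 10;

// A stored version that does not match the compiled-in constant marks the
// database as out of date, triggering the DROP TABLE / CREATE TABLE path.
fn needs_rebuild(stored_version: usize) -> bool {
    stored_version != SEMANTIC_INDEX_VERSION
}

fn main() {
    assert!(needs_rebuild(9)); // pre-rename databases are dropped and recreated
    assert!(!needs_rebuild(SEMANTIC_INDEX_VERSION));
}
```
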
@@ -84,7 +84,7 @@ pub struct SemanticIndex {
db: VectorDatabase,
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -252,16 +252,16 @@ impl SemanticIndex {
let db = db.clone();
async move {
while let Ok(file) = embedded_files.recv().await {
db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
.await
.log_err();
}
}
});
// Parse files into embeddable documents.
// Parse files into embeddable spans.
let (parsing_files_tx, parsing_files_rx) =
channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
@@ -320,26 +320,26 @@ impl SemanticIndex {
pending_file: PendingFile,
retriever: &mut CodeContextRetriever,
embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
) {
let Some(language) = pending_file.language else {
return;
};
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
if let Some(mut documents) = retriever
if let Some(mut spans) = retriever
.parse_file_with_template(&pending_file.relative_path, &content, language)
.log_err()
{
log::trace!(
"parsed path {:?}: {} documents",
"parsed path {:?}: {} spans",
pending_file.relative_path,
documents.len()
spans.len()
);
for document in documents.iter_mut() {
if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
document.embedding = Some(embedding.to_owned());
for span in &mut spans {
if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
span.embedding = Some(embedding.to_owned());
}
}
@@ -348,7 +348,7 @@ impl SemanticIndex {
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
documents,
spans,
});
}
}
@@ -708,13 +708,13 @@ impl SemanticIndex {
}
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
let documents = database.get_documents_by_ids(ids.as_slice()).await?;
let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new();
let mut ranges = Vec::new();
let weak_project = project.downgrade();
project.update(&mut cx, |project, cx| {
for (worktree_db_id, file_path, byte_range) in documents {
for (worktree_db_id, file_path, byte_range) in spans {
let project_state =
if let Some(state) = this.read(cx).projects.get(&weak_project) {
state

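After similarity search produces span ids, `spans_for_ids` (formerly `get_documents_by_ids`) maps them back to worktree paths and byte ranges so the results can be opened as buffers, erroring on any missing id just as the database code above does. A stand-in sketch of that resolution step, with an in-memory map in place of the SQLite query:

```rust
use std::{collections::HashMap, ops::Range, path::PathBuf};

fn spans_for_ids(
    index: &HashMap<i64, (PathBuf, Range<usize>)>,
    ids: &[i64],
) -> anyhow::Result<Vec<(i64, PathBuf, Range<usize>)>> {
    ids.iter()
        .map(|id| {
            // Same contract as the real query: every requested id must resolve.
            let (path, range) = index
                .get(id)
                .ok_or_else(|| anyhow::anyhow!("missing span id {}", id))?;
            Ok((*id, path.clone(), range.clone()))
        })
        .collect()
}

fn main() -> anyhow::Result<()> {
    let index = HashMap::from([(3_i64, (PathBuf::from("src/db.rs"), 10..120))]);
    for (id, path, range) in spans_for_ids(&index, &[3])? {
        println!("span {id} -> {}:{}..{}", path.display(), range.start, range.end);
    }
    Ok(())
}
```
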
View File

@@ -1,7 +1,7 @@
use crate::{
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
@@ -204,15 +204,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
worktree_id: 5,
path: Path::new(&format!("path-{file_ix}")).into(),
mtime: SystemTime::now(),
documents: (0..rng.gen_range(4..22))
spans: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
let content = RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect::<String>();
let digest = DocumentDigest::from(content.as_str());
Document {
let digest = SpanDigest::from(content.as_str());
Span {
range: 0..10,
embedding: None,
name: format!("document {document_ix}"),
@@ -245,7 +245,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.iter()
.map(|file| {
let mut file = file.clone();
for doc in &mut file.documents {
for doc in &mut file.spans {
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
@@ -437,7 +437,7 @@ async fn test_code_context_retrieval_json() {
}
fn assert_documents_eq(
documents: &[Document],
documents: &[Span],
expected_contents_and_start_offsets: &[(String, usize)],
) {
assert_eq!(