From 0a0e40fb246b3f1e0e8751f24bf008387f223c4b Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Thu, 13 Jul 2023 16:34:32 -0400
Subject: [PATCH] refactored code context retrieval and standardized database migration

Co-authored-by: maxbrunsfeld
---
 Cargo.lock                                    |   2 +
 crates/vector_store/Cargo.toml                |   3 +
 crates/vector_store/src/db.rs                 | 134 +++++++++++------
 crates/vector_store/src/modal.rs              |   2 +-
 crates/vector_store/src/parsing.rs            |  78 +++++-----
 crates/vector_store/src/vector_store.rs       | 140 ++++++++++--------
 crates/vector_store/src/vector_store_tests.rs |  21 ++-
 7 files changed, 232 insertions(+), 148 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0ac6a2ee89..4359659a53 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8483,7 +8483,9 @@ dependencies = [
  "anyhow",
  "async-trait",
  "bincode",
+ "ctor",
  "editor",
+ "env_logger 0.9.3",
  "futures 0.3.28",
  "gpui",
  "isahc",
diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml
index 40bff8b95c..8e1dea59fd 100644
--- a/crates/vector_store/Cargo.toml
+++ b/crates/vector_store/Cargo.toml
@@ -44,6 +44,9 @@ rpc = { path = "../rpc", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
 settings = { path = "../settings", features = ["test-support"]}
 tree-sitter-rust = "*"
+
 rand.workspace = true
 unindent.workspace = true
 tempdir.workspace = true
+ctor.workspace = true
+env_logger.workspace = true
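Note: the db.rs changes below standardize migration around a single `vector_store_config` table: if the stored version matches `VECTOR_STORE_VERSION`, startup is a no-op; otherwise every table is dropped and rebuilt, which also removes the old per-file `vector_store_version` column. A minimal standalone sketch of that pattern, assuming only the `rusqlite` crate (an illustration, not the patch's exact code):

    // Version-gated rebuild: match means no-op, mismatch means drop and recreate.
    use rusqlite::{params, Connection, Result};

    const VECTOR_STORE_VERSION: usize = 1;

    fn initialize(db: &Connection) -> Result<()> {
        // If the stored version matches, the schema is already current.
        let stored: Result<i64> =
            db.query_row("SELECT version FROM vector_store_config", [], |row| row.get(0));
        if stored.map_or(false, |v| v == VECTOR_STORE_VERSION as i64) {
            return Ok(());
        }

        // Otherwise rebuild; `.ok()` tolerates a fresh database where the
        // table does not exist yet (the patch drops its other tables here too).
        db.execute("DROP TABLE vector_store_config", []).ok();
        db.execute("CREATE TABLE vector_store_config (version INTEGER NOT NULL)", [])?;
        db.execute(
            "INSERT INTO vector_store_config (version) VALUES (?1)",
            params![VECTOR_STORE_VERSION],
        )?;
        Ok(())
    }

    fn main() -> Result<()> {
        let db = Connection::open_in_memory()?;
        initialize(&db)?;
        initialize(&db)?; // second run is a no-op thanks to the version check
        Ok(())
    }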
diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index a91a1872b5..d3d05f8c62 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -1,20 +1,20 @@
-use std::{
-    cmp::Ordering,
-    collections::HashMap,
-    path::{Path, PathBuf},
-    rc::Rc,
-    time::SystemTime,
-};
-
+use crate::{parsing::Document, VECTOR_STORE_VERSION};
 use anyhow::{anyhow, Result};
-
-use crate::parsing::ParsedFile;
-use crate::VECTOR_STORE_VERSION;
+use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
     params,
     types::{FromSql, FromSqlResult, ValueRef},
 };
+use std::{
+    cmp::Ordering,
+    collections::HashMap,
+    ops::Range,
+    path::{Path, PathBuf},
+    rc::Rc,
+    sync::Arc,
+    time::SystemTime,
+};
 
 #[derive(Debug)]
 pub struct FileRecord {
@@ -42,48 +42,88 @@ pub struct VectorDatabase {
 }
 
 impl VectorDatabase {
-    pub fn new(path: String) -> Result<Self> {
+    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
+        if let Some(db_directory) = path.parent() {
+            fs.create_dir(db_directory).await?;
+        }
+
         let this = Self {
-            db: rusqlite::Connection::open(path)?,
+            db: rusqlite::Connection::open(path.as_path())?,
         };
         this.initialize_database()?;
        Ok(this)
     }
 
+    fn get_existing_version(&self) -> Result<i64> {
+        let mut version_query = self.db.prepare("SELECT version from vector_store_config")?;
+        version_query
+            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
+            .map_err(|err| anyhow!("version query failed: {err}"))
+    }
+
     fn initialize_database(&self) -> Result<()> {
         rusqlite::vtab::array::load_module(&self.db)?;
 
-        // This will create the database if it doesnt exist
+        if self
+            .get_existing_version()
+            .map_or(false, |version| version == VECTOR_STORE_VERSION as i64)
+        {
+            return Ok(());
+        }
+
+        self.db
+            .execute(
+                "
+                DROP TABLE vector_store_config;
+                DROP TABLE worktrees;
+                DROP TABLE files;
+                DROP TABLE documents;
+                ",
+                [],
+            )
+            .ok();
 
         // Initialize Vector Databasing Tables
         self.db.execute(
-            "CREATE TABLE IF NOT EXISTS worktrees (
+            "CREATE TABLE vector_store_config (
+                version INTEGER NOT NULL
+            )",
+            [],
+        )?;
+
+        self.db.execute(
+            "INSERT INTO vector_store_config (version) VALUES (?1)",
+            params![VECTOR_STORE_VERSION],
+        )?;
+
+        self.db.execute(
+            "CREATE TABLE worktrees (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 absolute_path VARCHAR NOT NULL
             );
-            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
             ",
             [],
         )?;
 
         self.db.execute(
-            "CREATE TABLE IF NOT EXISTS files (
+            "CREATE TABLE files (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 worktree_id INTEGER NOT NULL,
                 relative_path VARCHAR NOT NULL,
                 mtime_seconds INTEGER NOT NULL,
                 mtime_nanos INTEGER NOT NULL,
-                vector_store_version INTEGER NOT NULL,
                 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
         )?;
 
         self.db.execute(
-            "CREATE TABLE IF NOT EXISTS documents (
+            "CREATE TABLE documents (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 file_id INTEGER NOT NULL,
-                offset INTEGER NOT NULL,
+                start_byte INTEGER NOT NULL,
+                end_byte INTEGER NOT NULL,
                 name VARCHAR NOT NULL,
                 embedding BLOB NOT NULL,
                 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
@@ -102,43 +142,44 @@ impl VectorDatabase {
         Ok(())
     }
 
-    pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
+    pub fn insert_file(
+        &self,
+        worktree_id: i64,
+        path: PathBuf,
+        mtime: SystemTime,
+        documents: Vec<Document>,
+    ) -> Result<()> {
         // Write to files table, and return generated id.
         self.db.execute(
             "
             DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
             ",
-            params![worktree_id, indexed_file.path.to_str()],
+            params![worktree_id, path.to_str()],
         )?;
-        let mtime = Timestamp::from(indexed_file.mtime);
+        let mtime = Timestamp::from(mtime);
         self.db.execute(
             "
             INSERT INTO files
-            (worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version)
+            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
             VALUES
-            (?1, ?2, $3, $4, $5);
+            (?1, ?2, $3, $4);
             ",
-            params![
-                worktree_id,
-                indexed_file.path.to_str(),
-                mtime.seconds,
-                mtime.nanos,
-                VECTOR_STORE_VERSION
-            ],
+            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
         )?;
 
         let file_id = self.db.last_insert_rowid();
 
         // Currently inserting at approximately 3400 documents a second
        // I imagine we can speed this up with a bulk insert of some kind.
-        for document in indexed_file.documents {
+        for document in documents {
             let embedding_blob = bincode::serialize(&document.embedding)?;
 
             self.db.execute(
-                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
+                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
                 params![
                     file_id,
-                    document.offset.to_string(),
+                    document.range.start.to_string(),
+                    document.range.end.to_string(),
                     document.name,
                     embedding_blob
                 ],
             )?;
         }
 
         Ok(())
     }
@@ -204,7 +245,7 @@
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-    ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
         self.for_each_document(&worktree_ids, |id, embedding| {
             let similarity = dot(&embedding, &query_embedding);
@@ -248,11 +289,18 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+    fn get_documents_by_ids(
+        &self,
+        ids: &[i64],
+    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
         let mut statement = self.db.prepare(
             "
             SELECT
-                documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
+                documents.id,
+                files.worktree_id,
+                files.relative_path,
+                documents.start_byte,
+                documents.end_byte,
+                documents.name
             FROM
                 documents, files
             WHERE
@@ -266,15 +314,15 @@ impl VectorDatabase {
                 row.get::<_, i64>(0)?,
                 row.get::<_, i64>(1)?,
                 row.get::<_, String>(2)?.into(),
-                row.get(3)?,
-                row.get(4)?,
+                row.get(3)?..row.get(4)?,
+                row.get(5)?,
             ))
         })?;
 
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
 
         for row in result_iter {
-            let (id, worktree_id, path, offset, name) = row?;
-            values_by_id.insert(id, (worktree_id, path, offset, name));
+            let (id, worktree_id, path, range, name) = row?;
+            values_by_id.insert(id, (worktree_id, path, range, name));
         }
 
         let mut results = Vec::with_capacity(ids.len());
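Note: `top_k_search` above scores every stored embedding with `dot` and keeps only the best `limit` ids in a buffer kept sorted by similarity. A standalone sketch of that brute-force scan (plain slices stand in for the database rows; this is an illustration, not the crate's exact code):

    // Score every candidate, insert into a descending-similarity buffer,
    // and truncate so at most `limit` results survive the scan.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn top_k(corpus: &[(i64, Vec<f32>)], query: &[f32], limit: usize) -> Vec<(i64, f32)> {
        let mut results: Vec<(i64, f32)> = Vec::with_capacity(limit + 1);
        for (id, embedding) in corpus {
            let similarity = dot(embedding, query);
            // Find the insertion point that keeps the buffer sorted descending.
            let ix = results
                .binary_search_by(|(_, s)| {
                    similarity.partial_cmp(s).unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or_else(|i| i);
            results.insert(ix, (*id, similarity));
            results.truncate(limit);
        }
        results
    }

    fn main() {
        let corpus = vec![(1, vec![1.0, 0.0]), (2, vec![0.6, 0.8]), (3, vec![0.0, 1.0])];
        let hits = top_k(&corpus, &[1.0, 0.0], 2);
        assert_eq!(hits[0].0, 1); // the aligned vector wins
        println!("{hits:?}");
    }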
diff --git a/crates/vector_store/src/modal.rs b/crates/vector_store/src/modal.rs
index 9225fe8786..b797a20806 100644
--- a/crates/vector_store/src/modal.rs
+++ b/crates/vector_store/src/modal.rs
@@ -66,7 +66,7 @@ impl PickerDelegate for SemanticSearchDelegate {
                     });
 
                 let workspace = self.workspace.clone();
-                let position = search_result.clone().offset;
+                let position = search_result.clone().byte_range.start;
                 cx.spawn(|_, mut cx| async move {
                     let buffer = buffer.await?;
                     workspace.update(&mut cx, |workspace, cx| {
diff --git a/crates/vector_store/src/parsing.rs b/crates/vector_store/src/parsing.rs
index 3e697399b1..23dcf505c9 100644
--- a/crates/vector_store/src/parsing.rs
+++ b/crates/vector_store/src/parsing.rs
@@ -1,41 +1,39 @@
-use std::{path::PathBuf, sync::Arc, time::SystemTime};
-
 use anyhow::{anyhow, Ok, Result};
-use project::Fs;
+use language::Language;
+use std::{ops::Range, path::Path, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
 
-use crate::PendingFile;
-
 #[derive(Debug, PartialEq, Clone)]
 pub struct Document {
-    pub offset: usize,
     pub name: String,
+    pub range: Range<usize>,
+    pub content: String,
     pub embedding: Vec<f32>,
 }
 
-#[derive(Debug, PartialEq, Clone)]
-pub struct ParsedFile {
-    pub path: PathBuf,
-    pub mtime: SystemTime,
-    pub documents: Vec<Document>,
-}
-
 const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 
 pub struct CodeContextRetriever {
     pub parser: Parser,
     pub cursor: QueryCursor,
-    pub fs: Arc<dyn Fs>,
 }
 
 impl CodeContextRetriever {
-    pub async fn parse_file(
+    pub fn new() -> Self {
+        Self {
+            parser: Parser::new(),
+            cursor: QueryCursor::new(),
+        }
+    }
+
+    pub fn parse_file(
         &mut self,
-        pending_file: PendingFile,
-    ) -> Result<(ParsedFile, Vec<String>)> {
-        let grammar = pending_file
-            .language
+        relative_path: &Path,
+        content: &str,
+        language: Arc<Language>,
+    ) -> Result<Vec<Document>> {
+        let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
         let embedding_config = grammar
@@ -43,8 +41,6 @@ impl CodeContextRetriever {
             .embedding_config
             .as_ref()
             .ok_or_else(|| anyhow!("no embedding queries"))?;
 
-        let content = self.fs.load(&pending_file.absolute_path).await?;
-
         self.parser.set_language(grammar.ts_language).unwrap();
 
         let tree = self
             .parser
             .parse(&content, None)
             .ok_or_else(|| anyhow!("parsing failed"))?;
 
         let mut documents = Vec::new();
-        let mut document_texts = Vec::new();
 
         // Iterate through query matches
         for mat in self.cursor.matches(
             &embedding_config.query,
             tree.root_node(),
             content.as_bytes(),
         ) {
             let mut name: Vec<&str> = vec![];
             let mut item: Option<&str> = None;
-            let mut offset: Option<usize> = None;
+            let mut byte_range: Option<Range<usize>> = None;
             let mut context_spans: Vec<&str> = vec![];
             for capture in mat.captures {
                 if capture.index == embedding_config.item_capture_ix {
-                    offset = Some(capture.node.byte_range().start);
+                    byte_range = Some(capture.node.byte_range());
                     item = content.get(capture.node.byte_range());
                 } else if capture.index == embedding_config.name_capture_ix {
                     if let Some(name_content) = content.get(capture.node.byte_range()) {
                         name.push(name_content);
                     }
                 }
             }
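Note: for each captured item, `parse_file` above fills `CODE_CONTEXT_TEMPLATE` to build the text that is actually embedded. A runnable rendering of just that substitution step (a hypothetical standalone helper; the `<path>`, `<language>`, and `<item>` placeholders mirror the template above):

    // Fill the context template for one captured item, as the replace()
    // chain in parse_file does.
    const CODE_CONTEXT_TEMPLATE: &str =
        "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";

    fn document_text(path: &str, language: &str, item: &str) -> String {
        CODE_CONTEXT_TEMPLATE
            .replace("<path>", path)
            .replace("<language>", language)
            .replace("<item>", item)
    }

    fn main() {
        let text = document_text("src/db.rs", "rust", "fn delete_file(&self) {}");
        println!("{text}");
    }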
 
-            if item.is_some() && offset.is_some() && name.len() > 0 {
-                let item = format!("{}\n{}", context_spans.join("\n"), item.unwrap());
+            if let Some((item, byte_range)) = item.zip(byte_range) {
+                if !name.is_empty() {
+                    let item = format!("{}\n{}", context_spans.join("\n"), item);
 
-                let document_text = CODE_CONTEXT_TEMPLATE
-                    .replace("<path>", pending_file.relative_path.to_str().unwrap())
-                    .replace("<language>", &pending_file.language.name().to_lowercase())
-                    .replace("<item>", item.as_str());
+                    let document_text = CODE_CONTEXT_TEMPLATE
+                        .replace("<path>", relative_path.to_str().unwrap())
+                        .replace("<language>", &language.name().to_lowercase())
+                        .replace("<item>", item.as_str());
 
-                document_texts.push(document_text);
-                documents.push(Document {
-                    name: name.join(" "),
-                    offset: offset.unwrap(),
-                    embedding: Vec::new(),
-                })
+                    documents.push(Document {
+                        range: byte_range,
+                        content: document_text,
+                        embedding: Vec::new(),
+                        name: name.join(" ").to_string(),
+                    });
+                }
             }
         }
 
-        return Ok((
-            ParsedFile {
-                path: pending_file.relative_path,
-                mtime: pending_file.modified_time,
-                documents,
-            },
-            document_texts,
-        ));
+        return Ok(documents);
     }
 }
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index 0a197bc406..3d9c32875e 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -18,16 +18,16 @@ use gpui::{
 };
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
-use parsing::{CodeContextRetriever, ParsedFile};
+use parsing::{CodeContextRetriever, Document};
 use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 use smol::channel;
 use std::{
     collections::HashMap,
+    ops::Range,
     path::{Path, PathBuf},
     sync::Arc,
     time::{Duration, Instant, SystemTime},
 };
-use tree_sitter::{Parser, QueryCursor};
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
     http::HttpClient,
@@ -36,7 +36,7 @@ use util::{
 };
 use workspace::{Workspace, WorkspaceCreated};
 
-const VECTOR_STORE_VERSION: usize = 0;
+const VECTOR_STORE_VERSION: usize = 1;
 const EMBEDDINGS_BATCH_SIZE: usize = 150;
 
 pub fn init(
@@ -80,11 +80,11 @@ pub fn init(
         let vector_store = VectorStore::new(
             fs,
             db_file_path,
-            // Arc::new(embedding::DummyEmbeddings {}),
-            Arc::new(OpenAIEmbeddings {
-                client: http_client,
-                executor: cx.background(),
-            }),
+            Arc::new(embedding::DummyEmbeddings {}),
+            // Arc::new(OpenAIEmbeddings {
+            //     client: http_client,
+            //     executor: cx.background(),
+            // }),
             language_registry,
             cx.clone(),
         )
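Note: the hunk above swaps `OpenAIEmbeddings` for `DummyEmbeddings` behind the same provider seam, so indexing can run offline while this refactor is in flight. A hedged sketch of such a seam; the crate's real `EmbeddingProvider` trait is async and its exact signature may differ from what is assumed here:

    // A provider trait makes the swap above a one-line change at the call site.
    use anyhow::Result;

    trait EmbeddingProvider: Send + Sync {
        fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
    }

    struct DummyEmbeddings;

    impl EmbeddingProvider for DummyEmbeddings {
        fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
            // One fixed-size placeholder vector per span: cheap, deterministic, offline.
            Ok(spans.iter().map(|_| vec![0.0; 1536]).collect())
        }
    }

    fn main() -> Result<()> {
        let provider: Box<dyn EmbeddingProvider> = Box::new(DummyEmbeddings);
        let embeddings = provider.embed_batch(vec!["fn main() {}"])?;
        assert_eq!(embeddings.len(), 1);
        Ok(())
    }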
@@ -212,14 +212,16 @@ pub struct PendingFile {
 pub struct SearchResult {
     pub worktree_id: WorktreeId,
     pub name: String,
-    pub offset: usize,
+    pub byte_range: Range<usize>,
     pub file_path: PathBuf,
 }
 
 enum DbOperation {
     InsertFile {
         worktree_id: i64,
-        indexed_file: ParsedFile,
+        documents: Vec<Document>,
+        path: PathBuf,
+        mtime: SystemTime,
     },
     Delete {
@@ -238,8 +240,9 @@ enum DbOperation {
 enum EmbeddingJob {
     Enqueue {
         worktree_id: i64,
-        parsed_file: ParsedFile,
-        document_spans: Vec<String>,
+        path: PathBuf,
+        mtime: SystemTime,
+        documents: Vec<Document>,
     },
     Flush,
 }
@@ -256,18 +259,7 @@ impl VectorStore {
 
         let db = cx
             .background()
-            .spawn({
-                let fs = fs.clone();
-                let database_url = database_url.clone();
-                async move {
-                    if let Some(db_directory) = database_url.parent() {
-                        fs.create_dir(db_directory).await.log_err();
-                    }
-
-                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
-                    anyhow::Ok(db)
-                }
-            })
+            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
             .await?;
 
         Ok(cx.add_model(|cx| {
@@ -280,9 +272,12 @@ impl VectorStore {
                     match job {
                         DbOperation::InsertFile {
                             worktree_id,
-                            indexed_file,
+                            documents,
+                            path,
+                            mtime,
                         } => {
-                            db.insert_file(worktree_id, indexed_file).log_err();
+                            db.insert_file(worktree_id, path, mtime, documents)
+                                .log_err();
                         }
                         DbOperation::Delete { worktree_id, path } => {
                             db.delete_file(worktree_id, path).log_err();
                         }
@@ -304,35 +299,45 @@ impl VectorStore {
             // embed_tx/rx: Embed Batch and Send to Database
             let (embed_batch_tx, embed_batch_rx) =
-                channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
+                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
             let _embed_batch_task = cx.background().spawn({
                 let db_update_tx = db_update_tx.clone();
                 let embedding_provider = embedding_provider.clone();
                 async move {
                     while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                         // Construct Batch
-                        let mut document_spans = vec![];
-                        for (_, _, document_span) in embeddings_queue.iter() {
-                            document_spans.extend(document_span.iter().map(|s| s.as_str()));
+                        let mut batch_documents = vec![];
+                        for (_, documents, _, _) in embeddings_queue.iter() {
+                            batch_documents
+                                .extend(documents.iter().map(|document| document.content.as_str()));
                         }
 
-                        if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
+                        if let Ok(embeddings) =
+                            embedding_provider.embed_batch(batch_documents).await
                         {
+                            log::trace!(
+                                "created {} embeddings for {} files",
+                                embeddings.len(),
+                                embeddings_queue.len(),
+                            );
+
                             let mut i = 0;
                             let mut j = 0;
                             for embedding in embeddings.iter() {
-                                while embeddings_queue[i].1.documents.len() == j {
+                                while embeddings_queue[i].1.len() == j {
                                     i += 1;
                                     j = 0;
                                 }
-                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
+                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
                                 j += 1;
                             }
 
-                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
-                                for document in indexed_file.documents.iter() {
+                            for (worktree_id, documents, path, mtime) in
+                                embeddings_queue.into_iter()
+                            {
+                                for document in documents.iter() {
                                     // TODO: Update this so it doesn't panic
                                     assert!(
                                         document.embedding.len() > 0,
                                         "document embedding not complete"
                                     );
                                 }
 
                                 db_update_tx
                                     .send(DbOperation::InsertFile {
                                         worktree_id,
-                                        indexed_file,
+                                        documents,
+                                        path,
+                                        mtime,
                                     })
                                     .await
                                     .unwrap();
                             }
                         }
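Note: `embed_batch` above returns one flat list of embeddings covering every document of every queued file, and the `(i, j)` walk hands each embedding back to its document. The same logic with plain types (a standalone illustration; `Vec<Vec<f32>>` stands in for the per-file document lists):

    // Redistribute one flat batch of embeddings across a queue of
    // per-file document lists, in order.
    fn redistribute(queue: &mut Vec<(i64, Vec<Vec<f32>>)>, embeddings: Vec<Vec<f32>>) {
        let mut i = 0; // which queued file
        let mut j = 0; // which document within that file
        for embedding in embeddings {
            // Skip files whose documents are all assigned (including empty files).
            while queue[i].1.len() == j {
                i += 1;
                j = 0;
            }
            queue[i].1[j] = embedding;
            j += 1;
        }
    }

    fn main() {
        // Two files with 2 and 1 placeholder documents; 3 embeddings total.
        let mut queue = vec![(1, vec![vec![]; 2]), (2, vec![vec![]; 1])];
        let flat = vec![vec![0.1], vec![0.2], vec![0.3]];
        redistribute(&mut queue, flat);
        assert_eq!(queue[1].1[0], vec![0.3]);
    }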
@@ -362,12 +369,13 @@ impl VectorStore {
                 while let Ok(job) = batch_files_rx.recv().await {
                     let should_flush = match job {
                         EmbeddingJob::Enqueue {
-                            document_spans,
+                            documents,
                             worktree_id,
-                            parsed_file,
+                            path,
+                            mtime,
                         } => {
-                            queue_len += &document_spans.len();
-                            embeddings_queue.push((worktree_id, parsed_file, document_spans));
+                            queue_len += &documents.len();
+                            embeddings_queue.push((worktree_id, documents, path, mtime));
                             queue_len >= EMBEDDINGS_BATCH_SIZE
                         }
                         EmbeddingJob::Flush => true,
                     };
@@ -385,26 +393,38 @@ impl VectorStore {
             let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
 
             let mut _parsing_files_tasks = Vec::new();
-            // for _ in 0..cx.background().num_cpus() {
-            for _ in 0..1 {
+            for _ in 0..cx.background().num_cpus() {
                 let fs = fs.clone();
                 let parsing_files_rx = parsing_files_rx.clone();
                 let batch_files_tx = batch_files_tx.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
-                    let parser = Parser::new();
-                    let cursor = QueryCursor::new();
-                    let mut retriever = CodeContextRetriever { parser, cursor, fs };
+                    let mut retriever = CodeContextRetriever::new();
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
-                        if let Some((indexed_file, document_spans)) =
-                            retriever.parse_file(pending_file.clone()).await.log_err()
+                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
                         {
-                            batch_files_tx
-                                .try_send(EmbeddingJob::Enqueue {
-                                    worktree_id: pending_file.worktree_db_id,
-                                    parsed_file: indexed_file,
-                                    document_spans,
-                                })
-                                .unwrap();
+                            if let Some(documents) = retriever
+                                .parse_file(
+                                    &pending_file.relative_path,
+                                    &content,
+                                    pending_file.language,
+                                )
+                                .log_err()
+                            {
+                                log::trace!(
+                                    "parsed path {:?}: {} documents",
+                                    pending_file.relative_path,
+                                    documents.len()
+                                );
+
+                                batch_files_tx
+                                    .try_send(EmbeddingJob::Enqueue {
+                                        worktree_id: pending_file.worktree_db_id,
+                                        path: pending_file.relative_path,
+                                        mtime: pending_file.modified_time,
+                                        documents,
+                                    })
+                                    .unwrap();
+                            }
                         }
 
                         if parsing_files_rx.len() == 0 {
@@ -543,6 +563,7 @@ impl VectorStore {
                     });
 
                     if !already_stored {
+                        log::trace!("sending for parsing: {:?}", path_buf);
                         parsing_files_tx
                             .try_send(PendingFile {
                                 worktree_db_id: db_ids_by_worktree_id
@@ -565,8 +586,8 @@ impl VectorStore {
                         .unwrap();
                 }
             }
-            log::info!(
-                "Parsing Worktree Completed in {:?}",
+            log::trace!(
+                "parsing worktree completed in {:?}",
                 t0.elapsed().as_millis()
             );
         }
@@ -622,11 +643,12 @@ impl VectorStore {
 
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
+        let fs = self.fs.clone();
         cx.spawn(|this, cx| async move {
             let documents = cx
                 .background()
                 .spawn(async move {
-                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
+                    let database = VectorDatabase::new(fs, database_url).await?;
 
                     let phrase_embedding = embedding_provider
                         .embed_batch(vec![&phrase])
                         .await?
@@ -648,12 +670,12 @@ impl VectorStore {
 
             Ok(documents
                 .into_iter()
-                .filter_map(|(worktree_db_id, file_path, offset, name)| {
+                .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
                     let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                     Some(SearchResult {
                         worktree_id,
                         name,
-                        offset,
+                        byte_range,
                         file_path,
                     })
                 })
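Note: the batching task above accumulates `Enqueue` jobs until roughly `EMBEDDINGS_BATCH_SIZE` document spans are pending, and an explicit `Flush` forces a drain regardless of size. The same size-or-flush policy as a plain struct (a sketch only; the async channel plumbing is omitted):

    // Accumulate ids until the span budget is hit or a flush arrives,
    // then emit the whole batch and reset.
    const EMBEDDINGS_BATCH_SIZE: usize = 150;

    enum Job {
        Enqueue { file_id: i64, spans: usize },
        Flush,
    }

    struct Batcher {
        queue: Vec<i64>,
        queued_spans: usize,
    }

    impl Batcher {
        fn handle(&mut self, job: Job) -> Option<Vec<i64>> {
            let should_flush = match job {
                Job::Enqueue { file_id, spans } => {
                    self.queued_spans += spans;
                    self.queue.push(file_id);
                    self.queued_spans >= EMBEDDINGS_BATCH_SIZE
                }
                Job::Flush => true,
            };
            if should_flush {
                self.queued_spans = 0;
                Some(std::mem::take(&mut self.queue))
            } else {
                None
            }
        }
    }

    fn main() {
        let mut batcher = Batcher { queue: Vec::new(), queued_spans: 0 };
        assert!(batcher.handle(Job::Enqueue { file_id: 1, spans: 100 }).is_none());
        // 100 + 60 >= 150, so the second enqueue triggers a batch.
        assert_eq!(batcher.handle(Job::Enqueue { file_id: 2, spans: 60 }), Some(vec![1, 2]));
        assert_eq!(batcher.handle(Job::Flush), Some(vec![]));
    }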
diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs
index b6e47e7a23..c4349c7280 100644
--- a/crates/vector_store/src/vector_store_tests.rs
+++ b/crates/vector_store/src/vector_store_tests.rs
@@ -12,6 +12,13 @@ use settings::SettingsStore;
 use std::sync::Arc;
 use unindent::Unindent;
 
+#[ctor::ctor]
+fn init_logger() {
+    if std::env::var("RUST_LOG").is_ok() {
+        env_logger::init();
+    }
+}
+
 #[gpui::test]
 async fn test_vector_store(cx: &mut TestAppContext) {
     cx.update(|cx| {
@@ -95,11 +102,23 @@ async fn test_vector_store(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    assert_eq!(search_results[0].offset, 0);
+    assert_eq!(search_results[0].byte_range.start, 0);
     assert_eq!(search_results[0].name, "aaa");
     assert_eq!(search_results[0].worktree_id, worktree_id);
 }
 
+#[gpui::test]
+async fn test_code_context_retrieval(cx: &mut TestAppContext) {
+    // let mut retriever = CodeContextRetriever::new(fs);
+
+    // retriever::parse_file(
+    //     "
+    //     //
+    //     ",
+    // );
+    //
+}
+
 #[gpui::test]
 fn test_dot_product(mut rng: StdRng) {
     assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
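Note: `test_dot_product` above exercises the crate's `dot` helper. A reference implementation consistent with those assertions (the crate's real version may be manually unrolled or chunked for speed):

    // Plain dot product over two equal-length f32 slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn main() {
        // Orthogonal vectors score zero; aligned components multiply through.
        assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
        assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
    }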