refactored code context retrieval and standardized database migration

Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
KCaverly 2023-07-13 16:34:32 -04:00
parent 5eab628580
commit 0a0e40fb24
7 changed files with 232 additions and 148 deletions

2
Cargo.lock generated
View File

@ -8483,7 +8483,9 @@ dependencies = [
"anyhow",
"async-trait",
"bincode",
"ctor",
"editor",
"env_logger 0.9.3",
"futures 0.3.28",
"gpui",
"isahc",

View File

@ -44,6 +44,9 @@ rpc = { path = "../rpc", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"]}
tree-sitter-rust = "*"
rand.workspace = true
unindent.workspace = true
tempdir.workspace = true
ctor.workspace = true
env_logger.workspace = true

View File

@ -1,20 +1,20 @@
use std::{
cmp::Ordering,
collections::HashMap,
path::{Path, PathBuf},
rc::Rc,
time::SystemTime,
};
use crate::{parsing::Document, VECTOR_STORE_VERSION};
use anyhow::{anyhow, Result};
use crate::parsing::ParsedFile;
use crate::VECTOR_STORE_VERSION;
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
cmp::Ordering,
collections::HashMap,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
};
#[derive(Debug)]
pub struct FileRecord {
@ -42,48 +42,88 @@ pub struct VectorDatabase {
}
impl VectorDatabase {
pub fn new(path: String) -> Result<Self> {
pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
let this = Self {
db: rusqlite::Connection::open(path)?,
db: rusqlite::Connection::open(path.as_path())?,
};
this.initialize_database()?;
Ok(this)
}
fn get_existing_version(&self) -> Result<i64> {
let mut version_query = self.db.prepare("SELECT version from vector_store_config")?;
version_query
.query_row([], |row| Ok(row.get::<_, i64>(0)?))
.map_err(|err| anyhow!("version query failed: {err}"))
}
fn initialize_database(&self) -> Result<()> {
rusqlite::vtab::array::load_module(&self.db)?;
// This will create the database if it doesnt exist
if self
.get_existing_version()
.map_or(false, |version| version == VECTOR_STORE_VERSION as i64)
{
return Ok(());
}
self.db
.execute(
"
DROP TABLE vector_store_config;
DROP TABLE worktrees;
DROP TABLE files;
DROP TABLE documents;
",
[],
)
.ok();
// Initialize Vector Databasing Tables
self.db.execute(
"CREATE TABLE IF NOT EXISTS worktrees (
"CREATE TABLE vector_store_config (
version INTEGER NOT NULL
)",
[],
)?;
self.db.execute(
"INSERT INTO vector_store_config (version) VALUES (?1)",
params![VECTOR_STORE_VERSION],
)?;
self.db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS files (
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
vector_store_version INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS documents (
"CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
offset INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
@ -102,43 +142,44 @@ impl VectorDatabase {
Ok(())
}
pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
pub fn insert_file(
&self,
worktree_id: i64,
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
) -> Result<()> {
// Write to files table, and return generated id.
self.db.execute(
"
DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
",
params![worktree_id, indexed_file.path.to_str()],
params![worktree_id, path.to_str()],
)?;
let mtime = Timestamp::from(indexed_file.mtime);
let mtime = Timestamp::from(mtime);
self.db.execute(
"
INSERT INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version)
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
VALUES
(?1, ?2, $3, $4, $5);
(?1, ?2, $3, $4);
",
params![
worktree_id,
indexed_file.path.to_str(),
mtime.seconds,
mtime.nanos,
VECTOR_STORE_VERSION
],
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
)?;
let file_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents {
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
params![
file_id,
document.offset.to_string(),
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob
],
@ -204,7 +245,7 @@ impl VectorDatabase {
worktree_ids: &[i64],
query_embedding: &Vec<f32>,
limit: usize,
) -> Result<Vec<(i64, PathBuf, usize, String)>> {
) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
self.for_each_document(&worktree_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
@ -248,11 +289,18 @@ impl VectorDatabase {
Ok(())
}
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
fn get_documents_by_ids(
&self,
ids: &[i64],
) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
let mut statement = self.db.prepare(
"
SELECT
documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
documents.id,
files.worktree_id,
files.relative_path,
documents.start_byte,
documents.end_byte, documents.name
FROM
documents, files
WHERE
@ -266,15 +314,15 @@ impl VectorDatabase {
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?,
row.get(4)?,
row.get(3)?..row.get(4)?,
row.get(5)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
for row in result_iter {
let (id, worktree_id, path, offset, name) = row?;
values_by_id.insert(id, (worktree_id, path, offset, name));
let (id, worktree_id, path, range, name) = row?;
values_by_id.insert(id, (worktree_id, path, range, name));
}
let mut results = Vec::with_capacity(ids.len());

View File

@ -66,7 +66,7 @@ impl PickerDelegate for SemanticSearchDelegate {
});
let workspace = self.workspace.clone();
let position = search_result.clone().offset;
let position = search_result.clone().byte_range.start;
cx.spawn(|_, mut cx| async move {
let buffer = buffer.await?;
workspace.update(&mut cx, |workspace, cx| {

View File

@ -1,41 +1,39 @@
use std::{path::PathBuf, sync::Arc, time::SystemTime};
use anyhow::{anyhow, Ok, Result};
use project::Fs;
use language::Language;
use std::{ops::Range, path::Path, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
use crate::PendingFile;
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub offset: usize,
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Vec<f32>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct ParsedFile {
pub path: PathBuf,
pub mtime: SystemTime,
pub documents: Vec<Document>,
}
const CODE_CONTEXT_TEMPLATE: &str =
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub fs: Arc<dyn Fs>,
}
impl CodeContextRetriever {
pub async fn parse_file(
pub fn new() -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
}
}
pub fn parse_file(
&mut self,
pending_file: PendingFile,
) -> Result<(ParsedFile, Vec<String>)> {
let grammar = pending_file
.language
relative_path: &Path,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Document>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
let embedding_config = grammar
@ -43,8 +41,6 @@ impl CodeContextRetriever {
.as_ref()
.ok_or_else(|| anyhow!("no embedding queries"))?;
let content = self.fs.load(&pending_file.absolute_path).await?;
self.parser.set_language(grammar.ts_language).unwrap();
let tree = self
@ -53,7 +49,6 @@ impl CodeContextRetriever {
.ok_or_else(|| anyhow!("parsing failed"))?;
let mut documents = Vec::new();
let mut document_texts = Vec::new();
// Iterate through query matches
for mat in self.cursor.matches(
@ -63,11 +58,11 @@ impl CodeContextRetriever {
) {
let mut name: Vec<&str> = vec![];
let mut item: Option<&str> = None;
let mut offset: Option<usize> = None;
let mut byte_range: Option<Range<usize>> = None;
let mut context_spans: Vec<&str> = vec![];
for capture in mat.captures {
if capture.index == embedding_config.item_capture_ix {
offset = Some(capture.node.byte_range().start);
byte_range = Some(capture.node.byte_range());
item = content.get(capture.node.byte_range());
} else if capture.index == embedding_config.name_capture_ix {
if let Some(name_content) = content.get(capture.node.byte_range()) {
@ -84,30 +79,25 @@ impl CodeContextRetriever {
}
}
if item.is_some() && offset.is_some() && name.len() > 0 {
let item = format!("{}\n{}", context_spans.join("\n"), item.unwrap());
if let Some((item, byte_range)) = item.zip(byte_range) {
if !name.is_empty() {
let item = format!("{}\n{}", context_spans.join("\n"), item);
let document_text = CODE_CONTEXT_TEMPLATE
.replace("<path>", pending_file.relative_path.to_str().unwrap())
.replace("<language>", &pending_file.language.name().to_lowercase())
.replace("<path>", relative_path.to_str().unwrap())
.replace("<language>", &language.name().to_lowercase())
.replace("<item>", item.as_str());
document_texts.push(document_text);
documents.push(Document {
name: name.join(" "),
offset: offset.unwrap(),
range: byte_range,
content: document_text,
embedding: Vec::new(),
})
name: name.join(" ").to_string(),
});
}
}
}
return Ok((
ParsedFile {
path: pending_file.relative_path,
mtime: pending_file.modified_time,
documents,
},
document_texts,
));
return Ok(documents);
}
}

View File

@ -18,16 +18,16 @@ use gpui::{
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parsing::{CodeContextRetriever, ParsedFile};
use parsing::{CodeContextRetriever, Document};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
collections::HashMap,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
http::HttpClient,
@ -36,7 +36,7 @@ use util::{
};
use workspace::{Workspace, WorkspaceCreated};
const VECTOR_STORE_VERSION: usize = 0;
const VECTOR_STORE_VERSION: usize = 1;
const EMBEDDINGS_BATCH_SIZE: usize = 150;
pub fn init(
@ -80,11 +80,11 @@ pub fn init(
let vector_store = VectorStore::new(
fs,
db_file_path,
// Arc::new(embedding::DummyEmbeddings {}),
Arc::new(OpenAIEmbeddings {
client: http_client,
executor: cx.background(),
}),
Arc::new(embedding::DummyEmbeddings {}),
// Arc::new(OpenAIEmbeddings {
// client: http_client,
// executor: cx.background(),
// }),
language_registry,
cx.clone(),
)
@ -212,14 +212,16 @@ pub struct PendingFile {
pub struct SearchResult {
pub worktree_id: WorktreeId,
pub name: String,
pub offset: usize,
pub byte_range: Range<usize>,
pub file_path: PathBuf,
}
enum DbOperation {
InsertFile {
worktree_id: i64,
indexed_file: ParsedFile,
documents: Vec<Document>,
path: PathBuf,
mtime: SystemTime,
},
Delete {
worktree_id: i64,
@ -238,8 +240,9 @@ enum DbOperation {
enum EmbeddingJob {
Enqueue {
worktree_id: i64,
parsed_file: ParsedFile,
document_spans: Vec<String>,
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
},
Flush,
}
@ -256,18 +259,7 @@ impl VectorStore {
let db = cx
.background()
.spawn({
let fs = fs.clone();
let database_url = database_url.clone();
async move {
if let Some(db_directory) = database_url.parent() {
fs.create_dir(db_directory).await.log_err();
}
let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
anyhow::Ok(db)
}
})
.spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
.await?;
Ok(cx.add_model(|cx| {
@ -280,9 +272,12 @@ impl VectorStore {
match job {
DbOperation::InsertFile {
worktree_id,
indexed_file,
documents,
path,
mtime,
} => {
db.insert_file(worktree_id, indexed_file).log_err();
db.insert_file(worktree_id, path, mtime, documents)
.log_err();
}
DbOperation::Delete { worktree_id, path } => {
db.delete_file(worktree_id, path).log_err();
@ -304,35 +299,45 @@ impl VectorStore {
// embed_tx/rx: Embed Batch and Send to Database
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
// Construct Batch
let mut document_spans = vec![];
for (_, _, document_span) in embeddings_queue.iter() {
document_spans.extend(document_span.iter().map(|s| s.as_str()));
let mut batch_documents = vec![];
for (_, documents, _, _) in embeddings_queue.iter() {
batch_documents
.extend(documents.iter().map(|document| document.content.as_str()));
}
if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
if let Ok(embeddings) =
embedding_provider.embed_batch(batch_documents).await
{
log::trace!(
"created {} embeddings for {} files",
embeddings.len(),
embeddings_queue.len(),
);
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.documents.len() == j {
while embeddings_queue[i].1.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
embeddings_queue[i].1[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
for document in indexed_file.documents.iter() {
for (worktree_id, documents, path, mtime) in
embeddings_queue.into_iter()
{
for document in documents.iter() {
// TODO: Update this so it doesn't panic
assert!(
document.embedding.len() > 0,
@ -343,7 +348,9 @@ impl VectorStore {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
indexed_file,
documents,
path,
mtime,
})
.await
.unwrap();
@ -362,12 +369,13 @@ impl VectorStore {
while let Ok(job) = batch_files_rx.recv().await {
let should_flush = match job {
EmbeddingJob::Enqueue {
document_spans,
documents,
worktree_id,
parsed_file,
path,
mtime,
} => {
queue_len += &document_spans.len();
embeddings_queue.push((worktree_id, parsed_file, document_spans));
queue_len += &documents.len();
embeddings_queue.push((worktree_id, documents, path, mtime));
queue_len >= EMBEDDINGS_BATCH_SIZE
}
EmbeddingJob::Flush => true,
@ -385,27 +393,39 @@ impl VectorStore {
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
let mut _parsing_files_tasks = Vec::new();
// for _ in 0..cx.background().num_cpus() {
for _ in 0..1 {
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let parser = Parser::new();
let cursor = QueryCursor::new();
let mut retriever = CodeContextRetriever { parser, cursor, fs };
let mut retriever = CodeContextRetriever::new();
while let Ok(pending_file) = parsing_files_rx.recv().await {
if let Some((indexed_file, document_spans)) =
retriever.parse_file(pending_file.clone()).await.log_err()
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
{
if let Some(documents) = retriever
.parse_file(
&pending_file.relative_path,
&content,
pending_file.language,
)
.log_err()
{
log::trace!(
"parsed path {:?}: {} documents",
pending_file.relative_path,
documents.len()
);
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id,
parsed_file: indexed_file,
document_spans,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
documents,
})
.unwrap();
}
}
if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
@ -543,6 +563,7 @@ impl VectorStore {
});
if !already_stored {
log::trace!("sending for parsing: {:?}", path_buf);
parsing_files_tx
.try_send(PendingFile {
worktree_db_id: db_ids_by_worktree_id
@ -565,8 +586,8 @@ impl VectorStore {
.unwrap();
}
}
log::info!(
"Parsing Worktree Completed in {:?}",
log::trace!(
"parsing worktree completed in {:?}",
t0.elapsed().as_millis()
);
}
@ -622,11 +643,12 @@ impl VectorStore {
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
let fs = self.fs.clone();
cx.spawn(|this, cx| async move {
let documents = cx
.background()
.spawn(async move {
let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
let database = VectorDatabase::new(fs, database_url).await?;
let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase])
@ -648,12 +670,12 @@ impl VectorStore {
Ok(documents
.into_iter()
.filter_map(|(worktree_db_id, file_path, offset, name)| {
.filter_map(|(worktree_db_id, file_path, byte_range, name)| {
let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
Some(SearchResult {
worktree_id,
name,
offset,
byte_range,
file_path,
})
})

View File

@ -12,6 +12,13 @@ use settings::SettingsStore;
use std::sync::Arc;
use unindent::Unindent;
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}
#[gpui::test]
async fn test_vector_store(cx: &mut TestAppContext) {
cx.update(|cx| {
@ -95,11 +102,23 @@ async fn test_vector_store(cx: &mut TestAppContext) {
.await
.unwrap();
assert_eq!(search_results[0].offset, 0);
assert_eq!(search_results[0].byte_range.start, 0);
assert_eq!(search_results[0].name, "aaa");
assert_eq!(search_results[0].worktree_id, worktree_id);
}
#[gpui::test]
async fn test_code_context_retrieval(cx: &mut TestAppContext) {
// let mut retriever = CodeContextRetriever::new(fs);
// retriever::parse_file(
// "
// //
// ",
// );
//
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);