WIP: Got the streaming matrix multiplication working, and started work on file hashing.

Co-authored-by: maxbrunsfeld <max@zed.dev>
KCaverly 2023-06-26 19:01:19 -04:00
parent 74b693d6b9
commit 953e928bdb
7 changed files with 396 additions and 97 deletions

Cargo.lock (generated)

@@ -7958,13 +7958,18 @@ dependencies = [
"language",
"lazy_static",
"log",
"matrixmultiply",
"ndarray",
"project",
"rand 0.8.5",
"rusqlite",
"serde",
"serde_json",
"sha-1 0.10.1",
"smol",
"tree-sitter",
"tree-sitter-rust",
"unindent",
"util",
"workspace",
]

crates/vector_store/Cargo.toml

@@ -27,9 +27,14 @@ serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
ndarray = "0.15.6"
sha-1 = "0.10.1"
matrixmultiply = "0.3.7"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
tree-sitter-rust = "*"
rand.workspace = true
unindent.workspace = true

crates/vector_store/src/db.rs

@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+collections::HashMap,
+path::{Path, PathBuf},
+};
use anyhow::{anyhow, Result};
@@ -13,7 +16,7 @@ use crate::IndexedFile;
// This saves to a local database store within the user's dev zed path.
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";
// Note this is not an appropriate document
#[derive(Debug)]
@@ -28,7 +31,7 @@ pub struct DocumentRecord {
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
-pub path: String,
+pub relative_path: String,
pub sha1: String,
}
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
}
impl VectorDatabase {
-pub fn new() -> Result<Self> {
+pub fn new(path: &str) -> Result<Self> {
let this = Self {
-db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+db: rusqlite::Connection::open(path)?,
};
this.initialize_database()?;
Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
// This will create the database if it doesn't exist
// Initialize vector database tables
// self.db.execute(
// "
// CREATE TABLE IF NOT EXISTS projects (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// path NVARCHAR(100) NOT NULL
// )
// ",
// [],
// )?;
// Two SQL statements here, so use execute_batch; a plain execute would
// not run the CREATE UNIQUE INDEX statement.
self.db.execute_batch(
"CREATE TABLE IF NOT EXISTS worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
",
)?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-path NVARCHAR(100) NOT NULL,
-sha1 NVARCHAR(40) NOT NULL
+worktree_id INTEGER NOT NULL,
+relative_path VARCHAR NOT NULL,
+sha1 NVARCHAR(40) NOT NULL,
+FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
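// NOTE: SQLite only enforces foreign keys (including the ON DELETE CASCADE
// above) when the connection runs `PRAGMA foreign_keys = ON;`. It is off by
// default, and rusqlite does not enable it for you.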
@@ -87,7 +92,7 @@ impl VectorDatabase {
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
offset INTEGER NOT NULL,
-name NVARCHAR(100) NOT NULL,
+name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
@@ -116,7 +121,7 @@ impl VectorDatabase {
pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
// Write to the files table and return the generated id.
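// NOTE: the new schema makes files.worktree_id NOT NULL, but this insert
// does not supply a worktree id yet, so this WIP path still needs updating.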
let files_insert = self.db.execute(
"INSERT INTO files (path, sha1) VALUES (?1, ?2)",
"INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
params![indexed_file.path.to_str(), indexed_file.sha1],
)?;
@@ -141,12 +146,38 @@
Ok(())
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
ON CONFLICT DO NOTHING
",
params![worktree_root_path.to_string_lossy()],
)?;
// last_insert_rowid() is stale when ON CONFLICT skips the insert,
// so look the id up explicitly.
Ok(self.db.query_row(
"SELECT id FROM worktrees WHERE absolute_path = ?1",
params![worktree_root_path.to_string_lossy()],
|row| row.get(0),
)?)
}
pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
let mut statement = self.db.prepare(
"SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
)?;
let mut result = Vec::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
})?
{
result.push(row?);
}
Ok(result)
}
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
let mut query_statement = self
.db
.prepare("SELECT id, relative_path, sha1 FROM files")?;
let result_iter = query_statement.query_map([], |row| {
Ok(FileRecord {
id: row.get(0)?,
-path: row.get(1)?,
+relative_path: row.get(1)?,
sha1: row.get(2)?,
})
})?;
@@ -160,6 +191,19 @@ impl VectorDatabase {
Ok(pages)
}
pub fn for_each_document(
&self,
worktree_id: i64,
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
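// NOTE: worktree_id is accepted but not used yet; this visits every
// document in the database regardless of worktree.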
let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
query_statement
.query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
.filter_map(|row| row.ok())
.for_each(|row| f(row.0, row.1));
Ok(())
}
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
let mut query_statement = self
.db

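This commit also pulls in the sha-1 crate and gives files a 40-character sha1 column, but nothing computes the digest yet. A minimal sketch of how it might be produced, assuming the indexer hashes file contents as it loads them (file_sha1 is a hypothetical helper, not part of the diff):

use sha1::{Digest, Sha1};

// Hash loaded file contents into the lowercase hex digest the `sha1`
// column expects (the sha-1 crate is imported as `sha1` in code).
fn file_sha1(content: &str) -> String {
    let mut hasher = Sha1::new();
    hasher.update(content.as_bytes());
    hasher
        .finalize()
        .iter()
        .map(|byte| format!("{:02x}", byte))
        .collect()
}

Hashing the already-loaded content, rather than re-reading the file, would keep the stored digest consistent with exactly what was parsed and embedded.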
crates/vector_store/src/embedding.rs

@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
}
#[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

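The only change here widens the trait bound from Sync to Sync + Send. A small illustration of why that matters, given that the store now holds an Arc<dyn EmbeddingProvider> and moves it into background tasks (this function is illustrative, not from the commit):

use std::sync::Arc;

// Moving an Arc<dyn Trait> into another thread requires the trait object
// to be Send, and Sync as well since clones may be used concurrently.
fn send_to_background(provider: Arc<dyn EmbeddingProvider>) {
    std::thread::spawn(move || {
        // Compiles only because EmbeddingProvider: Sync + Send.
        drop(provider);
    });
}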
crates/vector_store/src/search.rs

@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};
use async_trait::async_trait;
use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
impl BruteForceSearch {
pub fn load(db: &VectorDatabase) -> Result<Self> {
-// let db = VectorDatabase {};
let documents = db.get_documents()?;
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
with_indices.truncate(limit);
with_indices
-// // extract the sorted indices from the sorted tuple vector
-// let stored_indices = with_indices
-// .into_iter()
-// .map(|(index, value)| index)
-// .collect::<Vec<>>();
-// let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-// let mut results = vec![];
-// for idx in sorted_indices[0..limit].to_vec() {
-// results.push((self.document_ids[idx], 1.0 - similarities[idx]));
-// }
-// return results;
}
}

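BruteForceSearch scores a query against every stored embedding at once with an ndarray matrix-vector product. A minimal sketch of that core operation, assuming each row of the matrix is a unit-normalized document embedding (names are illustrative):

use ndarray::{Array1, Array2};

// One similarity score per document: with unit-normalized rows, each
// entry of the product is the cosine similarity to the query.
fn score_all(embeddings: &Array2<f32>, query: &Array1<f32>) -> Array1<f32> {
    embeddings.dot(query)
}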
crates/vector_store/src/vector_store.rs

@@ -3,16 +3,19 @@ mod embedding;
mod parsing;
mod search;
#[cfg(test)]
mod vector_store_tests;
use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
use language::LanguageRegistry;
use parsing::Document;
use project::{Fs, Project};
use search::{BruteForceSearch, VectorSearch};
use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@
language_registry: Arc<LanguageRegistry>,
cx: &mut AppContext,
) {
-let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+let vector_store = cx.add_model(|cx| {
+VectorStore::new(
+fs,
+VECTOR_DB_URL.to_string(),
+Arc::new(OpenAIEmbeddings {
+client: http_client,
+}),
+language_registry,
+)
+});
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
documents: Vec<Document>,
}
-struct SearchResult {
-path: PathBuf,
-offset: usize,
-name: String,
-distance: f32,
-}
+// struct SearchResult {
+//     path: PathBuf,
+//     offset: usize,
+//     name: String,
+//     distance: f32,
+// }
struct VectorStore {
fs: Arc<dyn Fs>,
-http_client: Arc<dyn HttpClient>,
+database_url: Arc<str>,
+embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
}
pub struct SearchResult {
pub name: String,
pub offset: usize,
pub file_path: PathBuf,
}
impl VectorStore {
fn new(
fs: Arc<dyn Fs>,
-http_client: Arc<dyn HttpClient>,
+database_url: String,
+embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
fs,
-http_client,
+database_url: database_url.into(),
+embedding_provider,
language_registry,
}
}
@@ -79,10 +99,12 @@ impl VectorStore {
cursor: &mut QueryCursor,
parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
-fs: &Arc<dyn Fs>,
language_registry: &Arc<LanguageRegistry>,
file_path: PathBuf,
+content: String,
) -> Result<IndexedFile> {
+dbg!(&file_path, &content);
let language = language_registry
.language_for_file(&file_path, None)
.await?;
@@ -97,7 +119,6 @@ impl VectorStore {
.as_ref()
.ok_or_else(|| anyhow!("no outline query"))?;
-let content = fs.load(&file_path).await?;
parser.set_language(grammar.ts_language).unwrap();
let tree = parser
.parse(&content, None)
@@ -142,7 +163,11 @@ impl VectorStore {
});
}
-fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+fn add_project(
+&mut self,
+project: ModelHandle<Project>,
+cx: &mut ModelContext<Self>,
+) -> Task<Result<()>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@@ -151,7 +176,8 @@ impl VectorStore {
let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
-let client = self.http_client.clone();
+let embedding_provider = self.embedding_provider.clone();
+let database_url = self.database_url.clone();
cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@
.collect::<Vec<_>>()
});
-let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
+let db = VectorDatabase::new(&database_url)?;
+let worktree_root_paths = worktrees
+.iter()
+.map(|worktree| worktree.abs_path().clone())
+.collect::<Vec<_>>();
+let (db, file_hashes) = cx
+.background()
+.spawn(async move {
+let mut hashes = Vec::new();
+for worktree_root_path in worktree_root_paths {
+let worktree_id =
+db.find_or_create_worktree(worktree_root_path.as_ref())?;
+hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+}
+anyhow::Ok((db, hashes))
+})
+.await?;
+let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
cx.background()
-.spawn(async move {
-for worktree in worktrees {
-for file in worktree.files(false, 0) {
-paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
-}
-}
-})
+.spawn({
+let fs = fs.clone();
+async move {
+for worktree in worktrees.into_iter() {
+for file in worktree.files(false, 0) {
+let absolute_path = worktree.absolutize(&file.path);
+dbg!(&absolute_path);
+if let Some(content) = fs.load(&absolute_path).await.log_err() {
+dbg!(&content);
+paths_tx.try_send((0, absolute_path, content)).unwrap();
+}
+}
+}
+}
+})
.detach();
-cx.background()
-.spawn({
-let client = client.clone();
-async move {
+let db_write_task = cx.background().spawn(
+async move {
-// Initialize Database, creates database and tables if not exists
-let db = VectorDatabase::new()?;
while let Ok(indexed_file) = indexed_files_rx.recv().await {
db.insert_file(indexed_file).log_err();
}
@@ -188,39 +237,39 @@
// ALL OF THE BELOW IS FOR TESTING,
// This should be removed once we find an appropriate place to evaluate our search.
-let embedding_provider = OpenAIEmbeddings{ client };
-let queries = vec![
-"compute embeddings for all of the symbols in the codebase, and write them to a database",
-"compute an outline view of all of the symbols in a buffer",
-"scan a directory on the file system and load all of its children into an in-memory snapshot",
-];
-let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+// let queries = vec![
+//     "compute embeddings for all of the symbols in the codebase, and write them to a database",
+//     "compute an outline view of all of the symbols in a buffer",
+//     "scan a directory on the file system and load all of its children into an in-memory snapshot",
+// ];
+// let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
-let t2 = Instant::now();
-let documents = db.get_documents().unwrap();
-let files = db.get_files().unwrap();
-println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+// let t2 = Instant::now();
+// let documents = db.get_documents().unwrap();
+// let files = db.get_files().unwrap();
+// println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
-let t1 = Instant::now();
-let mut bfs = BruteForceSearch::load(&db).unwrap();
-println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
-for (idx, embed) in embeddings.into_iter().enumerate() {
-let t0 = Instant::now();
-println!("\nQuery: {:?}", queries[idx]);
-let results = bfs.top_k_search(&embed, 5).await;
-println!("Search Elapsed: {}", t0.elapsed().as_millis());
-for (id, distance) in results {
-println!("");
-println!(" distance: {:?}", distance);
-println!(" document: {:?}", documents[&id].name);
-println!(" path: {:?}", files[&documents[&id].file_id].path);
-}
+// let t1 = Instant::now();
+// let mut bfs = BruteForceSearch::load(&db).unwrap();
+// println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+// for (idx, embed) in embeddings.into_iter().enumerate() {
+//     let t0 = Instant::now();
+//     println!("\nQuery: {:?}", queries[idx]);
+//     let results = bfs.top_k_search(&embed, 5).await;
+//     println!("Search Elapsed: {}", t0.elapsed().as_millis());
+//     for (id, distance) in results {
+//         println!("");
+//         println!(" distance: {:?}", distance);
+//         println!(" document: {:?}", documents[&id].name);
+//         println!(" path: {:?}", files[&documents[&id].file_id].relative_path);
+//     }
-}
+// }
anyhow::Ok(())
-}}.log_err())
-.detach();
+}
+.log_err(),
+);
let provider = DummyEmbeddings {};
// let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@ impl VectorStore {
scope.spawn(async {
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
-while let Ok(file_path) = paths_rx.recv().await {
+while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+{
if let Some(indexed_file) = Self::index_file(
&mut cursor,
&mut parser,
&provider,
-&fs,
&language_registry,
file_path,
+content,
)
.await
.log_err()
@@ -250,11 +300,86 @@
}
})
.await;
drop(indexed_files_tx);
db_write_task.await;
anyhow::Ok(())
})
}
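// The wiring above forms a fan-out/fan-in pipeline: one background task
// enumerates files into `paths_tx`, a scoped pool of parser tasks drains
// `paths_rx` into `indexed_files_tx`, and `db_write_task` alone performs
// the SQLite writes. A self-contained sketch of the same shape using only
// smol (all names illustrative, not part of the commit):
fn pipeline_sketch() {
    smol::block_on(async {
        let (paths_tx, paths_rx) = smol::channel::unbounded::<String>();
        let (lens_tx, lens_rx) = smol::channel::unbounded::<usize>();
        // Producer: the channel closes when its last sender is dropped.
        let producer = smol::spawn(async move {
            for path in ["a.rs", "b.rs"] {
                paths_tx.send(path.to_string()).await.unwrap();
            }
        });
        // Worker: runs until the paths channel closes, then drops lens_tx.
        let worker = smol::spawn(async move {
            while let Ok(path) = paths_rx.recv().await {
                lens_tx.send(path.len()).await.unwrap();
            }
        });
        producer.await;
        worker.await;
        // Writer: drains results until every sender is gone.
        while let Ok(len) = lens_rx.recv().await {
            println!("processed a path of length {len}");
        }
    });
}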
pub fn search(
&mut self,
phrase: String,
limit: usize,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone();
cx.spawn(|this, cx| async move {
let database = VectorDatabase::new(database_url.as_ref())?;
// let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
//
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
database.for_each_document(0, |id, embedding| {
dbg!(id, &embedding);
// WIP: scores each stored embedding against itself until the query
// embedding above is wired in.
let similarity = dot(&embedding.0, &embedding.0);
// Keep results sorted in descending similarity order so that
// truncate() retains the best matches.
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
dbg!(&results);
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
// let documents = database.get_documents_by_ids(ids)?;
// let search_provider = cx
// .background()
// .spawn(async move { BruteForceSearch::load(&database) })
// .await?;
// let results = search_provider.top_k_search(&embedding, limit))
anyhow::Ok(vec![])
})
}
}
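// `search` above keeps a bounded hit list by binary-searching each score's
// insertion point. A standalone sketch of that pattern, ordered by
// descending similarity (illustrative, not part of the commit):
fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    let ix = results
        .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
        .unwrap_or_else(|ix| ix);
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}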
impl Entity for VectorStore {
type Event = ();
}
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
assert_eq!(len, vec_b.len());
let mut result = 0.0;
unsafe {
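// vec_a is treated as a 1×len row matrix and vec_b as a len×1 column
// matrix; sgemm computes C = alpha·A·B + beta·C, so the single entry of
// the 1×1 result C is the dot product. The integer arguments after each
// pointer are that matrix's row and column strides.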
matrixmultiply::sgemm(
1,
len,
1,
1.0,
vec_a.as_ptr(),
len as isize,
1,
vec_b.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}

crates/vector_store/src/vector_store_tests.rs

@@ -0,0 +1,136 @@
use std::sync::Arc;
use crate::{dot, embedding::EmbeddingProvider, VectorStore};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry};
use project::{FakeFs, Project};
use rand::Rng;
use serde_json::json;
use unindent::Unindent;
#[gpui::test]
async fn test_vector_store(cx: &mut TestAppContext) {
let fs = FakeFs::new(cx.background());
fs.insert_tree(
"/the-root",
json!({
"src": {
"file1.rs": "
fn aaa() {
println!(\"aaaa!\");
}
fn zzzzzzzzz() {
println!(\"SLEEPING\");
}
".unindent(),
"file2.rs": "
fn bbb() {
println!(\"bbbb!\");
}
".unindent(),
}
}),
)
.await;
let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
let rust_language = Arc::new(
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".into()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_outline_query(
r#"
(function_item
name: (identifier) @name
body: (block)) @item
"#,
)
.unwrap(),
);
languages.add(rust_language);
let store = cx.add_model(|_| {
VectorStore::new(
fs.clone(),
"foo".to_string(),
Arc::new(FakeEmbeddingProvider),
languages,
)
});
let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
store
.update(cx, |store, cx| store.add_project(project, cx))
.await
.unwrap();
let search_results = store
.update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
.await
.unwrap();
assert_eq!(search_results[0].offset, 0);
assert_eq!(search_results[0].name, "aaa");
}
#[test]
fn test_dot_product() {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let mut rng = rand::thread_rng();
let a: [f32; 32] = rng.gen();
let b: [f32; 32] = rng.gen();
assert_eq!(
round_to_decimals(dot(&a, &b), 3),
round_to_decimals(reference_dot(&a, &b), 3)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
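// Embeds each span as a letter-frequency histogram over 'a'..='z',
// normalized to unit length so dot products act like cosine similarity.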
struct FakeEmbeddingProvider;
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
Ok(spans
.iter()
.map(|span| {
let mut result = vec![0.0; 26];
for letter in span.chars() {
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result
})
.collect())
}
}