Mirror of https://github.com/zed-industries/zed.git (synced 2024-11-07 20:39:04 +03:00)

WIP: Got the streaming matrix multiplication working, and started work on file hashing.

Co-authored-by: maxbrunsfeld <max@zed.dev>
Parent: 74b693d6b9
Commit: 953e928bdb

Cargo.lock (generated, +5 lines):
@@ -7958,13 +7958,18 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "matrixmultiply",
+ "ndarray",
  "project",
  "rand 0.8.5",
  "rusqlite",
  "serde",
  "serde_json",
+ "sha-1 0.10.1",
  "smol",
  "tree-sitter",
+ "tree-sitter-rust",
+ "unindent",
  "util",
  "workspace",
 ]

crates/vector_store/Cargo.toml:

@@ -27,9 +27,14 @@ serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
+ndarray = "0.15.6"
+sha-1 = "0.10.1"
+matrixmultiply = "0.3.7"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
+tree-sitter-rust = "*"
 rand.workspace = true
+unindent.workspace = true
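
(The new sha-1 dependency lines up with the "file hashing" half of the commit message, even though no hashing code lands in this diff yet: the files table below already has a sha1 column and IndexedFile carries a sha1 string. A minimal sketch of how that column could be populated with the sha-1 crate, as an assumption about where this is heading rather than code from the commit:)

    use sha1::{Digest, Sha1};

    // Hypothetical helper: hash a file's content for the files.sha1 column
    // (20 bytes of SHA-1, hex-encoded to the 40 chars the schema expects).
    fn content_sha1(content: &str) -> String {
        let mut hasher = Sha1::new();
        hasher.update(content.as_bytes());
        format!("{:x}", hasher.finalize())
    }
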

crates/vector_store/src/db.rs:

@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+    collections::HashMap,
+    path::{Path, PathBuf},
+};
 
 use anyhow::{anyhow, Result};
@@ -13,7 +16,7 @@ use crate::IndexedFile;
 // This is saving to a local database store within the users dev zed path
 // Where do we want this to sit?
 // Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";
 
 // Note this is not an appropriate document
 #[derive(Debug)]
@@ -28,7 +31,7 @@ pub struct DocumentRecord {
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
-    pub path: String,
+    pub relative_path: String,
     pub sha1: String,
 }
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
 }
 
 impl VectorDatabase {
-    pub fn new() -> Result<Self> {
+    pub fn new(path: &str) -> Result<Self> {
         let this = Self {
-            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+            db: rusqlite::Connection::open(path)?,
         };
         this.initialize_database()?;
         Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
         // This will create the database if it doesnt exist
 
         // Initialize Vector Databasing Tables
-        // self.db.execute(
-        //     "
-        //     CREATE TABLE IF NOT EXISTS projects (
-        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
-        //         path NVARCHAR(100) NOT NULL
-        //     )
-        //     ",
-        //     [],
-        // )?;
+        self.db.execute(
+            "CREATE TABLE IF NOT EXISTS worktrees (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                absolute_path VARCHAR NOT NULL
+            );
+            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+            ",
+            [],
+        )?;
+
         self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                path NVARCHAR(100) NOT NULL,
-                sha1 NVARCHAR(40) NOT NULL
+                worktree_id INTEGER NOT NULL,
+                relative_path VARCHAR NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL,
+                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
         )?;
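
(A note on the schema above: SQLite leaves foreign-key enforcement switched off by default on every connection, so the ON DELETE CASCADE clauses are inert unless the connection opts in. A minimal sketch of that opt-in with rusqlite, as an illustration rather than part of the commit:)

    // Assumption: run once right after opening the connection; without it,
    // the CASCADE deletes declared in the schema are silently ignored.
    let db = rusqlite::Connection::open(path)?;
    db.execute_batch("PRAGMA foreign_keys = ON;")?;
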
@@ -87,7 +92,7 @@ impl VectorDatabase {
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 file_id INTEGER NOT NULL,
                 offset INTEGER NOT NULL,
-                name NVARCHAR(100) NOT NULL,
+                name VARCHAR NOT NULL,
                 embedding BLOB NOT NULL,
                 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
@@ -116,7 +121,7 @@ impl VectorDatabase {
     pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
         let files_insert = self.db.execute(
-            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
+            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
@@ -141,12 +146,38 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        self.db.execute(
+            "
+            INSERT into worktrees (absolute_path) VALUES (?1)
+            ON CONFLICT DO NOTHING
+            ",
+            params![worktree_root_path.to_string_lossy()],
+        )?;
+        Ok(self.db.last_insert_rowid())
+    }
+
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
+        let mut statement = self
+            .db
+            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
+        let mut result = Vec::new();
+        for row in
+            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
+        {
+            result.push(row?);
+        }
+        Ok(result)
+    }
+
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, relative_path, sha1 FROM files")?;
         let result_iter = query_statement.query_map([], |row| {
             Ok(FileRecord {
                 id: row.get(0)?,
-                path: row.get(1)?,
+                relative_path: row.get(1)?,
                 sha1: row.get(2)?,
             })
         })?;
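
(Worth flagging in find_or_create_worktree above: when the INSERT hits ON CONFLICT DO NOTHING, last_insert_rowid() still reports the rowid of whatever this connection inserted last, not the id of the already-existing worktree row, so the "find" half is not really implemented yet. A hedged sketch of an explicit lookup fallback, not code from the commit:)

    // Hypothetical fallback: when the insert was a no-op, fetch the existing id.
    let id: i64 = self.db.query_row(
        "SELECT id FROM worktrees WHERE absolute_path = ?1",
        params![worktree_root_path.to_string_lossy()],
        |row| row.get(0),
    )?;
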
@@ -160,6 +191,19 @@ impl VectorDatabase {
         Ok(pages)
     }
 
+    pub fn for_each_document(
+        &self,
+        worktree_id: i64,
+        mut f: impl FnMut(i64, Embedding),
+    ) -> Result<()> {
+        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        query_statement
+            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .filter_map(|row| row.ok())
+            .for_each(|row| f(row.0, row.1));
+        Ok(())
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db
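
(for_each_document above accepts a worktree_id but the SELECT does not use it yet; a sketch of the join it would presumably need, given the files.worktree_id column in the schema. Hypothetical SQL, not in the commit:)

    let mut statement = self.db.prepare(
        "SELECT documents.id, documents.embedding
         FROM documents
         JOIN files ON files.id = documents.file_id
         WHERE files.worktree_id = ?1",
    )?;
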

crates/vector_store/src/embedding.rs:

@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
 }
 
 #[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 }
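
(The added Send bound matters because this commit starts moving Arc<dyn EmbeddingProvider> values into background tasks, and a future is only Send if everything it captures is Send. A minimal sketch of the constraint, independent of gpui:)

    use std::sync::Arc;

    trait EmbeddingProvider: Sync + Send {}

    // Stand-in for spawning on a thread pool: the future must be Send.
    fn spawn<F: std::future::Future + Send + 'static>(_fut: F) {}

    fn demo(provider: Arc<dyn EmbeddingProvider>) {
        // Without Send on the trait, this capture would not compile:
        // Arc<T> is Send only when T is Send + Sync.
        spawn(async move {
            let _provider = provider;
        });
    }
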

crates/vector_store/src/search.rs:

@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};
 
 use async_trait::async_trait;
 use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
 
 impl BruteForceSearch {
     pub fn load(db: &VectorDatabase) -> Result<Self> {
-        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
         with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
         with_indices.truncate(limit);
         with_indices
-
-        // // extract the sorted indices from the sorted tuple vector
-        // let stored_indices = with_indices
-        //     .into_iter()
-        //     .map(|(index, value)| index)
-        //     .collect::<Vec<>>();
-
-        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-
-        // let mut results = vec![];
-        // for idx in sorted_indices[0..limit].to_vec() {
-        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
-        // }
-
-        // return results;
     }
 }
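
(For context, BruteForceSearch scores a query against every stored embedding at once; with the ndarray types imported above, that is one matrix-vector product over an n_docs x dim candidate matrix. A minimal sketch of that scoring step, assuming embeddings are stacked as rows:)

    use ndarray::{Array1, Array2};

    // Hypothetical shapes: candidates is (n_docs, dim), query has length dim.
    fn score_all(candidates: &Array2<f32>, query: &Array1<f32>) -> Array1<f32> {
        candidates.dot(query) // one dot-product similarity per document row
    }
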

crates/vector_store/src/vector_store.rs:

@@ -3,16 +3,19 @@ mod embedding;
 mod parsing;
 mod search;
 
+#[cfg(test)]
+mod vector_store_tests;
+
 use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 use language::LanguageRegistry;
 use parsing::Document;
 use project::{Fs, Project};
 use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+    let vector_store = cx.add_model(|cx| {
+        VectorStore::new(
+            fs,
+            VECTOR_DB_URL.to_string(),
+            Arc::new(OpenAIEmbeddings {
+                client: http_client,
+            }),
+            language_registry,
+        )
+    });
 
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
     documents: Vec<Document>,
 }
 
-struct SearchResult {
-    path: PathBuf,
-    offset: usize,
-    name: String,
-    distance: f32,
-}
-
+// struct SearchResult {
+//     path: PathBuf,
+//     offset: usize,
+//     name: String,
+//     distance: f32,
+// }
 struct VectorStore {
     fs: Arc<dyn Fs>,
-    http_client: Arc<dyn HttpClient>,
+    database_url: Arc<str>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
 }
 
+pub struct SearchResult {
+    pub name: String,
+    pub offset: usize,
+    pub file_path: PathBuf,
+}
+
 impl VectorStore {
     fn new(
         fs: Arc<dyn Fs>,
-        http_client: Arc<dyn HttpClient>,
+        database_url: String,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
     ) -> Self {
         Self {
             fs,
-            http_client,
+            database_url: database_url.into(),
+            embedding_provider,
             language_registry,
         }
     }
@@ -79,10 +99,12 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
+        content: String,
     ) -> Result<IndexedFile> {
+        dbg!(&file_path, &content);
+
         let language = language_registry
             .language_for_file(&file_path, None)
             .await?;
@@ -97,7 +119,6 @@ impl VectorStore {
             .as_ref()
             .ok_or_else(|| anyhow!("no outline query"))?;
 
-        let content = fs.load(&file_path).await?;
         parser.set_language(grammar.ts_language).unwrap();
         let tree = parser
             .parse(&content, None)
@@ -142,7 +163,11 @@ impl VectorStore {
         });
     }
 
-    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+    fn add_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -151,7 +176,8 @@ impl VectorStore {
 
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
-        let client = self.http_client.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
 
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@ impl VectorStore {
                 .collect::<Vec<_>>()
             });
 
-            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
+            let db = VectorDatabase::new(&database_url)?;
+            let worktree_root_paths = worktrees
+                .iter()
+                .map(|worktree| worktree.abs_path().clone())
+                .collect::<Vec<_>>();
+            let (db, file_hashes) = cx
+                .background()
+                .spawn(async move {
+                    let mut hashes = Vec::new();
+                    for worktree_root_path in worktree_root_paths {
+                        let worktree_id =
+                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
+                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                    }
+                    anyhow::Ok((db, hashes))
+                })
+                .await?;
+
+            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
             let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
             cx.background()
-                .spawn(async move {
-                    for worktree in worktrees {
-                        for file in worktree.files(false, 0) {
-                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
-                        }
-                    }
+                .spawn({
+                    let fs = fs.clone();
+                    async move {
+                        for worktree in worktrees.into_iter() {
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+                                dbg!(&absolute_path);
+                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                    dbg!(&content);
+                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+                                }
+                            }
+                        }
+                    }
                 })
                 .detach();
 
-            cx.background()
-                .spawn({
-                    let client = client.clone();
-                    async move {
+            let db_write_task = cx.background().spawn(
+                async move {
                     // Initialize Database, creates database and tables if not exists
-                    let db = VectorDatabase::new()?;
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
                         db.insert_file(indexed_file).log_err();
                     }
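
(add_project above wires up a channel pipeline: one background task walks the worktrees and feeds (worktree_id, absolute_path, content) tuples into paths_tx, parser tasks drain paths_rx into indexed_files_tx, and the single db_write_task drains indexed_files_rx into SQLite so the connection stays on one task. A stripped-down, runnable sketch of that shape with toy types, using the smol runtime the crate already depends on:)

    use smol::channel;

    fn main() {
        smol::block_on(async {
            let (paths_tx, paths_rx) = channel::unbounded::<String>();
            let (files_tx, files_rx) = channel::unbounded::<usize>();

            // Producer: enumerate work (stands in for the worktree walk).
            let producer = smol::spawn(async move {
                for path in ["a.rs", "b.rs"] {
                    paths_tx.send(path.to_string()).await.ok();
                }
                // paths_tx drops here, closing the channel so workers finish.
            });

            // Worker: transform each item (stands in for parsing and embedding).
            let worker = smol::spawn(async move {
                while let Ok(path) = paths_rx.recv().await {
                    files_tx.send(path.len()).await.ok();
                }
                // files_tx drops here, closing the writer's channel.
            });

            // Writer: single consumer (stands in for the SQLite insert task).
            while let Ok(n) = files_rx.recv().await {
                println!("indexed {n}");
            }
            producer.await;
            worker.await;
        });
    }
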
@@ -188,39 +237,39 @@ impl VectorStore {
                     // ALL OF THE BELOW IS FOR TESTING,
                     // This should be removed as we find and appropriate place for evaluate our search.
 
-                    let embedding_provider = OpenAIEmbeddings{ client };
-                    let queries = vec![
-                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
-                        "compute an outline view of all of the symbols in a buffer",
-                        "scan a directory on the file system and load all of its children into an in-memory snapshot",
-                    ];
-                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+                    // let queries = vec![
+                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                    //     "compute an outline view of all of the symbols in a buffer",
+                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                    // ];
+                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
 
-                    let t2 = Instant::now();
-                    let documents = db.get_documents().unwrap();
-                    let files = db.get_files().unwrap();
-                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+                    // let t2 = Instant::now();
+                    // let documents = db.get_documents().unwrap();
+                    // let files = db.get_files().unwrap();
+                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
 
-                    let t1 = Instant::now();
-                    let mut bfs = BruteForceSearch::load(&db).unwrap();
-                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
-                    for (idx, embed) in embeddings.into_iter().enumerate() {
-                        let t0 = Instant::now();
-                        println!("\nQuery: {:?}", queries[idx]);
-                        let results = bfs.top_k_search(&embed, 5).await;
-                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
-                        for (id, distance) in results {
-                            println!("");
-                            println!("  distance: {:?}", distance);
-                            println!("  document: {:?}", documents[&id].name);
-                            println!("  path: {:?}", files[&documents[&id].file_id].path);
-                        }
-                    }
+                    // let t1 = Instant::now();
+                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
+                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                    // for (idx, embed) in embeddings.into_iter().enumerate() {
+                    //     let t0 = Instant::now();
+                    //     println!("\nQuery: {:?}", queries[idx]);
+                    //     let results = bfs.top_k_search(&embed, 5).await;
+                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                    //     for (id, distance) in results {
+                    //         println!("");
+                    //         println!("  distance: {:?}", distance);
+                    //         println!("  document: {:?}", documents[&id].name);
+                    //         println!("  path: {:?}", files[&documents[&id].file_id].relative_path);
+                    //     }
+                    // }
 
                     anyhow::Ok(())
-                }}.log_err())
-                .detach();
-        }
+                }
+                .log_err(),
+            );
 
             let provider = DummyEmbeddings {};
             // let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@ impl VectorStore {
                     scope.spawn(async {
                         let mut parser = Parser::new();
                         let mut cursor = QueryCursor::new();
-                        while let Ok(file_path) = paths_rx.recv().await {
+                        while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                        {
                             if let Some(indexed_file) = Self::index_file(
                                 &mut cursor,
                                 &mut parser,
                                 &provider,
-                                &fs,
                                 &language_registry,
                                 file_path,
+                                content,
                             )
                             .await
                             .log_err()
@@ -250,11 +300,86 @@ impl VectorStore {
                     }
                 })
                 .await;
             drop(indexed_files_tx);
+
+            db_write_task.await;
             anyhow::Ok(())
         })
-        .detach();
     }
+
+    pub fn search(
+        &mut self,
+        phrase: String,
+        limit: usize,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
+        cx.spawn(|this, cx| async move {
+            let database = VectorDatabase::new(database_url.as_ref())?;
+
+            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
+            //
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+
+            database.for_each_document(0, |id, embedding| {
+                dbg!(id, &embedding);
+
+                let similarity = dot(&embedding.0, &embedding.0);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+
+            dbg!(&results);
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+            // let documents = database.get_documents_by_ids(ids)?;
+
+            // let search_provider = cx
+            //     .background()
+            //     .spawn(async move { BruteForceSearch::load(&database) })
+            //     .await?;
+
+            // let results = search_provider.top_k_search(&embedding, limit))
+
+            anyhow::Ok(vec![])
+        })
+    }
 }
 
 impl Entity for VectorStore {
     type Event = ();
 }
+
+fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}
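
(Two notes on the code above. First, the WIP search currently computes dot(&embedding.0, &embedding.0), scoring each stored embedding against itself, since the query-embedding line is still commented out. Second, matrixmultiply::sgemm computes C = alpha * A * B + beta * C over strided matrices; dot() treats vec_a as a 1 x len row vector and vec_b as a len x 1 column vector, so the 1 x 1 product with alpha = 1, beta = 0 is exactly the dot product. The same call with the strides named, assuming the crate's (m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc) parameter order:)

    // The same sgemm call as in dot(), with named bindings for readability.
    let (m, k, n) = (1, len, 1);        // A: 1 x len, B: len x 1, C: 1 x 1
    let (rsa, csa) = (len as isize, 1); // A is one row, elements consecutive
    let (rsb, csb) = (1, len as isize); // B is one column, elements consecutive
    unsafe {
        matrixmultiply::sgemm(
            m, k, n, 1.0, vec_a.as_ptr(), rsa, csa,
            vec_b.as_ptr(), rsb, csb, 0.0, &mut result as *mut f32, 1, 1,
        );
    }
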

crates/vector_store/src/vector_store_tests.rs (new file, 136 lines, all added):

@@ -0,0 +1,136 @@
use std::sync::Arc;

use crate::{dot, embedding::EmbeddingProvider, VectorStore};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry};
use project::{FakeFs, Project};
use rand::Rng;
use serde_json::json;
use unindent::Unindent;

#[gpui::test]
async fn test_vector_store(cx: &mut TestAppContext) {
    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaa!\");
                    }

                    fn zzzzzzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbb!\");
                    }
                ".unindent(),
            }
        }),
    )
    .await;

    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = Arc::new(
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".into()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_outline_query(
            r#"
            (function_item
                name: (identifier) @name
                body: (block)) @item
            "#,
        )
        .unwrap(),
    );
    languages.add(rust_language);

    let store = cx.add_model(|_| {
        VectorStore::new(
            fs.clone(),
            "foo".to_string(),
            Arc::new(FakeEmbeddingProvider),
            languages,
        )
    });

    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
    store
        .update(cx, |store, cx| store.add_project(project, cx))
        .await
        .unwrap();

    let search_results = store
        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
        .await
        .unwrap();

    assert_eq!(search_results[0].offset, 0);
    assert_eq!(search_results[1].name, "aaa");
}

#[test]
fn test_dot_product() {
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    for _ in 0..100 {
        let mut rng = rand::thread_rng();
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();
        assert_eq!(
            round_to_decimals(dot(&a, &b), 3),
            round_to_decimals(reference_dot(&a, &b), 3)
        );
    }

    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
        let factor = (10.0 as f32).powi(decimal_places);
        (n * factor).round() / factor
    }

    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }
}

struct FakeEmbeddingProvider;

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        Ok(spans
            .iter()
            .map(|span| {
                let mut result = vec![0.0; 26];
                for letter in span.chars() {
                    if letter as u32 >= 'a' as u32 {
                        let ix = (letter as u32) - ('a' as u32);
                        if ix < 26 {
                            result[ix as usize] += 1.0;
                        }
                    }
                }

                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
                for x in &mut result {
                    *x /= norm;
                }

                result
            })
            .collect())
    }
}
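
(Two small points on the fake provider above: the letter check uses >= so that 'a' itself, index 0, is counted; with a strict >, a query like "aaaa" would embed to the zero vector and the normalization would divide by zero. And because each embedding is a normalized letter-frequency vector, similarities are easy to hand-check, e.g.:)

    // "aaa" and "aaaa" both normalize to [1.0, 0.0, ...], so their dot is 1.0,
    // while "aaa" vs "bbb" share no letters and score 0.0.
    assert_eq!(dot(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]), 1.0);
    assert_eq!(dot(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]), 0.0);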