Eager background indexing (#2928)

This PR ships a series of optimizations for the semantic search engine.
Mostly focused on removing invalid states, optimizing requests to
OpenAI, and reducing token usage.

Release Notes (Preview-Only):

- Added eager incremental indexing in the background on a debounce.
- Added a local embeddings cache for reducing redundant calls to OpenAI.
- Moved to an Embeddings Queue model which ensures optimal batch sizes
at the token level, and atomic file & document writes.
- Adjusted OpenAI Embedding API requests to use provided backoff delays
during Rate Limiting.
- Removed flush races between parsing files step and embedding queue
steps.
- Moved truncation to parsing step reducing the probability that OpenAI
encounters bad data.
This commit is contained in:
Kyle Caverly 2023-09-05 13:15:54 -04:00 committed by GitHub
commit 49af2874bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1405 additions and 1085 deletions

65
Cargo.lock generated
View File

@ -3539,7 +3539,7 @@ dependencies = [
"gif",
"jpeg-decoder",
"num-iter",
"num-rational",
"num-rational 0.3.2",
"num-traits",
"png",
"scoped_threadpool",
@ -4631,6 +4631,31 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "num"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
dependencies = [
"num-bigint 0.2.6",
"num-complex",
"num-integer",
"num-iter",
"num-rational 0.2.4",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
@ -4659,6 +4684,16 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-complex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-derive"
version = "0.3.3"
@ -4691,6 +4726,18 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
dependencies = [
"autocfg",
"num-bigint 0.2.6",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.3.2"
@ -5007,6 +5054,17 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "parse_duration"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
dependencies = [
"lazy_static",
"num",
"regex",
]
[[package]]
name = "password-hash"
version = "0.2.3"
@ -6674,6 +6732,7 @@ dependencies = [
"log",
"matrixmultiply",
"parking_lot 0.11.2",
"parse_duration",
"picker",
"postage",
"pretty_assertions",
@ -7005,7 +7064,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
dependencies = [
"chrono",
"num-bigint",
"num-bigint 0.4.4",
"num-traits",
"thiserror",
]
@ -7237,7 +7296,7 @@ dependencies = [
"log",
"md-5",
"memchr",
"num-bigint",
"num-bigint 0.4.4",
"once_cell",
"paste",
"percent-encoding",

View File

@ -39,6 +39,7 @@ rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
parse_duration = "2.1.1"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View File

@ -1,20 +1,26 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use crate::{
embedding::Embedding,
parsing::{Document, DocumentDigest},
SEMANTIC_INDEX_VERSION,
};
use anyhow::{anyhow, Context, Result};
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
use rusqlite::params;
use rusqlite::types::Value;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
time::{Instant, SystemTime},
};
use util::TryFutureExt;
#[derive(Debug)]
pub struct FileRecord {
@ -23,145 +29,181 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
return Ok(Embedding(embedding.unwrap()));
}
}
impl FromSql for Sha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if sha1.is_err() {
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
}
return Ok(Sha1(sha1.unwrap()));
}
}
#[derive(Clone)]
pub struct VectorDatabase {
db: rusqlite::Connection,
path: Arc<Path>,
transactions:
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
impl VectorDatabase {
pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
pub async fn new(
fs: Arc<dyn Fs>,
path: Arc<Path>,
executor: Arc<executor::Background>,
) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
>();
executor
.spawn({
let path = path.clone();
async move {
let mut connection = rusqlite::Connection::open(&path)?;
connection.pragma_update(None, "journal_mode", "wal")?;
connection.pragma_update(None, "synchronous", "normal")?;
connection.pragma_update(None, "cache_size", 1000000)?;
connection.pragma_update(None, "temp_store", "MEMORY")?;
while let Ok(transaction) = transactions_rx.recv().await {
transaction(&mut connection);
}
anyhow::Ok(())
}
.log_err()
})
.detach();
let this = Self {
db: rusqlite::Connection::open(path.as_path())?,
transactions: transactions_tx,
path,
};
this.initialize_database()?;
this.initialize_database().await?;
Ok(this)
}
fn get_existing_version(&self) -> Result<i64> {
let mut version_query = self
.db
.prepare("SELECT version from semantic_index_config")?;
version_query
.query_row([], |row| Ok(row.get::<_, i64>(0)?))
.map_err(|err| anyhow!("version query failed: {err}"))
pub fn path(&self) -> &Arc<Path> {
&self.path
}
fn initialize_database(&self) -> Result<()> {
rusqlite::vtab::array::load_module(&self.db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
if self
.get_existing_version()
.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
{
log::trace!("vector database schema up to date");
return Ok(());
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
where
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
T: 'static + Send,
{
let (tx, rx) = oneshot::channel();
let transactions = self.transactions.clone();
async move {
if transactions
.send(Box::new(|connection| {
let result = connection
.transaction()
.map_err(|err| anyhow!(err))
.and_then(|transaction| {
let result = f(&transaction)?;
transaction.commit()?;
Ok(result)
});
let _ = tx.send(result);
}))
.await
.is_err()
{
return Err(anyhow!("connection was dropped"))?;
}
rx.await?
}
log::trace!("vector database schema out of date. updating...");
self.db
.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
self.db
.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
self.db
.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
self.db
.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
self.db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
self.db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
self.db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
self.db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
self.db.execute(
"CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
sha1 BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
}
pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
self.db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
self.transact(|db| {
rusqlite::vtab::array::load_module(&db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
let version_query = db.prepare("SELECT version from semantic_index_config");
let version = version_query
.and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
log::trace!("vector database schema up to date");
return Ok(());
}
log::trace!("vector database schema out of date. updating...");
db.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
db.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
db.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
[],
)?;
db.execute(
"CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
digest BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
})
}
pub fn delete_file(
&self,
worktree_id: i64,
delete_path: PathBuf,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
})
}
pub fn insert_file(
@ -170,139 +212,187 @@ impl VectorDatabase {
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
) -> Result<()> {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
let existing_id = existing_id_query
.query_row(
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
db.execute(
"
REPLACE INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
VALUES (?1, ?2, ?3, ?4)
",
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|row| Ok(row.get::<_, i64>(0)?),
)
.map_err(|err| anyhow!(err));
let file_id = if existing_id.is_ok() {
// If already exists, just return the existing id
existing_id.unwrap()
} else {
// Delete Existing Row
self.db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
params![worktree_id, path.to_str()],
)?;
self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
self.db.last_insert_rowid()
};
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
let sha_blob = bincode::serialize(&document.sha1)?;
let file_id = db.last_insert_rowid();
self.db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
let t0 = Instant::now();
let mut query = db.prepare(
"
INSERT INTO documents
(file_id, start_byte, end_byte, name, embedding, digest)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)?;
log::trace!(
"Preparing Query Took: {:?} milliseconds",
t0.elapsed().as_millis()
);
for document in documents {
query.execute(params![
file_id,
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob,
sha_blob
],
)?;
}
document.embedding,
document.digest
])?;
}
Ok(())
Ok(())
})
}
pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
pub fn worktree_previously_indexed(
&self,
worktree_root_path: &Path,
) -> impl Future<Output = Result<bool>> {
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
if worktree_id.is_ok() {
return Ok(true);
} else {
return Ok(false);
}
if worktree_id.is_ok() {
return Ok(true);
} else {
return Ok(false);
}
})
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
// Check that the absolute path doesnt exist
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
if worktree_id.is_ok() {
return worktree_id;
}
// If worktree_id is Err, insert new worktree
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM documents
LEFT JOIN files ON files.id = documents.file_id
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(self.db.last_insert_rowid())
)?;
let mut embeddings_by_digest = HashMap::new();
for (worktree_id, file_paths) in worktree_id_file_paths {
let file_paths = Rc::new(
file_paths
.into_iter()
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![worktree_id, file_paths], |row| {
Ok((
row.get::<_, DocumentDigest>(0)?,
row.get::<_, Embedding>(1)?,
))
})?;
for row in rows {
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
}
Ok(embeddings_by_digest)
})
}
pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
let mut statement = self.db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
pub fn find_or_create_worktree(
&self,
worktree_root_path: PathBuf,
) -> impl Future<Output = Result<i64>> {
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
});
if worktree_id.is_ok() {
return Ok(worktree_id?);
}
// If worktree_id is Err, insert new worktree
db.execute(
"INSERT into worktrees (absolute_path) VALUES (?1)",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(db.last_insert_rowid())
})
}
pub fn get_file_mtimes(
&self,
worktree_id: i64,
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
})
}
pub fn top_k_search(
&self,
query_embedding: &Vec<f32>,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> Result<Vec<(i64, f32)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
self.for_each_document(file_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let ix = match results
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_document(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
Ok(results)
anyhow::Ok(results)
})
}
pub fn retrieve_included_file_ids(
@ -310,37 +400,46 @@ impl VectorDatabase {
worktree_ids: &[i64],
includes: &[PathMatcher],
excludes: &[PathMatcher],
) -> Result<Vec<i64>> {
let mut file_query = self.db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
) -> impl Future<Output = Result<Vec<i64>>> {
let worktree_ids = worktree_ids.to_vec();
let includes = includes.to_vec();
let excludes = excludes.to_vec();
self.transact(move |db| {
let mut file_query = db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
}
}
}
Ok(file_ids)
anyhow::Ok(file_ids)
})
}
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
let mut query_statement = self.db.prepare(
fn for_each_document(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
SELECT
id, embedding
@ -356,51 +455,57 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding.0));
.for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
let mut statement = self.db.prepare(
"
SELECT
documents.id,
files.worktree_id,
files.relative_path,
documents.start_byte,
documents.end_byte
FROM
documents, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
",
)?;
pub fn get_documents_by_ids(
&self,
ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
let ids = ids.to_vec();
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT
documents.id,
files.worktree_id,
files.relative_path,
documents.start_byte,
documents.end_byte
FROM
documents, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
",
)?;
let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut results = Vec::with_capacity(ids.len());
for id in ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
results.push(value);
}
let mut results = Vec::with_capacity(ids.len());
for id in &ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
results.push(value);
}
Ok(results)
Ok(results)
})
}
}
@ -412,29 +517,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
assert_eq!(len, vec_b.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
vec_a.as_ptr(),
len as isize,
1,
vec_b.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}

View File

@ -7,6 +7,9 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
@ -19,6 +22,62 @@ lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}
}
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1536];
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
fn max_tokens_per_batch(&self) -> usize {
OPENAI_INPUT_LIMIT
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let token_count = tokens.len();
let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
new_input.ok().unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
fn truncate(span: String) -> String {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
if result.is_ok() {
let transformed = result.unwrap();
return transformed;
}
}
span
}
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(4))
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
@ -105,7 +175,27 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let token_count = tokens.len();
let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens)
.ok()
.unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, token_count)
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@ -114,45 +204,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
let mut truncated = false;
let mut request_timeout: u64 = 10;
let mut response: Response<AsyncBody>;
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
.send_request(api_key, spans.iter().map(|x| &**x).collect())
.send_request(
api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
return Err(anyhow!(
"openai max retries, error: {:?}",
&response.status()
));
}
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
log::trace!(
"open ai rate limiting, delaying request by {:?} seconds",
delay.as_secs()
);
self.executor.timer(delay).await;
}
StatusCode::BAD_REQUEST => {
// Only truncate if it hasnt been truncated before
if !truncated {
for span in spans.iter_mut() {
*span = Self::truncate(span.clone());
}
truncated = true;
} else {
// If failing once already truncated, log the error and break the loop
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
break;
}
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
@ -163,18 +229,96 @@ impl EmbeddingProvider for OpenAIEmbeddings {
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
return Ok(response
.data
.into_iter()
.map(|embedding| embedding.embedding)
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
return Err(anyhow!("openai embedding failed {}", response.status()));
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai embedding failed"))
Err(anyhow!("openai max retries"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
}

View File

@ -0,0 +1,173 @@
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
pub worktree_id: i64,
pub path: PathBuf,
pub mtime: SystemTime,
pub documents: Vec<Document>,
pub job_handle: JobHandle,
}
impl std::fmt::Debug for FileToEmbed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FileToEmbed")
.field("worktree_id", &self.worktree_id)
.field("path", &self.path)
.field("mtime", &self.mtime)
.field("document", &self.documents)
.finish_non_exhaustive()
}
}
impl PartialEq for FileToEmbed {
fn eq(&self, other: &Self) -> bool {
self.worktree_id == other.worktree_id
&& self.path == other.path
&& self.mtime == other.mtime
&& self.documents == other.documents
}
}
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileToEmbedFragment>,
executor: Arc<Background>,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileToEmbedFragment {
file: Arc<Mutex<FileToEmbed>>,
document_range: Range<usize>,
}
impl EmbeddingQueue {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
}
}
pub fn push(&mut self, file: FileToEmbed) {
if file.documents.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileToEmbedFragment {
file: file.clone(),
document_range: 0..0,
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
let mut saved_tokens = 0;
for (ix, document) in file.lock().documents.iter().enumerate() {
let document_token_count = if document.embedding.is_none() {
document.token_count
} else {
saved_tokens += document.token_count;
0
};
let next_token_count = self.pending_batch_token_count + document_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
self.pending_batch.push(FileToEmbedFragment {
file: file.clone(),
document_range: range_end..range_end,
});
fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += document_token_count;
}
log::trace!("Saved Tokens: {:?}", saved_tokens);
}
pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
return;
}
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor.spawn(async move {
let mut spans = Vec::new();
let mut document_count = 0;
for fragment in &batch {
let file = fragment.file.lock();
document_count += file.documents[fragment.document_range.clone()].len();
spans.extend(
{
file.documents[fragment.document_range.clone()]
.iter().filter(|d| d.embedding.is_none())
.map(|d| d.content.clone())
}
);
}
log::trace!("Documents Length: {:?}", document_count);
log::trace!("Span Length: {:?}", spans.clone().len());
// If spans is 0, just send the fragment to the finished files if its the last one.
if spans.len() == 0 {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for document in
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
document.embedding = Some(embedding);
} else {
//
log::error!("number of embeddings returned different from number of documents");
}
}
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
}
Err(error) => {
log::error!("{:?}", error);
}
}
})
.detach();
}
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
self.finished_files_rx.clone()
}
}

View File

@ -1,5 +1,10 @@
use anyhow::{anyhow, Ok, Result};
use crate::embedding::{Embedding, EmbeddingProvider};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::{Digest, Sha1};
use std::{
cmp::{self, Reverse},
@ -10,13 +15,44 @@ use std::{
};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct DocumentDigest([u8; 20]);
impl FromSql for DocumentDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
blob.try_into()
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(DocumentDigest(bytes));
}
}
impl ToSql for DocumentDigest {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for DocumentDigest {
fn from(value: &'_ str) -> Self {
let mut sha1 = Sha1::new();
sha1.update(value);
Self(sha1.finalize().into())
}
}
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Vec<f32>,
pub sha1: [u8; 20],
pub embedding: Option<Embedding>,
pub digest: DocumentDigest,
pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@ -47,10 +84,11 @@ pub struct CodeContextMatch {
}
impl CodeContextRetriever {
pub fn new() -> Self {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
embedding_provider,
}
}
@ -64,16 +102,15 @@ impl CodeContextRetriever {
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let digest = DocumentDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: Default::default(),
name: language_name.to_string(),
sha1: sha1.finalize().into(),
digest,
token_count,
}])
}
@ -81,16 +118,15 @@ impl CodeContextRetriever {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
let digest = DocumentDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Document {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: None,
name: "Markdown".to_string(),
sha1: sha1.finalize().into(),
digest,
token_count,
}])
}
@ -166,10 +202,16 @@ impl CodeContextRetriever {
let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
document.content = CODE_CONTEXT_TEMPLATE
let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
let (document_content, token_count) =
self.embedding_provider.truncate(&document_content);
document.content = document_content;
document.token_count = token_count;
}
Ok(documents)
}
@ -263,15 +305,14 @@ impl CodeContextRetriever {
);
}
let mut sha1 = Sha1::new();
sha1.update(&document_content);
let sha1 = DocumentDigest::from(document_content.as_str());
documents.push(Document {
name,
content: document_content,
range: item_range.clone(),
embedding: vec![],
sha1: sha1.finalize().into(),
embedding: None,
digest: sha1,
token_count: 0,
})
}

File diff suppressed because it is too large Load Diff

View File

@ -1,14 +1,15 @@
use crate::{
db::dot,
embedding::EmbeddingProvider,
parsing::{subtract_ranges, CodeContextRetriever, Document},
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
semantic_index_settings::SemanticIndexSettings,
SearchResult, SemanticIndex,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
@ -20,8 +21,10 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
time::SystemTime,
};
use unindent::Unindent;
use util::RandomCharIter;
#[ctor::ctor]
fn init_logger() {
@ -31,12 +34,8 @@ fn init_logger() {
}
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<SemanticIndexSettings>(cx);
settings::register::<ProjectSettings>(cx);
});
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background());
fs.insert_tree(
@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
fn bbb() {
println!(\"bbbbbbbbbbbbb!\");
}
struct pqpqpqp {}
".unindent(),
"file3.toml": "
ZZZZZZZZZZZZZZZZZZ = 5
@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let db_path = db_dir.path().join("db.sqlite");
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let store = SemanticIndex::new(
let semantic_index = SemanticIndex::new(
fs.clone(),
db_path,
embedding_provider.clone(),
@ -87,21 +87,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
let _ = store
let _ = semantic_index
.update(cx, |store, cx| {
store.initialize_project(project.clone(), cx)
})
.await;
let (file_count, outstanding_file_count) = store
let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 3);
cx.foreground().run_until_parked();
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*outstanding_file_count.borrow(), 0);
let search_results = store
let search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@ -122,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file3.toml").into(), 0),
(Path::new("src/file1.rs").into(), 45),
(Path::new("src/file2.rs").into(), 45),
],
cx,
);
@ -129,7 +130,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
// Test Include Files Functonality
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
let rust_only_search_results = store
let rust_only_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@ -149,11 +150,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file1.rs").into(), 0),
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file1.rs").into(), 45),
(Path::new("src/file2.rs").into(), 45),
],
cx,
);
let no_rust_search_results = store
let no_rust_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@ -186,24 +188,87 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
.await
.unwrap();
cx.foreground().run_until_parked();
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
let prev_embedding_count = embedding_provider.embedding_count();
let (file_count, outstanding_file_count) = store
let (file_count, outstanding_file_count) = semantic_index
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 1);
cx.foreground().run_until_parked();
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*outstanding_file_count.borrow(), 0);
assert_eq!(
embedding_provider.embedding_count() - prev_embedding_count,
2
1
);
}
#[gpui::test(iterations = 10)]
async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let (outstanding_job_count, _) = postage::watch::channel_with(0);
let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
let files = (1..=3)
.map(|file_ix| FileToEmbed {
worktree_id: 5,
path: format!("path-{file_ix}").into(),
mtime: SystemTime::now(),
documents: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
let content = RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect::<String>();
let digest = DocumentDigest::from(content.as_str());
Document {
range: 0..10,
embedding: None,
name: format!("document {document_ix}"),
content,
digest,
token_count: rng.gen_range(10..30),
}
})
.collect(),
job_handle: JobHandle::new(&outstanding_job_count),
})
.collect::<Vec<_>>();
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
queue.flush();
cx.foreground().run_until_parked();
let finished_files = queue.finished_files();
let mut embedded_files: Vec<_> = files
.iter()
.map(|_| finished_files.try_recv().expect("no finished file"))
.collect();
let expected_files: Vec<_> = files
.iter()
.map(|file| {
let mut file = file.clone();
for doc in &mut file.documents {
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
})
.collect();
embedded_files.sort_by_key(|f| f.path.clone());
assert_eq!(embedded_files, expected_files);
}
#[track_caller]
fn assert_search_results(
actual: &[SearchResult],
@ -227,7 +292,8 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/// A doc comment
@ -314,7 +380,8 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
{
@ -397,7 +464,8 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/* globals importScripts, backend */
@ -495,7 +563,8 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
-- Creates a new class
@ -568,7 +637,8 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
defmodule File.Stream do
@ -684,7 +754,8 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/**
@ -836,7 +907,8 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
# This concern is inspired by "sudo mode" on GitHub. It
@ -1026,7 +1098,8 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
<?php
@ -1173,36 +1246,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
assert_eq!(
round_to_decimals(dot(&a, &b), 1),
round_to_decimals(reference_dot(&a, &b), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@ -1212,35 +1255,42 @@ impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
}
fn max_tokens_per_batch(&self) -> usize {
200
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans
.iter()
.map(|span| {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result
})
.collect())
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
@ -1684,3 +1734,11 @@ fn test_subtract_ranges() {
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<SemanticIndexSettings>(cx);
settings::register::<ProjectSettings>(cx);
});
}

View File

@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
Defer(Some(f))
}
pub struct RandomCharIter<T: Rng>(T);
pub struct RandomCharIter<T: Rng> {
rng: T,
simple_text: bool,
}
impl<T: Rng> RandomCharIter<T> {
pub fn new(rng: T) -> Self {
Self(rng)
Self {
rng,
simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
}
}
pub fn with_simple_text(mut self) -> Self {
self.simple_text = true;
self
}
}
@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
type Item = char;
fn next(&mut self) -> Option<Self::Item> {
if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
return if self.0.gen_range(0..100) < 5 {
if self.simple_text {
return if self.rng.gen_range(0..100) < 5 {
Some('\n')
} else {
Some(self.0.gen_range(b'a'..b'z' + 1).into())
Some(self.rng.gen_range(b'a'..b'z' + 1).into())
};
}
match self.0.gen_range(0..100) {
match self.rng.gen_range(0..100) {
// whitespace
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
// two-byte greek letters
20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
// // three-byte characters
33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
33..=45 => ['✋', '✅', '❌', '❎', '⭐']
.choose(&mut self.rng)
.copied(),
// // four-byte characters
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
// ascii letters
_ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
_ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
}
}
}