From 09db455db25ef222abf74b5d45cf4138117ef334 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Dec 2023 15:38:36 +0100 Subject: [PATCH] Port `semantic_index` to gpui2 Co-Authored-By: Julia Risley --- Cargo.lock | 51 + Cargo.toml | 3 +- crates/ai2/src/auth.rs | 2 +- crates/ai2/src/providers/open_ai/embedding.rs | 4 +- crates/gpui2/src/app/entity_map.rs | 19 +- crates/semantic_index2/Cargo.toml | 69 + crates/semantic_index2/README.md | 20 + crates/semantic_index2/eval/gpt-engineer.json | 114 ++ crates/semantic_index2/eval/tree-sitter.json | 104 + crates/semantic_index2/src/db.rs | 603 ++++++ crates/semantic_index2/src/embedding_queue.rs | 169 ++ crates/semantic_index2/src/parsing.rs | 414 ++++ crates/semantic_index2/src/semantic_index.rs | 1280 +++++++++++++ .../src/semantic_index_settings.rs | 28 + .../src/semantic_index_tests.rs | 1697 +++++++++++++++++ crates/workspace2/src/workspace2.rs | 2 - 16 files changed, 4569 insertions(+), 10 deletions(-) create mode 100644 crates/semantic_index2/Cargo.toml create mode 100644 crates/semantic_index2/README.md create mode 100644 crates/semantic_index2/eval/gpt-engineer.json create mode 100644 crates/semantic_index2/eval/tree-sitter.json create mode 100644 crates/semantic_index2/src/db.rs create mode 100644 crates/semantic_index2/src/embedding_queue.rs create mode 100644 crates/semantic_index2/src/parsing.rs create mode 100644 crates/semantic_index2/src/semantic_index.rs create mode 100644 crates/semantic_index2/src/semantic_index_settings.rs create mode 100644 crates/semantic_index2/src/semantic_index_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 6121ec9718..39683c9fc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8232,6 +8232,57 @@ dependencies = [ "workspace", ] +[[package]] +name = "semantic_index2" +version = "0.1.0" +dependencies = [ + "ai2", + "anyhow", + "async-trait", + "client2", + "collections", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "globset", + "gpui2", + "language2", + "lazy_static", + "log", + "ndarray", + "node_runtime", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "postage", + "pretty_assertions", + "project2", + "rand 0.8.5", + "rpc2", + "rusqlite", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings2", + "sha1", + "smol", + "tempdir", + "tiktoken-rs", + "tree-sitter", + "tree-sitter-cpp", + "tree-sitter-elixir", + "tree-sitter-json 0.20.0", + "tree-sitter-lua", + "tree-sitter-php", + "tree-sitter-ruby", + "tree-sitter-rust", + "tree-sitter-toml", + "tree-sitter-typescript", + "unindent", + "util", + "workspace2", +] + [[package]] name = "semver" version = "1.0.18" diff --git a/Cargo.toml b/Cargo.toml index 3658ffad29..610a4dc11e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,6 +95,8 @@ members = [ "crates/rpc2", "crates/search", "crates/search2", + "crates/semantic_index", + "crates/semantic_index2", "crates/settings", "crates/settings2", "crates/snippet", @@ -114,7 +116,6 @@ members = [ "crates/theme_selector2", "crates/ui2", "crates/util", - "crates/semantic_index", "crates/story", "crates/vim", "crates/vcs_menu", diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs index baa1fe7b83..1ea49bd615 100644 --- a/crates/ai2/src/auth.rs +++ b/crates/ai2/src/auth.rs @@ -7,7 +7,7 @@ pub enum ProviderCredential { NotNeeded, } -pub trait CredentialProvider { +pub trait CredentialProvider: Send + Sync { fn has_credentials(&self) -> bool; fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; fn save_credentials(&self, cx: &mut AppContext, credential: 
ProviderCredential);
diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs
index 8f62c8dc0d..d5fe4e8c58 100644
--- a/crates/ai2/src/providers/open_ai/embedding.rs
+++ b/crates/ai2/src/providers/open_ai/embedding.rs
@@ -35,7 +35,7 @@ pub struct OpenAIEmbeddingProvider {
     model: OpenAILanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
+    pub executor: BackgroundExecutor,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
     rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
 }
@@ -66,7 +66,7 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl OpenAIEmbeddingProvider {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
diff --git a/crates/gpui2/src/app/entity_map.rs b/crates/gpui2/src/app/entity_map.rs
index a34582f4f4..99d8542eba 100644
--- a/crates/gpui2/src/app/entity_map.rs
+++ b/crates/gpui2/src/app/entity_map.rs
@@ -482,10 +482,6 @@ impl<T: 'static> WeakModel<T> {
     /// Update the entity referenced by this model with the given function if
     /// the referenced entity still exists. Returns an error if the entity has
     /// been released.
-    ///
-    /// The update function receives a context appropriate for its environment.
-    /// When updating in an `AppContext`, it receives a `ModelContext`.
-    /// When updating an a `WindowContext`, it receives a `ViewContext`.
     pub fn update<C, R>(
         &self,
         cx: &mut C,
@@ -501,6 +497,21 @@ impl<T: 'static> WeakModel<T> {
                 .map(|this| cx.update_model(&this, update)),
         )
     }
+
+    /// Reads the entity referenced by this model with the given function if
+    /// the referenced entity still exists. Returns an error if the entity has
+    /// been released.
+    pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
+    where
+        C: Context,
+        Result<C::Result<R>>: crate::Flatten<R>,
+    {
+        crate::Flatten::flatten(
+            self.upgrade()
+                .ok_or_else(|| anyhow!("entity release"))
+                .map(|this| cx.read_model(&this, read)),
+        )
+    }
 }
 
 impl<T> Hash for WeakModel<T> {
diff --git a/crates/semantic_index2/Cargo.toml b/crates/semantic_index2/Cargo.toml
new file mode 100644
index 0000000000..65ffb05ca5
--- /dev/null
+++ b/crates/semantic_index2/Cargo.toml
@@ -0,0 +1,69 @@
+[package]
+name = "semantic_index2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/semantic_index.rs"
+doctest = false
+
+[dependencies]
+ai = { package = "ai2", path = "../ai2" }
+collections = { path = "../collections" }
+gpui = { package = "gpui2", path = "../gpui2" }
+language = { package = "language2", path = "../language2" }
+project = { package = "project2", path = "../project2" }
+workspace = { package = "workspace2", path = "../workspace2" }
+util = { path = "../util" }
+rpc = { package = "rpc2", path = "../rpc2" }
+settings = { package = "settings2", path = "../settings2" }
+anyhow.workspace = true
+postage.workspace = true
+futures.workspace = true
+ordered-float.workspace = true
+smol.workspace = true
+rusqlite.workspace = true
+log.workspace = true
+tree-sitter.workspace = true
+lazy_static.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
+tiktoken-rs.workspace = true
+parking_lot.workspace = true
+rand.workspace = true
+schemars.workspace = true
+globset.workspace = true
+sha1 = "0.10.5"
+ndarray = { version = "0.15.0" }
+
+[dev-dependencies]
+ai = { package = "ai2", path = "../ai2", features = ["test-support"] }
+collections = { path = "../collections", features = ["test-support"] }
+gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
+language = { package = "language2", path = "../language2", features = ["test-support"] }
+project = { package = "project2", path = "../project2", features = ["test-support"] }
+rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
+workspace = { package = "workspace2", path = "../workspace2", features = ["test-support"] }
+settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
+rust-embed = { version = "8.0", features = ["include-exclude"] }
+client = { package = "client2", path = "../client2" }
+node_runtime = { path = "../node_runtime"}
+
+pretty_assertions.workspace = true
+rand.workspace = true
+unindent.workspace = true
+tempdir.workspace = true
+ctor.workspace = true
+env_logger.workspace = true
+
+tree-sitter-typescript.workspace = true
+tree-sitter-json.workspace = true
+tree-sitter-rust.workspace = true
+tree-sitter-toml.workspace = true
+tree-sitter-cpp.workspace = true
+tree-sitter-elixir.workspace = true
+tree-sitter-lua.workspace = true
+tree-sitter-ruby.workspace = true
+tree-sitter-php.workspace = true
diff --git a/crates/semantic_index2/README.md b/crates/semantic_index2/README.md
new file mode 100644
index 0000000000..85f83af121
--- /dev/null
+++ b/crates/semantic_index2/README.md
@@ -0,0 +1,20 @@
+
+# Semantic Index
+
+## Evaluation
+
+### Metrics
+
+nDCG@k:
+- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical 'ideal' search engine would return."
+- "The relevance of a result is represented by a score (also known as a 'grade') that is assigned to the search query. 
The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
+
+MRR@k:
+- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list."
+
+MAP@k:
+- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list."
+
+Resources:
+- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
+- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)
diff --git a/crates/semantic_index2/eval/gpt-engineer.json b/crates/semantic_index2/eval/gpt-engineer.json
new file mode 100644
index 0000000000..d008cc65d1
--- /dev/null
+++ b/crates/semantic_index2/eval/gpt-engineer.json
@@ -0,0 +1,114 @@
+{
+    "repo": "https://github.com/AntonOsika/gpt-engineer.git",
+    "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
+    "assertions": [
+        {
+            "query": "How do I contribute to this project?",
+            "matches": [
+                ".github/CONTRIBUTING.md:1",
+                "ROADMAP.md:48"
+            ]
+        },
+        {
+            "query": "What version of the openai package is active?",
+            "matches": [
+                "pyproject.toml:14"
+            ]
+        },
+        {
+            "query": "Ask user for clarification",
+            "matches": [
+                "gpt_engineer/steps.py:69"
+            ]
+        },
+        {
+            "query": "generate tests for python code",
+            "matches": [
+                "gpt_engineer/steps.py:153"
+            ]
+        },
+        {
+            "query": "get item from database based on key",
+            "matches": [
+                "gpt_engineer/db.py:42",
+                "gpt_engineer/db.py:68"
+            ]
+        },
+        {
+            "query": "prompt user to select files",
+            "matches": [
+                "gpt_engineer/file_selector.py:171",
+                "gpt_engineer/file_selector.py:306",
+                "gpt_engineer/file_selector.py:289",
+                "gpt_engineer/file_selector.py:234"
+            ]
+        },
+        {
+            "query": "send to rudderstack",
+            "matches": [
+                "gpt_engineer/collect.py:11",
+                "gpt_engineer/collect.py:38"
+            ]
+        },
+        {
+            "query": "parse code blocks from chat messages",
+            "matches": [
+                "gpt_engineer/chat_to_files.py:10",
+                "docs/intro/chat_parsing.md:1"
+            ]
+        },
+        {
+            "query": "how do I use the docker cli?",
+            "matches": [
+                "docker/README.md:1"
+            ]
+        },
+        {
+            "query": "ask the user if the code ran successfully?",
+            "matches": [
+                "gpt_engineer/learning.py:54"
+            ]
+        },
+        {
+            "query": "how is consent granted by the user?",
+            "matches": [
+                "gpt_engineer/learning.py:107",
+                "gpt_engineer/learning.py:130",
+                "gpt_engineer/learning.py:152"
+            ]
+        },
+        {
+            "query": "what are all the different steps the agent can take?",
+            "matches": [
+                "docs/intro/steps_module.md:1",
+                "gpt_engineer/steps.py:391"
+            ]
+        },
+        {
+            "query": "ask the user for clarification?",
+            "matches": [
+                "gpt_engineer/steps.py:69"
+            ]
+        },
+        {
+            "query": "what models are available?",
+            "matches": [
+                "gpt_engineer/ai.py:315",
+                "gpt_engineer/ai.py:341",
+                "docs/open-models.md:1"
+            ]
+        },
+        {
+            "query": "what is the current focus of the project?",
+            "matches": [
+                "ROADMAP.md:11"
+            ]
+        },
+        {
+            "query": "does the agent know how to fix code?",
+            "matches": [
+                "gpt_engineer/steps.py:367"
+            ]
+        }
+    ]
+}
diff --git a/crates/semantic_index2/eval/tree-sitter.json b/crates/semantic_index2/eval/tree-sitter.json
new file mode 100644
index 0000000000..d3dcc86937
--- /dev/null
+++ b/crates/semantic_index2/eval/tree-sitter.json
@@ -0,0 +1,104 @@
+{
+    "repo": "https://github.com/tree-sitter/tree-sitter.git",
+    "commit": "46af27796a76c72d8466627d499f2bca4af958ee",
+    "assertions": [
+        {
+            "query": "What attributes are available for the tags configuration struct?",
+            "matches": [
+                
"tags/src/lib.rs:24" + ] + }, + { + "query": "create a new tag configuration", + "matches": [ + "tags/src/lib.rs:119" + ] + }, + { + "query": "generate tags based on config", + "matches": [ + "tags/src/lib.rs:261" + ] + }, + { + "query": "match on ts quantifier in rust", + "matches": [ + "lib/binding_rust/lib.rs:139" + ] + }, + { + "query": "cli command to generate tags", + "matches": [ + "cli/src/tags.rs:10" + ] + }, + { + "query": "what version of the tree-sitter-tags package is active?", + "matches": [ + "tags/Cargo.toml:4" + ] + }, + { + "query": "Insert a new parse state", + "matches": [ + "cli/src/generate/build_tables/build_parse_table.rs:153" + ] + }, + { + "query": "Handle conflict when numerous actions occur on the same symbol", + "matches": [ + "cli/src/generate/build_tables/build_parse_table.rs:363", + "cli/src/generate/build_tables/build_parse_table.rs:442" + ] + }, + { + "query": "Match based on associativity of actions", + "matches": [ + "cri/src/generate/build_tables/build_parse_table.rs:542" + ] + }, + { + "query": "Format token set display", + "matches": [ + "cli/src/generate/build_tables/item.rs:246" + ] + }, + { + "query": "extract choices from rule", + "matches": [ + "cli/src/generate/prepare_grammar/flatten_grammar.rs:124" + ] + }, + { + "query": "How do we identify if a symbol is being used?", + "matches": [ + "cli/src/generate/prepare_grammar/flatten_grammar.rs:175" + ] + }, + { + "query": "How do we launch the playground?", + "matches": [ + "cli/src/playground.rs:46" + ] + }, + { + "query": "How do we test treesitter query matches in rust?", + "matches": [ + "cli/src/query_testing.rs:152", + "cli/src/tests/query_test.rs:781", + "cli/src/tests/query_test.rs:2163", + "cli/src/tests/query_test.rs:3781", + "cli/src/tests/query_test.rs:887" + ] + }, + { + "query": "What does the CLI do?", + "matches": [ + "cli/README.md:10", + "cli/loader/README.md:3", + "docs/section-5-implementation.md:14", + "docs/section-5-implementation.md:18" + ] + } + ] +} diff --git a/crates/semantic_index2/src/db.rs b/crates/semantic_index2/src/db.rs new file mode 100644 index 0000000000..f34baeaaae --- /dev/null +++ b/crates/semantic_index2/src/db.rs @@ -0,0 +1,603 @@ +use crate::{ + parsing::{Span, SpanDigest}, + SEMANTIC_INDEX_VERSION, +}; +use ai::embedding::Embedding; +use anyhow::{anyhow, Context, Result}; +use collections::HashMap; +use futures::channel::oneshot; +use gpui::BackgroundExecutor; +use ndarray::{Array1, Array2}; +use ordered_float::OrderedFloat; +use project::Fs; +use rpc::proto::Timestamp; +use rusqlite::params; +use rusqlite::types::Value; +use std::{ + future::Future, + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, + time::SystemTime, +}; +use util::{paths::PathMatcher, TryFutureExt}; + +pub fn argsort(data: &[T]) -> Vec { + let mut indices = (0..data.len()).collect::>(); + indices.sort_by_key(|&i| &data[i]); + indices.reverse(); + indices +} + +#[derive(Debug)] +pub struct FileRecord { + pub id: usize, + pub relative_path: String, + pub mtime: Timestamp, +} + +#[derive(Clone)] +pub struct VectorDatabase { + path: Arc, + transactions: + smol::channel::Sender>, +} + +impl VectorDatabase { + pub async fn new( + fs: Arc, + path: Arc, + executor: BackgroundExecutor, + ) -> Result { + if let Some(db_directory) = path.parent() { + fs.create_dir(db_directory).await?; + } + + let (transactions_tx, transactions_rx) = smol::channel::unbounded::< + Box, + >(); + executor + .spawn({ + let path = path.clone(); + async move { + let mut connection = 
rusqlite::Connection::open(&path)?;
+
+                    connection.pragma_update(None, "journal_mode", "wal")?;
+                    connection.pragma_update(None, "synchronous", "normal")?;
+                    connection.pragma_update(None, "cache_size", 1000000)?;
+                    connection.pragma_update(None, "temp_store", "MEMORY")?;
+
+                    while let Ok(transaction) = transactions_rx.recv().await {
+                        transaction(&mut connection);
+                    }
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            })
+            .detach();
+        let this = Self {
+            transactions: transactions_tx,
+            path,
+        };
+        this.initialize_database().await?;
+        Ok(this)
+    }
+
+    pub fn path(&self) -> &Arc<Path> {
+        &self.path
+    }
+
+    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
+    where
+        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
+        T: 'static + Send,
+    {
+        let (tx, rx) = oneshot::channel();
+        let transactions = self.transactions.clone();
+        async move {
+            if transactions
+                .send(Box::new(|connection| {
+                    let result = connection
+                        .transaction()
+                        .map_err(|err| anyhow!(err))
+                        .and_then(|transaction| {
+                            let result = f(&transaction)?;
+                            transaction.commit()?;
+                            Ok(result)
+                        });
+                    let _ = tx.send(result);
+                }))
+                .await
+                .is_err()
+            {
+                return Err(anyhow!("connection was dropped"))?;
+            }
+            rx.await?
+        }
+    }
+
+    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+        self.transact(|db| {
+            rusqlite::vtab::array::load_module(&db)?;
+
+            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
+            let version_query = db.prepare("SELECT version from semantic_index_config");
+            let version = version_query
+                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
+                log::trace!("vector database schema up to date");
+                return Ok(());
+            }
+
+            log::trace!("vector database schema out of date. updating...");
+            // We renamed the `documents` table to `spans`, so we want to drop
+            // `documents` without recreating it if it exists. 
+ db.execute("DROP TABLE IF EXISTS documents", []) + .context("failed to drop 'documents' table")?; + db.execute("DROP TABLE IF EXISTS spans", []) + .context("failed to drop 'spans' table")?; + db.execute("DROP TABLE IF EXISTS files", []) + .context("failed to drop 'files' table")?; + db.execute("DROP TABLE IF EXISTS worktrees", []) + .context("failed to drop 'worktrees' table")?; + db.execute("DROP TABLE IF EXISTS semantic_index_config", []) + .context("failed to drop 'semantic_index_config' table")?; + + // Initialize Vector Databasing Tables + db.execute( + "CREATE TABLE semantic_index_config ( + version INTEGER NOT NULL + )", + [], + )?; + + db.execute( + "INSERT INTO semantic_index_config (version) VALUES (?1)", + params![SEMANTIC_INDEX_VERSION], + )?; + + db.execute( + "CREATE TABLE worktrees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + absolute_path VARCHAR NOT NULL + ); + CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); + ", + [], + )?; + + db.execute( + "CREATE TABLE files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + worktree_id INTEGER NOT NULL, + relative_path VARCHAR NOT NULL, + mtime_seconds INTEGER NOT NULL, + mtime_nanos INTEGER NOT NULL, + FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE + )", + [], + )?; + + db.execute( + "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)", + [], + )?; + + db.execute( + "CREATE TABLE spans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + start_byte INTEGER NOT NULL, + end_byte INTEGER NOT NULL, + name VARCHAR NOT NULL, + embedding BLOB NOT NULL, + digest BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + )", + [], + )?; + db.execute( + "CREATE INDEX spans_digest ON spans (digest)", + [], + )?; + + log::trace!("vector database initialized with updated schema."); + Ok(()) + }) + } + + pub fn delete_file( + &self, + worktree_id: i64, + delete_path: Arc, + ) -> impl Future> { + self.transact(move |db| { + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", + params![worktree_id, delete_path.to_str()], + )?; + Ok(()) + }) + } + + pub fn insert_file( + &self, + worktree_id: i64, + path: Arc, + mtime: SystemTime, + spans: Vec, + ) -> impl Future> { + self.transact(move |db| { + // Return the existing ID, if both the file and mtime match + let mtime = Timestamp::from(mtime); + + db.execute( + " + REPLACE INTO files + (worktree_id, relative_path, mtime_seconds, mtime_nanos) + VALUES (?1, ?2, ?3, ?4) + ", + params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], + )?; + + let file_id = db.last_insert_rowid(); + + let mut query = db.prepare( + " + INSERT INTO spans + (file_id, start_byte, end_byte, name, embedding, digest) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + )?; + + for span in spans { + query.execute(params![ + file_id, + span.range.start.to_string(), + span.range.end.to_string(), + span.name, + span.embedding, + span.digest + ])?; + } + + Ok(()) + }) + } + + pub fn worktree_previously_indexed( + &self, + worktree_root_path: &Path, + ) -> impl Future> { + let worktree_root_path = worktree_root_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?)); + + if worktree_id.is_ok() { + return Ok(true); + } else { + return Ok(false); + } + }) + } + + pub fn 
embeddings_for_digests( + &self, + digests: Vec, + ) -> impl Future>> { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM spans + WHERE digest IN rarray(?) + ", + )?; + let mut embeddings_by_digest = HashMap::default(); + let digests = Rc::new( + digests + .into_iter() + .map(|p| Value::Blob(p.0.to_vec())) + .collect::>(), + ); + let rows = query.query_map(params![digests], |row| { + Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) + })?; + + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + + Ok(embeddings_by_digest) + }) + } + + pub fn embeddings_for_files( + &self, + worktree_id_file_paths: HashMap>>, + ) -> impl Future>> { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM spans + LEFT JOIN files ON files.id = spans.file_id + WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) + ", + )?; + let mut embeddings_by_digest = HashMap::default(); + for (worktree_id, file_paths) in worktree_id_file_paths { + let file_paths = Rc::new( + file_paths + .into_iter() + .map(|p| Value::Text(p.to_string_lossy().into_owned())) + .collect::>(), + ); + let rows = query.query_map(params![worktree_id, file_paths], |row| { + Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) + })?; + + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + } + + Ok(embeddings_by_digest) + }) + } + + pub fn find_or_create_worktree( + &self, + worktree_root_path: Arc, + ) -> impl Future> { + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path.to_string_lossy()], |row| { + Ok(row.get::<_, i64>(0)?) + }); + + if worktree_id.is_ok() { + return Ok(worktree_id?); + } + + // If worktree_id is Err, insert new worktree + db.execute( + "INSERT into worktrees (absolute_path) VALUES (?1)", + params![worktree_root_path.to_string_lossy()], + )?; + Ok(db.last_insert_rowid()) + }) + } + + pub fn get_file_mtimes( + &self, + worktree_id: i64, + ) -> impl Future>> { + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT relative_path, mtime_seconds, mtime_nanos + FROM files + WHERE worktree_id = ?1 + ORDER BY relative_path", + )?; + let mut result: HashMap = HashMap::default(); + for row in statement.query_map(params![worktree_id], |row| { + Ok(( + row.get::<_, String>(0)?.into(), + Timestamp { + seconds: row.get(1)?, + nanos: row.get(2)?, + } + .into(), + )) + })? { + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) + } + + pub fn top_k_search( + &self, + query_embedding: &Embedding, + limit: usize, + file_ids: &[i64], + ) -> impl Future)>>> { + let file_ids = file_ids.to_vec(); + let query = query_embedding.clone().0; + let query = Array1::from_vec(query); + self.transact(move |db| { + let mut query_statement = db.prepare( + " + SELECT + id, embedding + FROM + spans + WHERE + file_id IN rarray(?) + ", + )?; + + let deserialized_rows = query_statement + .query_map(params![ids_to_sql(&file_ids)], |row| { + Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?)) + })? 
+ .filter_map(|row| row.ok()) + .collect::>(); + + if deserialized_rows.len() == 0 { + return Ok(Vec::new()); + } + + // Get Length of Embeddings Returned + let embedding_len = deserialized_rows[0].1 .0.len(); + + let batch_n = 1000; + let mut batches = Vec::new(); + let mut batch_ids = Vec::new(); + let mut batch_embeddings: Vec = Vec::new(); + deserialized_rows.iter().for_each(|(id, embedding)| { + batch_ids.push(id); + batch_embeddings.extend(&embedding.0); + + if batch_ids.len() == batch_n { + let embeddings = std::mem::take(&mut batch_embeddings); + let ids = std::mem::take(&mut batch_ids); + let array = + Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings); + match array { + Ok(array) => { + batches.push((ids, array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + } + }); + + if batch_ids.len() > 0 { + let array = Array2::from_shape_vec( + (batch_ids.len(), embedding_len), + batch_embeddings.clone(), + ); + match array { + Ok(array) => { + batches.push((batch_ids.clone(), array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + } + + let mut ids: Vec = Vec::new(); + let mut results = Vec::new(); + for (batch_ids, array) in batches { + let scores = array + .dot(&query.t()) + .to_vec() + .iter() + .map(|score| OrderedFloat(*score)) + .collect::>>(); + results.extend(scores); + ids.extend(batch_ids); + } + + let sorted_idx = argsort(&results); + let mut sorted_results = Vec::new(); + let last_idx = limit.min(sorted_idx.len()); + for idx in &sorted_idx[0..last_idx] { + sorted_results.push((ids[*idx] as i64, results[*idx])) + } + + Ok(sorted_results) + }) + } + + pub fn retrieve_included_file_ids( + &self, + worktree_ids: &[i64], + includes: &[PathMatcher], + excludes: &[PathMatcher], + ) -> impl Future>> { + let worktree_ids = worktree_ids.to_vec(); + let includes = includes.to_vec(); + let excludes = excludes.to_vec(); + self.transact(move |db| { + let mut file_query = db.prepare( + " + SELECT + id, relative_path + FROM + files + WHERE + worktree_id IN rarray(?) + ", + )?; + + let mut file_ids = Vec::::new(); + let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?; + + while let Some(row) = rows.next()? { + let file_id = row.get(0)?; + let relative_path = row.get_ref(1)?.as_str()?; + let included = + includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); + let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); + if included && !excluded { + file_ids.push(file_id); + } + } + + anyhow::Ok(file_ids) + }) + } + + pub fn spans_for_ids( + &self, + ids: &[i64], + ) -> impl Future)>>> { + let ids = ids.to_vec(); + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT + spans.id, + files.worktree_id, + files.relative_path, + spans.start_byte, + spans.end_byte + FROM + spans, files + WHERE + spans.file_id = files.id AND + spans.id in rarray(?) 
+ ", + )?; + + let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?.into(), + row.get(3)?..row.get(4)?, + )) + })?; + + let mut values_by_id = HashMap::)>::default(); + for row in result_iter { + let (id, worktree_id, path, range) = row?; + values_by_id.insert(id, (worktree_id, path, range)); + } + + let mut results = Vec::with_capacity(ids.len()); + for id in &ids { + let value = values_by_id + .remove(id) + .ok_or(anyhow!("missing span id {}", id))?; + results.push(value); + } + + Ok(results) + }) + } +} + +fn ids_to_sql(ids: &[i64]) -> Rc> { + Rc::new( + ids.iter() + .copied() + .map(|v| rusqlite::types::Value::from(v)) + .collect::>(), + ) +} diff --git a/crates/semantic_index2/src/embedding_queue.rs b/crates/semantic_index2/src/embedding_queue.rs new file mode 100644 index 0000000000..a2371a1196 --- /dev/null +++ b/crates/semantic_index2/src/embedding_queue.rs @@ -0,0 +1,169 @@ +use crate::{parsing::Span, JobHandle}; +use ai::embedding::EmbeddingProvider; +use gpui::BackgroundExecutor; +use parking_lot::Mutex; +use smol::channel; +use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime}; + +#[derive(Clone)] +pub struct FileToEmbed { + pub worktree_id: i64, + pub path: Arc, + pub mtime: SystemTime, + pub spans: Vec, + pub job_handle: JobHandle, +} + +impl std::fmt::Debug for FileToEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileToEmbed") + .field("worktree_id", &self.worktree_id) + .field("path", &self.path) + .field("mtime", &self.mtime) + .field("spans", &self.spans) + .finish_non_exhaustive() + } +} + +impl PartialEq for FileToEmbed { + fn eq(&self, other: &Self) -> bool { + self.worktree_id == other.worktree_id + && self.path == other.path + && self.mtime == other.mtime + && self.spans == other.spans + } +} + +pub struct EmbeddingQueue { + embedding_provider: Arc, + pending_batch: Vec, + executor: BackgroundExecutor, + pending_batch_token_count: usize, + finished_files_tx: channel::Sender, + finished_files_rx: channel::Receiver, +} + +#[derive(Clone)] +pub struct FileFragmentToEmbed { + file: Arc>, + span_range: Range, +} + +impl EmbeddingQueue { + pub fn new( + embedding_provider: Arc, + executor: BackgroundExecutor, + ) -> Self { + let (finished_files_tx, finished_files_rx) = channel::unbounded(); + Self { + embedding_provider, + executor, + pending_batch: Vec::new(), + pending_batch_token_count: 0, + finished_files_tx, + finished_files_rx, + } + } + + pub fn push(&mut self, file: FileToEmbed) { + if file.spans.is_empty() { + self.finished_files_tx.try_send(file).unwrap(); + return; + } + + let file = Arc::new(Mutex::new(file)); + + self.pending_batch.push(FileFragmentToEmbed { + file: file.clone(), + span_range: 0..0, + }); + + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; + for (ix, span) in file.lock().spans.iter().enumerate() { + let span_token_count = if span.embedding.is_none() { + span.token_count + } else { + 0 + }; + + let next_token_count = self.pending_batch_token_count + span_token_count; + if next_token_count > self.embedding_provider.max_tokens_per_batch() { + let range_end = fragment_range.end; + self.flush(); + self.pending_batch.push(FileFragmentToEmbed { + file: file.clone(), + span_range: range_end..range_end, + }); + fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; + } + + fragment_range.end = ix + 1; + self.pending_batch_token_count += 
span_token_count;
+        }
+    }
+
+    pub fn flush(&mut self) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+
+        self.executor
+            .spawn(async move {
+                let mut spans = Vec::new();
+                for fragment in &batch {
+                    let file = fragment.file.lock();
+                    spans.extend(
+                        file.spans[fragment.span_range.clone()]
+                            .iter()
+                            .filter(|d| d.embedding.is_none())
+                            .map(|d| d.content.clone()),
+                    );
+                }
+
+                // If there are no spans to embed, send each file to the finished-files
+                // channel once this fragment is the last one holding it.
+                if spans.is_empty() {
+                    for fragment in batch.clone() {
+                        if let Some(file) = Arc::into_inner(fragment.file) {
+                            finished_files_tx.try_send(file.into_inner()).unwrap();
+                        }
+                    }
+                    return;
+                };
+
+                match embedding_provider.embed_batch(spans).await {
+                    Ok(embeddings) => {
+                        let mut embeddings = embeddings.into_iter();
+                        for fragment in batch {
+                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
+                                .iter_mut()
+                                .filter(|d| d.embedding.is_none())
+                            {
+                                if let Some(embedding) = embeddings.next() {
+                                    span.embedding = Some(embedding);
+                                } else {
+                                    log::error!("number of embeddings != number of documents");
+                                }
+                            }
+
+                            if let Some(file) = Arc::into_inner(fragment.file) {
+                                finished_files_tx.try_send(file.into_inner()).unwrap();
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        log::error!("{:?}", error);
+                    }
+                }
+            })
+            .detach();
+    }
+
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}
diff --git a/crates/semantic_index2/src/parsing.rs b/crates/semantic_index2/src/parsing.rs
new file mode 100644
index 0000000000..cb15ca453b
--- /dev/null
+++ b/crates/semantic_index2/src/parsing.rs
@@ -0,0 +1,414 @@
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
+use anyhow::{anyhow, Result};
+use language::{Grammar, Language};
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+use sha1::{Digest, Sha1};
+use std::{
+    borrow::Cow,
+    cmp::{self, Reverse},
+    collections::HashSet,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+};
+use tree_sitter::{Parser, QueryCursor};
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct SpanDigest(pub [u8; 20]);
+
+impl FromSql for SpanDigest {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let blob = value.as_blob()?;
+        let bytes =
+            blob.try_into()
+                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+                    expected_size: 20,
+                    blob_size: blob.len(),
+                })?;
+        return Ok(SpanDigest(bytes));
+    }
+}
+
+impl ToSql for SpanDigest {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        self.0.to_sql()
+    }
+}
+
+impl From<&'_ str> for SpanDigest {
+    fn from(value: &'_ str) -> Self {
+        let mut sha1 = Sha1::new();
+        sha1.update(value);
+        Self(sha1.finalize().into())
+    }
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct Span {
+    pub name: String,
+    pub range: Range<usize>,
+    pub content: String,
+    pub embedding: Option<Embedding>,
+    pub digest: SpanDigest,
+    pub token_count: usize,
+}
+
+const CODE_CONTEXT_TEMPLATE: &str =
+    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+const ENTIRE_FILE_TEMPLATE: &str =
+    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
+    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
+];
+
+pub struct CodeContextRetriever {
+    pub 
parser: Parser,
+    pub cursor: QueryCursor,
+    pub embedding_provider: Arc<dyn EmbeddingProvider>,
+}
+
+// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
+// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
+// If there are preceding comments, we track this with a context capture
+// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
+// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
+#[derive(Debug, Clone)]
+pub struct CodeContextMatch {
+    pub start_col: usize,
+    pub item_range: Option<Range<usize>>,
+    pub name_range: Option<Range<usize>>,
+    pub context_ranges: Vec<Range<usize>>,
+    pub collapse_ranges: Vec<Range<usize>>,
+}
+
+impl CodeContextRetriever {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+        Self {
+            parser: Parser::new(),
+            cursor: QueryCursor::new(),
+            embedding_provider,
+        }
+    }
+
+    fn parse_entire_file(
+        &self,
+        relative_path: Option<&Path>,
+        language_name: Arc<str>,
+        content: &str,
+    ) -> Result<Vec<Span>> {
+        let document_span = ENTIRE_FILE_TEMPLATE
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
+            .replace("<language>", language_name.as_ref())
+            .replace("<item>", &content);
+        let digest = SpanDigest::from(document_span.as_str());
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
+        Ok(vec![Span {
+            range: 0..content.len(),
+            content: document_span,
+            embedding: Default::default(),
+            name: language_name.to_string(),
+            digest,
+            token_count,
+        }])
+    }
+
+    fn parse_markdown_file(
+        &self,
+        relative_path: Option<&Path>,
+        content: &str,
+    ) -> Result<Vec<Span>> {
+        let document_span = MARKDOWN_CONTEXT_TEMPLATE
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
+            .replace("<item>", &content);
+        let digest = SpanDigest::from(document_span.as_str());
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
+        Ok(vec![Span {
+            range: 0..content.len(),
+            content: document_span,
+            embedding: None,
+            name: "Markdown".to_string(),
+            digest,
+            token_count,
+        }])
+    }
+
+    fn get_matches_in_file(
+        &mut self,
+        content: &str,
+        grammar: &Arc<Grammar>,
+    ) -> Result<Vec<CodeContextMatch>> {
+        let embedding_config = grammar
+            .embedding_config
+            .as_ref()
+            .ok_or_else(|| anyhow!("no embedding queries"))?;
+        self.parser.set_language(grammar.ts_language).unwrap();
+
+        let tree = self
+            .parser
+            .parse(&content, None)
+            .ok_or_else(|| anyhow!("parsing failed"))?;
+
+        let mut captures: Vec<CodeContextMatch> = Vec::new();
+        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
+        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
+        for mat in self.cursor.matches(
+            &embedding_config.query,
+            tree.root_node(),
+            content.as_bytes(),
+        ) {
+            let mut start_col = 0;
+            let mut item_range: Option<Range<usize>> = None;
+            let mut name_range: Option<Range<usize>> = None;
+            let mut context_ranges: Vec<Range<usize>> = Vec::new();
+            collapse_ranges.clear();
+            keep_ranges.clear();
+            for capture in mat.captures {
+                if capture.index == embedding_config.item_capture_ix {
+                    item_range = Some(capture.node.byte_range());
+                    start_col = capture.node.start_position().column;
+                } else if Some(capture.index) == embedding_config.name_capture_ix {
+                    name_range 
= Some(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.context_capture_ix {
+                    context_ranges.push(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
+                    collapse_ranges.push(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.keep_capture_ix {
+                    keep_ranges.push(capture.node.byte_range());
+                }
+            }
+
+            captures.push(CodeContextMatch {
+                start_col,
+                item_range,
+                name_range,
+                context_ranges,
+                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+            });
+        }
+        Ok(captures)
+    }
+
+    pub fn parse_file_with_template(
+        &mut self,
+        relative_path: Option<&Path>,
+        content: &str,
+        language: Arc<Language>,
+    ) -> Result<Vec<Span>> {
+        let language_name = language.name();
+
+        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+            return self.parse_entire_file(relative_path, language_name, &content);
+        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
+            return self.parse_markdown_file(relative_path, &content);
+        }
+
+        let mut spans = self.parse_file(content, language)?;
+        for span in &mut spans {
+            let document_content = CODE_CONTEXT_TEMPLATE
+                .replace(
+                    "<path>",
+                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+                )
+                .replace("<language>", language_name.as_ref())
+                .replace("<item>", &span.content);
+
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;
+
+            span.content = document_content;
+            span.token_count = token_count;
+        }
+        Ok(spans)
+    }
+
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
+        let grammar = language
+            .grammar()
+            .ok_or_else(|| anyhow!("no grammar for language"))?;
+
+        // Iterate through query matches
+        let matches = self.get_matches_in_file(content, grammar)?;
+
+        let language_scope = language.default_scope();
+        let placeholder = language_scope.collapsed_placeholder();
+
+        let mut spans = Vec::new();
+        let mut collapsed_ranges_within = Vec::new();
+        let mut parsed_name_ranges = HashSet::new();
+        for (i, context_match) in matches.iter().enumerate() {
+            // Items which are collapsible but not embeddable have no item range
+            let item_range = if let Some(item_range) = context_match.item_range.clone() {
+                item_range
+            } else {
+                continue;
+            };
+
+            // Checks for deduplication
+            let name;
+            if let Some(name_range) = context_match.name_range.clone() {
+                name = content
+                    .get(name_range.clone())
+                    .map_or(String::new(), |s| s.to_string());
+                if parsed_name_ranges.contains(&name_range) {
+                    continue;
+                }
+                parsed_name_ranges.insert(name_range);
+            } else {
+                name = String::new();
+            }
+
+            collapsed_ranges_within.clear();
+            'outer: for remaining_match in &matches[(i + 1)..] 
{ + for collapsed_range in &remaining_match.collapse_ranges { + if item_range.start <= collapsed_range.start + && item_range.end >= collapsed_range.end + { + collapsed_ranges_within.push(collapsed_range.clone()); + } else { + break 'outer; + } + } + } + + collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end))); + + let mut span_content = String::new(); + for context_range in &context_match.context_ranges { + add_content_from_range( + &mut span_content, + content, + context_range.clone(), + context_match.start_col, + ); + span_content.push_str("\n"); + } + + let mut offset = item_range.start; + for collapsed_range in &collapsed_ranges_within { + if collapsed_range.start > offset { + add_content_from_range( + &mut span_content, + content, + offset..collapsed_range.start, + context_match.start_col, + ); + offset = collapsed_range.start; + } + + if collapsed_range.end > offset { + span_content.push_str(placeholder); + offset = collapsed_range.end; + } + } + + if offset < item_range.end { + add_content_from_range( + &mut span_content, + content, + offset..item_range.end, + context_match.start_col, + ); + } + + let sha1 = SpanDigest::from(span_content.as_str()); + spans.push(Span { + name, + content: span_content, + range: item_range.clone(), + embedding: None, + digest: sha1, + token_count: 0, + }) + } + + return Ok(spans); + } +} + +pub(crate) fn subtract_ranges( + ranges: &[Range], + ranges_to_subtract: &[Range], +) -> Vec> { + let mut result = Vec::new(); + + let mut ranges_to_subtract = ranges_to_subtract.iter().peekable(); + + for range in ranges { + let mut offset = range.start; + + while offset < range.end { + if let Some(range_to_subtract) = ranges_to_subtract.peek() { + if offset < range_to_subtract.start { + let next_offset = cmp::min(range_to_subtract.start, range.end); + result.push(offset..next_offset); + offset = next_offset; + } else { + let next_offset = cmp::min(range_to_subtract.end, range.end); + offset = next_offset; + } + + if offset >= range_to_subtract.end { + ranges_to_subtract.next(); + } + } else { + result.push(offset..range.end); + offset = range.end; + } + } + } + + result +} + +fn add_content_from_range( + output: &mut String, + content: &str, + range: Range, + start_col: usize, +) { + for mut line in content.get(range.clone()).unwrap_or("").lines() { + for _ in 0..start_col { + if line.starts_with(' ') { + line = &line[1..]; + } else { + break; + } + } + output.push_str(line); + output.push('\n'); + } + output.pop(); +} diff --git a/crates/semantic_index2/src/semantic_index.rs b/crates/semantic_index2/src/semantic_index.rs new file mode 100644 index 0000000000..0b207b0bf6 --- /dev/null +++ b/crates/semantic_index2/src/semantic_index.rs @@ -0,0 +1,1280 @@ +mod db; +mod embedding_queue; +mod parsing; +pub mod semantic_index_settings; + +#[cfg(test)] +mod semantic_index_tests; + +use crate::semantic_index_settings::SemanticIndexSettings; +use ai::embedding::{Embedding, EmbeddingProvider}; +use ai::providers::open_ai::OpenAIEmbeddingProvider; +use anyhow::{anyhow, Context as _, Result}; +use collections::{BTreeMap, HashMap, HashSet}; +use db::VectorDatabase; +use embedding_queue::{EmbeddingQueue, FileToEmbed}; +use futures::{future, FutureExt, StreamExt}; +use gpui::{ + AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext, + WeakModel, +}; +use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; +use lazy_static::lazy_static; +use ordered_float::OrderedFloat; +use parking_lot::Mutex; +use 
parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES}; +use postage::watch; +use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; +use settings::Settings; +use smol::channel; +use std::{ + cmp::Reverse, + env, + future::Future, + mem, + ops::Range, + path::{Path, PathBuf}, + sync::{Arc, Weak}, + time::{Duration, Instant, SystemTime}, +}; +use util::paths::PathMatcher; +use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt}; +use workspace::Workspace; + +const SEMANTIC_INDEX_VERSION: usize = 11; +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); +const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); + +lazy_static! { + static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); +} + +pub fn init( + fs: Arc, + http_client: Arc, + language_registry: Arc, + cx: &mut AppContext, +) { + SemanticIndexSettings::register(cx); + + let db_file_path = EMBEDDINGS_DIR + .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) + .join("embeddings_db"); + + cx.observe_new_views( + |workspace: &mut Workspace, cx: &mut ViewContext| { + let Some(semantic_index) = SemanticIndex::global(cx) else { + return; + }; + let project = workspace.project().clone(); + + if project.read(cx).is_local() { + cx.app_mut() + .spawn(|mut cx| async move { + let previously_indexed = semantic_index + .update(&mut cx, |index, cx| { + index.project_previously_indexed(&project, cx) + })? + .await?; + if previously_indexed { + semantic_index + .update(&mut cx, |index, cx| index.index_project(project, cx))? + .await?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + }, + ) + .detach(); + + cx.spawn(move |cx| async move { + let semantic_index = SemanticIndex::new( + fs, + db_file_path, + Arc::new(OpenAIEmbeddingProvider::new( + http_client, + cx.background_executor().clone(), + )), + language_registry, + cx.clone(), + ) + .await?; + + cx.update(|cx| cx.set_global(semantic_index.clone()))?; + + anyhow::Ok(()) + }) + .detach(); +} + +#[derive(Copy, Clone, Debug)] +pub enum SemanticIndexStatus { + NotAuthenticated, + NotIndexed, + Indexed, + Indexing { + remaining_files: usize, + rate_limit_expiry: Option, + }, +} + +pub struct SemanticIndex { + fs: Arc, + db: VectorDatabase, + embedding_provider: Arc, + language_registry: Arc, + parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, + _embedding_task: Task<()>, + _parsing_files_tasks: Vec>, + projects: HashMap, ProjectState>, +} + +struct ProjectState { + worktrees: HashMap, + pending_file_count_rx: watch::Receiver, + pending_file_count_tx: Arc>>, + pending_index: usize, + _subscription: gpui::Subscription, + _observe_pending_file_count: Task<()>, +} + +enum WorktreeState { + Registering(RegisteringWorktreeState), + Registered(RegisteredWorktreeState), +} + +impl WorktreeState { + fn is_registered(&self) -> bool { + matches!(self, Self::Registered(_)) + } + + fn paths_changed( + &mut self, + changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, + worktree: &Worktree, + ) { + let changed_paths = match self { + Self::Registering(state) => &mut state.changed_paths, + Self::Registered(state) => &mut state.changed_paths, + }; + + for (path, entry_id, change) in changes.iter() { + let Some(entry) = worktree.entry_for_id(*entry_id) else { + continue; + }; + if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() { + continue; + } + changed_paths.insert( + path.clone(), + ChangedPathInfo { + mtime: entry.mtime, + is_deleted: *change 
== PathChange::Removed, + }, + ); + } + } +} + +struct RegisteringWorktreeState { + changed_paths: BTreeMap, ChangedPathInfo>, + done_rx: watch::Receiver>, + _registration: Task<()>, +} + +impl RegisteringWorktreeState { + fn done(&self) -> impl Future { + let mut done_rx = self.done_rx.clone(); + async move { + while let Some(result) = done_rx.next().await { + if result.is_some() { + break; + } + } + } + } +} + +struct RegisteredWorktreeState { + db_id: i64, + changed_paths: BTreeMap, ChangedPathInfo>, +} + +struct ChangedPathInfo { + mtime: SystemTime, + is_deleted: bool, +} + +#[derive(Clone)] +pub struct JobHandle { + /// The outer Arc is here to count the clones of a JobHandle instance; + /// when the last handle to a given job is dropped, we decrement a counter (just once). + tx: Arc>>>, +} + +impl JobHandle { + fn new(tx: &Arc>>) -> Self { + *tx.lock().borrow_mut() += 1; + Self { + tx: Arc::new(Arc::downgrade(&tx)), + } + } +} + +impl ProjectState { + fn new(subscription: gpui::Subscription, cx: &mut ModelContext) -> Self { + let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0); + let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx)); + Self { + worktrees: Default::default(), + pending_file_count_rx: pending_file_count_rx.clone(), + pending_file_count_tx, + pending_index: 0, + _subscription: subscription, + _observe_pending_file_count: cx.spawn({ + let mut pending_file_count_rx = pending_file_count_rx.clone(); + |this, mut cx| async move { + while let Some(_) = pending_file_count_rx.next().await { + if this.update(&mut cx, |_, cx| cx.notify()).is_err() { + break; + } + } + } + }), + } + } + + fn worktree_id_for_db_id(&self, id: i64) -> Option { + self.worktrees + .iter() + .find_map(|(worktree_id, worktree_state)| match worktree_state { + WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id), + _ => None, + }) + } +} + +#[derive(Clone)] +pub struct PendingFile { + worktree_db_id: i64, + relative_path: Arc, + absolute_path: PathBuf, + language: Option>, + modified_time: SystemTime, + job_handle: JobHandle, +} + +#[derive(Clone)] +pub struct SearchResult { + pub buffer: Model, + pub range: Range, + pub similarity: OrderedFloat, +} + +impl SemanticIndex { + pub fn global(cx: &mut AppContext) -> Option> { + if cx.has_global::>() { + Some(cx.global::>().clone()) + } else { + None + } + } + + pub fn authenticate(&mut self, cx: &mut AppContext) -> bool { + if !self.embedding_provider.has_credentials() { + self.embedding_provider.retrieve_credentials(cx); + } else { + return true; + } + + self.embedding_provider.has_credentials() + } + + pub fn is_authenticated(&self) -> bool { + self.embedding_provider.has_credentials() + } + + pub fn enabled(cx: &AppContext) -> bool { + SemanticIndexSettings::get_global(cx).enabled + } + + pub fn status(&self, project: &Model) -> SemanticIndexStatus { + if !self.is_authenticated() { + return SemanticIndexStatus::NotAuthenticated; + } + + if let Some(project_state) = self.projects.get(&project.downgrade()) { + if project_state + .worktrees + .values() + .all(|worktree| worktree.is_registered()) + && project_state.pending_index == 0 + { + SemanticIndexStatus::Indexed + } else { + SemanticIndexStatus::Indexing { + remaining_files: project_state.pending_file_count_rx.borrow().clone(), + rate_limit_expiry: self.embedding_provider.rate_limit_expiration(), + } + } + } else { + SemanticIndexStatus::NotIndexed + } + } + + pub async fn new( + fs: Arc, + database_path: PathBuf, + embedding_provider: Arc, 
+ language_registry: Arc, + mut cx: AsyncAppContext, + ) -> Result> { + let t0 = Instant::now(); + let database_path = Arc::from(database_path); + let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone()) + .await?; + + log::trace!( + "db initialization took {:?} milliseconds", + t0.elapsed().as_millis() + ); + + cx.build_model(|cx| { + let t0 = Instant::now(); + let embedding_queue = + EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone()); + let _embedding_task = cx.background_executor().spawn({ + let embedded_files = embedding_queue.finished_files(); + let db = db.clone(); + async move { + while let Ok(file) = embedded_files.recv().await { + db.insert_file(file.worktree_id, file.path, file.mtime, file.spans) + .await + .log_err(); + } + } + }); + + // Parse files into embeddable spans. + let (parsing_files_tx, parsing_files_rx) = + channel::unbounded::<(Arc>, PendingFile)>(); + let embedding_queue = Arc::new(Mutex::new(embedding_queue)); + let mut _parsing_files_tasks = Vec::new(); + for _ in 0..cx.background_executor().num_cpus() { + let fs = fs.clone(); + let mut parsing_files_rx = parsing_files_rx.clone(); + let embedding_provider = embedding_provider.clone(); + let embedding_queue = embedding_queue.clone(); + let background = cx.background_executor().clone(); + _parsing_files_tasks.push(cx.background_executor().spawn(async move { + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); + loop { + let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse(); + let mut next_file_to_parse = parsing_files_rx.next().fuse(); + futures::select_biased! { + next_file_to_parse = next_file_to_parse => { + if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse { + Self::parse_file( + &fs, + pending_file, + &mut retriever, + &embedding_queue, + &embeddings_for_digest, + ) + .await + } else { + break; + } + }, + _ = timer => { + embedding_queue.lock().flush(); + } + } + } + })); + } + + log::trace!( + "semantic index task initialization took {:?} milliseconds", + t0.elapsed().as_millis() + ); + Self { + fs, + db, + embedding_provider, + language_registry, + parsing_files_tx, + _embedding_task, + _parsing_files_tasks, + projects: Default::default(), + } + }) + } + + async fn parse_file( + fs: &Arc, + pending_file: PendingFile, + retriever: &mut CodeContextRetriever, + embedding_queue: &Arc>, + embeddings_for_digest: &HashMap, + ) { + let Some(language) = pending_file.language else { + return; + }; + + if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { + if let Some(mut spans) = retriever + .parse_file_with_template(Some(&pending_file.relative_path), &content, language) + .log_err() + { + log::trace!( + "parsed path {:?}: {} spans", + pending_file.relative_path, + spans.len() + ); + + for span in &mut spans { + if let Some(embedding) = embeddings_for_digest.get(&span.digest) { + span.embedding = Some(embedding.to_owned()); + } + } + + embedding_queue.lock().push(FileToEmbed { + worktree_id: pending_file.worktree_db_id, + path: pending_file.relative_path, + mtime: pending_file.modified_time, + job_handle: pending_file.job_handle, + spans, + }); + } + } + } + + pub fn project_previously_indexed( + &mut self, + project: &Model, + cx: &mut ModelContext, + ) -> Task> { + let worktrees_indexed_previously = project + .read(cx) + .worktrees() + .map(|worktree| { + self.db + .worktree_previously_indexed(&worktree.read(cx).abs_path()) + }) + .collect::>(); + cx.spawn(|_, _cx| 
async move { + let worktree_indexed_previously = + futures::future::join_all(worktrees_indexed_previously).await; + + Ok(worktree_indexed_previously + .iter() + .filter(|worktree| worktree.is_ok()) + .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned()))) + }) + } + + fn project_entries_changed( + &mut self, + project: Model, + worktree_id: WorktreeId, + changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, + cx: &mut ModelContext, + ) { + let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else { + return; + }; + let project = project.downgrade(); + let Some(project_state) = self.projects.get_mut(&project) else { + return; + }; + + let worktree = worktree.read(cx); + let worktree_state = + if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) { + worktree_state + } else { + return; + }; + worktree_state.paths_changed(changes, worktree); + if let WorktreeState::Registered(_) = worktree_state { + cx.spawn(|this, mut cx| async move { + cx.background_executor() + .timer(BACKGROUND_INDEXING_DELAY) + .await; + if let Some((this, project)) = this.upgrade().zip(project.upgrade()) { + this.update(&mut cx, |this, cx| { + this.index_project(project, cx).detach_and_log_err(cx) + })?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + } + + fn register_worktree( + &mut self, + project: Model, + worktree: Model, + cx: &mut ModelContext, + ) { + let project = project.downgrade(); + let project_state = if let Some(project_state) = self.projects.get_mut(&project) { + project_state + } else { + return; + }; + let worktree = if let Some(worktree) = worktree.read(cx).as_local() { + worktree + } else { + return; + }; + let worktree_abs_path = worktree.abs_path().clone(); + let scan_complete = worktree.scan_complete(); + let worktree_id = worktree.id(); + let db = self.db.clone(); + let language_registry = self.language_registry.clone(); + let (mut done_tx, done_rx) = watch::channel(); + let registration = cx.spawn(|this, mut cx| { + async move { + let register = async { + scan_complete.await; + let db_id = db.find_or_create_worktree(worktree_abs_path).await?; + let mut file_mtimes = db.get_file_mtimes(db_id).await?; + let worktree = if let Some(project) = project.upgrade() { + project + .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx)) + .ok() + .flatten() + .context("worktree not found")? 
+                    } else {
+                        return anyhow::Ok(());
+                    };
+                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
+                    let mut changed_paths = cx
+                        .background_executor()
+                        .spawn(async move {
+                            let mut changed_paths = BTreeMap::new();
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+
+                                if file.is_external || file.is_ignored || file.is_symlink {
+                                    continue;
+                                }
+
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    // Skip files that can't be parsed into embeddable spans.
+                                    if !PARSEABLE_ENTIRE_FILE_TYPES
+                                        .contains(&language.name().as_ref())
+                                        && language.name().as_ref() != "Markdown"
+                                        && language
+                                            .grammar()
+                                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                                            .is_none()
+                                    {
+                                        continue;
+                                    }
+
+                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+                                    let already_stored = stored_mtime
+                                        .map_or(false, |existing_mtime| existing_mtime == file.mtime);
+
+                                    if !already_stored {
+                                        changed_paths.insert(
+                                            file.path.clone(),
+                                            ChangedPathInfo {
+                                                mtime: file.mtime,
+                                                is_deleted: false,
+                                            },
+                                        );
+                                    }
+                                }
+                            }
+
+                            // Clean up entries from database that are no longer in the worktree.
+                            for (path, mtime) in file_mtimes {
+                                changed_paths.insert(
+                                    path.into(),
+                                    ChangedPathInfo {
+                                        mtime,
+                                        is_deleted: true,
+                                    },
+                                );
+                            }
+
+                            anyhow::Ok(changed_paths)
+                        })
+                        .await?;
+                    this.update(&mut cx, |this, cx| {
+                        let project_state = this
+                            .projects
+                            .get_mut(&project)
+                            .context("project not registered")?;
+                        let project = project.upgrade().context("project was dropped")?;
+
+                        if let Some(WorktreeState::Registering(state)) =
+                            project_state.worktrees.remove(&worktree_id)
+                        {
+                            changed_paths.extend(state.changed_paths);
+                        }
+                        project_state.worktrees.insert(
+                            worktree_id,
+                            WorktreeState::Registered(RegisteredWorktreeState {
+                                db_id,
+                                changed_paths,
+                            }),
+                        );
+                        this.index_project(project, cx).detach_and_log_err(cx);
+
+                        anyhow::Ok(())
+                    })??;
+
+                    anyhow::Ok(())
+                };
+
+                if register.await.log_err().is_none() {
+                    // Stop tracking this worktree if the registration failed.
+                    this.update(&mut cx, |this, _| {
+                        this.projects.get_mut(&project).map(|project_state| {
+                            project_state.worktrees.remove(&worktree_id);
+                        });
+                    })
+                    .ok();
+                }
+
+                *done_tx.borrow_mut() = Some(());
+            }
+        });
+        project_state.worktrees.insert(
+            worktree_id,
+            WorktreeState::Registering(RegisteringWorktreeState {
+                changed_paths: Default::default(),
+                done_rx,
+                _registration: registration,
+            }),
+        );
+    }
+
+    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
+        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
+        {
+            project_state
+        } else {
+            return;
+        };
+
+        let mut worktrees = project
+            .read(cx)
+            .worktrees()
+            .filter(|worktree| worktree.read(cx).is_local())
+            .collect::<Vec<_>>();
+        let worktree_ids = worktrees
+            .iter()
+            .map(|worktree| worktree.read(cx).id())
+            .collect::<Vec<_>>();
+
+        // Remove worktrees that are no longer present
+        project_state
+            .worktrees
+            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
+
+        // Register new worktrees
+        worktrees.retain(|worktree| {
+            let worktree_id = worktree.read(cx).id();
+            !project_state.worktrees.contains_key(&worktree_id)
+        });
+        for worktree in worktrees {
+            self.register_worktree(project.clone(), worktree, cx);
+        }
+    }
+
+    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
+        Some(
+            self.projects
+                .get(&project.downgrade())?
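+                // The receiver tracks the number of files still waiting to be
+                // parsed and embedded; it reaches zero once indexing is idle.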
+ .pending_file_count_rx + .clone(), + ) + } + + pub fn search_project( + &mut self, + project: Model, + query: String, + limit: usize, + includes: Vec, + excludes: Vec, + cx: &mut ModelContext, + ) -> Task>> { + if query.is_empty() { + return Task::ready(Ok(Vec::new())); + } + + let index = self.index_project(project.clone(), cx); + let embedding_provider = self.embedding_provider.clone(); + + cx.spawn(|this, mut cx| async move { + index.await?; + let t0 = Instant::now(); + + let query = embedding_provider + .embed_batch(vec![query]) + .await? + .pop() + .context("could not embed query")?; + log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis()); + + let search_start = Instant::now(); + let modified_buffer_results = this.update(&mut cx, |this, cx| { + this.search_modified_buffers( + &project, + query.clone(), + limit, + &includes, + &excludes, + cx, + ) + })?; + let file_results = this.update(&mut cx, |this, cx| { + this.search_files(project, query, limit, includes, excludes, cx) + })?; + let (modified_buffer_results, file_results) = + futures::join!(modified_buffer_results, file_results); + + // Weave together the results from modified buffers and files. + let mut results = Vec::new(); + let mut modified_buffers = HashSet::default(); + for result in modified_buffer_results.log_err().unwrap_or_default() { + modified_buffers.insert(result.buffer.clone()); + results.push(result); + } + for result in file_results.log_err().unwrap_or_default() { + if !modified_buffers.contains(&result.buffer) { + results.push(result); + } + } + results.sort_by_key(|result| Reverse(result.similarity)); + results.truncate(limit); + log::trace!("Semantic search took {:?}", search_start.elapsed()); + Ok(results) + }) + } + + pub fn search_files( + &mut self, + project: Model, + query: Embedding, + limit: usize, + includes: Vec, + excludes: Vec, + cx: &mut ModelContext, + ) -> Task>> { + let db_path = self.db.path().clone(); + let fs = self.fs.clone(); + cx.spawn(|this, mut cx| async move { + let database = VectorDatabase::new( + fs.clone(), + db_path.clone(), + cx.background_executor().clone(), + ) + .await?; + + let worktree_db_ids = this.read_with(&cx, |this, _| { + let project_state = this + .projects + .get(&project.downgrade()) + .context("project was not indexed")?; + let worktree_db_ids = project_state + .worktrees + .values() + .filter_map(|worktree| { + if let WorktreeState::Registered(worktree) = worktree { + Some(worktree.db_id) + } else { + None + } + }) + .collect::>(); + anyhow::Ok(worktree_db_ids) + })??; + + let file_ids = database + .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) + .await?; + + let batch_n = cx.background_executor().num_cpus(); + let ids_len = file_ids.clone().len(); + let minimum_batch_size = 50; + + let batch_size = { + let size = ids_len / batch_n; + if size < minimum_batch_size { + minimum_batch_size + } else { + size + } + }; + + let mut batch_results = Vec::new(); + for batch in file_ids.chunks(batch_size) { + let batch = batch.into_iter().map(|v| *v).collect::>(); + let limit = limit.clone(); + let fs = fs.clone(); + let db_path = db_path.clone(); + let query = query.clone(); + if let Some(db) = + VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone()) + .await + .log_err() + { + batch_results.push(async move { + db.top_k_search(&query, limit, batch.as_slice()).await + }); + } + } + + let batch_results = futures::future::join_all(batch_results).await; + + let mut results = Vec::new(); + for batch_result in 
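+            // Merge each batch's top-k hits into one sorted list, keeping at
+            // most `limit` results overall.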
batch_results {
+                if let Ok(batch_result) = batch_result {
+                    for (id, similarity) in batch_result {
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                        {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<_>>();
+            let scores = results
+                .into_iter()
+                .map(|(_, score)| score)
+                .collect::<Vec<_>>();
+            let spans = database.spans_for_ids(ids.as_slice()).await?;
+
+            let mut tasks = Vec::new();
+            let mut ranges = Vec::new();
+            let weak_project = project.downgrade();
+            project.update(&mut cx, |project, cx| {
+                let this = this.upgrade().context("index was dropped")?;
+                for (worktree_db_id, file_path, byte_range) in spans {
+                    let project_state =
+                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
+                            state
+                        } else {
+                            return Err(anyhow!("project not added"));
+                        };
+                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
+                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
+                        ranges.push(byte_range);
+                    }
+                }
+
+                Ok(())
+            })??;
+
+            let buffers = futures::future::join_all(tasks).await;
+            Ok(buffers
+                .into_iter()
+                .zip(ranges)
+                .zip(scores)
+                .filter_map(|((buffer, range), similarity)| {
+                    let buffer = buffer.log_err()?;
+                    let range = buffer
+                        .read_with(&cx, |buffer, _| {
+                            let start = buffer.clip_offset(range.start, Bias::Left);
+                            let end = buffer.clip_offset(range.end, Bias::Right);
+                            buffer.anchor_before(start)..buffer.anchor_after(end)
+                        })
+                        .log_err()?;
+                    Some(SearchResult {
+                        buffer,
+                        range,
+                        similarity,
+                    })
+                })
+                .collect())
+        })
+    }
+
+    fn search_modified_buffers(
+        &self,
+        project: &Model<Project>,
+        query: Embedding,
+        limit: usize,
+        includes: &[PathMatcher],
+        excludes: &[PathMatcher],
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let modified_buffers = project
+            .read(cx)
+            .opened_buffers()
+            .into_iter()
+            .filter_map(|buffer_handle| {
+                let buffer = buffer_handle.read(cx);
+                let snapshot = buffer.snapshot();
+                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                    excludes.iter().any(|matcher| matcher.is_match(&path))
+                });
+
+                let included = if includes.is_empty() {
+                    true
+                } else {
+                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                        includes.iter().any(|matcher| matcher.is_match(&path))
+                    })
+                };
+
+                if buffer.is_dirty() && !excluded && included {
+                    Some((buffer_handle, snapshot))
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let embedding_provider = self.embedding_provider.clone();
+        let fs = self.fs.clone();
+        let db_path = self.db.path().clone();
+        let background = cx.background_executor().clone();
+        cx.background_executor().spawn(async move {
+            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
+            let mut results = Vec::<SearchResult>::new();
+
+            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+            for (buffer, snapshot) in modified_buffers {
+                let language = snapshot
+                    .language_at(0)
+                    .cloned()
+                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                let mut spans = retriever
+                    .parse_file_with_template(None, &snapshot.text(), language)
+                    .log_err()
+                    .unwrap_or_default();
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
+                {
+                    for span in spans {
+                        let similarity = span.embedding.unwrap().similarity(&query);
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |result| {
+                                Reverse(result.similarity)
+                            }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
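+                        // Anchor the span's byte range in the snapshot so the
+                        // result stays valid as the buffer is edited.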
+ + let range = { + let start = snapshot.clip_offset(span.range.start, Bias::Left); + let end = snapshot.clip_offset(span.range.end, Bias::Right); + snapshot.anchor_before(start)..snapshot.anchor_after(end) + }; + + results.insert( + ix, + SearchResult { + buffer: buffer.clone(), + range, + similarity, + }, + ); + results.truncate(limit); + } + } + } + + Ok(results) + }) + } + + pub fn index_project( + &mut self, + project: Model, + cx: &mut ModelContext, + ) -> Task> { + if !self.is_authenticated() { + if !self.authenticate(cx) { + return Task::ready(Err(anyhow!("user is not authenticated"))); + } + } + + if !self.projects.contains_key(&project.downgrade()) { + let subscription = cx.subscribe(&project, |this, project, event, cx| match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { + this.project_worktrees_changed(project.clone(), cx); + } + project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { + this.project_entries_changed(project, *worktree_id, changes.clone(), cx); + } + _ => {} + }); + let project_state = ProjectState::new(subscription, cx); + self.projects.insert(project.downgrade(), project_state); + self.project_worktrees_changed(project.clone(), cx); + } + let project_state = self.projects.get_mut(&project.downgrade()).unwrap(); + project_state.pending_index += 1; + cx.notify(); + + let mut pending_file_count_rx = project_state.pending_file_count_rx.clone(); + let db = self.db.clone(); + let language_registry = self.language_registry.clone(); + let parsing_files_tx = self.parsing_files_tx.clone(); + let worktree_registration = self.wait_for_worktree_registration(&project, cx); + + cx.spawn(|this, mut cx| async move { + worktree_registration.await?; + + let mut pending_files = Vec::new(); + let mut files_to_delete = Vec::new(); + this.update(&mut cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.downgrade()) + .context("project was dropped")?; + let pending_file_count_tx = &project_state.pending_file_count_tx; + + project_state + .worktrees + .retain(|worktree_id, worktree_state| { + let worktree = if let Some(worktree) = + project.read(cx).worktree_for_id(*worktree_id, cx) + { + worktree + } else { + return false; + }; + let worktree_state = + if let WorktreeState::Registered(worktree_state) = worktree_state { + worktree_state + } else { + return true; + }; + + worktree_state.changed_paths.retain(|path, info| { + if info.is_deleted { + files_to_delete.push((worktree_state.db_id, path.clone())); + } else { + let absolute_path = worktree.read(cx).absolutize(path); + let job_handle = JobHandle::new(pending_file_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.clone(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id: worktree_state.db_id, + }); + } + + false + }); + true + }); + + anyhow::Ok(()) + })??; + + cx.background_executor() + .spawn(async move { + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); + } + + let embeddings_for_digest = { + let mut files = HashMap::default(); + for pending_file in &pending_files { + files + .entry(pending_file.worktree_db_id) + .or_insert(Vec::new()) + .push(pending_file.relative_path.clone()); + } + Arc::new( + db.embeddings_for_files(files) + .await + .log_err() + .unwrap_or_default(), + ) + }; + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if 
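+                            // Skip languages that can't be embedded: not parseable as
+                            // entire files, not Markdown, and lacking an embedding
+                            // grammar config.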
!PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
+                                && language.name().as_ref() != "Markdown"
+                                && language
+                                    .grammar()
+                                    .and_then(|grammar| grammar.embedding_config.as_ref())
+                                    .is_none()
+                            {
+                                continue;
+                            }
+                            pending_file.language = Some(language);
+                        }
+                        parsing_files_tx
+                            .try_send((embeddings_for_digest.clone(), pending_file))
+                            .ok();
+                    }
+
+                    // Wait until we're done indexing.
+                    while let Some(count) = pending_file_count_rx.next().await {
+                        if count == 0 {
+                            break;
+                        }
+                    }
+                })
+                .await;
+
+            this.update(&mut cx, |this, cx| {
+                let project_state = this
+                    .projects
+                    .get_mut(&project.downgrade())
+                    .context("project was dropped")?;
+                project_state.pending_index -= 1;
+                cx.notify();
+                anyhow::Ok(())
+            })??;
+
+            Ok(())
+        })
+    }
+
+    fn wait_for_worktree_registration(
+        &self,
+        project: &Model<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let project = project.downgrade();
+        cx.spawn(|this, cx| async move {
+            loop {
+                let mut pending_worktrees = Vec::new();
+                this.upgrade()
+                    .context("semantic index dropped")?
+                    .read_with(&cx, |this, _| {
+                        if let Some(project) = this.projects.get(&project) {
+                            for worktree in project.worktrees.values() {
+                                if let WorktreeState::Registering(worktree) = worktree {
+                                    pending_worktrees.push(worktree.done());
+                                }
+                            }
+                        }
+                    })?;
+
+                if pending_worktrees.is_empty() {
+                    break;
+                } else {
+                    future::join_all(pending_worktrees).await;
+                }
+            }
+            Ok(())
+        })
+    }
+
+    async fn embed_spans(
+        spans: &mut [Span],
+        embedding_provider: &dyn EmbeddingProvider,
+        db: &VectorDatabase,
+    ) -> Result<()> {
+        let mut batch = Vec::new();
+        let mut batch_tokens = 0;
+        let mut embeddings = Vec::new();
+
+        let digests = spans
+            .iter()
+            .map(|span| span.digest.clone())
+            .collect::<Vec<_>>();
+        let embeddings_for_digests = db
+            .embeddings_for_digests(digests)
+            .await
+            .log_err()
+            .unwrap_or_default();
+
+        for span in &*spans {
+            if embeddings_for_digests.contains_key(&span.digest) {
+                continue;
+            }
+
+            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
+                let batch_embeddings = embedding_provider
+                    .embed_batch(mem::take(&mut batch))
+                    .await?;
+                embeddings.extend(batch_embeddings);
+                batch_tokens = 0;
+            }
+
+            batch_tokens += span.token_count;
+            batch.push(span.content.clone());
+        }
+
+        if !batch.is_empty() {
+            let batch_embeddings = embedding_provider
+                .embed_batch(mem::take(&mut batch))
+                .await?;
+
+            embeddings.extend(batch_embeddings);
+        }
+
+        let mut embeddings = embeddings.into_iter();
+        for span in spans {
+            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
+                Some(embedding.clone())
+            } else {
+                embeddings.next()
+            };
+            let embedding = embedding.context("failed to embed spans")?;
+            span.embedding = Some(embedding);
+        }
+        Ok(())
+    }
+}
+
+impl Drop for JobHandle {
+    fn drop(&mut self) {
+        if let Some(inner) = Arc::get_mut(&mut self.tx) {
+            // This is the last remaining instance of the JobHandle, regardless of
+            // its origin - whether it was cloned or not.
+            if let Some(tx) = inner.upgrade() {
+                let mut tx = tx.lock();
+                *tx.borrow_mut() -= 1;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_job_handle() {
+        let (job_count_tx, job_count_rx) = watch::channel_with(0);
+        let tx = Arc::new(Mutex::new(job_count_tx));
+        let job_handle = JobHandle::new(&tx);
+
+        assert_eq!(1, *job_count_rx.borrow());
+        let new_job_handle = job_handle.clone();
+        assert_eq!(1, *job_count_rx.borrow());
+        drop(job_handle);
+        assert_eq!(1, *job_count_rx.borrow());
+        drop(new_job_handle);
+        assert_eq!(0,
*job_count_rx.borrow()); + } +} diff --git a/crates/semantic_index2/src/semantic_index_settings.rs b/crates/semantic_index2/src/semantic_index_settings.rs new file mode 100644 index 0000000000..306a38fa9c --- /dev/null +++ b/crates/semantic_index2/src/semantic_index_settings.rs @@ -0,0 +1,28 @@ +use anyhow; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; + +#[derive(Deserialize, Debug)] +pub struct SemanticIndexSettings { + pub enabled: bool, +} + +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +pub struct SemanticIndexSettingsContent { + pub enabled: Option, +} + +impl Settings for SemanticIndexSettings { + const KEY: Option<&'static str> = Some("semantic_index"); + + type FileContent = SemanticIndexSettingsContent; + + fn load( + default_value: &Self::FileContent, + user_values: &[&Self::FileContent], + _: &mut gpui::AppContext, + ) -> anyhow::Result { + Self::load_via_json_merge(default_value, user_values) + } +} diff --git a/crates/semantic_index2/src/semantic_index_tests.rs b/crates/semantic_index2/src/semantic_index_tests.rs new file mode 100644 index 0000000000..ced08f4cbc --- /dev/null +++ b/crates/semantic_index2/src/semantic_index_tests.rs @@ -0,0 +1,1697 @@ +use crate::{ + embedding_queue::EmbeddingQueue, + parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest}, + semantic_index_settings::SemanticIndexSettings, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, +}; +use ai::test::FakeEmbeddingProvider; + +use gpui::{Task, TestAppContext}; +use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; +use parking_lot::Mutex; +use pretty_assertions::assert_eq; +use project::{project_settings::ProjectSettings, FakeFs, Fs, Project}; +use rand::{rngs::StdRng, Rng}; +use serde_json::json; +use settings::{Settings, SettingsStore}; +use std::{path::Path, sync::Arc, time::SystemTime}; +use unindent::Unindent; +use util::{paths::PathMatcher, RandomCharIter}; + +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} + +#[gpui::test] +async fn test_semantic_index(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/the-root", + json!({ + "src": { + "file1.rs": " + fn aaa() { + println!(\"aaaaaaaaaaaa!\"); + } + + fn zzzzz() { + println!(\"SLEEPING\"); + } + ".unindent(), + "file2.rs": " + fn bbb() { + println!(\"bbbbbbbbbbbbb!\"); + } + struct pqpqpqp {} + ".unindent(), + "file3.toml": " + ZZZZZZZZZZZZZZZZZZ = 5 + ".unindent(), + } + }), + ) + .await; + + let languages = Arc::new(LanguageRegistry::new(Task::ready(()))); + let rust_language = rust_lang(); + let toml_language = toml_lang(); + languages.add(rust_language); + languages.add(toml_language); + + let db_dir = tempdir::TempDir::new("vector-store").unwrap(); + let db_path = db_dir.path().join("db.sqlite"); + + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let semantic_index = SemanticIndex::new( + fs.clone(), + db_path, + embedding_provider.clone(), + languages, + cx.to_async(), + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; + + let search_results = semantic_index.update(cx, |store, cx| { + store.search_project( + project.clone(), + "aaaaaabbbbzz".to_string(), + 5, + vec![], + vec![], + cx, + ) + }); + let pending_file_count = + semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap()); + 
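+    // Let the worktree scan and parser tasks start before checking how many
+    // files are waiting to be embedded.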
cx.background_executor.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 3);
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+    assert_eq!(*pending_file_count.borrow(), 0);
+
+    let search_results = search_results.await.unwrap();
+    assert_search_results(
+        &search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file3.toml").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
+        ],
+        cx,
+    );
+
+    // Test include and exclude file functionality
+    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
+    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
+    let rust_only_search_results = semantic_index
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                include_files,
+                vec![],
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    assert_search_results(
+        &rust_only_search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
+        ],
+        cx,
+    );
+
+    let no_rust_search_results = semantic_index
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                vec![],
+                exclude_files,
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    assert_search_results(
+        &no_rust_search_results,
+        &[(Path::new("src/file3.toml").into(), 0)],
+        cx,
+    );
+
+    fs.save(
+        "/the-root/src/file2.rs".as_ref(),
+        &"
+            fn dddd() { println!(\"ddddd!\"); }
+            struct pqpqpqp {}
+        "
+        .unindent()
+        .into(),
+        Default::default(),
+    )
+    .await
+    .unwrap();
+
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+
+    let prev_embedding_count = embedding_provider.embedding_count();
+    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
+    cx.background_executor.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 1);
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+    assert_eq!(*pending_file_count.borrow(), 0);
+    index.await.unwrap();
+
+    assert_eq!(
+        embedding_provider.embedding_count() - prev_embedding_count,
+        1
+    );
+}
+
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    let (outstanding_job_count, _) = postage::watch::channel_with(0);
+    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
+    let files = (1..=3)
+        .map(|file_ix| FileToEmbed {
+            worktree_id: 5,
+            path: Path::new(&format!("path-{file_ix}")).into(),
+            mtime: SystemTime::now(),
+            spans: (0..rng.gen_range(4..22))
+                .map(|document_ix| {
+                    let content_len = rng.gen_range(10..100);
+                    let content = RandomCharIter::new(&mut rng)
+                        .with_simple_text()
+                        .take(content_len)
+                        .collect::<String>();
+                    let digest = SpanDigest::from(content.as_str());
+                    Span {
+                        range: 0..10,
+                        embedding: None,
+                        name: format!("document {document_ix}"),
+                        content,
+                        digest,
+                        token_count: rng.gen_range(10..30),
+                    }
+                })
+                .collect(),
+            job_handle: JobHandle::new(&outstanding_job_count),
+        })
+        .collect::<Vec<_>>();
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
+    for file in &files {
+        queue.push(file.clone());
+    }
+    queue.flush();
+
+    cx.background_executor.run_until_parked();
+    let finished_files = queue.finished_files();
+    let mut embedded_files: Vec<_> = files
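+        // Drain one finished file per input file from the queue's output channel.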
.iter() + .map(|_| finished_files.try_recv().expect("no finished file")) + .collect(); + + let expected_files: Vec<_> = files + .iter() + .map(|file| { + let mut file = file.clone(); + for doc in &mut file.spans { + doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); + } + file + }) + .collect(); + + embedded_files.sort_by_key(|f| f.path.clone()); + + assert_eq!(embedded_files, expected_files); +} + +#[track_caller] +fn assert_search_results( + actual: &[SearchResult], + expected: &[(Arc, usize)], + cx: &TestAppContext, +) { + let actual = actual + .iter() + .map(|search_result| { + search_result.buffer.read_with(cx, |buffer, _cx| { + ( + buffer.file().unwrap().path().clone(), + search_result.range.start.to_offset(buffer), + ) + }) + }) + .collect::>(); + assert_eq!(actual, expected); +} + +#[gpui::test] +async fn test_code_context_retrieval_rust() { + let language = rust_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = " + /// A doc comment + /// that spans multiple lines + #[gpui::test] + fn a() { + b + } + + impl C for D { + } + + impl E { + // This is also a preceding comment + pub fn function_1() -> Option<()> { + unimplemented!(); + } + + // This is a preceding comment + fn function_2() -> Result<()> { + unimplemented!(); + } + } + + #[derive(Clone)] + struct D { + name: String + } + " + .unindent(); + + let documents = retriever.parse_file(&text, language).unwrap(); + + assert_documents_eq( + &documents, + &[ + ( + " + /// A doc comment + /// that spans multiple lines + #[gpui::test] + fn a() { + b + }" + .unindent(), + text.find("fn a").unwrap(), + ), + ( + " + impl C for D { + }" + .unindent(), + text.find("impl C").unwrap(), + ), + ( + " + impl E { + // This is also a preceding comment + pub fn function_1() -> Option<()> { /* ... */ } + + // This is a preceding comment + fn function_2() -> Result<()> { /* ... 
*/ } + }" + .unindent(), + text.find("impl E").unwrap(), + ), + ( + " + // This is also a preceding comment + pub fn function_1() -> Option<()> { + unimplemented!(); + }" + .unindent(), + text.find("pub fn function_1").unwrap(), + ), + ( + " + // This is a preceding comment + fn function_2() -> Result<()> { + unimplemented!(); + }" + .unindent(), + text.find("fn function_2").unwrap(), + ), + ( + " + #[derive(Clone)] + struct D { + name: String + }" + .unindent(), + text.find("struct D").unwrap(), + ), + ], + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_json() { + let language = json_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = r#" + { + "array": [1, 2, 3, 4], + "string": "abcdefg", + "nested_object": { + "array_2": [5, 6, 7, 8], + "string_2": "hijklmnop", + "boolean": true, + "none": null + } + } + "# + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[( + r#" + { + "array": [], + "string": "", + "nested_object": { + "array_2": [], + "string_2": "", + "boolean": true, + "none": null + } + }"# + .unindent(), + text.find("{").unwrap(), + )], + ); + + let text = r#" + [ + { + "name": "somebody", + "age": 42 + }, + { + "name": "somebody else", + "age": 43 + } + ] + "# + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[( + r#" + [{ + "name": "", + "age": 42 + }]"# + .unindent(), + text.find("[").unwrap(), + )], + ); +} + +fn assert_documents_eq( + documents: &[Span], + expected_contents_and_start_offsets: &[(String, usize)], +) { + assert_eq!( + documents + .iter() + .map(|document| (document.content.clone(), document.range.start)) + .collect::>(), + expected_contents_and_start_offsets + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_javascript() { + let language = js_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = " + /* globals importScripts, backend */ + function _authorize() {} + + /** + * Sometimes the frontend build is way faster than backend. + */ + export async function authorizeBank() { + _authorize(pushModal, upgradingAccountId, {}); + } + + export class SettingsPage { + /* This is a test setting */ + constructor(page) { + this.page = page; + } + } + + /* This is a test comment */ + class TestClass {} + + /* Schema for editor_events in Clickhouse. */ + export interface ClickhouseEditorEvent { + installation_id: string + operation: string + } + " + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[ + ( + " + /* globals importScripts, backend */ + function _authorize() {}" + .unindent(), + 37, + ), + ( + " + /** + * Sometimes the frontend build is way faster than backend. 
+ */ + export async function authorizeBank() { + _authorize(pushModal, upgradingAccountId, {}); + }" + .unindent(), + 131, + ), + ( + " + export class SettingsPage { + /* This is a test setting */ + constructor(page) { + this.page = page; + } + }" + .unindent(), + 225, + ), + ( + " + /* This is a test setting */ + constructor(page) { + this.page = page; + }" + .unindent(), + 290, + ), + ( + " + /* This is a test comment */ + class TestClass {}" + .unindent(), + 374, + ), + ( + " + /* Schema for editor_events in Clickhouse. */ + export interface ClickhouseEditorEvent { + installation_id: string + operation: string + }" + .unindent(), + 440, + ), + ], + ) +} + +#[gpui::test] +async fn test_code_context_retrieval_lua() { + let language = lua_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = r#" + -- Creates a new class + -- @param baseclass The Baseclass of this class, or nil. + -- @return A new class reference. + function classes.class(baseclass) + -- Create the class definition and metatable. + local classdef = {} + -- Find the super class, either Object or user-defined. + baseclass = baseclass or classes.Object + -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable. + setmetatable(classdef, { __index = baseclass }) + -- All class instances have a reference to the class object. + classdef.class = classdef + --- Recursivly allocates the inheritance tree of the instance. + -- @param mastertable The 'root' of the inheritance tree. + -- @return Returns the instance with the allocated inheritance tree. + function classdef.alloc(mastertable) + -- All class instances have a reference to a superclass object. + local instance = { super = baseclass.alloc(mastertable) } + -- Any functions this instance does not know of will 'look up' to the superclass definition. + setmetatable(instance, { __index = classdef, __newindex = mastertable }) + return instance + end + end + "#.unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[ + (r#" + -- Creates a new class + -- @param baseclass The Baseclass of this class, or nil. + -- @return A new class reference. + function classes.class(baseclass) + -- Create the class definition and metatable. + local classdef = {} + -- Find the super class, either Object or user-defined. + baseclass = baseclass or classes.Object + -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable. + setmetatable(classdef, { __index = baseclass }) + -- All class instances have a reference to the class object. + classdef.class = classdef + --- Recursivly allocates the inheritance tree of the instance. + -- @param mastertable The 'root' of the inheritance tree. + -- @return Returns the instance with the allocated inheritance tree. + function classdef.alloc(mastertable) + --[ ... ]-- + --[ ... ]-- + end + end"#.unindent(), + 114), + (r#" + --- Recursivly allocates the inheritance tree of the instance. + -- @param mastertable The 'root' of the inheritance tree. + -- @return Returns the instance with the allocated inheritance tree. + function classdef.alloc(mastertable) + -- All class instances have a reference to a superclass object. 
+ local instance = { super = baseclass.alloc(mastertable) } + -- Any functions this instance does not know of will 'look up' to the superclass definition. + setmetatable(instance, { __index = classdef, __newindex = mastertable }) + return instance + end"#.unindent(), 809), + ] + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_elixir() { + let language = elixir_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = r#" + defmodule File.Stream do + @moduledoc """ + Defines a `File.Stream` struct returned by `File.stream!/3`. + + The following fields are public: + + * `path` - the file path + * `modes` - the file modes + * `raw` - a boolean indicating if bin functions should be used + * `line_or_bytes` - if reading should read lines or a given number of bytes + * `node` - the node the file belongs to + + """ + + defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil + + @type t :: %__MODULE__{} + + @doc false + def __build__(path, modes, line_or_bytes) do + raw = :lists.keyfind(:encoding, 1, modes) == false + + modes = + case raw do + true -> + case :lists.keyfind(:read_ahead, 1, modes) do + {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] + {:read_ahead, _} -> [:raw | modes] + false -> [:raw, :read_ahead | modes] + end + + false -> + modes + end + + %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} + + end"# + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[( + r#" + defmodule File.Stream do + @moduledoc """ + Defines a `File.Stream` struct returned by `File.stream!/3`. 
+ + The following fields are public: + + * `path` - the file path + * `modes` - the file modes + * `raw` - a boolean indicating if bin functions should be used + * `line_or_bytes` - if reading should read lines or a given number of bytes + * `node` - the node the file belongs to + + """ + + defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil + + @type t :: %__MODULE__{} + + @doc false + def __build__(path, modes, line_or_bytes) do + raw = :lists.keyfind(:encoding, 1, modes) == false + + modes = + case raw do + true -> + case :lists.keyfind(:read_ahead, 1, modes) do + {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] + {:read_ahead, _} -> [:raw | modes] + false -> [:raw, :read_ahead | modes] + end + + false -> + modes + end + + %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} + + end"# + .unindent(), + 0, + ),(r#" + @doc false + def __build__(path, modes, line_or_bytes) do + raw = :lists.keyfind(:encoding, 1, modes) == false + + modes = + case raw do + true -> + case :lists.keyfind(:read_ahead, 1, modes) do + {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] + {:read_ahead, _} -> [:raw | modes] + false -> [:raw, :read_ahead | modes] + end + + false -> + modes + end + + %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} + + end"#.unindent(), 574)], + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_cpp() { + let language = cpp_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = " + /** + * @brief Main function + * @returns 0 on exit + */ + int main() { return 0; } + + /** + * This is a test comment + */ + class MyClass { // The class + public: // Access specifier + int myNum; // Attribute (int variable) + string myString; // Attribute (string variable) + }; + + // This is a test comment + enum Color { red, green, blue }; + + /** This is a preceding block comment + * This is the second line + */ + struct { // Structure declaration + int myNum; // Member (int variable) + string myString; // Member (string variable) + } myStructure; + + /** + * @brief Matrix class. + */ + template ::value || std::is_floating_point::value, + bool>::type> + class Matrix2 { + std::vector> _mat; + + public: + /** + * @brief Constructor + * @tparam Integer ensuring integers are being evaluated and not other + * data types. 
+ * @param size denoting the size of Matrix as size x size + */ + template ::value, + Integer>::type> + explicit Matrix(const Integer size) { + for (size_t i = 0; i < size; ++i) { + _mat.emplace_back(std::vector(size, 0)); + } + } + }" + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[ + ( + " + /** + * @brief Main function + * @returns 0 on exit + */ + int main() { return 0; }" + .unindent(), + 54, + ), + ( + " + /** + * This is a test comment + */ + class MyClass { // The class + public: // Access specifier + int myNum; // Attribute (int variable) + string myString; // Attribute (string variable) + }" + .unindent(), + 112, + ), + ( + " + // This is a test comment + enum Color { red, green, blue }" + .unindent(), + 322, + ), + ( + " + /** This is a preceding block comment + * This is the second line + */ + struct { // Structure declaration + int myNum; // Member (int variable) + string myString; // Member (string variable) + } myStructure;" + .unindent(), + 425, + ), + ( + " + /** + * @brief Matrix class. + */ + template ::value || std::is_floating_point::value, + bool>::type> + class Matrix2 { + std::vector> _mat; + + public: + /** + * @brief Constructor + * @tparam Integer ensuring integers are being evaluated and not other + * data types. + * @param size denoting the size of Matrix as size x size + */ + template ::value, + Integer>::type> + explicit Matrix(const Integer size) { + for (size_t i = 0; i < size; ++i) { + _mat.emplace_back(std::vector(size, 0)); + } + } + }" + .unindent(), + 612, + ), + ( + " + explicit Matrix(const Integer size) { + for (size_t i = 0; i < size; ++i) { + _mat.emplace_back(std::vector(size, 0)); + } + }" + .unindent(), + 1226, + ), + ], + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_ruby() { + let language = ruby_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = r#" + # This concern is inspired by "sudo mode" on GitHub. It + # is a way to re-authenticate a user before allowing them + # to see or perform an action. + # + # Add `before_action :require_challenge!` to actions you + # want to protect. + # + # The user will be shown a page to enter the challenge (which + # is either the password, or just the username when no + # password exists). Upon passing, there is a grace period + # during which no challenge will be asked from the user. + # + # Accessing challenge-protected resources during the grace + # period will refresh the grace period. + module ChallengableConcern + extend ActiveSupport::Concern + + CHALLENGE_TIMEOUT = 1.hour.freeze + + def require_challenge! + return if skip_challenge? + + if challenge_passed_recently? + session[:challenge_passed_at] = Time.now.utc + return + end + + @challenge = Form::Challenge.new(return_to: request.url) + + if params.key?(:form_challenge) + if challenge_passed? + session[:challenge_passed_at] = Time.now.utc + else + flash.now[:alert] = I18n.t('challenge.invalid_password') + render_challenge + end + else + render_challenge + end + end + + def challenge_passed? 
+ current_user.valid_password?(challenge_params[:current_password]) + end + end + + class Animal + include Comparable + + attr_reader :legs + + def initialize(name, legs) + @name, @legs = name, legs + end + + def <=>(other) + legs <=> other.legs + end + end + + # Singleton method for car object + def car.wheels + puts "There are four wheels" + end"# + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[ + ( + r#" + # This concern is inspired by "sudo mode" on GitHub. It + # is a way to re-authenticate a user before allowing them + # to see or perform an action. + # + # Add `before_action :require_challenge!` to actions you + # want to protect. + # + # The user will be shown a page to enter the challenge (which + # is either the password, or just the username when no + # password exists). Upon passing, there is a grace period + # during which no challenge will be asked from the user. + # + # Accessing challenge-protected resources during the grace + # period will refresh the grace period. + module ChallengableConcern + extend ActiveSupport::Concern + + CHALLENGE_TIMEOUT = 1.hour.freeze + + def require_challenge! + # ... + end + + def challenge_passed? + # ... + end + end"# + .unindent(), + 558, + ), + ( + r#" + def require_challenge! + return if skip_challenge? + + if challenge_passed_recently? + session[:challenge_passed_at] = Time.now.utc + return + end + + @challenge = Form::Challenge.new(return_to: request.url) + + if params.key?(:form_challenge) + if challenge_passed? + session[:challenge_passed_at] = Time.now.utc + else + flash.now[:alert] = I18n.t('challenge.invalid_password') + render_challenge + end + else + render_challenge + end + end"# + .unindent(), + 663, + ), + ( + r#" + def challenge_passed? + current_user.valid_password?(challenge_params[:current_password]) + end"# + .unindent(), + 1254, + ), + ( + r#" + class Animal + include Comparable + + attr_reader :legs + + def initialize(name, legs) + # ... + end + + def <=>(other) + # ... + end + end"# + .unindent(), + 1363, + ), + ( + r#" + def initialize(name, legs) + @name, @legs = name, legs + end"# + .unindent(), + 1427, + ), + ( + r#" + def <=>(other) + legs <=> other.legs + end"# + .unindent(), + 1501, + ), + ( + r#" + # Singleton method for car object + def car.wheels + puts "There are four wheels" + end"# + .unindent(), + 1591, + ), + ], + ); +} + +#[gpui::test] +async fn test_code_context_retrieval_php() { + let language = php_lang(); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut retriever = CodeContextRetriever::new(embedding_provider); + + let text = r#" + 100) { + throw new Exception(message: 'Progress cannot be greater than 100'); + } + + if ($this->achievements()->find($achievement->id)) { + throw new Exception(message: 'User already has this Achievement'); + } + + $this->achievements()->attach($achievement, [ + 'progress' => $progress ?? 
null, + ]); + + $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this))); + } + + public function achievements(): BelongsToMany + { + return $this->belongsToMany(related: Achievement::class) + ->withPivot(columns: 'progress') + ->where('is_secret', false) + ->using(AchievementUser::class); + } + } + + interface Multiplier + { + public function qualifies(array $data): bool; + + public function setMultiplier(): int; + } + + enum AuditType: string + { + case Add = 'add'; + case Remove = 'remove'; + case Reset = 'reset'; + case LevelUp = 'level_up'; + } + + ?>"# + .unindent(); + + let documents = retriever.parse_file(&text, language.clone()).unwrap(); + + assert_documents_eq( + &documents, + &[ + ( + r#" + /* + This is a multiple-lines comment block + that spans over multiple + lines + */ + function functionName() { + echo "Hello world!"; + }"# + .unindent(), + 123, + ), + ( + r#" + trait HasAchievements + { + /** + * @throws \Exception + */ + public function grantAchievement(Achievement $achievement, $progress = null): void + {/* ... */} + + public function achievements(): BelongsToMany + {/* ... */} + }"# + .unindent(), + 177, + ), + (r#" + /** + * @throws \Exception + */ + public function grantAchievement(Achievement $achievement, $progress = null): void + { + if ($progress > 100) { + throw new Exception(message: 'Progress cannot be greater than 100'); + } + + if ($this->achievements()->find($achievement->id)) { + throw new Exception(message: 'User already has this Achievement'); + } + + $this->achievements()->attach($achievement, [ + 'progress' => $progress ?? null, + ]); + + $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this))); + }"#.unindent(), 245), + (r#" + public function achievements(): BelongsToMany + { + return $this->belongsToMany(related: Achievement::class) + ->withPivot(columns: 'progress') + ->where('is_secret', false) + ->using(AchievementUser::class); + }"#.unindent(), 902), + (r#" + interface Multiplier + { + public function qualifies(array $data): bool; + + public function setMultiplier(): int; + }"#.unindent(), + 1146), + (r#" + enum AuditType: string + { + case Add = 'add'; + case Remove = 'remove'; + case Reset = 'reset'; + case LevelUp = 'level_up'; + }"#.unindent(), 1265) + ], + ); +} + +fn js_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Javascript".into(), + path_suffixes: vec!["js".into()], + ..Default::default() + }, + Some(tree_sitter_typescript::language_tsx()), + ) + .with_embedding_query( + &r#" + + ( + (comment)* @context + . + [ + (export_statement + (function_declaration + "async"? @name + "function" @name + name: (_) @name)) + (function_declaration + "async"? @name + "function" @name + name: (_) @name) + ] @item + ) + + ( + (comment)* @context + . + [ + (export_statement + (class_declaration + "class" @name + name: (_) @name)) + (class_declaration + "class" @name + name: (_) @name) + ] @item + ) + + ( + (comment)* @context + . + [ + (export_statement + (interface_declaration + "interface" @name + name: (_) @name)) + (interface_declaration + "interface" @name + name: (_) @name) + ] @item + ) + + ( + (comment)* @context + . + [ + (export_statement + (enum_declaration + "enum" @name + name: (_) @name)) + (enum_declaration + "enum" @name + name: (_) @name) + ] @item + ) + + ( + (comment)* @context + . 
+ (method_definition + [ + "get" + "set" + "async" + "*" + "static" + ]* @name + name: (_) @name) @item + ) + + "# + .unindent(), + ) + .unwrap(), + ) +} + +fn rust_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".into()], + collapsed_placeholder: " /* ... */ ".to_string(), + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_embedding_query( + r#" + ( + [(line_comment) (attribute_item)]* @context + . + [ + (struct_item + name: (_) @name) + + (enum_item + name: (_) @name) + + (impl_item + trait: (_)? @name + "for"? @name + type: (_) @name) + + (trait_item + name: (_) @name) + + (function_item + name: (_) @name + body: (block + "{" @keep + "}" @keep) @collapse) + + (macro_definition + name: (_) @name) + ] @item + ) + + (attribute_item) @collapse + (use_declaration) @collapse + "#, + ) + .unwrap(), + ) +} + +fn json_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "JSON".into(), + path_suffixes: vec!["json".into()], + ..Default::default() + }, + Some(tree_sitter_json::language()), + ) + .with_embedding_query( + r#" + (document) @item + + (array + "[" @keep + . + (object)? @keep + "]" @keep) @collapse + + (pair value: (string + "\"" @keep + "\"" @keep) @collapse) + "#, + ) + .unwrap(), + ) +} + +fn toml_lang() -> Arc { + Arc::new(Language::new( + LanguageConfig { + name: "TOML".into(), + path_suffixes: vec!["toml".into()], + ..Default::default() + }, + Some(tree_sitter_toml::language()), + )) +} + +fn cpp_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "CPP".into(), + path_suffixes: vec!["cpp".into()], + ..Default::default() + }, + Some(tree_sitter_cpp::language()), + ) + .with_embedding_query( + r#" + ( + (comment)* @context + . + (function_definition + (type_qualifier)? @name + type: (_)? @name + declarator: [ + (function_declarator + declarator: (_) @name) + (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name)) + (pointer_declarator + "*" @name + declarator: (pointer_declarator + "*" @name + declarator: (function_declarator + declarator: (_) @name))) + (reference_declarator + ["&" "&&"] @name + (function_declarator + declarator: (_) @name)) + ] + (type_qualifier)? @name) @item + ) + + ( + (comment)* @context + . + (template_declaration + (class_specifier + "class" @name + name: (_) @name) + ) @item + ) + + ( + (comment)* @context + . + (class_specifier + "class" @name + name: (_) @name) @item + ) + + ( + (comment)* @context + . + (enum_specifier + "enum" @name + name: (_) @name) @item + ) + + ( + (comment)* @context + . + (declaration + type: (struct_specifier + "struct" @name) + declarator: (_) @name) @item + ) + + "#, + ) + .unwrap(), + ) +} + +fn lua_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Lua".into(), + path_suffixes: vec!["lua".into()], + collapsed_placeholder: "--[ ... ]--".to_string(), + ..Default::default() + }, + Some(tree_sitter_lua::language()), + ) + .with_embedding_query( + r#" + ( + (comment)* @context + . + (function_declaration + "function" @name + name: (_) @name + (comment)* @collapse + body: (block) @collapse + ) @item + ) + "#, + ) + .unwrap(), + ) +} + +fn php_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "PHP".into(), + path_suffixes: vec!["php".into()], + collapsed_placeholder: "/* ... */".into(), + ..Default::default() + }, + Some(tree_sitter_php::language()), + ) + .with_embedding_query( + r#" + ( + (comment)* @context + . 
+            [
+                (function_definition
+                    "function" @name
+                    name: (_) @name
+                    body: (_
+                        "{" @keep
+                        "}" @keep) @collapse
+                )
+
+                (trait_declaration
+                    "trait" @name
+                    name: (_) @name)
+
+                (method_declaration
+                    "function" @name
+                    name: (_) @name
+                    body: (_
+                        "{" @keep
+                        "}" @keep) @collapse
+                )
+
+                (interface_declaration
+                    "interface" @name
+                    name: (_) @name
+                )
+
+                (enum_declaration
+                    "enum" @name
+                    name: (_) @name
+                )
+
+            ] @item
+            )
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn ruby_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Ruby".into(),
+                path_suffixes: vec!["rb".into()],
+                collapsed_placeholder: "# ...".to_string(),
+                ..Default::default()
+            },
+            Some(tree_sitter_ruby::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                [
+                (module
+                    "module" @name
+                    name: (_) @name)
+                (method
+                    "def" @name
+                    name: (_) @name
+                    body: (body_statement) @collapse)
+                (class
+                    "class" @name
+                    name: (_) @name)
+                (singleton_method
+                    "def" @name
+                    object: (_) @name
+                    "." @name
+                    name: (_) @name
+                    body: (body_statement) @collapse)
+                ] @item
+            )
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn elixir_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Elixir".into(),
+                path_suffixes: vec!["ex".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_elixir::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (unary_operator
+                    operator: "@"
+                    operand: (call
+                        target: (identifier) @unary
+                        (#match? @unary "^(doc)$"))
+                ) @context
+                .
+                (call
+                target: (identifier) @name
+                (arguments
+                [
+                    (identifier) @name
+                    (call
+                    target: (identifier) @name)
+                    (binary_operator
+                    left: (call
+                    target: (identifier) @name)
+                    operator: "when")
+                ])
+                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
+            )
+
+            (call
+            target: (identifier) @name
+            (arguments (alias) @name)
+            (#any-match? @name "^(defmodule|defprotocol)$")) @item
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+#[gpui::test]
+fn test_subtract_ranges() {
+    // subtract_ranges(collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>)
+    assert_eq!(
+        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
+        vec![1..4, 10..21]
+    );
+
+    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
+}
+
+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|cx| {
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
+        SemanticIndexSettings::register(cx);
+        ProjectSettings::register(cx);
+    });
+}
diff --git a/crates/workspace2/src/workspace2.rs b/crates/workspace2/src/workspace2.rs
index 77d744b9fc..5741fa4a94 100644
--- a/crates/workspace2/src/workspace2.rs
+++ b/crates/workspace2/src/workspace2.rs
@@ -3942,8 +3942,6 @@ impl std::fmt::Debug for OpenPaths {
     }
 }
 
-pub struct WorkspaceCreated(pub WeakView<Workspace>);
-
 pub fn activate_workspace_for_project(
     cx: &mut AppContext,
     predicate: impl Fn(&Project, &AppContext) -> bool + Send + 'static,