From 49371b44cb6c7de5997deda18346238cb932905d Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Fri, 12 Apr 2024 10:40:59 -0700 Subject: [PATCH] Semantic Index (#10329) This introduces semantic indexing in Zed based on chunking text from files in the developer's workspace and creating vector embeddings using an embedding model. As part of this, we've created an embeddings provider trait that allows us to work with OpenAI, a local Ollama model, or a Zed hosted embedding. The semantic index is built by breaking down text for known (programming) languages into manageable chunks that are smaller than the max token size. Each chunk is then fed to a language model to create a high dimensional vector which is then normalized to a unit vector to allow fast comparison with other vectors with a simple dot product. Alongside the vector, we store the path of the file and the range within the document where the vector was sourced from. Zed will soon grok contextual similarity across different text snippets, allowing for natural language search beyond keyword matching. This is being put together both for human-based search as well as providing results to Large Language Models to allow them to refine how they help developers. Remaining todo: * [x] Change `provider` to `model` within the zed hosted embeddings database (as its currently a combo of the provider and the model in one name) Release Notes: - N/A --------- Co-authored-by: Nathan Sobo Co-authored-by: Antonio Scandurra Co-authored-by: Conrad Irwin Co-authored-by: Marshall Bowers Co-authored-by: Antonio --- Cargo.lock | 143 +++ Cargo.toml | 3 + crates/channel/src/channel_store_tests.rs | 4 +- .../20240409082755_create_embeddings.sql | 9 + crates/collab/src/db/queries.rs | 1 + crates/collab/src/db/queries/embeddings.rs | 94 ++ crates/collab/src/db/tables.rs | 1 + crates/collab/src/db/tables/embedding.rs | 18 + crates/collab/src/db/tests.rs | 1 + crates/collab/src/db/tests/embedding_tests.rs | 84 ++ crates/collab/src/main.rs | 7 +- crates/collab/src/rpc.rs | 142 ++- crates/editor/src/git/blame.rs | 2 +- crates/gpui/src/app/test_context.rs | 33 +- crates/gpui/src/executor.rs | 7 +- crates/language/src/language.rs | 12 +- crates/open_ai/src/open_ai.rs | 71 +- crates/project/src/project.rs | 48 + crates/project/src/project_tests.rs | 6 +- crates/rpc/proto/zed.proto | 30 +- crates/rpc/src/proto.rs | 6 + crates/semantic_index/Cargo.toml | 48 + crates/semantic_index/LICENSE-GPL | 1 + crates/semantic_index/examples/index.rs | 140 +++ crates/semantic_index/fixture/main.rs | 3 + crates/semantic_index/fixture/needle.md | 43 + crates/semantic_index/src/chunking.rs | 409 ++++++++ crates/semantic_index/src/embedding.rs | 125 +++ crates/semantic_index/src/embedding/cloud.rs | 88 ++ crates/semantic_index/src/embedding/ollama.rs | 74 ++ .../semantic_index/src/embedding/open_ai.rs | 55 + crates/semantic_index/src/semantic_index.rs | 954 ++++++++++++++++++ crates/util/src/http.rs | 28 +- 33 files changed, 2649 insertions(+), 41 deletions(-) create mode 100644 crates/collab/migrations/20240409082755_create_embeddings.sql create mode 100644 crates/collab/src/db/queries/embeddings.rs create mode 100644 crates/collab/src/db/tables/embedding.rs create mode 100644 crates/collab/src/db/tests/embedding_tests.rs create mode 100644 crates/semantic_index/Cargo.toml create mode 120000 crates/semantic_index/LICENSE-GPL create mode 100644 crates/semantic_index/examples/index.rs create mode 100644 crates/semantic_index/fixture/main.rs create mode 100644 crates/semantic_index/fixture/needle.md create mode 100644 crates/semantic_index/src/chunking.rs create mode 100644 crates/semantic_index/src/embedding.rs create mode 100644 crates/semantic_index/src/embedding/cloud.rs create mode 100644 crates/semantic_index/src/embedding/ollama.rs create mode 100644 crates/semantic_index/src/embedding/open_ai.rs create mode 100644 crates/semantic_index/src/semantic_index.rs diff --git a/Cargo.lock b/Cargo.lock index 92ede849ff..814630d8e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3265,6 +3265,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" +[[package]] +name = "doxygen-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "415b6ec780d34dcf624666747194393603d0373b7141eef01d12ee58881507d9" +dependencies = [ + "phf", +] + [[package]] name = "dwrote" version = "0.11.0" @@ -4085,6 +4094,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-batch" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a" +dependencies = [ + "futures 0.3.28", + "futures-timer", + "pin-utils", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -4180,6 +4200,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -4659,6 +4685,41 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heed" +version = "0.20.0-alpha.9" +source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a" +dependencies = [ + "bitflags 2.4.2", + "byteorder", + "heed-traits", + "heed-types", + "libc", + "lmdb-master-sys", + "once_cell", + "page_size", + "serde", + "synchronoise", + "url", +] + +[[package]] +name = "heed-traits" +version = "0.20.0-alpha.9" +source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a" + +[[package]] +name = "heed-types" +version = "0.20.0-alpha.9" +source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a" +dependencies = [ + "bincode", + "byteorder", + "heed-traits", + "serde", + "serde_json", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -5664,6 +5725,16 @@ dependencies = [ "sha2 0.10.7", ] +[[package]] +name = "lmdb-master-sys" +version = "0.1.0" +source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a" +dependencies = [ + "cc", + "doxygen-rs", + "libc", +] + [[package]] name = "lock_api" version = "0.4.10" @@ -6683,6 +6754,16 @@ dependencies = [ "sha2 0.10.7", ] +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "palette" version = "0.7.5" @@ -6856,9 +6937,33 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" dependencies = [ + "phf_macros", "phf_shared", ] +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand 0.8.5", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "phf_shared" version = "0.11.2" @@ -8473,6 +8578,35 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba" +[[package]] +name = "semantic_index" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "clock", + "collections", + "env_logger", + "fs", + "futures 0.3.28", + "futures-batch", + "gpui", + "heed", + "language", + "languages", + "log", + "open_ai", + "project", + "serde", + "serde_json", + "settings", + "sha2 0.10.7", + "smol", + "tempfile", + "util", + "worktree", +] + [[package]] name = "semantic_version" version = "0.1.0" @@ -9478,6 +9612,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "synchronoise" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2" +dependencies = [ + "crossbeam-queue", +] + [[package]] name = "sys-locale" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index f58d998a8b..22999d86b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ members = [ "crates/task", "crates/tasks_ui", "crates/search", + "crates/semantic_index", "crates/semantic_version", "crates/settings", "crates/snippet", @@ -253,9 +254,11 @@ derive_more = "0.99.17" emojis = "0.6.1" env_logger = "0.9" futures = "0.3" +futures-batch = "0.6.1" futures-lite = "1.13" git2 = { version = "0.15", default-features = false } globset = "0.4" +heed = { git = "https://github.com/meilisearch/heed", rev = "036ac23f73a021894974b9adc815bc95b3e0482a", features = ["read-txn-no-tls"] } hex = "0.4.3" ignore = "0.4.22" indoc = "1" diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index cee747a4e9..c9f7bf2485 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -264,7 +264,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { ); assert_eq!( - channel.next_event(cx), + channel.next_event(cx).await, ChannelChatEvent::MessagesUpdated { old_range: 2..2, new_count: 1, @@ -317,7 +317,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { ); assert_eq!( - channel.next_event(cx), + channel.next_event(cx).await, ChannelChatEvent::MessagesUpdated { old_range: 0..0, new_count: 2, diff --git a/crates/collab/migrations/20240409082755_create_embeddings.sql b/crates/collab/migrations/20240409082755_create_embeddings.sql new file mode 100644 index 0000000000..ae4b4bcb61 --- /dev/null +++ b/crates/collab/migrations/20240409082755_create_embeddings.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS "embeddings" ( + "model" TEXT, + "digest" BYTEA, + "dimensions" FLOAT4[1536], + "retrieved_at" TIMESTAMP NOT NULL DEFAULT now(), + PRIMARY KEY ("model", "digest") +); + +CREATE INDEX IF NOT EXISTS "idx_retrieved_at_on_embeddings" ON "embeddings" ("retrieved_at"); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 2cbbc67969..b7670aa60c 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -6,6 +6,7 @@ pub mod channels; pub mod contacts; pub mod contributors; pub mod dev_servers; +pub mod embeddings; pub mod extensions; pub mod hosted_projects; pub mod messages; diff --git a/crates/collab/src/db/queries/embeddings.rs b/crates/collab/src/db/queries/embeddings.rs new file mode 100644 index 0000000000..d901b59659 --- /dev/null +++ b/crates/collab/src/db/queries/embeddings.rs @@ -0,0 +1,94 @@ +use super::*; +use time::Duration; +use time::OffsetDateTime; + +impl Database { + pub async fn get_embeddings( + &self, + model: &str, + digests: &[Vec], + ) -> Result, Vec>> { + self.weak_transaction(|tx| async move { + let embeddings = { + let mut db_embeddings = embedding::Entity::find() + .filter( + embedding::Column::Model.eq(model).and( + embedding::Column::Digest + .is_in(digests.iter().map(|digest| digest.as_slice())), + ), + ) + .stream(&*tx) + .await?; + + let mut embeddings = HashMap::default(); + while let Some(db_embedding) = db_embeddings.next().await { + let db_embedding = db_embedding?; + embeddings.insert(db_embedding.digest, db_embedding.dimensions); + } + embeddings + }; + + if !embeddings.is_empty() { + let now = OffsetDateTime::now_utc(); + let retrieved_at = PrimitiveDateTime::new(now.date(), now.time()); + + embedding::Entity::update_many() + .filter( + embedding::Column::Digest + .is_in(embeddings.keys().map(|digest| digest.as_slice())), + ) + .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at)) + .exec(&*tx) + .await?; + } + + Ok(embeddings) + }) + .await + } + + pub async fn save_embeddings( + &self, + model: &str, + embeddings: &HashMap, Vec>, + ) -> Result<()> { + self.weak_transaction(|tx| async move { + embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| { + let now_offset_datetime = OffsetDateTime::now_utc(); + let retrieved_at = + PrimitiveDateTime::new(now_offset_datetime.date(), now_offset_datetime.time()); + + embedding::ActiveModel { + model: ActiveValue::set(model.to_string()), + digest: ActiveValue::set(digest.clone()), + dimensions: ActiveValue::set(dimensions.clone()), + retrieved_at: ActiveValue::set(retrieved_at), + } + })) + .on_conflict( + OnConflict::columns([embedding::Column::Model, embedding::Column::Digest]) + .do_nothing() + .to_owned(), + ) + .exec_without_returning(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn purge_old_embeddings(&self) -> Result<()> { + self.weak_transaction(|tx| async move { + embedding::Entity::delete_many() + .filter( + embedding::Column::RetrievedAt + .lte(OffsetDateTime::now_utc() - Duration::days(60)), + ) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 4a284682b2..2af78f776e 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -11,6 +11,7 @@ pub mod channel_message_mention; pub mod contact; pub mod contributor; pub mod dev_server; +pub mod embedding; pub mod extension; pub mod extension_version; pub mod feature_flag; diff --git a/crates/collab/src/db/tables/embedding.rs b/crates/collab/src/db/tables/embedding.rs new file mode 100644 index 0000000000..8743b4b9e6 --- /dev/null +++ b/crates/collab/src/db/tables/embedding.rs @@ -0,0 +1,18 @@ +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "embeddings")] +pub struct Model { + #[sea_orm(primary_key)] + pub model: String, + #[sea_orm(primary_key)] + pub digest: Vec, + pub dimensions: Vec, + pub retrieved_at: PrimitiveDateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 35da659e54..e3ce834295 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -2,6 +2,7 @@ mod buffer_tests; mod channel_tests; mod contributor_tests; mod db_tests; +mod embedding_tests; mod extension_tests; mod feature_flag_tests; mod message_tests; diff --git a/crates/collab/src/db/tests/embedding_tests.rs b/crates/collab/src/db/tests/embedding_tests.rs new file mode 100644 index 0000000000..fcafac625d --- /dev/null +++ b/crates/collab/src/db/tests/embedding_tests.rs @@ -0,0 +1,84 @@ +use super::TestDb; +use crate::db::embedding; +use collections::HashMap; +use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, QueryFilter}; +use std::ops::Sub; +use time::{Duration, OffsetDateTime, PrimitiveDateTime}; + +// SQLite does not support array arguments, so we only test this against a real postgres instance +#[gpui::test] +async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) { + let test_db = TestDb::postgres(cx.executor().clone()); + let db = test_db.db(); + + let provider = "test_model"; + let digest1 = vec![1, 2, 3]; + let digest2 = vec![4, 5, 6]; + let embeddings = HashMap::from_iter([ + (digest1.clone(), vec![0.1, 0.2, 0.3]), + (digest2.clone(), vec![0.4, 0.5, 0.6]), + ]); + + // Save embeddings + db.save_embeddings(provider, &embeddings).await.unwrap(); + + // Retrieve embeddings + let retrieved_embeddings = db + .get_embeddings(provider, &[digest1.clone(), digest2.clone()]) + .await + .unwrap(); + assert_eq!(retrieved_embeddings.len(), 2); + assert!(retrieved_embeddings.contains_key(&digest1)); + assert!(retrieved_embeddings.contains_key(&digest2)); + + // Check if the retrieved embeddings are correct + assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]); + assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]); +} + +#[gpui::test] +async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) { + let test_db = TestDb::postgres(cx.executor().clone()); + let db = test_db.db(); + + let model = "test_model"; + let digest = vec![7, 8, 9]; + let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]); + + // Save old embeddings + db.save_embeddings(model, &embeddings).await.unwrap(); + + // Reach into the DB and change the retrieved at to be > 60 days + db.weak_transaction(|tx| { + let digest = digest.clone(); + async move { + let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61)); + let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time()); + + embedding::Entity::update_many() + .filter( + embedding::Column::Model + .eq(model) + .and(embedding::Column::Digest.eq(digest)), + ) + .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at)) + .exec(&*tx) + .await + .unwrap(); + + Ok(()) + } + }) + .await + .unwrap(); + + // Purge old embeddings + db.purge_old_embeddings().await.unwrap(); + + // Try to retrieve the purged embeddings + let retrieved_embeddings = db.get_embeddings(model, &[digest.clone()]).await.unwrap(); + assert!( + retrieved_embeddings.is_empty(), + "Old embeddings should have been purged" + ); +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 728544c533..b85d378ef2 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -6,8 +6,8 @@ use axum::{ Extension, Router, }; use collab::{ - api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState, - Config, RateLimiter, Result, + api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, + rpc::ResultExt, AppState, Config, RateLimiter, Result, }; use db::Database; use std::{ @@ -23,7 +23,7 @@ use tower_http::trace::TraceLayer; use tracing_subscriber::{ filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer, }; -use util::ResultExt; +use util::ResultExt as _; const VERSION: &str = env!("CARGO_PKG_VERSION"); const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -90,6 +90,7 @@ async fn main() -> Result<()> { }; if is_collab { + state.db.purge_old_embeddings().await.trace_err(); RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone()); } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index bdcfd487f1..da8328c411 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -32,6 +32,8 @@ use axum::{ use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; +use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL}; +use sha2::Digest; use futures::{ channel::oneshot, @@ -568,6 +570,22 @@ impl Server { app_state.config.google_ai_api_key.clone(), ) }) + }) + .add_request_handler({ + user_handler(move |request, response, session| { + get_cached_embeddings(request, response, session) + }) + }) + .add_request_handler({ + let app_state = app_state.clone(); + user_handler(move |request, response, session| { + compute_embeddings( + request, + response, + session, + app_state.config.openai_api_key.clone(), + ) + }) }); Arc::new(server) @@ -4021,8 +4039,6 @@ async fn complete_with_open_ai( session: UserSession, api_key: Arc, ) -> Result<()> { - const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; - let mut completion_stream = open_ai::stream_completion( &session.http_client, OPEN_AI_API_URL, @@ -4276,6 +4292,128 @@ async fn count_tokens_with_language_model( Ok(()) } +struct ComputeEmbeddingsRateLimit; + +impl RateLimit for ComputeEmbeddingsRateLimit { + fn capacity() -> usize { + std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(120) // Picked arbitrarily + } + + fn refill_duration() -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name() -> &'static str { + "compute-embeddings" + } +} + +async fn compute_embeddings( + request: proto::ComputeEmbeddings, + response: Response, + session: UserSession, + api_key: Option>, +) -> Result<()> { + let api_key = api_key.context("no OpenAI API key configured on the server")?; + authorize_access_to_language_models(&session).await?; + + session + .rate_limiter + .check::(session.user_id()) + .await?; + + let embeddings = match request.model.as_str() { + "openai/text-embedding-3-small" => { + open_ai::embed( + &session.http_client, + OPEN_AI_API_URL, + &api_key, + OpenAiEmbeddingModel::TextEmbedding3Small, + request.texts.iter().map(|text| text.as_str()), + ) + .await? + } + provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?, + }; + + let embeddings = request + .texts + .iter() + .map(|text| { + let mut hasher = sha2::Sha256::new(); + hasher.update(text.as_bytes()); + let result = hasher.finalize(); + result.to_vec() + }) + .zip( + embeddings + .data + .into_iter() + .map(|embedding| embedding.embedding), + ) + .collect::>(); + + let db = session.db().await; + db.save_embeddings(&request.model, &embeddings) + .await + .context("failed to save embeddings") + .trace_err(); + + response.send(proto::ComputeEmbeddingsResponse { + embeddings: embeddings + .into_iter() + .map(|(digest, dimensions)| proto::Embedding { digest, dimensions }) + .collect(), + })?; + Ok(()) +} + +struct GetCachedEmbeddingsRateLimit; + +impl RateLimit for GetCachedEmbeddingsRateLimit { + fn capacity() -> usize { + std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(120) // Picked arbitrarily + } + + fn refill_duration() -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name() -> &'static str { + "get-cached-embeddings" + } +} + +async fn get_cached_embeddings( + request: proto::GetCachedEmbeddings, + response: Response, + session: UserSession, +) -> Result<()> { + authorize_access_to_language_models(&session).await?; + + session + .rate_limiter + .check::(session.user_id()) + .await?; + + let db = session.db().await; + let embeddings = db.get_embeddings(&request.model, &request.digests).await?; + + response.send(proto::GetCachedEmbeddingsResponse { + embeddings: embeddings + .into_iter() + .map(|(digest, dimensions)| proto::Embedding { digest, dimensions }) + .collect(), + })?; + Ok(()) +} + async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> { let db = session.db().await; let flags = db.get_user_flags(session.user_id()).await?; diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index 8f753f7186..6b5420f15c 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -396,7 +396,7 @@ mod tests { let blame = cx.new_model(|cx| GitBlame::new(buffer.clone(), project.clone(), cx)); - let event = project.next_event(cx); + let event = project.next_event(cx).await; assert_eq!( event, project::Event::Notification( diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index b6bfd33f66..d049ceef2f 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -7,7 +7,7 @@ use crate::{ TextSystem, View, ViewContext, VisualContext, WindowContext, WindowHandle, WindowOptions, }; use anyhow::{anyhow, bail}; -use futures::{Stream, StreamExt}; +use futures::{channel::oneshot, Stream, StreamExt}; use std::{cell::RefCell, future::Future, ops::Deref, rc::Rc, sync::Arc, time::Duration}; /// A TestAppContext is provided to tests created with `#[gpui::test]`, it provides @@ -479,31 +479,26 @@ impl TestAppContext { impl Model { /// Block until the next event is emitted by the model, then return it. - pub fn next_event(&self, cx: &mut TestAppContext) -> Evt + pub fn next_event(&self, cx: &mut TestAppContext) -> impl Future where - Evt: Send + Clone + 'static, - T: EventEmitter, + Event: Send + Clone + 'static, + T: EventEmitter, { - let (tx, mut rx) = futures::channel::mpsc::unbounded(); - let _subscription = self.update(cx, |_, cx| { + let (tx, mut rx) = oneshot::channel(); + let mut tx = Some(tx); + let subscription = self.update(cx, |_, cx| { cx.subscribe(self, move |_, _, event, _| { - tx.unbounded_send(event.clone()).ok(); + if let Some(tx) = tx.take() { + _ = tx.send(event.clone()); + } }) }); - // Run other tasks until the event is emitted. - loop { - match rx.try_next() { - Ok(Some(event)) => return event, - Ok(None) => panic!("model was dropped"), - Err(_) => { - if !cx.executor().tick() { - break; - } - } - } + async move { + let event = rx.await.expect("no event emitted"); + drop(subscription); + event } - panic!("no event received") } /// Returns a future that resolves when the model notifies. diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 841fc5b19e..115359231c 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -372,7 +372,7 @@ impl BackgroundExecutor { self.dispatcher.as_test().unwrap().rng() } - /// How many CPUs are available to the dispatcher + /// How many CPUs are available to the dispatcher. pub fn num_cpus(&self) -> usize { num_cpus::get() } @@ -440,6 +440,11 @@ impl<'a> Scope<'a> { } } + /// How many CPUs are available to the dispatcher. + pub fn num_cpus(&self) -> usize { + self.executor.num_cpus() + } + /// Spawn a future into this scope. pub fn spawn(&mut self, f: F) where diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 65c838cda0..b3046bb562 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -72,7 +72,7 @@ pub use lsp::LanguageServerId; pub use outline::{Outline, OutlineItem}; pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer}; pub use text::LineEnding; -pub use tree_sitter::{Parser, Tree}; +pub use tree_sitter::{Node, Parser, Tree, TreeCursor}; use crate::language_settings::SoftWrap; @@ -91,6 +91,16 @@ thread_local! { }; } +pub fn with_parser(func: F) -> R +where + F: FnOnce(&mut Parser) -> R, +{ + PARSER.with(|parser| { + let mut parser = parser.borrow_mut(); + func(&mut parser) + }) +} + lazy_static! { static ref NEXT_LANGUAGE_ID: AtomicUsize = Default::default(); static ref NEXT_GRAMMAR_ID: AtomicUsize = Default::default(); diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index fcf4aa04bf..97abb45dfc 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,9 +1,11 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; +use std::{convert::TryFrom, future::Future}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; + #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { @@ -188,3 +190,68 @@ pub async fn stream_completion( } } } + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub enum OpenAiEmbeddingModel { + #[serde(rename = "text-embedding-3-small")] + TextEmbedding3Small, + #[serde(rename = "text-embedding-3-large")] + TextEmbedding3Large, +} + +#[derive(Serialize)] +struct OpenAiEmbeddingRequest<'a> { + model: OpenAiEmbeddingModel, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +pub struct OpenAiEmbeddingResponse { + pub data: Vec, +} + +#[derive(Deserialize)] +pub struct OpenAiEmbedding { + pub embedding: Vec, +} + +pub fn embed<'a>( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + model: OpenAiEmbeddingModel, + texts: impl IntoIterator, +) -> impl 'static + Future> { + let uri = format!("{api_url}/embeddings"); + + let request = OpenAiEmbeddingRequest { + model, + input: texts.into_iter().collect(), + }; + let body = AsyncBody::from(serde_json::to_string(&request).unwrap()); + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(body) + .map(|request| client.send(request)); + + async move { + let mut response = request?.await?; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + if response.status().is_success() { + let response: OpenAiEmbeddingResponse = + serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?; + Ok(response) + } else { + Err(anyhow!( + "error during embedding, status: {:?}, body: {:?}", + response.status(), + body + )) + } + } +} diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 6c27e4d612..92b522b3e9 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -978,6 +978,50 @@ impl Project { } } + #[cfg(any(test, feature = "test-support"))] + pub async fn example( + root_paths: impl IntoIterator, + cx: &mut AsyncAppContext, + ) -> Model { + use clock::FakeSystemClock; + + let fs = Arc::new(RealFs::default()); + let languages = LanguageRegistry::test(cx.background_executor().clone()); + let clock = Arc::new(FakeSystemClock::default()); + let http_client = util::http::FakeHttpClient::with_404_response(); + let client = cx + .update(|cx| client::Client::new(clock, http_client.clone(), cx)) + .unwrap(); + let user_store = cx + .new_model(|cx| UserStore::new(client.clone(), cx)) + .unwrap(); + let project = cx + .update(|cx| { + Project::local( + client, + node_runtime::FakeNodeRuntime::new(), + user_store, + Arc::new(languages), + fs, + cx, + ) + }) + .unwrap(); + for path in root_paths { + let (tree, _) = project + .update(cx, |project, cx| { + project.find_or_create_local_worktree(path, true, cx) + }) + .unwrap() + .await + .unwrap(); + tree.update(cx, |tree, _| tree.as_local().unwrap().scan_complete()) + .unwrap() + .await; + } + project + } + #[cfg(any(test, feature = "test-support"))] pub async fn test( fs: Arc, @@ -1146,6 +1190,10 @@ impl Project { self.user_store.clone() } + pub fn node_runtime(&self) -> Option<&Arc> { + self.node.as_ref() + } + pub fn opened_buffers(&self) -> Vec> { self.opened_buffers .values() diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index fa00ea6736..fed32ccf4e 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -2661,7 +2661,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext) ) .await .unwrap(); - worktree.next_event(cx); + worktree.next_event(cx).await; // Change the buffer's file again. Depending on the random seed, the // previous file change may still be in progress. @@ -2672,7 +2672,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext) ) .await .unwrap(); - worktree.next_event(cx); + worktree.next_event(cx).await; cx.executor().run_until_parked(); let on_disk_text = fs.load(Path::new("/dir/file1")).await.unwrap(); @@ -2716,7 +2716,7 @@ async fn test_edit_buffer_while_it_reloads(cx: &mut gpui::TestAppContext) { ) .await .unwrap(); - worktree.next_event(cx); + worktree.next_event(cx).await; cx.executor() .spawn(cx.executor().simulate_random_delay()) diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 8cbeed1ae8..606c0bb101 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -204,6 +204,11 @@ message Envelope { LanguageModelResponse language_model_response = 167; CountTokensWithLanguageModel count_tokens_with_language_model = 168; CountTokensResponse count_tokens_response = 169; + GetCachedEmbeddings get_cached_embeddings = 189; + GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; + ComputeEmbeddings compute_embeddings = 191; + ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max + UpdateChannelMessage update_channel_message = 170; ChannelMessageUpdate channel_message_update = 171; @@ -216,7 +221,7 @@ message Envelope { MultiLspQueryResponse multi_lsp_query_response = 176; CreateRemoteProject create_remote_project = 177; - CreateRemoteProjectResponse create_remote_project_response = 188; // current max + CreateRemoteProjectResponse create_remote_project_response = 188; CreateDevServer create_dev_server = 178; CreateDevServerResponse create_dev_server_response = 179; ShutdownDevServer shutdown_dev_server = 180; @@ -1892,6 +1897,29 @@ message CountTokensResponse { uint32 token_count = 1; } +message GetCachedEmbeddings { + string model = 1; + repeated bytes digests = 2; +} + +message GetCachedEmbeddingsResponse { + repeated Embedding embeddings = 1; +} + +message ComputeEmbeddings { + string model = 1; + repeated string texts = 2; +} + +message ComputeEmbeddingsResponse { + repeated Embedding embeddings = 1; +} + +message Embedding { + bytes digest = 1; + repeated float dimensions = 2; +} + message BlameBuffer { uint64 project_id = 1; uint64 buffer_id = 2; diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index a117648cec..48160b2fe4 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -151,6 +151,8 @@ messages!( (ChannelMessageSent, Foreground), (ChannelMessageUpdate, Foreground), (CompleteWithLanguageModel, Background), + (ComputeEmbeddings, Background), + (ComputeEmbeddingsResponse, Background), (CopyProjectEntry, Foreground), (CountTokensWithLanguageModel, Background), (CountTokensResponse, Background), @@ -174,6 +176,8 @@ messages!( (FormatBuffers, Foreground), (FormatBuffersResponse, Foreground), (FuzzySearchUsers, Foreground), + (GetCachedEmbeddings, Background), + (GetCachedEmbeddingsResponse, Background), (GetChannelMembers, Foreground), (GetChannelMembersResponse, Foreground), (GetChannelMessages, Background), @@ -325,6 +329,7 @@ request_messages!( (CancelCall, Ack), (CopyProjectEntry, ProjectEntryResponse), (CompleteWithLanguageModel, LanguageModelResponse), + (ComputeEmbeddings, ComputeEmbeddingsResponse), (CountTokensWithLanguageModel, CountTokensResponse), (CreateChannel, CreateChannelResponse), (CreateProjectEntry, ProjectEntryResponse), @@ -336,6 +341,7 @@ request_messages!( (Follow, FollowResponse), (FormatBuffers, FormatBuffersResponse), (FuzzySearchUsers, UsersResponse), + (GetCachedEmbeddings, GetCachedEmbeddingsResponse), (GetChannelMembers, GetChannelMembersResponse), (GetChannelMessages, GetChannelMessagesResponse), (GetChannelMessagesById, GetChannelMessagesResponse), diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml new file mode 100644 index 0000000000..5010a2fa49 --- /dev/null +++ b/crates/semantic_index/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "semantic_index" +description = "Process, chunk, and embed text as vectors for semantic search." +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lib] +path = "src/semantic_index.rs" + +[dependencies] +anyhow.workspace = true +client.workspace = true +clock.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +futures-batch.workspace = true +gpui.workspace = true +language.workspace = true +log.workspace = true +heed.workspace = true +open_ai.workspace = true +project.workspace = true +settings.workspace = true +serde.workspace = true +serde_json.workspace = true +sha2.workspace = true +smol.workspace = true +util. workspace = true +worktree.workspace = true + +[dev-dependencies] +env_logger.workspace = true +client = { workspace = true, features = ["test-support"] } +fs = { workspace = true, features = ["test-support"] } +futures.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +languages.workspace = true +project = { workspace = true, features = ["test-support"] } +tempfile.workspace = true +util = { workspace = true, features = ["test-support"] } +worktree = { workspace = true, features = ["test-support"] } + +[lints] +workspace = true diff --git a/crates/semantic_index/LICENSE-GPL b/crates/semantic_index/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/semantic_index/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs new file mode 100644 index 0000000000..494d8a0f81 --- /dev/null +++ b/crates/semantic_index/examples/index.rs @@ -0,0 +1,140 @@ +use client::Client; +use futures::channel::oneshot; +use gpui::{App, Global, TestAppContext}; +use language::language_settings::AllLanguageSettings; +use project::Project; +use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex}; +use settings::SettingsStore; +use std::{path::Path, sync::Arc}; +use util::http::HttpClientWithUrl; + +pub fn init_test(cx: &mut TestAppContext) { + _ = cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + language::init(cx); + Project::init_settings(cx); + SettingsStore::update(cx, |store, cx| { + store.update_user_settings::(cx, |_| {}); + }); + }); +} + +fn main() { + env_logger::init(); + + use clock::FakeSystemClock; + + App::new().run(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + language::init(cx); + Project::init_settings(cx); + SettingsStore::update(cx, |store, cx| { + store.update_user_settings::(cx, |_| {}); + }); + + let clock = Arc::new(FakeSystemClock::default()); + let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434")); + + let client = client::Client::new(clock, http.clone(), cx); + Client::set_global(client.clone(), cx); + + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: cargo run --example index -p semantic_index -- "); + cx.quit(); + return; + } + + // let embedding_provider = semantic_index::FakeEmbeddingProvider; + + let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let embedding_provider = OpenAiEmbeddingProvider::new( + http.clone(), + OpenAiEmbeddingModel::TextEmbedding3Small, + open_ai::OPEN_AI_API_URL.to_string(), + api_key, + ); + + let semantic_index = SemanticIndex::new( + Path::new("/tmp/semantic-index-db.mdb"), + Arc::new(embedding_provider), + cx, + ); + + cx.spawn(|mut cx| async move { + let mut semantic_index = semantic_index.await.unwrap(); + + let project_path = Path::new(&args[1]); + + let project = Project::example([project_path], &mut cx).await; + + cx.update(|cx| { + let language_registry = project.read(cx).languages().clone(); + let node_runtime = project.read(cx).node_runtime().unwrap().clone(); + languages::init(language_registry, node_runtime, cx); + }) + .unwrap(); + + let project_index = cx + .update(|cx| semantic_index.project_index(project.clone(), cx)) + .unwrap(); + + let (tx, rx) = oneshot::channel(); + let mut tx = Some(tx); + let subscription = cx.update(|cx| { + cx.subscribe(&project_index, move |_, event, _| { + if let Some(tx) = tx.take() { + _ = tx.send(*event); + } + }) + }); + + let index_start = std::time::Instant::now(); + rx.await.expect("no event emitted"); + drop(subscription); + println!("Index time: {:?}", index_start.elapsed()); + + let results = cx + .update(|cx| { + let project_index = project_index.read(cx); + let query = "converting an anchor to a point"; + project_index.search(query, 4, cx) + }) + .unwrap() + .await; + + for search_result in results { + let path = search_result.path.clone(); + + let content = cx + .update(|cx| { + let worktree = search_result.worktree.read(cx); + let entry_abs_path = worktree.abs_path().join(search_result.path.clone()); + let fs = project.read(cx).fs().clone(); + cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() }) + }) + .unwrap() + .await; + + let range = search_result.range.clone(); + let content = content[search_result.range].to_owned(); + + println!( + "✄✄✄✄✄✄✄✄✄✄✄✄✄✄ {:?} @ {} ✄✄✄✄✄✄✄✄✄✄✄✄✄✄", + path, search_result.score + ); + println!("{:?}:{:?}:{:?}", path, range.start, range.end); + println!("{}", content); + } + + cx.background_executor() + .timer(std::time::Duration::from_secs(100000)) + .await; + + cx.update(|cx| cx.quit()).unwrap(); + }) + .detach(); + }); +} diff --git a/crates/semantic_index/fixture/main.rs b/crates/semantic_index/fixture/main.rs new file mode 100644 index 0000000000..f8796c8f45 --- /dev/null +++ b/crates/semantic_index/fixture/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello Indexer!"); +} diff --git a/crates/semantic_index/fixture/needle.md b/crates/semantic_index/fixture/needle.md new file mode 100644 index 0000000000..80487c9983 --- /dev/null +++ b/crates/semantic_index/fixture/needle.md @@ -0,0 +1,43 @@ +# Searching for a needle in a haystack + +When you have a large amount of text, it can be useful to search for a specific word or phrase. This is often referred to as "finding a needle in a haystack." In this markdown document, we're "hiding" a key phrase for our text search to find. Can you find it? + +## Instructions + +1. Use the search functionality in your text editor or markdown viewer to find the hidden phrase in this document. + +2. Once you've found the **phrase**, write it down and proceed to the next step. + +Honestly, I just want to fill up plenty of characters so that we chunk this markdown into several chunks. + +## Tips + +- Relax +- Take a deep breath +- Focus on the task at hand +- Don't get distracted by other text +- Use the search functionality to your advantage + +## Example code + +```python +def search_for_needle(haystack, needle): + if needle in haystack: + return True + else: + return False +``` + +```javascript +function searchForNeedle(haystack, needle) { + return haystack.includes(needle); +} +``` + +## Background + +When creating an index for a book or searching for a specific term in a large document, the ability to quickly find a specific word or phrase is essential. This is where search functionality comes in handy. However, one should _remember_ that the search is only as good as the index that was built. As they say, garbage in, garbage out! + +## Conclusion + +Searching for a needle in a haystack can be a challenging task, but with the right tools and techniques, it becomes much easier. Whether you're looking for a specific word in a document or trying to find a key piece of information in a large dataset, the ability to search efficiently is a valuable skill to have. diff --git a/crates/semantic_index/src/chunking.rs b/crates/semantic_index/src/chunking.rs new file mode 100644 index 0000000000..da37afd78c --- /dev/null +++ b/crates/semantic_index/src/chunking.rs @@ -0,0 +1,409 @@ +use language::{with_parser, Grammar, Tree}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::{cmp, ops::Range, sync::Arc}; + +const CHUNK_THRESHOLD: usize = 1500; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Chunk { + pub range: Range, + pub digest: [u8; 32], +} + +pub fn chunk_text(text: &str, grammar: Option<&Arc>) -> Vec { + if let Some(grammar) = grammar { + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(&text, None).expect("invalid language") + }); + + chunk_parse_tree(tree, &text, CHUNK_THRESHOLD) + } else { + chunk_lines(&text) + } +} + +fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec { + let mut chunk_ranges = Vec::new(); + let mut cursor = tree.walk(); + + let mut range = 0..0; + loop { + let node = cursor.node(); + + // If adding the node to the current chunk exceeds the threshold + if node.end_byte() - range.start > chunk_threshold { + // Try to descend into its first child. If we can't, flush the current + // range and try again. + if cursor.goto_first_child() { + continue; + } else if !range.is_empty() { + chunk_ranges.push(range.clone()); + range.start = range.end; + continue; + } + + // If we get here, the node itself has no children but is larger than the threshold. + // Break its text into arbitrary chunks. + split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges); + } + range.end = node.end_byte(); + + // If we get here, we consumed the node. Advance to the next child, ascending if there isn't one. + while !cursor.goto_next_sibling() { + if !cursor.goto_parent() { + if !range.is_empty() { + chunk_ranges.push(range); + } + + return chunk_ranges + .into_iter() + .map(|range| { + let digest = Sha256::digest(&text[range.clone()]).into(); + Chunk { range, digest } + }) + .collect(); + } + } + } +} + +fn chunk_lines(text: &str) -> Vec { + let mut chunk_ranges = Vec::new(); + let mut range = 0..0; + + let mut newlines = text.match_indices('\n').peekable(); + while let Some((newline_ix, _)) = newlines.peek() { + let newline_ix = newline_ix + 1; + if newline_ix - range.start <= CHUNK_THRESHOLD { + range.end = newline_ix; + newlines.next(); + } else { + if range.is_empty() { + split_text(text, range, newline_ix, &mut chunk_ranges); + range = newline_ix..newline_ix; + } else { + chunk_ranges.push(range.clone()); + range.start = range.end; + } + } + } + + if !range.is_empty() { + chunk_ranges.push(range); + } + + chunk_ranges + .into_iter() + .map(|range| { + let mut hasher = Sha256::new(); + hasher.update(&text[range.clone()]); + let mut digest = [0u8; 32]; + digest.copy_from_slice(hasher.finalize().as_slice()); + Chunk { range, digest } + }) + .collect() +} + +fn split_text( + text: &str, + mut range: Range, + max_end: usize, + chunk_ranges: &mut Vec>, +) { + while range.start < max_end { + range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end); + while !text.is_char_boundary(range.end) { + range.end -= 1; + } + chunk_ranges.push(range.clone()); + range.start = range.end; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; + + // This example comes from crates/gpui/examples/window_positioning.rs which + // has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD + static TEXT: &str = r#" + use gpui::*; + + struct WindowContent { + text: SharedString, + } + + impl Render for WindowContent { + fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { + div() + .flex() + .bg(rgb(0x1e2025)) + .size_full() + .justify_center() + .items_center() + .text_xl() + .text_color(rgb(0xffffff)) + .child(self.text.clone()) + } + } + + fn main() { + App::new().run(|cx: &mut AppContext| { + // Create several new windows, positioned in the top right corner of each screen + + for screen in cx.displays() { + let options = { + let popup_margin_width = DevicePixels::from(16); + let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48); + + let window_size = Size { + width: px(400.), + height: px(72.), + }; + + let screen_bounds = screen.bounds(); + let size: Size = window_size.into(); + + let bounds = gpui::Bounds:: { + origin: screen_bounds.upper_right() + - point(size.width + popup_margin_width, popup_margin_height), + size: window_size.into(), + }; + + WindowOptions { + // Set the bounds of the window in screen coordinates + bounds: Some(bounds), + // Specify the display_id to ensure the window is created on the correct screen + display_id: Some(screen.id()), + + titlebar: None, + window_background: WindowBackgroundAppearance::default(), + focus: false, + show: true, + kind: WindowKind::PopUp, + is_movable: false, + fullscreen: false, + } + }; + + cx.open_window(options, |cx| { + cx.new_view(|_| WindowContent { + text: format!("{:?}", screen.id()).into(), + }) + }); + } + }); + }"#; + + fn setup_rust_language() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + } + + #[test] + fn test_chunk_text() { + let text = "a\n".repeat(1000); + let chunks = chunk_text(&text, None); + assert_eq!( + chunks.len(), + ((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize + ); + } + + #[test] + fn test_chunk_text_grammar() { + // Let's set up a big text with some known segments + // We'll then chunk it and verify that the chunks are correct + + let language = setup_rust_language(); + + let chunks = chunk_text(TEXT, language.grammar()); + assert_eq!(chunks.len(), 2); + + assert_eq!(chunks[0].range.start, 0); + assert_eq!(chunks[0].range.end, 1498); + // The break between chunks is right before the "Specify the display_id" comment + + assert_eq!(chunks[1].range.start, 1498); + assert_eq!(chunks[1].range.end, 2396); + } + + #[test] + fn test_chunk_parse_tree() { + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(TEXT, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, TEXT, 250); + assert_eq!(chunks.len(), 11); + } + + #[test] + fn test_chunk_unparsable() { + // Even if a chunk is unparsable, we should still be able to chunk it + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let text = r#"fn main() {"#; + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(text, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, text, 250); + assert_eq!(chunks.len(), 1); + + assert_eq!(chunks[0].range.start, 0); + assert_eq!(chunks[0].range.end, 11); + } + + #[test] + fn test_empty_text() { + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse("", None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD); + assert!(chunks.is_empty(), "Chunks should be empty for empty text"); + } + + #[test] + fn test_single_large_node() { + let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2"; + + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(&large_text, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD); + + assert_eq!( + chunks.len(), + 3, + "Large chunks are broken up according to grammar as best as possible" + ); + + // Expect chunks to be static, aaaaaa..., and = 2 + assert_eq!(chunks[0].range.start, 0); + assert_eq!(chunks[0].range.end, "static".len()); + + assert_eq!(chunks[1].range.start, "static".len()); + assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD); + + assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD); + assert_eq!(chunks[2].range.end, large_text.len()); + } + + #[test] + fn test_multiple_small_nodes() { + let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z"; + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(small_text, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, small_text, 5); + assert!( + chunks.len() > 1, + "Should have multiple chunks for multiple small nodes" + ); + } + + #[test] + fn test_node_with_children() { + let nested_text = "fn main() { let a = 1; let b = 2; }"; + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(nested_text, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, nested_text, 10); + assert!( + chunks.len() > 1, + "Should have multiple chunks for a node with children" + ); + } + + #[test] + fn test_text_with_unparsable_sections() { + // This test uses purposefully hit-or-miss sizing of 11 characters per likely chunk + let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here"; + let language = setup_rust_language(); + let grammar = language.grammar().unwrap(); + + let tree = with_parser(|parser| { + parser + .set_language(&grammar.ts_language) + .expect("incompatible grammar"); + parser.parse(mixed_text, None).expect("invalid language") + }); + + let chunks = chunk_parse_tree(tree, mixed_text, 11); + assert!( + chunks.len() > 1, + "Should handle both parsable and unparsable sections correctly" + ); + + let expected_chunks = [ + "fn main() {", + " let a = 1;", + " let b = 2;", + " }", + " unparsable", + " bits here", + ]; + + for (i, chunk) in chunks.iter().enumerate() { + assert_eq!( + &mixed_text[chunk.range.clone()], + expected_chunks[i], + "Chunk {} should match", + i + ); + } + } +} diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs new file mode 100644 index 0000000000..b5195c8911 --- /dev/null +++ b/crates/semantic_index/src/embedding.rs @@ -0,0 +1,125 @@ +mod cloud; +mod ollama; +mod open_ai; + +pub use cloud::*; +pub use ollama::*; +pub use open_ai::*; +use sha2::{Digest, Sha256}; + +use anyhow::Result; +use futures::{future::BoxFuture, FutureExt}; +use serde::{Deserialize, Serialize}; +use std::{fmt, future}; + +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] +pub struct Embedding(Vec); + +impl Embedding { + pub fn new(mut embedding: Vec) -> Self { + let len = embedding.len(); + let mut norm = 0f32; + + for i in 0..len { + norm += embedding[i] * embedding[i]; + } + + norm = norm.sqrt(); + for dimension in &mut embedding { + *dimension /= norm; + } + + Self(embedding) + } + + fn len(&self) -> usize { + self.0.len() + } + + pub fn similarity(self, other: &Embedding) -> f32 { + debug_assert_eq!(self.0.len(), other.0.len()); + self.0 + .iter() + .copied() + .zip(other.0.iter().copied()) + .map(|(a, b)| a * b) + .sum() + } +} + +impl fmt::Display for Embedding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let digits_to_display = 3; + + // Start the Embedding display format + write!(f, "Embedding(sized: {}; values: [", self.len())?; + + for (index, value) in self.0.iter().enumerate().take(digits_to_display) { + // Lead with comma if not the first element + if index != 0 { + write!(f, ", ")?; + } + write!(f, "{:.3}", value)?; + } + if self.len() > digits_to_display { + write!(f, "...")?; + } + write!(f, "])") + } +} + +/// Trait for embedding providers. Texts in, vectors out. +pub trait EmbeddingProvider: Sync + Send { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>>; + fn batch_size(&self) -> usize; +} + +#[derive(Debug)] +pub struct TextToEmbed<'a> { + pub text: &'a str, + pub digest: [u8; 32], +} + +impl<'a> TextToEmbed<'a> { + pub fn new(text: &'a str) -> Self { + let digest = Sha256::digest(text.as_bytes()); + Self { + text, + digest: digest.into(), + } + } +} + +pub struct FakeEmbeddingProvider; + +impl EmbeddingProvider for FakeEmbeddingProvider { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { + let embeddings = texts + .iter() + .map(|_text| { + let mut embedding = vec![0f32; 1536]; + for i in 0..embedding.len() { + embedding[i] = i as f32; + } + Embedding::new(embedding) + }) + .collect(); + future::ready(Ok(embeddings)).boxed() + } + + fn batch_size(&self) -> usize { + 16 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[gpui::test] + fn test_normalize_embedding() { + let normalized = Embedding::new(vec![1.0, 1.0, 1.0]); + let value: f32 = 1.0 / 3.0_f32.sqrt(); + assert_eq!(normalized, Embedding(vec![value; 3])); + } +} diff --git a/crates/semantic_index/src/embedding/cloud.rs b/crates/semantic_index/src/embedding/cloud.rs new file mode 100644 index 0000000000..2a1df705c8 --- /dev/null +++ b/crates/semantic_index/src/embedding/cloud.rs @@ -0,0 +1,88 @@ +use crate::{Embedding, EmbeddingProvider, TextToEmbed}; +use anyhow::{anyhow, Context, Result}; +use client::{proto, Client}; +use collections::HashMap; +use futures::{future::BoxFuture, FutureExt}; +use std::sync::Arc; + +pub struct CloudEmbeddingProvider { + model: String, + client: Arc, +} + +impl CloudEmbeddingProvider { + pub fn new(client: Arc) -> Self { + Self { + model: "openai/text-embedding-3-small".into(), + client, + } + } +} + +impl EmbeddingProvider for CloudEmbeddingProvider { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { + // First, fetch any embeddings that are cached based on the requested texts' digests + // Then compute any embeddings that are missing. + async move { + let cached_embeddings = self.client.request(proto::GetCachedEmbeddings { + model: self.model.clone(), + digests: texts + .iter() + .map(|to_embed| to_embed.digest.to_vec()) + .collect(), + }); + let mut embeddings = cached_embeddings + .await + .context("failed to fetch cached embeddings via cloud model")? + .embeddings + .into_iter() + .map(|embedding| { + let digest: [u8; 32] = embedding + .digest + .try_into() + .map_err(|_| anyhow!("invalid digest for cached embedding"))?; + Ok((digest, embedding.dimensions)) + }) + .collect::>>()?; + + let compute_embeddings_request = proto::ComputeEmbeddings { + model: self.model.clone(), + texts: texts + .iter() + .filter_map(|to_embed| { + if embeddings.contains_key(&to_embed.digest) { + None + } else { + Some(to_embed.text.to_string()) + } + }) + .collect(), + }; + if !compute_embeddings_request.texts.is_empty() { + let missing_embeddings = self.client.request(compute_embeddings_request).await?; + for embedding in missing_embeddings.embeddings { + let digest: [u8; 32] = embedding + .digest + .try_into() + .map_err(|_| anyhow!("invalid digest for cached embedding"))?; + embeddings.insert(digest, embedding.dimensions); + } + } + + texts + .iter() + .map(|to_embed| { + let dimensions = embeddings.remove(&to_embed.digest).with_context(|| { + format!("server did not return an embedding for {:?}", to_embed) + })?; + Ok(Embedding::new(dimensions)) + }) + .collect() + } + .boxed() + } + + fn batch_size(&self) -> usize { + 2048 + } +} diff --git a/crates/semantic_index/src/embedding/ollama.rs b/crates/semantic_index/src/embedding/ollama.rs new file mode 100644 index 0000000000..8b8a36481b --- /dev/null +++ b/crates/semantic_index/src/embedding/ollama.rs @@ -0,0 +1,74 @@ +use anyhow::{Context as _, Result}; +use futures::{future::BoxFuture, AsyncReadExt, FutureExt}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use util::http::HttpClient; + +use crate::{Embedding, EmbeddingProvider, TextToEmbed}; + +pub enum OllamaEmbeddingModel { + NomicEmbedText, + MxbaiEmbedLarge, +} + +pub struct OllamaEmbeddingProvider { + client: Arc, + model: OllamaEmbeddingModel, +} + +#[derive(Serialize)] +struct OllamaEmbeddingRequest { + model: String, + prompt: String, +} + +#[derive(Deserialize)] +struct OllamaEmbeddingResponse { + embedding: Vec, +} + +impl OllamaEmbeddingProvider { + pub fn new(client: Arc, model: OllamaEmbeddingModel) -> Self { + Self { client, model } + } +} + +impl EmbeddingProvider for OllamaEmbeddingProvider { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { + // + let model = match self.model { + OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text", + OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large", + }; + + futures::future::try_join_all(texts.into_iter().map(|to_embed| { + let request = OllamaEmbeddingRequest { + model: model.to_string(), + prompt: to_embed.text.to_string(), + }; + + let request = serde_json::to_string(&request).unwrap(); + + async { + let response = self + .client + .post_json("http://localhost:11434/api/embeddings", request.into()) + .await?; + + let mut body = String::new(); + response.into_body().read_to_string(&mut body).await?; + + let response: OllamaEmbeddingResponse = + serde_json::from_str(&body).context("Unable to pull response")?; + + Ok(Embedding::new(response.embedding)) + } + })) + .boxed() + } + + fn batch_size(&self) -> usize { + // TODO: Figure out decent value + 10 + } +} diff --git a/crates/semantic_index/src/embedding/open_ai.rs b/crates/semantic_index/src/embedding/open_ai.rs new file mode 100644 index 0000000000..8eccb5272f --- /dev/null +++ b/crates/semantic_index/src/embedding/open_ai.rs @@ -0,0 +1,55 @@ +use crate::{Embedding, EmbeddingProvider, TextToEmbed}; +use anyhow::Result; +use futures::{future::BoxFuture, FutureExt}; +pub use open_ai::OpenAiEmbeddingModel; +use std::sync::Arc; +use util::http::HttpClient; + +pub struct OpenAiEmbeddingProvider { + client: Arc, + model: OpenAiEmbeddingModel, + api_url: String, + api_key: String, +} + +impl OpenAiEmbeddingProvider { + pub fn new( + client: Arc, + model: OpenAiEmbeddingModel, + api_url: String, + api_key: String, + ) -> Self { + Self { + client, + model, + api_url, + api_key, + } + } +} + +impl EmbeddingProvider for OpenAiEmbeddingProvider { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { + let embed = open_ai::embed( + self.client.as_ref(), + &self.api_url, + &self.api_key, + self.model, + texts.iter().map(|to_embed| to_embed.text), + ); + async move { + let response = embed.await?; + Ok(response + .data + .into_iter() + .map(|data| Embedding::new(data.embedding)) + .collect()) + } + .boxed() + } + + fn batch_size(&self) -> usize { + // From https://platform.openai.com/docs/api-reference/embeddings/create + 2048 + } +} diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs new file mode 100644 index 0000000000..86c18c6fe6 --- /dev/null +++ b/crates/semantic_index/src/semantic_index.rs @@ -0,0 +1,954 @@ +mod chunking; +mod embedding; + +use anyhow::{anyhow, Context as _, Result}; +use chunking::{chunk_text, Chunk}; +use collections::{Bound, HashMap}; +pub use embedding::*; +use fs::Fs; +use futures::stream::StreamExt; +use futures_batch::ChunksTimeoutStreamExt; +use gpui::{ + AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext, + Subscription, Task, WeakModel, +}; +use heed::types::{SerdeBincode, Str}; +use language::LanguageRegistry; +use project::{Entry, Project, UpdatedEntriesSet, Worktree}; +use serde::{Deserialize, Serialize}; +use smol::channel; +use std::{ + cmp::Ordering, + future::Future, + ops::Range, + path::Path, + sync::Arc, + time::{Duration, SystemTime}, +}; +use util::ResultExt; +use worktree::LocalSnapshot; + +pub struct SemanticIndex { + embedding_provider: Arc, + db_connection: heed::Env, + project_indices: HashMap, Model>, +} + +impl Global for SemanticIndex {} + +impl SemanticIndex { + pub fn new( + db_path: &Path, + embedding_provider: Arc, + cx: &mut AppContext, + ) -> Task> { + let db_path = db_path.to_path_buf(); + cx.spawn(|cx| async move { + let db_connection = cx + .background_executor() + .spawn(async move { + unsafe { + heed::EnvOpenOptions::new() + .map_size(1024 * 1024 * 1024) + .max_dbs(3000) + .open(db_path) + } + }) + .await?; + + Ok(SemanticIndex { + db_connection, + embedding_provider, + project_indices: HashMap::default(), + }) + }) + } + + pub fn project_index( + &mut self, + project: Model, + cx: &mut AppContext, + ) -> Model { + self.project_indices + .entry(project.downgrade()) + .or_insert_with(|| { + cx.new_model(|cx| { + ProjectIndex::new( + project, + self.db_connection.clone(), + self.embedding_provider.clone(), + cx, + ) + }) + }) + .clone() + } +} + +pub struct ProjectIndex { + db_connection: heed::Env, + project: Model, + worktree_indices: HashMap, + language_registry: Arc, + fs: Arc, + last_status: Status, + embedding_provider: Arc, + _subscription: Subscription, +} + +enum WorktreeIndexHandle { + Loading { + _task: Task>, + }, + Loaded { + index: Model, + _subscription: Subscription, + }, +} + +impl ProjectIndex { + fn new( + project: Model, + db_connection: heed::Env, + embedding_provider: Arc, + cx: &mut ModelContext, + ) -> Self { + let language_registry = project.read(cx).languages().clone(); + let fs = project.read(cx).fs().clone(); + let mut this = ProjectIndex { + db_connection, + project: project.clone(), + worktree_indices: HashMap::default(), + language_registry, + fs, + last_status: Status::Idle, + embedding_provider, + _subscription: cx.subscribe(&project, Self::handle_project_event), + }; + this.update_worktree_indices(cx); + this + } + + fn handle_project_event( + &mut self, + _: Model, + event: &project::Event, + cx: &mut ModelContext, + ) { + match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { + self.update_worktree_indices(cx); + } + _ => {} + } + } + + fn update_worktree_indices(&mut self, cx: &mut ModelContext) { + let worktrees = self + .project + .read(cx) + .visible_worktrees(cx) + .filter_map(|worktree| { + if worktree.read(cx).is_local() { + Some((worktree.entity_id(), worktree)) + } else { + None + } + }) + .collect::>(); + + self.worktree_indices + .retain(|worktree_id, _| worktrees.contains_key(worktree_id)); + for (worktree_id, worktree) in worktrees { + self.worktree_indices.entry(worktree_id).or_insert_with(|| { + let worktree_index = WorktreeIndex::load( + worktree.clone(), + self.db_connection.clone(), + self.language_registry.clone(), + self.fs.clone(), + self.embedding_provider.clone(), + cx, + ); + + let load_worktree = cx.spawn(|this, mut cx| async move { + if let Some(index) = worktree_index.await.log_err() { + this.update(&mut cx, |this, cx| { + this.worktree_indices.insert( + worktree_id, + WorktreeIndexHandle::Loaded { + _subscription: cx + .observe(&index, |this, _, cx| this.update_status(cx)), + index, + }, + ); + })?; + } else { + this.update(&mut cx, |this, _cx| { + this.worktree_indices.remove(&worktree_id) + })?; + } + + this.update(&mut cx, |this, cx| this.update_status(cx)) + }); + + WorktreeIndexHandle::Loading { + _task: load_worktree, + } + }); + } + + self.update_status(cx); + } + + fn update_status(&mut self, cx: &mut ModelContext) { + let mut status = Status::Idle; + for index in self.worktree_indices.values() { + match index { + WorktreeIndexHandle::Loading { .. } => { + status = Status::Scanning; + break; + } + WorktreeIndexHandle::Loaded { index, .. } => { + if index.read(cx).status == Status::Scanning { + status = Status::Scanning; + break; + } + } + } + } + + if status != self.last_status { + self.last_status = status; + cx.emit(status); + } + } + + pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task> { + let mut worktree_searches = Vec::new(); + for worktree_index in self.worktree_indices.values() { + if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index { + worktree_searches + .push(index.read_with(cx, |index, cx| index.search(query, limit, cx))); + } + } + + cx.spawn(|_| async move { + let mut results = Vec::new(); + let worktree_searches = futures::future::join_all(worktree_searches).await; + + for worktree_search_results in worktree_searches { + if let Some(worktree_search_results) = worktree_search_results.log_err() { + results.extend(worktree_search_results); + } + } + + results + .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + results.truncate(limit); + + results + }) + } +} + +pub struct SearchResult { + pub worktree: Model, + pub path: Arc, + pub range: Range, + pub score: f32, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum Status { + Idle, + Scanning, +} + +impl EventEmitter for ProjectIndex {} + +struct WorktreeIndex { + worktree: Model, + db_connection: heed::Env, + db: heed::Database>, + language_registry: Arc, + fs: Arc, + embedding_provider: Arc, + status: Status, + _index_entries: Task>, + _subscription: Subscription, +} + +impl WorktreeIndex { + pub fn load( + worktree: Model, + db_connection: heed::Env, + language_registry: Arc, + fs: Arc, + embedding_provider: Arc, + cx: &mut AppContext, + ) -> Task>> { + let worktree_abs_path = worktree.read(cx).abs_path(); + cx.spawn(|mut cx| async move { + let db = cx + .background_executor() + .spawn({ + let db_connection = db_connection.clone(); + async move { + let mut txn = db_connection.write_txn()?; + let db_name = worktree_abs_path.to_string_lossy(); + let db = db_connection.create_database(&mut txn, Some(&db_name))?; + txn.commit()?; + anyhow::Ok(db) + } + }) + .await?; + cx.new_model(|cx| { + Self::new( + worktree, + db_connection, + db, + language_registry, + fs, + embedding_provider, + cx, + ) + }) + }) + } + + fn new( + worktree: Model, + db_connection: heed::Env, + db: heed::Database>, + language_registry: Arc, + fs: Arc, + embedding_provider: Arc, + cx: &mut ModelContext, + ) -> Self { + let (updated_entries_tx, updated_entries_rx) = channel::unbounded(); + let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| { + if let worktree::Event::UpdatedEntries(update) = event { + _ = updated_entries_tx.try_send(update.clone()); + } + }); + + Self { + db_connection, + db, + worktree, + language_registry, + fs, + embedding_provider, + status: Status::Idle, + _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)), + _subscription, + } + } + + async fn index_entries( + this: WeakModel, + updated_entries: channel::Receiver, + mut cx: AsyncAppContext, + ) -> Result<()> { + let index = this.update(&mut cx, |this, cx| { + cx.notify(); + this.status = Status::Scanning; + this.index_entries_changed_on_disk(cx) + })?; + index.await.log_err(); + this.update(&mut cx, |this, cx| { + this.status = Status::Idle; + cx.notify(); + })?; + + while let Ok(updated_entries) = updated_entries.recv().await { + let index = this.update(&mut cx, |this, cx| { + cx.notify(); + this.status = Status::Scanning; + this.index_updated_entries(updated_entries, cx) + })?; + index.await.log_err(); + this.update(&mut cx, |this, cx| { + this.status = Status::Idle; + cx.notify(); + })?; + } + + Ok(()) + } + + fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future> { + let worktree = self.worktree.read(cx).as_local().unwrap().snapshot(); + let worktree_abs_path = worktree.abs_path().clone(); + let scan = self.scan_entries(worktree.clone(), cx); + let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); + let embed = self.embed_files(chunk.files, cx); + let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); + async move { + futures::try_join!(scan.task, chunk.task, embed.task, persist)?; + Ok(()) + } + } + + fn index_updated_entries( + &self, + updated_entries: UpdatedEntriesSet, + cx: &AppContext, + ) -> impl Future> { + let worktree = self.worktree.read(cx).as_local().unwrap().snapshot(); + let worktree_abs_path = worktree.abs_path().clone(); + let scan = self.scan_updated_entries(worktree, updated_entries, cx); + let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); + let embed = self.embed_files(chunk.files, cx); + let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); + async move { + futures::try_join!(scan.task, chunk.task, embed.task, persist)?; + Ok(()) + } + } + + fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries { + let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); + let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); + let db_connection = self.db_connection.clone(); + let db = self.db; + let task = cx.background_executor().spawn(async move { + let txn = db_connection + .read_txn() + .context("failed to create read transaction")?; + let mut db_entries = db + .iter(&txn) + .context("failed to create iterator")? + .move_between_keys() + .peekable(); + + let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None; + for entry in worktree.files(false, 0) { + let entry_db_key = db_key_for_path(&entry.path); + + let mut saved_mtime = None; + while let Some(db_entry) = db_entries.peek() { + match db_entry { + Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) { + Ordering::Less => { + if let Some(deletion_range) = deletion_range.as_mut() { + deletion_range.1 = Bound::Included(db_path); + } else { + deletion_range = + Some((Bound::Included(db_path), Bound::Included(db_path))); + } + + db_entries.next(); + } + Ordering::Equal => { + if let Some(deletion_range) = deletion_range.take() { + deleted_entry_ranges_tx + .send(( + deletion_range.0.map(ToString::to_string), + deletion_range.1.map(ToString::to_string), + )) + .await?; + } + saved_mtime = db_embedded_file.mtime; + db_entries.next(); + break; + } + Ordering::Greater => { + break; + } + }, + Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?, + } + } + + if entry.mtime != saved_mtime { + updated_entries_tx.send(entry.clone()).await?; + } + } + + if let Some(db_entry) = db_entries.next() { + let (db_path, _) = db_entry?; + deleted_entry_ranges_tx + .send((Bound::Included(db_path.to_string()), Bound::Unbounded)) + .await?; + } + + Ok(()) + }); + + ScanEntries { + updated_entries: updated_entries_rx, + deleted_entry_ranges: deleted_entry_ranges_rx, + task, + } + } + + fn scan_updated_entries( + &self, + worktree: LocalSnapshot, + updated_entries: UpdatedEntriesSet, + cx: &AppContext, + ) -> ScanEntries { + let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); + let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); + let task = cx.background_executor().spawn(async move { + for (path, entry_id, status) in updated_entries.iter() { + match status { + project::PathChange::Added + | project::PathChange::Updated + | project::PathChange::AddedOrUpdated => { + if let Some(entry) = worktree.entry_for_id(*entry_id) { + updated_entries_tx.send(entry.clone()).await?; + } + } + project::PathChange::Removed => { + let db_path = db_key_for_path(path); + deleted_entry_ranges_tx + .send((Bound::Included(db_path.clone()), Bound::Included(db_path))) + .await?; + } + project::PathChange::Loaded => { + // Do nothing. + } + } + } + + Ok(()) + }); + + ScanEntries { + updated_entries: updated_entries_rx, + deleted_entry_ranges: deleted_entry_ranges_rx, + task, + } + } + + fn chunk_files( + &self, + worktree_abs_path: Arc, + entries: channel::Receiver, + cx: &AppContext, + ) -> ChunkFiles { + let language_registry = self.language_registry.clone(); + let fs = self.fs.clone(); + let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048); + let task = cx.spawn(|cx| async move { + cx.background_executor() + .scoped(|cx| { + for _ in 0..cx.num_cpus() { + cx.spawn(async { + while let Ok(entry) = entries.recv().await { + let entry_abs_path = worktree_abs_path.join(&entry.path); + let Some(text) = fs.load(&entry_abs_path).await.log_err() else { + continue; + }; + let language = language_registry + .language_for_file_path(&entry.path) + .await + .ok(); + let grammar = + language.as_ref().and_then(|language| language.grammar()); + let chunked_file = ChunkedFile { + worktree_root: worktree_abs_path.clone(), + chunks: chunk_text(&text, grammar), + entry, + text, + }; + + if chunked_files_tx.send(chunked_file).await.is_err() { + return; + } + } + }); + } + }) + .await; + Ok(()) + }); + + ChunkFiles { + files: chunked_files_rx, + task, + } + } + + fn embed_files( + &self, + chunked_files: channel::Receiver, + cx: &AppContext, + ) -> EmbedFiles { + let embedding_provider = self.embedding_provider.clone(); + let (embedded_files_tx, embedded_files_rx) = channel::bounded(512); + let task = cx.background_executor().spawn(async move { + let mut chunked_file_batches = + chunked_files.chunks_timeout(512, Duration::from_secs(2)); + while let Some(chunked_files) = chunked_file_batches.next().await { + // View the batch of files as a vec of chunks + // Flatten out to a vec of chunks that we can subdivide into batch sized pieces + // Once those are done, reassemble it back into which files they belong to + + let chunks = chunked_files + .iter() + .flat_map(|file| { + file.chunks.iter().map(|chunk| TextToEmbed { + text: &file.text[chunk.range.clone()], + digest: chunk.digest, + }) + }) + .collect::>(); + + let mut embeddings = Vec::new(); + for embedding_batch in chunks.chunks(embedding_provider.batch_size()) { + // todo!("add a retry facility") + embeddings.extend(embedding_provider.embed(embedding_batch).await?); + } + + let mut embeddings = embeddings.into_iter(); + for chunked_file in chunked_files { + let chunk_embeddings = embeddings + .by_ref() + .take(chunked_file.chunks.len()) + .collect::>(); + let embedded_chunks = chunked_file + .chunks + .into_iter() + .zip(chunk_embeddings) + .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding }) + .collect(); + let embedded_file = EmbeddedFile { + path: chunked_file.entry.path.clone(), + mtime: chunked_file.entry.mtime, + chunks: embedded_chunks, + }; + + embedded_files_tx.send(embedded_file).await?; + } + } + Ok(()) + }); + + EmbedFiles { + files: embedded_files_rx, + task, + } + } + + fn persist_embeddings( + &self, + mut deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, + embedded_files: channel::Receiver, + cx: &AppContext, + ) -> Task> { + let db_connection = self.db_connection.clone(); + let db = self.db; + cx.background_executor().spawn(async move { + while let Some(deletion_range) = deleted_entry_ranges.next().await { + let mut txn = db_connection.write_txn()?; + let start = deletion_range.0.as_ref().map(|start| start.as_str()); + let end = deletion_range.1.as_ref().map(|end| end.as_str()); + log::debug!("deleting embeddings in range {:?}", &(start, end)); + db.delete_range(&mut txn, &(start, end))?; + txn.commit()?; + } + + let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2)); + while let Some(embedded_files) = embedded_files.next().await { + let mut txn = db_connection.write_txn()?; + for file in embedded_files { + log::debug!("saving embedding for file {:?}", file.path); + let key = db_key_for_path(&file.path); + db.put(&mut txn, &key, &file)?; + } + txn.commit()?; + log::debug!("committed"); + } + + Ok(()) + }) + } + + fn search( + &self, + query: &str, + limit: usize, + cx: &AppContext, + ) -> Task>> { + let (chunks_tx, chunks_rx) = channel::bounded(1024); + + let db_connection = self.db_connection.clone(); + let db = self.db; + let scan_chunks = cx.background_executor().spawn({ + async move { + let txn = db_connection + .read_txn() + .context("failed to create read transaction")?; + let db_entries = db.iter(&txn).context("failed to iterate database")?; + for db_entry in db_entries { + let (_, db_embedded_file) = db_entry?; + for chunk in db_embedded_file.chunks { + chunks_tx + .send((db_embedded_file.path.clone(), chunk)) + .await?; + } + } + anyhow::Ok(()) + } + }); + + let query = query.to_string(); + let embedding_provider = self.embedding_provider.clone(); + let worktree = self.worktree.clone(); + cx.spawn(|cx| async move { + #[cfg(debug_assertions)] + let embedding_query_start = std::time::Instant::now(); + + let mut query_embeddings = embedding_provider + .embed(&[TextToEmbed::new(&query)]) + .await?; + let query_embedding = query_embeddings + .pop() + .ok_or_else(|| anyhow!("no embedding for query"))?; + let mut workers = Vec::new(); + for _ in 0..cx.background_executor().num_cpus() { + workers.push(Vec::::new()); + } + + #[cfg(debug_assertions)] + let search_start = std::time::Instant::now(); + + cx.background_executor() + .scoped(|cx| { + for worker_results in workers.iter_mut() { + cx.spawn(async { + while let Ok((path, embedded_chunk)) = chunks_rx.recv().await { + let score = embedded_chunk.embedding.similarity(&query_embedding); + let ix = match worker_results.binary_search_by(|probe| { + score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal) + }) { + Ok(ix) | Err(ix) => ix, + }; + worker_results.insert( + ix, + SearchResult { + worktree: worktree.clone(), + path: path.clone(), + range: embedded_chunk.chunk.range.clone(), + score, + }, + ); + worker_results.truncate(limit); + } + }); + } + }) + .await; + scan_chunks.await?; + + let mut search_results = Vec::with_capacity(workers.len() * limit); + for worker_results in workers { + search_results.extend(worker_results); + } + search_results + .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + search_results.truncate(limit); + #[cfg(debug_assertions)] + { + let search_elapsed = search_start.elapsed(); + log::debug!( + "searched {} entries in {:?}", + search_results.len(), + search_elapsed + ); + let embedding_query_elapsed = embedding_query_start.elapsed(); + log::debug!("embedding query took {:?}", embedding_query_elapsed); + } + + Ok(search_results) + }) + } +} + +struct ScanEntries { + updated_entries: channel::Receiver, + deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, + task: Task>, +} + +struct ChunkFiles { + files: channel::Receiver, + task: Task>, +} + +struct ChunkedFile { + #[allow(dead_code)] + pub worktree_root: Arc, + pub entry: Entry, + pub text: String, + pub chunks: Vec, +} + +struct EmbedFiles { + files: channel::Receiver, + task: Task>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EmbeddedFile { + path: Arc, + mtime: Option, + chunks: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EmbeddedChunk { + chunk: Chunk, + embedding: Embedding, +} + +fn db_key_for_path(path: &Arc) -> String { + path.to_string_lossy().replace('/', "\0") +} + +#[cfg(test)] +mod tests { + use super::*; + + use futures::channel::oneshot; + use futures::{future::BoxFuture, FutureExt}; + + use gpui::{Global, TestAppContext}; + use language::language_settings::AllLanguageSettings; + use project::Project; + use settings::SettingsStore; + use std::{future, path::Path, sync::Arc}; + + fn init_test(cx: &mut TestAppContext) { + _ = cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + language::init(cx); + Project::init_settings(cx); + SettingsStore::update(cx, |store, cx| { + store.update_user_settings::(cx, |_| {}); + }); + }); + } + + pub struct TestEmbeddingProvider; + + impl EmbeddingProvider for TestEmbeddingProvider { + fn embed<'a>( + &'a self, + texts: &'a [TextToEmbed<'a>], + ) -> BoxFuture<'a, Result>> { + let embeddings = texts + .iter() + .map(|text| { + let mut embedding = vec![0f32; 2]; + // if the text contains garbage, give it a 1 in the first dimension + if text.text.contains("garbage in") { + embedding[0] = 0.9; + } else { + embedding[0] = -0.9; + } + + if text.text.contains("garbage out") { + embedding[1] = 0.9; + } else { + embedding[1] = -0.9; + } + + Embedding::new(embedding) + }) + .collect(); + future::ready(Ok(embeddings)).boxed() + } + + fn batch_size(&self) -> usize { + 16 + } + } + + #[gpui::test] + async fn test_search(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + init_test(cx); + + let temp_dir = tempfile::tempdir().unwrap(); + + let mut semantic_index = cx + .update(|cx| { + let semantic_index = SemanticIndex::new( + Path::new(temp_dir.path()), + Arc::new(TestEmbeddingProvider), + cx, + ); + semantic_index + }) + .await + .unwrap(); + + // todo!(): use a fixture + let project_path = Path::new("./fixture"); + + let project = cx + .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await }) + .await; + + cx.update(|cx| { + let language_registry = project.read(cx).languages().clone(); + let node_runtime = project.read(cx).node_runtime().unwrap().clone(); + languages::init(language_registry, node_runtime, cx); + }); + + let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx)); + + let (tx, rx) = oneshot::channel(); + let mut tx = Some(tx); + let subscription = cx.update(|cx| { + cx.subscribe(&project_index, move |_, event, _| { + if let Some(tx) = tx.take() { + _ = tx.send(*event); + } + }) + }); + + rx.await.expect("no event emitted"); + drop(subscription); + + let results = cx + .update(|cx| { + let project_index = project_index.read(cx); + let query = "garbage in, garbage out"; + project_index.search(query, 4, cx) + }) + .await; + + assert!(results.len() > 1, "should have found some results"); + + for result in &results { + println!("result: {:?}", result.path); + println!("score: {:?}", result.score); + } + + // Find result that is greater than 0.5 + let search_result = results.iter().find(|result| result.score > 0.9).unwrap(); + + assert_eq!(search_result.path.to_string_lossy(), "needle.md"); + + let content = cx + .update(|cx| { + let worktree = search_result.worktree.read(cx); + let entry_abs_path = worktree.abs_path().join(search_result.path.clone()); + let fs = project.read(cx).fs().clone(); + cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() }) + }) + .await; + + let range = search_result.range.clone(); + let content = content[range.clone()].to_owned(); + + assert!(content.contains("garbage in, garbage out")); + } +} diff --git a/crates/util/src/http.rs b/crates/util/src/http.rs index 3ee7b96eac..01a061cd1a 100644 --- a/crates/util/src/http.rs +++ b/crates/util/src/http.rs @@ -71,19 +71,28 @@ impl HttpClientWithUrl { } impl HttpClient for Arc { - fn send(&self, req: Request) -> BoxFuture, Error>> { + fn send( + &self, + req: Request, + ) -> BoxFuture<'static, Result, Error>> { self.client.send(req) } } impl HttpClient for HttpClientWithUrl { - fn send(&self, req: Request) -> BoxFuture, Error>> { + fn send( + &self, + req: Request, + ) -> BoxFuture<'static, Result, Error>> { self.client.send(req) } } pub trait HttpClient: Send + Sync { - fn send(&self, req: Request) -> BoxFuture, Error>>; + fn send( + &self, + req: Request, + ) -> BoxFuture<'static, Result, Error>>; fn get<'a>( &'a self, @@ -135,8 +144,12 @@ pub fn client() -> Arc { } impl HttpClient for isahc::HttpClient { - fn send(&self, req: Request) -> BoxFuture, Error>> { - Box::pin(async move { self.send_async(req).await }) + fn send( + &self, + req: Request, + ) -> BoxFuture<'static, Result, Error>> { + let client = self.clone(); + Box::pin(async move { client.send_async(req).await }) } } @@ -196,7 +209,10 @@ impl fmt::Debug for FakeHttpClient { #[cfg(feature = "test-support")] impl HttpClient for FakeHttpClient { - fn send(&self, req: Request) -> BoxFuture, Error>> { + fn send( + &self, + req: Request, + ) -> BoxFuture<'static, Result, Error>> { let future = (self.handler)(req); Box::pin(async move { future.await.map(Into::into) }) }