diff --git a/Cargo.lock b/Cargo.lock
index 92ede849ff..814630d8e4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3265,6 +3265,15 @@ version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
 
+[[package]]
+name = "doxygen-rs"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "415b6ec780d34dcf624666747194393603d0373b7141eef01d12ee58881507d9"
+dependencies = [
+ "phf",
+]
+
 [[package]]
 name = "dwrote"
 version = "0.11.0"
@@ -4085,6 +4094,17 @@ dependencies = [
  "futures-util",
 ]
 
+[[package]]
+name = "futures-batch"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a"
+dependencies = [
+ "futures 0.3.28",
+ "futures-timer",
+ "pin-utils",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.30"
@@ -4180,6 +4200,12 @@ version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.30"
@@ -4659,6 +4685,41 @@ dependencies = [
  "unicode-segmentation",
 ]
 
+[[package]]
+name = "heed"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "bitflags 2.4.2",
+ "byteorder",
+ "heed-traits",
+ "heed-types",
+ "libc",
+ "lmdb-master-sys",
+ "once_cell",
+ "page_size",
+ "serde",
+ "synchronoise",
+ "url",
+]
+
+[[package]]
+name = "heed-traits"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+
+[[package]]
+name = "heed-types"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "bincode",
+ "byteorder",
+ "heed-traits",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "hermit-abi"
 version = "0.1.19"
@@ -5664,6 +5725,16 @@ dependencies = [
  "sha2 0.10.7",
 ]
 
+[[package]]
+name = "lmdb-master-sys"
+version = "0.1.0"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "cc",
+ "doxygen-rs",
+ "libc",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.10"
@@ -6683,6 +6754,16 @@ dependencies = [
  "sha2 0.10.7",
 ]
 
+[[package]]
+name = "page_size"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "palette"
 version = "0.7.5"
@@ -6856,9 +6937,33 @@ version = "0.11.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
 dependencies = [
+ "phf_macros",
  "phf_shared",
 ]
 
+[[package]]
+name = "phf_generator"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
+dependencies = [
+ "phf_shared",
+ "rand 0.8.5",
+]
+
+[[package]]
+name = "phf_macros"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b"
+dependencies = [
+ "phf_generator",
+ "phf_shared",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.48",
+]
+
 [[package]]
 name = "phf_shared"
 version = "0.11.2"
@@ -8473,6 +8578,35 @@ version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
 
+[[package]]
+name = "semantic_index"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "fs",
+ "futures 0.3.28",
+ "futures-batch",
+ "gpui",
+ "heed",
+ "language",
+ "languages",
+ "log",
+ "open_ai",
+ "project",
+ "serde",
+ "serde_json",
+ "settings",
+ "sha2 0.10.7",
+ "smol",
+ "tempfile",
+ "util",
+ "worktree",
+]
+
 [[package]]
 name = "semantic_version"
 version = "0.1.0"
@@ -9478,6 +9612,15 @@ version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
 
+[[package]]
+name = "synchronoise"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2"
+dependencies = [
+ "crossbeam-queue",
+]
+
 [[package]]
 name = "sys-locale"
 version = "0.3.1"
diff --git a/Cargo.toml b/Cargo.toml
index f58d998a8b..22999d86b6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -73,6 +73,7 @@ members = [
     "crates/task",
     "crates/tasks_ui",
     "crates/search",
+    "crates/semantic_index",
     "crates/semantic_version",
     "crates/settings",
     "crates/snippet",
@@ -253,9 +254,11 @@ derive_more = "0.99.17"
 emojis = "0.6.1"
 env_logger = "0.9"
 futures = "0.3"
+futures-batch = "0.6.1"
 futures-lite = "1.13"
 git2 = { version = "0.15", default-features = false }
 globset = "0.4"
+heed = { git = "https://github.com/meilisearch/heed", rev = "036ac23f73a021894974b9adc815bc95b3e0482a", features = ["read-txn-no-tls"] }
 hex = "0.4.3"
 ignore = "0.4.22"
 indoc = "1"
diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs
index cee747a4e9..c9f7bf2485 100644
--- a/crates/channel/src/channel_store_tests.rs
+++ b/crates/channel/src/channel_store_tests.rs
@@ -264,7 +264,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
     );
 
     assert_eq!(
-        channel.next_event(cx),
+        channel.next_event(cx).await,
         ChannelChatEvent::MessagesUpdated {
             old_range: 2..2,
             new_count: 1,
@@ -317,7 +317,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
     );
 
     assert_eq!(
-        channel.next_event(cx),
+        channel.next_event(cx).await,
         ChannelChatEvent::MessagesUpdated {
             old_range: 0..0,
             new_count: 2,
diff --git a/crates/collab/migrations/20240409082755_create_embeddings.sql b/crates/collab/migrations/20240409082755_create_embeddings.sql
new file mode 100644
index 0000000000..ae4b4bcb61
--- /dev/null
+++ b/crates/collab/migrations/20240409082755_create_embeddings.sql
@@ -0,0 +1,9 @@
+CREATE TABLE IF NOT EXISTS "embeddings" (
+    "model" TEXT,
+    "digest" BYTEA,
+    "dimensions" FLOAT4[1536],
+    "retrieved_at" TIMESTAMP NOT NULL DEFAULT now(),
+    PRIMARY KEY ("model", "digest")
+);
+
+CREATE INDEX IF NOT EXISTS "idx_retrieved_at_on_embeddings" ON "embeddings" ("retrieved_at");
diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs
index 2cbbc67969..b7670aa60c 100644
--- a/crates/collab/src/db/queries.rs
+++ b/crates/collab/src/db/queries.rs
@@ -6,6 +6,7 @@ pub mod channels;
 pub mod contacts;
 pub mod contributors;
 pub mod dev_servers;
+pub mod embeddings;
 pub mod extensions;
 pub mod hosted_projects;
 pub mod messages;
diff --git a/crates/collab/src/db/queries/embeddings.rs b/crates/collab/src/db/queries/embeddings.rs
new file mode 100644
index 0000000000..d901b59659
--- /dev/null
+++ b/crates/collab/src/db/queries/embeddings.rs
@@ -0,0 +1,94 @@
+use super::*;
+use time::Duration;
+use time::OffsetDateTime;
+
+impl Database {
+    pub async fn get_embeddings(
+        &self,
+        model: &str,
+        digests: &[Vec<u8>],
+    ) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
+        self.weak_transaction(|tx| async move {
+            let embeddings = {
+                let mut db_embeddings = embedding::Entity::find()
+                    .filter(
+                        embedding::Column::Model.eq(model).and(
+                            embedding::Column::Digest
+                                .is_in(digests.iter().map(|digest| digest.as_slice())),
+                        ),
+                    )
+                    .stream(&*tx)
+                    .await?;
+
+                let mut embeddings = HashMap::default();
+                while let Some(db_embedding) = db_embeddings.next().await {
+                    let db_embedding = db_embedding?;
+                    embeddings.insert(db_embedding.digest, db_embedding.dimensions);
+                }
+                embeddings
+            };
+
+            if !embeddings.is_empty() {
+                // Bump retrieved_at on every hit so recently used embeddings survive purging.
+                let now = OffsetDateTime::now_utc();
+                let retrieved_at = PrimitiveDateTime::new(now.date(), now.time());
+
+                embedding::Entity::update_many()
+                    .filter(
+                        embedding::Column::Digest
+                            .is_in(embeddings.keys().map(|digest| digest.as_slice())),
+                    )
+                    .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
+                    .exec(&*tx)
+                    .await?;
+            }
+
+            Ok(embeddings)
+        })
+        .await
+    }
+
+    pub async fn save_embeddings(
+        &self,
+        model: &str,
+        embeddings: &HashMap<Vec<u8>, Vec<f32>>,
+    ) -> Result<()> {
+        self.weak_transaction(|tx| async move {
+            embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
+                let now_offset_datetime = OffsetDateTime::now_utc();
+                let retrieved_at =
+                    PrimitiveDateTime::new(now_offset_datetime.date(), now_offset_datetime.time());
+
+                embedding::ActiveModel {
+                    model: ActiveValue::set(model.to_string()),
+                    digest: ActiveValue::set(digest.clone()),
+                    dimensions: ActiveValue::set(dimensions.clone()),
+                    retrieved_at: ActiveValue::set(retrieved_at),
+                }
+            }))
+            .on_conflict(
+                OnConflict::columns([embedding::Column::Model, embedding::Column::Digest])
+                    .do_nothing()
+                    .to_owned(),
+            )
+            .exec_without_returning(&*tx)
+            .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn purge_old_embeddings(&self) -> Result<()> {
+        self.weak_transaction(|tx| async move {
+            embedding::Entity::delete_many()
+                .filter(
+                    embedding::Column::RetrievedAt
+                        .lte(OffsetDateTime::now_utc() - Duration::days(60)),
+                )
+                .exec(&*tx)
+                .await?;
+
+            Ok(())
+        })
+        .await
+    }
+}
diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs
index 4a284682b2..2af78f776e 100644
--- a/crates/collab/src/db/tables.rs
+++ b/crates/collab/src/db/tables.rs
@@ -11,6 +11,7 @@ pub mod channel_message_mention;
 pub mod contact;
 pub mod contributor;
 pub mod dev_server;
+pub mod embedding;
 pub mod extension;
 pub mod extension_version;
 pub mod feature_flag;
diff --git a/crates/collab/src/db/tables/embedding.rs b/crates/collab/src/db/tables/embedding.rs
new file mode 100644
index 0000000000..8743b4b9e6
--- /dev/null
+++ b/crates/collab/src/db/tables/embedding.rs
@@ -0,0 +1,18 @@
+use sea_orm::entity::prelude::*;
+use time::PrimitiveDateTime;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "embeddings")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub model: String,
+    #[sea_orm(primary_key)]
+    pub digest: Vec<u8>,
+    pub dimensions: Vec<f32>,
+    pub retrieved_at: PrimitiveDateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs
index 35da659e54..e3ce834295 100644
--- a/crates/collab/src/db/tests.rs
+++ b/crates/collab/src/db/tests.rs
@@ -2,6 +2,7 @@ mod buffer_tests;
 mod channel_tests;
 mod contributor_tests;
 mod db_tests;
+mod embedding_tests;
 mod extension_tests;
 mod feature_flag_tests;
 mod message_tests;
diff --git a/crates/collab/src/db/tests/embedding_tests.rs b/crates/collab/src/db/tests/embedding_tests.rs
new file mode 100644
index 0000000000..fcafac625d
--- /dev/null
+++ b/crates/collab/src/db/tests/embedding_tests.rs
@@ -0,0 +1,84 @@
+use super::TestDb;
+use crate::db::embedding;
+use collections::HashMap;
+use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, QueryFilter};
+use std::ops::Sub;
+use time::{Duration, OffsetDateTime, PrimitiveDateTime};
+
+// SQLite does not support array arguments, so we only test this against a real postgres instance
+#[gpui::test]
+async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) {
+    let test_db = TestDb::postgres(cx.executor().clone());
+    let db = test_db.db();
+
+    let provider = "test_model";
+    let digest1 = vec![1, 2, 3];
+    let digest2 = vec![4, 5, 6];
+    let embeddings = HashMap::from_iter([
+        (digest1.clone(), vec![0.1, 0.2, 0.3]),
+        (digest2.clone(), vec![0.4, 0.5, 0.6]),
+    ]);
+
+    // Save embeddings
+    db.save_embeddings(provider, &embeddings).await.unwrap();
+
+    // Retrieve embeddings
+    let retrieved_embeddings = db
+        .get_embeddings(provider, &[digest1.clone(), digest2.clone()])
+        .await
+        .unwrap();
+    assert_eq!(retrieved_embeddings.len(), 2);
+    assert!(retrieved_embeddings.contains_key(&digest1));
+    assert!(retrieved_embeddings.contains_key(&digest2));
+
+    // Check if the retrieved embeddings are correct
+    assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]);
+    assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]);
+}
+
+#[gpui::test]
+async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
+    let test_db = TestDb::postgres(cx.executor().clone());
+    let db = test_db.db();
+
+    let model = "test_model";
+    let digest = vec![7, 8, 9];
+    let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]);
+
+    // Save old embeddings
+    db.save_embeddings(model, &embeddings).await.unwrap();
+
+    // Reach into the DB and change retrieved_at to be more than 60 days ago
+    db.weak_transaction(|tx| {
+        let digest = digest.clone();
+        async move {
+            let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
+            let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time());
+
+            embedding::Entity::update_many()
+                .filter(
+                    embedding::Column::Model
+                        .eq(model)
+                        .and(embedding::Column::Digest.eq(digest)),
+                )
+                .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
+                .exec(&*tx)
+                .await
+                .unwrap();
+
+            Ok(())
+        }
+    })
+    .await
+    .unwrap();
+
+    // Purge old embeddings
+    db.purge_old_embeddings().await.unwrap();
+
+    // Try to retrieve the purged embeddings
+    let retrieved_embeddings = db.get_embeddings(model, &[digest.clone()]).await.unwrap();
+    assert!(
+        retrieved_embeddings.is_empty(),
+        "Old embeddings should have been purged"
+    );
+}
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index 728544c533..b85d378ef2 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -6,8 +6,8 @@ use axum::{
     Extension, Router,
 };
 use collab::{
-    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
-    Config, RateLimiter, Result,
+    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
+    rpc::ResultExt, AppState, Config, RateLimiter, Result,
 };
 use db::Database;
 use std::{
@@ -23,7 +23,7 @@ use tower_http::trace::TraceLayer;
 use tracing_subscriber::{
     filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 };
-use util::ResultExt;
+use util::ResultExt as _;
 
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@@ -90,6 +90,7 @@ async fn main() -> Result<()> {
     };
 
     if is_collab {
+        state.db.purge_old_embeddings().await.trace_err();
         RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
     }
 
diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index bdcfd487f1..da8328c411 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -32,6 +32,8 @@ use axum::{
 use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
+use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
+use sha2::Digest;
 
 use futures::{
     channel::oneshot,
@@ -568,6 +570,22 @@ impl Server {
                         app_state.config.google_ai_api_key.clone(),
                     )
                 })
+            })
+            .add_request_handler({
+                user_handler(move |request, response, session| {
+                    get_cached_embeddings(request, response, session)
+                })
+            })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                user_handler(move |request, response, session| {
+                    compute_embeddings(
+                        request,
+                        response,
+                        session,
+                        app_state.config.openai_api_key.clone(),
+                    )
+                })
             });
 
         Arc::new(server)
@@ -4021,8 +4039,6 @@ async fn complete_with_open_ai(
     session: UserSession,
    api_key: Arc<str>,
 ) -> Result<()> {
-    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
-
     let mut completion_stream = open_ai::stream_completion(
         &session.http_client,
         OPEN_AI_API_URL,
@@ -4276,6 +4292,128 @@ async fn count_tokens_with_language_model(
     Ok(())
 }
 
+struct ComputeEmbeddingsRateLimit;
+
+impl RateLimit for ComputeEmbeddingsRateLimit {
+    fn capacity() -> usize {
+        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(120) // Picked arbitrarily
+    }
+
+    fn refill_duration() -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name() -> &'static str {
+        "compute-embeddings"
+    }
+}
+
+async fn compute_embeddings(
+    request: proto::ComputeEmbeddings,
+    response: Response<proto::ComputeEmbeddings>,
+    session: UserSession,
+    api_key: Option<Arc<str>>,
+) -> Result<()> {
+    let api_key = api_key.context("no OpenAI API key configured on the server")?;
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
+        .await?;
+
+    let embeddings = match request.model.as_str() {
+        "openai/text-embedding-3-small" => {
+            open_ai::embed(
+                &session.http_client,
+                OPEN_AI_API_URL,
+                &api_key,
+                OpenAiEmbeddingModel::TextEmbedding3Small,
+                request.texts.iter().map(|text| text.as_str()),
+            )
+            .await?
+        }
+        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
+    };
+
+    // Key each embedding by the SHA-256 digest of the text it was computed from.
+    let embeddings = request
+        .texts
+        .iter()
+        .map(|text| {
+            let mut hasher = sha2::Sha256::new();
+            hasher.update(text.as_bytes());
+            let result = hasher.finalize();
+            result.to_vec()
+        })
+        .zip(
+            embeddings
+                .data
+                .into_iter()
+                .map(|embedding| embedding.embedding),
+        )
+        .collect::<HashMap<_, _>>();
+
+    let db = session.db().await;
+    db.save_embeddings(&request.model, &embeddings)
+        .await
+        .context("failed to save embeddings")
+        .trace_err();
+
+    response.send(proto::ComputeEmbeddingsResponse {
+        embeddings: embeddings
+            .into_iter()
+            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
+            .collect(),
+    })?;
+    Ok(())
+}
+
+struct GetCachedEmbeddingsRateLimit;
+
+impl RateLimit for GetCachedEmbeddingsRateLimit {
+    fn capacity() -> usize {
+        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(120) // Picked arbitrarily
+    }
+
+    fn refill_duration() -> chrono::Duration {
+        chrono::Duration::hours(1)
+    }
+
+    fn db_name() -> &'static str {
+        "get-cached-embeddings"
+    }
+}
+
+async fn get_cached_embeddings(
+    request: proto::GetCachedEmbeddings,
+    response: Response<proto::GetCachedEmbeddings>,
+    session: UserSession,
+) -> Result<()> {
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
+        .await?;
+
+    let db = session.db().await;
+    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
+
+    response.send(proto::GetCachedEmbeddingsResponse {
+        embeddings: embeddings
+            .into_iter()
+            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
+            .collect(),
+    })?;
+    Ok(())
+}
+
 async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
     let db = session.db().await;
     let flags = db.get_user_flags(session.user_id()).await?;
diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs
index 8f753f7186..6b5420f15c 100644
--- a/crates/editor/src/git/blame.rs
+++ b/crates/editor/src/git/blame.rs
@@ -396,7 +396,7 @@ mod tests {
 
         let blame = cx.new_model(|cx| GitBlame::new(buffer.clone(), project.clone(), cx));
 
-        let event = project.next_event(cx);
+        let event = project.next_event(cx).await;
         assert_eq!(
             event,
             project::Event::Notification(
diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs
index b6bfd33f66..d049ceef2f 100644
--- a/crates/gpui/src/app/test_context.rs
+++ b/crates/gpui/src/app/test_context.rs
@@ -7,7 +7,7 @@ use crate::{
     TextSystem, View, ViewContext, VisualContext, WindowContext, WindowHandle, WindowOptions,
 };
 use anyhow::{anyhow, bail};
-use futures::{Stream, StreamExt};
+use futures::{channel::oneshot, Stream, StreamExt};
 use std::{cell::RefCell, future::Future, ops::Deref, rc::Rc, sync::Arc, time::Duration};
 
 /// A TestAppContext is provided to tests created with `#[gpui::test]`, it provides
@@ -479,31 +479,26 @@ impl TestAppContext {
 
 impl<T: 'static> Model<T> {
     /// Block until the next event is emitted by the model, then return it.
-    pub fn next_event<Evt>(&self, cx: &mut TestAppContext) -> Evt
+    pub fn next_event<Event>(&self, cx: &mut TestAppContext) -> impl Future<Output = Event>
     where
-        Evt: Send + Clone + 'static,
-        T: EventEmitter<Evt>,
+        Event: Send + Clone + 'static,
+        T: EventEmitter<Event>,
     {
-        let (tx, mut rx) = futures::channel::mpsc::unbounded();
-        let _subscription = self.update(cx, |_, cx| {
+        let (tx, mut rx) = oneshot::channel();
+        let mut tx = Some(tx);
+        let subscription = self.update(cx, |_, cx| {
             cx.subscribe(self, move |_, _, event, _| {
-                tx.unbounded_send(event.clone()).ok();
+                if let Some(tx) = tx.take() {
+                    _ = tx.send(event.clone());
+                }
             })
         });
 
-        // Run other tasks until the event is emitted.
-        loop {
-            match rx.try_next() {
-                Ok(Some(event)) => return event,
-                Ok(None) => panic!("model was dropped"),
-                Err(_) => {
-                    if !cx.executor().tick() {
-                        break;
-                    }
-                }
-            }
+        async move {
+            let event = rx.await.expect("no event emitted");
+            drop(subscription);
+            event
         }
-        panic!("no event received")
     }
 
     /// Returns a future that resolves when the model notifies.
diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs
index 841fc5b19e..115359231c 100644
--- a/crates/gpui/src/executor.rs
+++ b/crates/gpui/src/executor.rs
@@ -372,7 +372,7 @@ impl BackgroundExecutor {
         self.dispatcher.as_test().unwrap().rng()
     }
 
-    /// How many CPUs are available to the dispatcher
+    /// How many CPUs are available to the dispatcher.
     pub fn num_cpus(&self) -> usize {
         num_cpus::get()
     }
@@ -440,6 +440,11 @@ impl<'a> Scope<'a> {
         }
     }
 
+    /// How many CPUs are available to the dispatcher.
+    pub fn num_cpus(&self) -> usize {
+        self.executor.num_cpus()
+    }
+
     /// Spawn a future into this scope.
     pub fn spawn<F>(&mut self, f: F)
     where
diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs
index 65c838cda0..b3046bb562 100644
--- a/crates/language/src/language.rs
+++ b/crates/language/src/language.rs
@@ -72,7 +72,7 @@ pub use lsp::LanguageServerId;
 pub use outline::{Outline, OutlineItem};
 pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer};
 pub use text::LineEnding;
-pub use tree_sitter::{Parser, Tree};
+pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
 
 use crate::language_settings::SoftWrap;
 
@@ -91,6 +91,16 @@ thread_local! {
     };
 }
 
+pub fn with_parser<F, R>(func: F) -> R
+where
+    F: FnOnce(&mut Parser) -> R,
+{
+    PARSER.with(|parser| {
+        let mut parser = parser.borrow_mut();
+        func(&mut parser)
+    })
+}
+
 lazy_static! {
     static ref NEXT_LANGUAGE_ID: AtomicUsize = Default::default();
     static ref NEXT_GRAMMAR_ID: AtomicUsize = Default::default();
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index fcf4aa04bf..97abb45dfc 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -1,9 +1,11 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
 use serde::{Deserialize, Serialize};
-use std::convert::TryFrom;
+use std::{convert::TryFrom, future::Future};
 use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 
+pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
+
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
 pub enum Role {
@@ -188,3 +190,68 @@ pub async fn stream_completion(
         }
     }
 }
+
+#[derive(Copy, Clone, Serialize, Deserialize)]
+pub enum OpenAiEmbeddingModel {
+    #[serde(rename = "text-embedding-3-small")]
+    TextEmbedding3Small,
+    #[serde(rename = "text-embedding-3-large")]
+    TextEmbedding3Large,
+}
+
+#[derive(Serialize)]
+struct OpenAiEmbeddingRequest<'a> {
+    model: OpenAiEmbeddingModel,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+pub struct OpenAiEmbeddingResponse {
+    pub data: Vec<OpenAiEmbedding>,
+}
+
+#[derive(Deserialize)]
+pub struct OpenAiEmbedding {
+    pub embedding: Vec<f32>,
+}
+
+pub fn embed<'a>(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    model: OpenAiEmbeddingModel,
+    texts: impl IntoIterator<Item = &'a str>,
+) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
+    let uri = format!("{api_url}/embeddings");
+
+    let request = OpenAiEmbeddingRequest {
+        model,
+        input: texts.into_iter().collect(),
+    };
+    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(body)
+        .map(|request| client.send(request));
+
+    async move {
+        let mut response = request?.await?;
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        if response.status().is_success() {
+            let response: OpenAiEmbeddingResponse =
+                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
+            Ok(response)
+        } else {
+            Err(anyhow!(
+                "error during embedding, status: {:?}, body: {:?}",
+                response.status(),
+                body
+            ))
+        }
+    }
+}
diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index 6c27e4d612..92b522b3e9 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -978,6 +978,50 @@ impl Project {
         }
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub async fn example(
+        root_paths: impl IntoIterator<Item = &Path>,
+        cx: &mut AsyncAppContext,
+    ) -> Model<Project> {
+        use clock::FakeSystemClock;
+
+        let fs = Arc::new(RealFs::default());
+        let languages = LanguageRegistry::test(cx.background_executor().clone());
+        let clock = Arc::new(FakeSystemClock::default());
+        let http_client = util::http::FakeHttpClient::with_404_response();
+        let client = cx
+            .update(|cx| client::Client::new(clock, http_client.clone(), cx))
+            .unwrap();
+        let user_store = cx
+            .new_model(|cx| UserStore::new(client.clone(), cx))
+            .unwrap();
+        let project = cx
+            .update(|cx| {
+                Project::local(
+                    client,
+                    node_runtime::FakeNodeRuntime::new(),
+                    user_store,
+                    Arc::new(languages),
+                    fs,
+                    cx,
+                )
+            })
+            .unwrap();
+        for path in root_paths {
+            let (tree, _) = project
+                .update(cx, |project, cx| {
+                    project.find_or_create_local_worktree(path, true, cx)
+                })
+                .unwrap()
+                .await
+                .unwrap();
+            tree.update(cx, |tree, _| tree.as_local().unwrap().scan_complete())
+                .unwrap()
+                .await;
+        }
+        project
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     pub async fn test(
         fs: Arc<dyn Fs>,
@@ -1146,6 +1190,10 @@ impl Project {
         self.user_store.clone()
     }
 
+    pub fn node_runtime(&self) -> Option<&Arc<dyn NodeRuntime>> {
+        self.node.as_ref()
+    }
+
     pub fn opened_buffers(&self) -> Vec<Model<Buffer>> {
         self.opened_buffers
             .values()
diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs
index fa00ea6736..fed32ccf4e 100644
--- a/crates/project/src/project_tests.rs
+++ b/crates/project/src/project_tests.rs
@@ -2661,7 +2661,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext) {
     )
     .await
     .unwrap();
-    worktree.next_event(cx);
+    worktree.next_event(cx).await;
 
     // Change the buffer's file again. Depending on the random seed, the
     // previous file change may still be in progress.
@@ -2672,7 +2672,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext) {
     )
     .await
     .unwrap();
-    worktree.next_event(cx);
+    worktree.next_event(cx).await;
 
     cx.executor().run_until_parked();
     let on_disk_text = fs.load(Path::new("/dir/file1")).await.unwrap();
@@ -2716,7 +2716,7 @@ async fn test_edit_buffer_while_it_reloads(cx: &mut gpui::TestAppContext) {
     )
     .await
     .unwrap();
-    worktree.next_event(cx);
+    worktree.next_event(cx).await;
 
     cx.executor()
         .spawn(cx.executor().simulate_random_delay())
diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto
index 8cbeed1ae8..606c0bb101 100644
--- a/crates/rpc/proto/zed.proto
+++ b/crates/rpc/proto/zed.proto
@@ -204,6 +204,11 @@ message Envelope {
         LanguageModelResponse language_model_response = 167;
         CountTokensWithLanguageModel count_tokens_with_language_model = 168;
         CountTokensResponse count_tokens_response = 169;
+        GetCachedEmbeddings get_cached_embeddings = 189;
+        GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
+        ComputeEmbeddings compute_embeddings = 191;
+        ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max
+
         UpdateChannelMessage update_channel_message = 170;
         ChannelMessageUpdate channel_message_update = 171;
 
@@ -216,7 +221,7 @@ message Envelope {
         MultiLspQueryResponse multi_lsp_query_response = 176;
 
         CreateRemoteProject create_remote_project = 177;
-        CreateRemoteProjectResponse create_remote_project_response = 188; // current max
+        CreateRemoteProjectResponse create_remote_project_response = 188;
         CreateDevServer create_dev_server = 178;
         CreateDevServerResponse create_dev_server_response = 179;
         ShutdownDevServer shutdown_dev_server = 180;
@@ -1892,6 +1897,29 @@ message CountTokensResponse {
     uint32 token_count = 1;
 }
 
+message GetCachedEmbeddings {
+    string model = 1;
+    repeated bytes digests = 2;
+}
+
+message GetCachedEmbeddingsResponse {
+    repeated Embedding embeddings = 1;
+}
+
+message ComputeEmbeddings {
+    string model = 1;
+    repeated string texts = 2;
+}
+
+message ComputeEmbeddingsResponse {
+    repeated Embedding embeddings = 1;
+}
+
+message Embedding {
+    bytes digest = 1;
+    repeated float dimensions = 2;
+}
+
 message BlameBuffer {
     uint64 project_id = 1;
     uint64 buffer_id = 2;
diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs
index a117648cec..48160b2fe4 100644
--- a/crates/rpc/src/proto.rs
+++ b/crates/rpc/src/proto.rs
@@ -151,6 +151,8 @@ messages!(
     (ChannelMessageSent, Foreground),
     (ChannelMessageUpdate, Foreground),
     (CompleteWithLanguageModel, Background),
+    (ComputeEmbeddings, Background),
+    (ComputeEmbeddingsResponse, Background),
     (CopyProjectEntry, Foreground),
     (CountTokensWithLanguageModel, Background),
     (CountTokensResponse, Background),
@@ -174,6 +176,8 @@ messages!(
     (FormatBuffers, Foreground),
     (FormatBuffersResponse, Foreground),
     (FuzzySearchUsers, Foreground),
+    (GetCachedEmbeddings, Background),
+    (GetCachedEmbeddingsResponse, Background),
     (GetChannelMembers, Foreground),
     (GetChannelMembersResponse, Foreground),
     (GetChannelMessages, Background),
@@ -325,6 +329,7 @@ request_messages!(
     (CancelCall, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
     (CompleteWithLanguageModel, LanguageModelResponse),
+    (ComputeEmbeddings, ComputeEmbeddingsResponse),
    (CountTokensWithLanguageModel, CountTokensResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
@@ -336,6 +341,7 @@ request_messages!(
     (Follow, FollowResponse),
     (FormatBuffers, FormatBuffersResponse),
     (FuzzySearchUsers, UsersResponse),
+    (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
     (GetChannelMembers, GetChannelMembersResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
     (GetChannelMessagesById, GetChannelMessagesResponse),
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
new file mode 100644
index 0000000000..5010a2fa49
--- /dev/null
+++ b/crates/semantic_index/Cargo.toml
@@ -0,0 +1,48 @@
+[package]
+name = "semantic_index"
+description = "Process, chunk, and embed text as vectors for semantic search."
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/semantic_index.rs"
+
+[dependencies]
+anyhow.workspace = true
+client.workspace = true
+clock.workspace = true
+collections.workspace = true
+fs.workspace = true
+futures.workspace = true
+futures-batch.workspace = true
+gpui.workspace = true
+language.workspace = true
+log.workspace = true
+heed.workspace = true
+open_ai.workspace = true
+project.workspace = true
+settings.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+sha2.workspace = true
+smol.workspace = true
+util.workspace = true
+worktree.workspace = true
+
+[dev-dependencies]
+env_logger.workspace = true
+client = { workspace = true, features = ["test-support"] }
+fs = { workspace = true, features = ["test-support"] }
+futures.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
+languages.workspace = true
+project = { workspace = true, features = ["test-support"] }
+tempfile.workspace = true
+util = { workspace = true, features = ["test-support"] }
+worktree = { workspace = true, features = ["test-support"] }
+
+[lints]
+workspace = true
diff --git a/crates/semantic_index/LICENSE-GPL b/crates/semantic_index/LICENSE-GPL
new file mode 120000
index 0000000000..89e542f750
--- /dev/null
+++ b/crates/semantic_index/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs
new file mode 100644
index 0000000000..494d8a0f81
--- /dev/null
+++ b/crates/semantic_index/examples/index.rs
@@ -0,0 +1,140 @@
+use client::Client;
+use futures::channel::oneshot;
+use gpui::{App, Global, TestAppContext};
+use language::language_settings::AllLanguageSettings;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use settings::SettingsStore;
+use std::{path::Path, sync::Arc};
+use util::http::HttpClientWithUrl;
+
+pub fn init_test(cx: &mut TestAppContext) {
+    _ = cx.update(|cx| {
+        let store = SettingsStore::test(cx);
+        cx.set_global(store);
+        language::init(cx);
+        Project::init_settings(cx);
+        SettingsStore::update(cx, |store, cx| {
+            store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+        });
+    });
+}
+
+fn main() {
+    env_logger::init();
+
+    use clock::FakeSystemClock;
+
+    App::new().run(|cx| {
+        let store = SettingsStore::test(cx);
+        cx.set_global(store);
+        language::init(cx);
+        Project::init_settings(cx);
+        SettingsStore::update(cx, |store, cx| {
+            store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+        });
+
+        let clock = Arc::new(FakeSystemClock::default());
+        let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
+
+        let client = client::Client::new(clock, http.clone(), cx);
+        Client::set_global(client.clone(), cx);
+
+        let args: Vec<String> = std::env::args().collect();
+        if args.len() < 2 {
+            eprintln!("Usage: cargo run --example index -p semantic_index -- <project_path>");
+            cx.quit();
+            return;
+        }
+
+        // let embedding_provider = semantic_index::FakeEmbeddingProvider;
+
+        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            http.clone(),
+            OpenAiEmbeddingModel::TextEmbedding3Small,
+            open_ai::OPEN_AI_API_URL.to_string(),
+            api_key,
+        );
+
+        let semantic_index = SemanticIndex::new(
+            Path::new("/tmp/semantic-index-db.mdb"),
+            Arc::new(embedding_provider),
+            cx,
+        );
+
+        cx.spawn(|mut cx| async move {
+            let mut semantic_index = semantic_index.await.unwrap();
+
+            let project_path = Path::new(&args[1]);
+
+            let project = Project::example([project_path], &mut cx).await;
+
+            cx.update(|cx| {
+                let language_registry = project.read(cx).languages().clone();
+                let node_runtime = project.read(cx).node_runtime().unwrap().clone();
+                languages::init(language_registry, node_runtime, cx);
+            })
+            .unwrap();
+
+            let project_index = cx
+                .update(|cx| semantic_index.project_index(project.clone(), cx))
+                .unwrap();
+
+            let (tx, rx) = oneshot::channel();
+            let mut tx = Some(tx);
+            let subscription = cx.update(|cx| {
+                cx.subscribe(&project_index, move |_, event, _| {
+                    if let Some(tx) = tx.take() {
+                        _ = tx.send(*event);
+                    }
+                })
+            });
+
+            let index_start = std::time::Instant::now();
+            rx.await.expect("no event emitted");
+            drop(subscription);
+            println!("Index time: {:?}", index_start.elapsed());
+
+            let results = cx
+                .update(|cx| {
+                    let project_index = project_index.read(cx);
+                    let query = "converting an anchor to a point";
+                    project_index.search(query, 4, cx)
+                })
+                .unwrap()
+                .await;
+
+            for search_result in results {
+                let path = search_result.path.clone();
+
+                let content = cx
+                    .update(|cx| {
+                        let worktree = search_result.worktree.read(cx);
+                        let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+                        let fs = project.read(cx).fs().clone();
+                        cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+                    })
+                    .unwrap()
+                    .await;
+
+                let range = search_result.range.clone();
+                let content = content[search_result.range].to_owned();
+
+                println!(
+                    "✄✄✄✄✄✄✄✄✄✄✄✄✄✄ {:?} @ {} ✄✄✄✄✄✄✄✄✄✄✄✄✄✄",
+                    path, search_result.score
+                );
+                println!("{:?}:{:?}:{:?}", path, range.start, range.end);
+                println!("{}", content);
+            }
+
+            cx.background_executor()
+                .timer(std::time::Duration::from_secs(100000))
+                .await;
+
+            cx.update(|cx| cx.quit()).unwrap();
+        })
+        .detach();
+    });
+}
diff --git a/crates/semantic_index/fixture/main.rs b/crates/semantic_index/fixture/main.rs
new file mode 100644
index 0000000000..f8796c8f45
--- /dev/null
+++ b/crates/semantic_index/fixture/main.rs
@@ -0,0 +1,3 @@
+fn main() {
+    println!("Hello Indexer!");
+}
diff --git a/crates/semantic_index/fixture/needle.md b/crates/semantic_index/fixture/needle.md
new file mode 100644
index 0000000000..80487c9983
--- /dev/null
+++ b/crates/semantic_index/fixture/needle.md
@@ -0,0 +1,43 @@
+# Searching for a needle in a haystack
+
+When you have a large amount of text, it can be useful to search for a specific word or phrase. This is often referred to as "finding a needle in a haystack." In this markdown document, we're "hiding" a key phrase for our text search to find. Can you find it?
+
+## Instructions
+
+1. Use the search functionality in your text editor or markdown viewer to find the hidden phrase in this document.
+
+2. Once you've found the **phrase**, write it down and proceed to the next step.
+
+Honestly, I just want to fill up plenty of characters so that we chunk this markdown into several chunks.
+
+## Tips
+
+- Relax
+- Take a deep breath
+- Focus on the task at hand
+- Don't get distracted by other text
+- Use the search functionality to your advantage
+
+## Example code
+
+```python
+def search_for_needle(haystack, needle):
+    if needle in haystack:
+        return True
+    else:
+        return False
+```
+
+```javascript
+function searchForNeedle(haystack, needle) {
+    return haystack.includes(needle);
+}
+```
+
+## Background
+
+When creating an index for a book or searching for a specific term in a large document, the ability to quickly find a specific word or phrase is essential. This is where search functionality comes in handy. However, one should _remember_ that the search is only as good as the index that was built. As they say, garbage in, garbage out!
+
+## Conclusion
+
+Searching for a needle in a haystack can be a challenging task, but with the right tools and techniques, it becomes much easier. Whether you're looking for a specific word in a document or trying to find a key piece of information in a large dataset, the ability to search efficiently is a valuable skill to have.
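As a point of reference for the RPC handlers and the example above, here is a minimal sketch (not part of the patch) of calling the new `open_ai::embed` function directly; the `http_client` argument is assumed to be any implementation of `util::http::HttpClient`, and error handling is deliberately simple:

```rust
use open_ai::{embed, OpenAiEmbeddingModel, OPEN_AI_API_URL};

// Hypothetical standalone usage of the embedding request added in this patch.
// `http_client` can be any `util::http::HttpClient` implementation.
async fn embed_two_texts(
    http_client: &dyn util::http::HttpClient,
    api_key: &str,
) -> anyhow::Result<()> {
    let response = embed(
        http_client,
        OPEN_AI_API_URL,
        api_key,
        OpenAiEmbeddingModel::TextEmbedding3Small,
        ["first text to embed", "second text to embed"],
    )
    .await?;
    for datum in response.data {
        println!("received embedding with {} dimensions", datum.embedding.len());
    }
    Ok(())
}
```

The collab server's `compute_embeddings` handler drives this same call, then caches each returned vector keyed by the SHA-256 digest of the input text so that `get_cached_embeddings` can serve repeat requests without re-embedding.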
diff --git a/crates/semantic_index/src/chunking.rs b/crates/semantic_index/src/chunking.rs
new file mode 100644
index 0000000000..da37afd78c
--- /dev/null
+++ b/crates/semantic_index/src/chunking.rs
@@ -0,0 +1,409 @@
+use language::{with_parser, Grammar, Tree};
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::{cmp, ops::Range, sync::Arc};
+
+const CHUNK_THRESHOLD: usize = 1500;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Chunk {
+    pub range: Range<usize>,
+    pub digest: [u8; 32],
+}
+
+pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
+    if let Some(grammar) = grammar {
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(&text, None).expect("invalid language")
+        });
+
+        chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
+    } else {
+        chunk_lines(&text)
+    }
+}
+
+fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
+    let mut chunk_ranges = Vec::new();
+    let mut cursor = tree.walk();
+
+    let mut range = 0..0;
+    loop {
+        let node = cursor.node();
+
+        // If adding the node to the current chunk exceeds the threshold
+        if node.end_byte() - range.start > chunk_threshold {
+            // Try to descend into its first child. If we can't, flush the current
+            // range and try again.
+            if cursor.goto_first_child() {
+                continue;
+            } else if !range.is_empty() {
+                chunk_ranges.push(range.clone());
+                range.start = range.end;
+                continue;
+            }
+
+            // If we get here, the node itself has no children but is larger than the threshold.
+            // Break its text into arbitrary chunks.
+            split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
+        }
+        range.end = node.end_byte();
+
+        // If we get here, we consumed the node. Advance to the next sibling, ascending if there isn't one.
+        while !cursor.goto_next_sibling() {
+            if !cursor.goto_parent() {
+                if !range.is_empty() {
+                    chunk_ranges.push(range);
+                }
+
+                return chunk_ranges
+                    .into_iter()
+                    .map(|range| {
+                        let digest = Sha256::digest(&text[range.clone()]).into();
+                        Chunk { range, digest }
+                    })
+                    .collect();
+            }
+        }
+    }
+}
+
+fn chunk_lines(text: &str) -> Vec<Chunk> {
+    let mut chunk_ranges = Vec::new();
+    let mut range = 0..0;
+
+    let mut newlines = text.match_indices('\n').peekable();
+    while let Some((newline_ix, _)) = newlines.peek() {
+        let newline_ix = newline_ix + 1;
+        if newline_ix - range.start <= CHUNK_THRESHOLD {
+            range.end = newline_ix;
+            newlines.next();
+        } else {
+            if range.is_empty() {
+                split_text(text, range, newline_ix, &mut chunk_ranges);
+                range = newline_ix..newline_ix;
+            } else {
+                chunk_ranges.push(range.clone());
+                range.start = range.end;
+            }
+        }
+    }
+
+    if !range.is_empty() {
+        chunk_ranges.push(range);
+    }
+
+    chunk_ranges
+        .into_iter()
+        .map(|range| {
+            let mut hasher = Sha256::new();
+            hasher.update(&text[range.clone()]);
+            let mut digest = [0u8; 32];
+            digest.copy_from_slice(hasher.finalize().as_slice());
+            Chunk { range, digest }
+        })
+        .collect()
+}
+
+fn split_text(
+    text: &str,
+    mut range: Range<usize>,
+    max_end: usize,
+    chunk_ranges: &mut Vec<Range<usize>>,
+) {
+    while range.start < max_end {
+        range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
+        // Back up until the end lands on a UTF-8 character boundary.
+        while !text.is_char_boundary(range.end) {
+            range.end -= 1;
+        }
+        chunk_ranges.push(range.clone());
+        range.start = range.end;
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
+
+    // This example comes from crates/gpui/examples/window_positioning.rs which
+    // has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD
+    static TEXT: &str = r#"
+    use gpui::*;
+
+    struct WindowContent {
+        text: SharedString,
+    }
+
+    impl Render for WindowContent {
+        fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+            div()
+                .flex()
+                .bg(rgb(0x1e2025))
+                .size_full()
+                .justify_center()
+                .items_center()
+                .text_xl()
+                .text_color(rgb(0xffffff))
+                .child(self.text.clone())
+        }
+    }
+
+    fn main() {
+        App::new().run(|cx: &mut AppContext| {
+            // Create several new windows, positioned in the top right corner of each screen
+
+            for screen in cx.displays() {
+                let options = {
+                    let popup_margin_width = DevicePixels::from(16);
+                    let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);
+
+                    let window_size = Size {
+                        width: px(400.),
+                        height: px(72.),
+                    };
+
+                    let screen_bounds = screen.bounds();
+                    let size: Size<DevicePixels> = window_size.into();
+
+                    let bounds = gpui::Bounds::<DevicePixels> {
+                        origin: screen_bounds.upper_right()
+                            - point(size.width + popup_margin_width, popup_margin_height),
+                        size: window_size.into(),
+                    };
+
+                    WindowOptions {
+                        // Set the bounds of the window in screen coordinates
+                        bounds: Some(bounds),
+                        // Specify the display_id to ensure the window is created on the correct screen
+                        display_id: Some(screen.id()),
+
+                        titlebar: None,
+                        window_background: WindowBackgroundAppearance::default(),
+                        focus: false,
+                        show: true,
+                        kind: WindowKind::PopUp,
+                        is_movable: false,
+                        fullscreen: false,
+                    }
+                };
+
+                cx.open_window(options, |cx| {
+                    cx.new_view(|_| WindowContent {
+                        text: format!("{:?}", screen.id()).into(),
+                    })
+                });
+            }
+        });
+    }"#;
+
+    fn setup_rust_language() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+    }
+
+    #[test]
+    fn test_chunk_text() {
+        let text = "a\n".repeat(1000);
+        let chunks = chunk_text(&text, None);
+        assert_eq!(
+            chunks.len(),
+            ((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
+        );
+    }
+
+    #[test]
+    fn test_chunk_text_grammar() {
+        // Let's set up a big text with some known segments
+        // We'll then chunk it and verify that the chunks are correct
+
+        let language = setup_rust_language();
+
+        let chunks = chunk_text(TEXT, language.grammar());
+        assert_eq!(chunks.len(), 2);
+
+        assert_eq!(chunks[0].range.start, 0);
+        assert_eq!(chunks[0].range.end, 1498);
+        // The break between chunks is right before the "Specify the display_id" comment
+
+        assert_eq!(chunks[1].range.start, 1498);
+        assert_eq!(chunks[1].range.end, 2396);
+    }
+
+    #[test]
+    fn test_chunk_parse_tree() {
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(TEXT, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, TEXT, 250);
+        assert_eq!(chunks.len(), 11);
+    }
+
+    #[test]
+    fn test_chunk_unparsable() {
+        // Even if a chunk is unparsable, we should still be able to chunk it
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let text = r#"fn main() {"#;
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(text, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, text, 250);
+        assert_eq!(chunks.len(), 1);
+
+        assert_eq!(chunks[0].range.start, 0);
+        assert_eq!(chunks[0].range.end, 11);
+    }
+
+    #[test]
+    fn test_empty_text() {
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse("", None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
+        assert!(chunks.is_empty(), "Chunks should be empty for empty text");
+    }
+
+    #[test]
+    fn test_single_large_node() {
+        let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";
+
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(&large_text, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);
+
+        assert_eq!(
+            chunks.len(),
+            3,
+            "Large chunks are broken up according to grammar as best as possible"
+        );
+
+        // Expect chunks to be "static", "aaaaaa...", and " = 2"
+        assert_eq!(chunks[0].range.start, 0);
+        assert_eq!(chunks[0].range.end, "static".len());
+
+        assert_eq!(chunks[1].range.start, "static".len());
+        assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);
+
+        assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
+        assert_eq!(chunks[2].range.end, large_text.len());
+    }
+
+    #[test]
+    fn test_multiple_small_nodes() {
+        let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(small_text, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, small_text, 5);
+        assert!(
+            chunks.len() > 1,
+            "Should have multiple chunks for multiple small nodes"
+        );
+    }
+
+    #[test]
+    fn test_node_with_children() {
+        let nested_text = "fn main() { let a = 1; let b = 2; }";
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(nested_text, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, nested_text, 10);
+        assert!(
+            chunks.len() > 1,
+            "Should have multiple chunks for a node with children"
+        );
+    }
+
+    #[test]
+    fn test_text_with_unparsable_sections() {
+        // This test uses purposefully hit-or-miss sizing of 11 characters per likely chunk
+        let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
+        let language = setup_rust_language();
+        let grammar = language.grammar().unwrap();
+
+        let tree = with_parser(|parser| {
+            parser
+                .set_language(&grammar.ts_language)
+                .expect("incompatible grammar");
+            parser.parse(mixed_text, None).expect("invalid language")
+        });
+
+        let chunks = chunk_parse_tree(tree, mixed_text, 11);
+        assert!(
+            chunks.len() > 1,
+            "Should handle both parsable and unparsable sections correctly"
+        );
+
+        let expected_chunks = [
+            "fn main() {",
+            " let a = 1;",
+            " let b = 2;",
+            " }",
+            " unparsable",
+            " bits here",
+        ];
+
+        for (i, chunk) in chunks.iter().enumerate() {
+            assert_eq!(
+                &mixed_text[chunk.range.clone()],
+                expected_chunks[i],
+                "Chunk {} should match",
+                i
+            );
+        }
+    }
+}
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
new file mode 100644
index 0000000000..b5195c8911
--- /dev/null
+++ b/crates/semantic_index/src/embedding.rs
@@ -0,0 +1,125 @@
+mod cloud;
+mod ollama;
+mod open_ai;
+
+pub use cloud::*;
+pub use ollama::*;
+pub use open_ai::*;
+use sha2::{Digest, Sha256};
+
+use anyhow::Result;
+use futures::{future::BoxFuture, FutureExt};
+use serde::{Deserialize, Serialize};
+use std::{fmt, future};
+
+#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Embedding(Vec<f32>);
+
+impl Embedding {
+    pub fn new(mut embedding: Vec<f32>) -> Self {
+        // Normalize to unit length so that a plain dot product equals cosine similarity.
+        let len = embedding.len();
+        let mut norm = 0f32;
+
+        for i in 0..len {
+            norm += embedding[i] * embedding[i];
+        }
+
+        norm = norm.sqrt();
+        for dimension in &mut embedding {
+            *dimension /= norm;
+        }
+
+        Self(embedding)
+    }
+
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    pub fn similarity(self, other: &Embedding) -> f32 {
+        debug_assert_eq!(self.0.len(), other.0.len());
+        self.0
+            .iter()
+            .copied()
+            .zip(other.0.iter().copied())
+            .map(|(a, b)| a * b)
+            .sum()
+    }
+}
+
+impl fmt::Display for Embedding {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let digits_to_display = 3;
+
+        // Start the Embedding display format
+        write!(f, "Embedding(sized: {}; values: [", self.len())?;
+
+        for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
+            // Lead with comma if not the first element
+            if index != 0 {
+                write!(f, ", ")?;
+            }
+            write!(f, "{:.3}", value)?;
+        }
+        if self.len() > digits_to_display {
+            write!(f, "...")?;
+        }
+        write!(f, "])")
+    }
+}
+
+/// Trait for embedding providers. Texts in, vectors out.
+pub trait EmbeddingProvider: Sync + Send {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
+    fn batch_size(&self) -> usize;
+}
+
+#[derive(Debug)]
+pub struct TextToEmbed<'a> {
+    pub text: &'a str,
+    pub digest: [u8; 32],
+}
+
+impl<'a> TextToEmbed<'a> {
+    pub fn new(text: &'a str) -> Self {
+        let digest = Sha256::digest(text.as_bytes());
+        Self {
+            text,
+            digest: digest.into(),
+        }
+    }
+}
+
+pub struct FakeEmbeddingProvider;
+
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        let embeddings = texts
+            .iter()
+            .map(|_text| {
+                let mut embedding = vec![0f32; 1536];
+                for i in 0..embedding.len() {
+                    embedding[i] = i as f32;
+                }
+                Embedding::new(embedding)
+            })
+            .collect();
+        future::ready(Ok(embeddings)).boxed()
+    }
+
+    fn batch_size(&self) -> usize {
+        16
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[gpui::test]
+    fn test_normalize_embedding() {
+        let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
+        let value: f32 = 1.0 / 3.0_f32.sqrt();
+        assert_eq!(normalized, Embedding(vec![value; 3]));
+    }
+}
diff --git a/crates/semantic_index/src/embedding/cloud.rs b/crates/semantic_index/src/embedding/cloud.rs
new file mode 100644
index 0000000000..2a1df705c8
--- /dev/null
+++ b/crates/semantic_index/src/embedding/cloud.rs
@@ -0,0 +1,88 @@
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+use anyhow::{anyhow, Context, Result};
+use client::{proto, Client};
+use collections::HashMap;
+use futures::{future::BoxFuture, FutureExt};
+use std::sync::Arc;
+
+pub struct CloudEmbeddingProvider {
+    model: String,
+    client: Arc<Client>,
+}
+
+impl CloudEmbeddingProvider {
+    pub fn new(client: Arc<Client>) -> Self {
+        Self {
+            model: "openai/text-embedding-3-small".into(),
+            client,
+        }
+    }
+}
+
+impl EmbeddingProvider for CloudEmbeddingProvider {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        // First, fetch any embeddings that are cached based on the requested texts' digests.
+        // Then compute any embeddings that are missing.
+        async move {
+            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
+                model: self.model.clone(),
+                digests: texts
+                    .iter()
+                    .map(|to_embed| to_embed.digest.to_vec())
+                    .collect(),
+            });
+            let mut embeddings = cached_embeddings
+                .await
+                .context("failed to fetch cached embeddings via cloud model")?
+ .embeddings + .into_iter() + .map(|embedding| { + let digest: [u8; 32] = embedding + .digest + .try_into() + .map_err(|_| anyhow!("invalid digest for cached embedding"))?; + Ok((digest, embedding.dimensions)) + }) + .collect::>>()?; + + let compute_embeddings_request = proto::ComputeEmbeddings { + model: self.model.clone(), + texts: texts + .iter() + .filter_map(|to_embed| { + if embeddings.contains_key(&to_embed.digest) { + None + } else { + Some(to_embed.text.to_string()) + } + }) + .collect(), + }; + if !compute_embeddings_request.texts.is_empty() { + let missing_embeddings = self.client.request(compute_embeddings_request).await?; + for embedding in missing_embeddings.embeddings { + let digest: [u8; 32] = embedding + .digest + .try_into() + .map_err(|_| anyhow!("invalid digest for cached embedding"))?; + embeddings.insert(digest, embedding.dimensions); + } + } + + texts + .iter() + .map(|to_embed| { + let dimensions = embeddings.remove(&to_embed.digest).with_context(|| { + format!("server did not return an embedding for {:?}", to_embed) + })?; + Ok(Embedding::new(dimensions)) + }) + .collect() + } + .boxed() + } + + fn batch_size(&self) -> usize { + 2048 + } +} diff --git a/crates/semantic_index/src/embedding/ollama.rs b/crates/semantic_index/src/embedding/ollama.rs new file mode 100644 index 0000000000..8b8a36481b --- /dev/null +++ b/crates/semantic_index/src/embedding/ollama.rs @@ -0,0 +1,74 @@ +use anyhow::{Context as _, Result}; +use futures::{future::BoxFuture, AsyncReadExt, FutureExt}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use util::http::HttpClient; + +use crate::{Embedding, EmbeddingProvider, TextToEmbed}; + +pub enum OllamaEmbeddingModel { + NomicEmbedText, + MxbaiEmbedLarge, +} + +pub struct OllamaEmbeddingProvider { + client: Arc, + model: OllamaEmbeddingModel, +} + +#[derive(Serialize)] +struct OllamaEmbeddingRequest { + model: String, + prompt: String, +} + +#[derive(Deserialize)] +struct OllamaEmbeddingResponse { + embedding: Vec, +} + +impl OllamaEmbeddingProvider { + pub fn new(client: Arc, model: OllamaEmbeddingModel) -> Self { + Self { client, model } + } +} + +impl EmbeddingProvider for OllamaEmbeddingProvider { + fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { + // + let model = match self.model { + OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text", + OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large", + }; + + futures::future::try_join_all(texts.into_iter().map(|to_embed| { + let request = OllamaEmbeddingRequest { + model: model.to_string(), + prompt: to_embed.text.to_string(), + }; + + let request = serde_json::to_string(&request).unwrap(); + + async { + let response = self + .client + .post_json("http://localhost:11434/api/embeddings", request.into()) + .await?; + + let mut body = String::new(); + response.into_body().read_to_string(&mut body).await?; + + let response: OllamaEmbeddingResponse = + serde_json::from_str(&body).context("Unable to pull response")?; + + Ok(Embedding::new(response.embedding)) + } + })) + .boxed() + } + + fn batch_size(&self) -> usize { + // TODO: Figure out decent value + 10 + } +} diff --git a/crates/semantic_index/src/embedding/open_ai.rs b/crates/semantic_index/src/embedding/open_ai.rs new file mode 100644 index 0000000000..8eccb5272f --- /dev/null +++ b/crates/semantic_index/src/embedding/open_ai.rs @@ -0,0 +1,55 @@ +use crate::{Embedding, EmbeddingProvider, TextToEmbed}; +use anyhow::Result; +use futures::{future::BoxFuture, FutureExt}; +pub use 
diff --git a/crates/semantic_index/src/embedding/ollama.rs b/crates/semantic_index/src/embedding/ollama.rs
new file mode 100644
index 0000000000..8b8a36481b
--- /dev/null
+++ b/crates/semantic_index/src/embedding/ollama.rs
@@ -0,0 +1,74 @@
+use anyhow::{Context as _, Result};
+use futures::{future::BoxFuture, AsyncReadExt, FutureExt};
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use util::http::HttpClient;
+
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+
+pub enum OllamaEmbeddingModel {
+    NomicEmbedText,
+    MxbaiEmbedLarge,
+}
+
+pub struct OllamaEmbeddingProvider {
+    client: Arc<dyn HttpClient>,
+    model: OllamaEmbeddingModel,
+}
+
+#[derive(Serialize)]
+struct OllamaEmbeddingRequest {
+    model: String,
+    prompt: String,
+}
+
+#[derive(Deserialize)]
+struct OllamaEmbeddingResponse {
+    embedding: Vec<f32>,
+}
+
+impl OllamaEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
+        Self { client, model }
+    }
+}
+
+impl EmbeddingProvider for OllamaEmbeddingProvider {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        //
+        let model = match self.model {
+            OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
+            OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
+        };
+
+        futures::future::try_join_all(texts.into_iter().map(|to_embed| {
+            let request = OllamaEmbeddingRequest {
+                model: model.to_string(),
+                prompt: to_embed.text.to_string(),
+            };
+
+            let request = serde_json::to_string(&request).unwrap();
+
+            async {
+                let response = self
+                    .client
+                    .post_json("http://localhost:11434/api/embeddings", request.into())
+                    .await?;
+
+                let mut body = String::new();
+                response.into_body().read_to_string(&mut body).await?;
+
+                let response: OllamaEmbeddingResponse =
+                    serde_json::from_str(&body).context("unable to parse Ollama embedding response")?;
+
+                Ok(Embedding::new(response.embedding))
+            }
+        }))
+        .boxed()
+    }
+
+    fn batch_size(&self) -> usize {
+        // TODO: Figure out a decent value.
+        10
+    }
+}
diff --git a/crates/semantic_index/src/embedding/open_ai.rs b/crates/semantic_index/src/embedding/open_ai.rs
new file mode 100644
index 0000000000..8eccb5272f
--- /dev/null
+++ b/crates/semantic_index/src/embedding/open_ai.rs
@@ -0,0 +1,55 @@
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+use anyhow::Result;
+use futures::{future::BoxFuture, FutureExt};
+pub use open_ai::OpenAiEmbeddingModel;
+use std::sync::Arc;
+use util::http::HttpClient;
+
+pub struct OpenAiEmbeddingProvider {
+    client: Arc<dyn HttpClient>,
+    model: OpenAiEmbeddingModel,
+    api_url: String,
+    api_key: String,
+}
+
+impl OpenAiEmbeddingProvider {
+    pub fn new(
+        client: Arc<dyn HttpClient>,
+        model: OpenAiEmbeddingModel,
+        api_url: String,
+        api_key: String,
+    ) -> Self {
+        Self {
+            client,
+            model,
+            api_url,
+            api_key,
+        }
+    }
+}
+
+impl EmbeddingProvider for OpenAiEmbeddingProvider {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        let embed = open_ai::embed(
+            self.client.as_ref(),
+            &self.api_url,
+            &self.api_key,
+            self.model,
+            texts.iter().map(|to_embed| to_embed.text),
+        );
+        async move {
+            let response = embed.await?;
+            Ok(response
+                .data
+                .into_iter()
+                .map(|data| Embedding::new(data.embedding))
+                .collect())
+        }
+        .boxed()
+    }
+
+    fn batch_size(&self) -> usize {
+        // From https://platform.openai.com/docs/api-reference/embeddings/create
+        2048
+    }
+}
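All three providers are consumed through the same `EmbeddingProvider` trait object; note that the Ollama impl issues one HTTP request per text via `try_join_all` (its API has no batch endpoint here), while `batch_size` only bounds how many texts the caller hands to a single `embed` call. A hedged usage sketch from inside the crate, assuming an async context, a local Ollama server, and the `util::http::client()` constructor visible in the http.rs hunk at the end of this diff:

```rust
use std::sync::Arc;

// Sketch: construct a provider and embed two chunks through the trait.
async fn example() -> anyhow::Result<()> {
    let http = util::http::client();
    let provider: Arc<dyn EmbeddingProvider> = Arc::new(OllamaEmbeddingProvider::new(
        http,
        OllamaEmbeddingModel::NomicEmbedText,
    ));

    let texts = [TextToEmbed::new("fn main() {}"), TextToEmbed::new("struct Foo;")];
    // Callers are expected to respect batch_size() when slicing larger inputs.
    assert!(texts.len() <= provider.batch_size());
    let embeddings = provider.embed(&texts).await?;
    assert_eq!(embeddings.len(), texts.len());
    Ok(())
}
```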
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
new file mode 100644
index 0000000000..86c18c6fe6
--- /dev/null
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -0,0 +1,954 @@
+mod chunking;
+mod embedding;
+
+use anyhow::{anyhow, Context as _, Result};
+use chunking::{chunk_text, Chunk};
+use collections::{Bound, HashMap};
+pub use embedding::*;
+use fs::Fs;
+use futures::stream::StreamExt;
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{
+    AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
+    Subscription, Task, WeakModel,
+};
+use heed::types::{SerdeBincode, Str};
+use language::LanguageRegistry;
+use project::{Entry, Project, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+    cmp::Ordering,
+    future::Future,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+    time::{Duration, SystemTime},
+};
+use util::ResultExt;
+use worktree::LocalSnapshot;
+
+pub struct SemanticIndex {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    db_connection: heed::Env,
+    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
+}
+
+impl Global for SemanticIndex {}
+
+impl SemanticIndex {
+    pub fn new(
+        db_path: &Path,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        cx: &mut AppContext,
+    ) -> Task<Result<SemanticIndex>> {
+        let db_path = db_path.to_path_buf();
+        cx.spawn(|cx| async move {
+            let db_connection = cx
+                .background_executor()
+                .spawn(async move {
+                    unsafe {
+                        heed::EnvOpenOptions::new()
+                            .map_size(1024 * 1024 * 1024)
+                            .max_dbs(3000)
+                            .open(db_path)
+                    }
+                })
+                .await?;
+
+            Ok(SemanticIndex {
+                db_connection,
+                embedding_provider,
+                project_indices: HashMap::default(),
+            })
+        })
+    }
+
+    pub fn project_index(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut AppContext,
+    ) -> Model<ProjectIndex> {
+        self.project_indices
+            .entry(project.downgrade())
+            .or_insert_with(|| {
+                cx.new_model(|cx| {
+                    ProjectIndex::new(
+                        project,
+                        self.db_connection.clone(),
+                        self.embedding_provider.clone(),
+                        cx,
+                    )
+                })
+            })
+            .clone()
+    }
+}
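`SemanticIndex::new` opens a single LMDB environment through heed (the `heed` 0.20.0-alpha.9 dependency added in Cargo.lock above) and, as `WorktreeIndex::load` shows further down, creates one named database per worktree inside it. A hedged sketch of the heed calls being relied on, with illustrative sizes and names; `open` is `unsafe` because the caller must uphold LMDB's invariants (notably, not opening the same environment twice in one process):

```rust
use heed::types::{SerdeBincode, Str};

// Sketch of the env/database lifecycle used above (heed 0.20 alpha API shape).
fn open_db(
    path: &std::path::Path,
) -> anyhow::Result<(heed::Env, heed::Database<Str, SerdeBincode<String>>)> {
    let env = unsafe {
        heed::EnvOpenOptions::new()
            .map_size(1024 * 1024 * 1024) // 1 GiB address-space reservation
            .max_dbs(3000)                // room for one named database per worktree
            .open(path)?
    };
    let mut txn = env.write_txn()?;
    let db = env.create_database(&mut txn, Some("example-worktree"))?;
    txn.commit()?;
    Ok((env, db))
}
```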
+
+pub struct ProjectIndex {
+    db_connection: heed::Env,
+    project: Model<Project>,
+    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
+    language_registry: Arc<LanguageRegistry>,
+    fs: Arc<dyn Fs>,
+    last_status: Status,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    _subscription: Subscription,
+}
+
+enum WorktreeIndexHandle {
+    Loading {
+        _task: Task<Result<()>>,
+    },
+    Loaded {
+        index: Model<WorktreeIndex>,
+        _subscription: Subscription,
+    },
+}
+
+impl ProjectIndex {
+    fn new(
+        project: Model<Project>,
+        db_connection: heed::Env,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let language_registry = project.read(cx).languages().clone();
+        let fs = project.read(cx).fs().clone();
+        let mut this = ProjectIndex {
+            db_connection,
+            project: project.clone(),
+            worktree_indices: HashMap::default(),
+            language_registry,
+            fs,
+            last_status: Status::Idle,
+            embedding_provider,
+            _subscription: cx.subscribe(&project, Self::handle_project_event),
+        };
+        this.update_worktree_indices(cx);
+        this
+    }
+
+    fn handle_project_event(
+        &mut self,
+        _: Model<Project>,
+        event: &project::Event,
+        cx: &mut ModelContext<Self>,
+    ) {
+        match event {
+            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
+                self.update_worktree_indices(cx);
+            }
+            _ => {}
+        }
+    }
+
+    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
+        let worktrees = self
+            .project
+            .read(cx)
+            .visible_worktrees(cx)
+            .filter_map(|worktree| {
+                if worktree.read(cx).is_local() {
+                    Some((worktree.entity_id(), worktree))
+                } else {
+                    None
+                }
+            })
+            .collect::<HashMap<_, _>>();
+
+        self.worktree_indices
+            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
+        for (worktree_id, worktree) in worktrees {
+            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
+                let worktree_index = WorktreeIndex::load(
+                    worktree.clone(),
+                    self.db_connection.clone(),
+                    self.language_registry.clone(),
+                    self.fs.clone(),
+                    self.embedding_provider.clone(),
+                    cx,
+                );
+
+                let load_worktree = cx.spawn(|this, mut cx| async move {
+                    if let Some(index) = worktree_index.await.log_err() {
+                        this.update(&mut cx, |this, cx| {
+                            this.worktree_indices.insert(
+                                worktree_id,
+                                WorktreeIndexHandle::Loaded {
+                                    _subscription: cx
+                                        .observe(&index, |this, _, cx| this.update_status(cx)),
+                                    index,
+                                },
+                            );
+                        })?;
+                    } else {
+                        this.update(&mut cx, |this, _cx| {
+                            this.worktree_indices.remove(&worktree_id)
+                        })?;
+                    }
+
+                    this.update(&mut cx, |this, cx| this.update_status(cx))
+                });
+
+                WorktreeIndexHandle::Loading {
+                    _task: load_worktree,
+                }
+            });
+        }
+
+        self.update_status(cx);
+    }
+
+    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
+        let mut status = Status::Idle;
+        for index in self.worktree_indices.values() {
+            match index {
+                WorktreeIndexHandle::Loading { .. } => {
+                    status = Status::Scanning;
+                    break;
+                }
+                WorktreeIndexHandle::Loaded { index, .. } => {
+                    if index.read(cx).status == Status::Scanning {
+                        status = Status::Scanning;
+                        break;
+                    }
+                }
+            }
+        }
+
+        if status != self.last_status {
+            self.last_status = status;
+            cx.emit(status);
+        }
+    }
+
+    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
+        let mut worktree_searches = Vec::new();
+        for worktree_index in self.worktree_indices.values() {
+            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+                worktree_searches
+                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+            }
+        }
+
+        cx.spawn(|_| async move {
+            let mut results = Vec::new();
+            let worktree_searches = futures::future::join_all(worktree_searches).await;
+
+            for worktree_search_results in worktree_searches {
+                if let Some(worktree_search_results) = worktree_search_results.log_err() {
+                    results.extend(worktree_search_results);
+                }
+            }
+
+            results
+                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
+            results.truncate(limit);
+
+            results
+        })
+    }
+}
+
+pub struct SearchResult {
+    pub worktree: Model<Worktree>,
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub score: f32,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum Status {
+    Idle,
+    Scanning,
+}
+
+impl EventEmitter<Status> for ProjectIndex {}
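`ProjectIndex::search` fans the query out to every loaded worktree index, concatenates the per-worktree hits, and re-ranks by score. Because `f32` is not `Ord`, the sort goes through `partial_cmp` with an `Ordering::Equal` fallback so NaN scores cannot panic. A minimal standalone sketch of that merge step, using hypothetical `(score, label)` pairs in place of `SearchResult`:

```rust
use std::cmp::Ordering;

// Sketch of the merge-and-rank step across worktree result sets.
fn merge_ranked(mut all: Vec<(f32, &'static str)>, limit: usize) -> Vec<(f32, &'static str)> {
    // Descending by score; NaN compares as Equal instead of panicking.
    all.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
    all.truncate(limit);
    all
}

fn main() {
    let merged = merge_ranked(vec![(0.2, "a"), (0.9, "b"), (0.5, "c")], 2);
    assert_eq!(merged, vec![(0.9, "b"), (0.5, "c")]);
}
```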
+
+struct WorktreeIndex {
+    worktree: Model<Worktree>,
+    db_connection: heed::Env,
+    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+    language_registry: Arc<LanguageRegistry>,
+    fs: Arc<dyn Fs>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    status: Status,
+    _index_entries: Task<Result<()>>,
+    _subscription: Subscription,
+}
+
+impl WorktreeIndex {
+    pub fn load(
+        worktree: Model<Worktree>,
+        db_connection: heed::Env,
+        language_registry: Arc<LanguageRegistry>,
+        fs: Arc<dyn Fs>,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        cx: &mut AppContext,
+    ) -> Task<Result<Model<WorktreeIndex>>> {
+        let worktree_abs_path = worktree.read(cx).abs_path();
+        cx.spawn(|mut cx| async move {
+            let db = cx
+                .background_executor()
+                .spawn({
+                    let db_connection = db_connection.clone();
+                    async move {
+                        let mut txn = db_connection.write_txn()?;
+                        let db_name = worktree_abs_path.to_string_lossy();
+                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
+                        txn.commit()?;
+                        anyhow::Ok(db)
+                    }
+                })
+                .await?;
+            cx.new_model(|cx| {
+                Self::new(
+                    worktree,
+                    db_connection,
+                    db,
+                    language_registry,
+                    fs,
+                    embedding_provider,
+                    cx,
+                )
+            })
+        })
+    }
+
+    fn new(
+        worktree: Model<Worktree>,
+        db_connection: heed::Env,
+        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+        language_registry: Arc<LanguageRegistry>,
+        fs: Arc<dyn Fs>,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
+        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
+            if let worktree::Event::UpdatedEntries(update) = event {
+                _ = updated_entries_tx.try_send(update.clone());
+            }
+        });
+
+        Self {
+            db_connection,
+            db,
+            worktree,
+            language_registry,
+            fs,
+            embedding_provider,
+            status: Status::Idle,
+            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
+            _subscription,
+        }
+    }
+
+    async fn index_entries(
+        this: WeakModel<Self>,
+        updated_entries: channel::Receiver<UpdatedEntriesSet>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        let index = this.update(&mut cx, |this, cx| {
+            cx.notify();
+            this.status = Status::Scanning;
+            this.index_entries_changed_on_disk(cx)
+        })?;
+        index.await.log_err();
+        this.update(&mut cx, |this, cx| {
+            this.status = Status::Idle;
+            cx.notify();
+        })?;
+
+        while let Ok(updated_entries) = updated_entries.recv().await {
+            let index = this.update(&mut cx, |this, cx| {
+                cx.notify();
+                this.status = Status::Scanning;
+                this.index_updated_entries(updated_entries, cx)
+            })?;
+            index.await.log_err();
+            this.update(&mut cx, |this, cx| {
+                this.status = Status::Idle;
+                cx.notify();
+            })?;
+        }
+
+        Ok(())
+    }
+
+    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
+        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan = self.scan_entries(worktree.clone(), cx);
+        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+        let embed = self.embed_files(chunk.files, cx);
+        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+        async move {
+            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+            Ok(())
+        }
+    }
+
+    fn index_updated_entries(
+        &self,
+        updated_entries: UpdatedEntriesSet,
+        cx: &AppContext,
+    ) -> impl Future<Output = Result<()>> {
+        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan = self.scan_updated_entries(worktree, updated_entries, cx);
+        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+        let embed = self.embed_files(chunk.files, cx);
+        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+        async move {
+            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+            Ok(())
+        }
+    }
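Indexing is structured as four concurrent stages (scan, chunk, embed, persist) wired together by bounded channels, with `futures::try_join!` driving them to completion and propagating the first error. A reduced two-stage sketch of the same shape, with stand-in stage bodies (assumed runnable on any smol-compatible executor):

```rust
use anyhow::Result;
use smol::channel;

// Sketch: bounded channels give backpressure between stages, and
// try_join! fails fast if either stage errors.
async fn pipeline(paths: Vec<String>) -> Result<()> {
    let (chunk_tx, chunk_rx) = channel::bounded::<String>(512);

    let scan = async move {
        for path in paths {
            chunk_tx.send(path).await?; // blocks when the next stage falls behind
        }
        anyhow::Ok(()) // dropping chunk_tx closes the channel
    };
    let persist = async move {
        while let Ok(path) = chunk_rx.recv().await {
            println!("persisting {path}"); // stand-in for the real stage work
        }
        anyhow::Ok(())
    };

    futures::try_join!(scan, persist)?;
    Ok(())
}
```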
+
+    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
+        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+        let db_connection = self.db_connection.clone();
+        let db = self.db;
+        let task = cx.background_executor().spawn(async move {
+            let txn = db_connection
+                .read_txn()
+                .context("failed to create read transaction")?;
+            let mut db_entries = db
+                .iter(&txn)
+                .context("failed to create iterator")?
+                .move_between_keys()
+                .peekable();
+
+            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
+            for entry in worktree.files(false, 0) {
+                let entry_db_key = db_key_for_path(&entry.path);
+
+                let mut saved_mtime = None;
+                while let Some(db_entry) = db_entries.peek() {
+                    match db_entry {
+                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
+                            Ordering::Less => {
+                                if let Some(deletion_range) = deletion_range.as_mut() {
+                                    deletion_range.1 = Bound::Included(db_path);
+                                } else {
+                                    deletion_range =
+                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
+                                }
+
+                                db_entries.next();
+                            }
+                            Ordering::Equal => {
+                                if let Some(deletion_range) = deletion_range.take() {
+                                    deleted_entry_ranges_tx
+                                        .send((
+                                            deletion_range.0.map(ToString::to_string),
+                                            deletion_range.1.map(ToString::to_string),
+                                        ))
+                                        .await?;
+                                }
+                                saved_mtime = db_embedded_file.mtime;
+                                db_entries.next();
+                                break;
+                            }
+                            Ordering::Greater => {
+                                break;
+                            }
+                        },
+                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
+                    }
+                }
+
+                if entry.mtime != saved_mtime {
+                    updated_entries_tx.send(entry.clone()).await?;
+                }
+            }
+
+            if let Some(db_entry) = db_entries.next() {
+                let (db_path, _) = db_entry?;
+                deleted_entry_ranges_tx
+                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
+                    .await?;
+            }
+
+            Ok(())
+        });
+
+        ScanEntries {
+            updated_entries: updated_entries_rx,
+            deleted_entry_ranges: deleted_entry_ranges_rx,
+            task,
+        }
+    }
+
+    fn scan_updated_entries(
+        &self,
+        worktree: LocalSnapshot,
+        updated_entries: UpdatedEntriesSet,
+        cx: &AppContext,
+    ) -> ScanEntries {
+        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+        let task = cx.background_executor().spawn(async move {
+            for (path, entry_id, status) in updated_entries.iter() {
+                match status {
+                    project::PathChange::Added
+                    | project::PathChange::Updated
+                    | project::PathChange::AddedOrUpdated => {
+                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
+                            updated_entries_tx.send(entry.clone()).await?;
+                        }
+                    }
+                    project::PathChange::Removed => {
+                        let db_path = db_key_for_path(path);
+                        deleted_entry_ranges_tx
+                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
+                            .await?;
+                    }
+                    project::PathChange::Loaded => {
+                        // Do nothing.
+                    }
+                }
+            }
+
+            Ok(())
+        });
+
+        ScanEntries {
+            updated_entries: updated_entries_rx,
+            deleted_entry_ranges: deleted_entry_ranges_rx,
+            task,
+        }
+    }
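`scan_entries` walks two sorted sequences in lockstep, like a sorted merge join: database keys with no matching file become deletion ranges, matching keys are compared by mtime to decide whether to re-embed, and whatever remains in the database after the worktree is exhausted is deleted wholesale. A reduced sketch of that comparison loop over plain sorted key lists, with stand-in types instead of heed iterators:

```rust
use std::cmp::Ordering;

// Sketch of the merge-join in scan_entries: `db` and `disk` are sorted key lists.
// Keys only in `db` are stale; keys in both would be checked for changes.
fn diff_sorted<'a>(db: Vec<&'a str>, disk: Vec<&'a str>) -> (Vec<&'a str>, Vec<&'a str>) {
    let (mut stale, mut present) = (Vec::new(), Vec::new());
    let mut db = db.into_iter().peekable();
    for key in disk {
        while let Some(&db_key) = db.peek() {
            match db_key.cmp(&key) {
                Ordering::Less => stale.push(db.next().unwrap()),
                Ordering::Equal => {
                    present.push(db.next().unwrap());
                    break;
                }
                Ordering::Greater => break,
            }
        }
    }
    stale.extend(db); // anything left in the db no longer exists on disk
    (stale, present)
}
```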
+
+    fn chunk_files(
+        &self,
+        worktree_abs_path: Arc<Path>,
+        entries: channel::Receiver<Entry>,
+        cx: &AppContext,
+    ) -> ChunkFiles {
+        let language_registry = self.language_registry.clone();
+        let fs = self.fs.clone();
+        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
+        let task = cx.spawn(|cx| async move {
+            cx.background_executor()
+                .scoped(|cx| {
+                    for _ in 0..cx.num_cpus() {
+                        cx.spawn(async {
+                            while let Ok(entry) = entries.recv().await {
+                                let entry_abs_path = worktree_abs_path.join(&entry.path);
+                                let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
+                                    continue;
+                                };
+                                let language = language_registry
+                                    .language_for_file_path(&entry.path)
+                                    .await
+                                    .ok();
+                                let grammar =
+                                    language.as_ref().and_then(|language| language.grammar());
+                                let chunked_file = ChunkedFile {
+                                    worktree_root: worktree_abs_path.clone(),
+                                    chunks: chunk_text(&text, grammar),
+                                    entry,
+                                    text,
+                                };
+
+                                if chunked_files_tx.send(chunked_file).await.is_err() {
+                                    return;
+                                }
+                            }
+                        });
+                    }
+                })
+                .await;
+            Ok(())
+        });
+
+        ChunkFiles {
+            files: chunked_files_rx,
+            task,
+        }
+    }
+
+    fn embed_files(
+        &self,
+        chunked_files: channel::Receiver<ChunkedFile>,
+        cx: &AppContext,
+    ) -> EmbedFiles {
+        let embedding_provider = self.embedding_provider.clone();
+        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
+        let task = cx.background_executor().spawn(async move {
+            let mut chunked_file_batches =
+                chunked_files.chunks_timeout(512, Duration::from_secs(2));
+            while let Some(chunked_files) = chunked_file_batches.next().await {
+                // View the batch of files as a flat list of chunks,
+                // embed the chunks in provider-sized batches, and then
+                // reassemble the results into the files they came from.
+                let chunks = chunked_files
+                    .iter()
+                    .flat_map(|file| {
+                        file.chunks.iter().map(|chunk| TextToEmbed {
+                            text: &file.text[chunk.range.clone()],
+                            digest: chunk.digest,
+                        })
+                    })
+                    .collect::<Vec<_>>();
+
+                let mut embeddings = Vec::new();
+                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
+                    // todo!("add a retry facility")
+                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
+                }
+
+                let mut embeddings = embeddings.into_iter();
+                for chunked_file in chunked_files {
+                    let chunk_embeddings = embeddings
+                        .by_ref()
+                        .take(chunked_file.chunks.len())
+                        .collect::<Vec<_>>();
+                    let embedded_chunks = chunked_file
+                        .chunks
+                        .into_iter()
+                        .zip(chunk_embeddings)
+                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
+                        .collect();
+                    let embedded_file = EmbeddedFile {
+                        path: chunked_file.entry.path.clone(),
+                        mtime: chunked_file.entry.mtime,
+                        chunks: embedded_chunks,
+                    };
+
+                    embedded_files_tx.send(embedded_file).await?;
+                }
+            }
+            Ok(())
+        });
+
+        EmbedFiles {
+            files: embedded_files_rx,
+            task,
+        }
+    }
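`embed_files` leans on `futures_batch::ChunksTimeoutStreamExt::chunks_timeout` (the `futures-batch` crate added in Cargo.lock) to group incoming files into batches of up to 512, flushing early after two seconds so a trickle of files still makes progress. The flatten-then-reassemble bookkeeping can be sketched independently; the `format!` call below is a stand-in for the real `embed` call:

```rust
// Sketch: each "file" owns some chunks; all chunks are embedded in
// provider-sized batches, then handed back to their files by consuming
// the flat result in order (exactly the by_ref().take(n) trick above).
fn reassemble(files: Vec<Vec<&'static str>>, batch_size: usize) -> Vec<Vec<String>> {
    let flat: Vec<&str> = files.iter().flatten().copied().collect();

    let mut embedded = Vec::new();
    for batch in flat.chunks(batch_size) {
        // Stand-in for `embedding_provider.embed(batch).await?`.
        embedded.extend(batch.iter().map(|text| format!("emb({text})")));
    }

    let mut embedded = embedded.into_iter();
    files
        .iter()
        .map(|file| embedded.by_ref().take(file.len()).collect())
        .collect()
}
```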
+
+    fn persist_embeddings(
+        &self,
+        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+        embedded_files: channel::Receiver<EmbeddedFile>,
+        cx: &AppContext,
+    ) -> Task<Result<()>> {
+        let db_connection = self.db_connection.clone();
+        let db = self.db;
+        cx.background_executor().spawn(async move {
+            while let Some(deletion_range) = deleted_entry_ranges.next().await {
+                let mut txn = db_connection.write_txn()?;
+                let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                log::debug!("deleting embeddings in range {:?}", &(start, end));
+                db.delete_range(&mut txn, &(start, end))?;
+                txn.commit()?;
+            }
+
+            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
+            while let Some(embedded_files) = embedded_files.next().await {
+                let mut txn = db_connection.write_txn()?;
+                for file in embedded_files {
+                    log::debug!("saving embedding for file {:?}", file.path);
+                    let key = db_key_for_path(&file.path);
+                    db.put(&mut txn, &key, &file)?;
+                }
+                txn.commit()?;
+                log::debug!("committed");
+            }
+
+            Ok(())
+        })
+    }
+
+    fn search(
+        &self,
+        query: &str,
+        limit: usize,
+        cx: &AppContext,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let (chunks_tx, chunks_rx) = channel::bounded(1024);
+
+        let db_connection = self.db_connection.clone();
+        let db = self.db;
+        let scan_chunks = cx.background_executor().spawn({
+            async move {
+                let txn = db_connection
+                    .read_txn()
+                    .context("failed to create read transaction")?;
+                let db_entries = db.iter(&txn).context("failed to iterate database")?;
+                for db_entry in db_entries {
+                    let (_, db_embedded_file) = db_entry?;
+                    for chunk in db_embedded_file.chunks {
+                        chunks_tx
+                            .send((db_embedded_file.path.clone(), chunk))
+                            .await?;
+                    }
+                }
+                anyhow::Ok(())
+            }
+        });
+
+        let query = query.to_string();
+        let embedding_provider = self.embedding_provider.clone();
+        let worktree = self.worktree.clone();
+        cx.spawn(|cx| async move {
+            #[cfg(debug_assertions)]
+            let embedding_query_start = std::time::Instant::now();
+
+            let mut query_embeddings = embedding_provider
+                .embed(&[TextToEmbed::new(&query)])
+                .await?;
+            let query_embedding = query_embeddings
+                .pop()
+                .ok_or_else(|| anyhow!("no embedding for query"))?;
+            let mut workers = Vec::new();
+            for _ in 0..cx.background_executor().num_cpus() {
+                workers.push(Vec::<SearchResult>::new());
+            }
+
+            #[cfg(debug_assertions)]
+            let search_start = std::time::Instant::now();
+
+            cx.background_executor()
+                .scoped(|cx| {
+                    for worker_results in workers.iter_mut() {
+                        cx.spawn(async {
+                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
+                                let score = embedded_chunk.embedding.similarity(&query_embedding);
+                                let ix = match worker_results.binary_search_by(|probe| {
+                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+                                }) {
+                                    Ok(ix) | Err(ix) => ix,
+                                };
+                                worker_results.insert(
+                                    ix,
+                                    SearchResult {
+                                        worktree: worktree.clone(),
+                                        path: path.clone(),
+                                        range: embedded_chunk.chunk.range.clone(),
+                                        score,
+                                    },
+                                );
+                                worker_results.truncate(limit);
+                            }
+                        });
+                    }
+                })
+                .await;
+            scan_chunks.await?;
+
+            let mut search_results = Vec::with_capacity(workers.len() * limit);
+            for worker_results in workers {
+                search_results.extend(worker_results);
+            }
+            search_results
+                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
+            search_results.truncate(limit);
+            #[cfg(debug_assertions)]
+            {
+                let search_elapsed = search_start.elapsed();
+                log::debug!(
+                    "searched {} entries in {:?}",
+                    search_results.len(),
+                    search_elapsed
+                );
+                let embedding_query_elapsed = embedding_query_start.elapsed();
+                log::debug!("embedding query took {:?}", embedding_query_elapsed);
+            }
+
+            Ok(search_results)
+        })
+    }
+}
+
+struct ScanEntries {
+    updated_entries: channel::Receiver<Entry>,
+    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+    task: Task<Result<()>>,
+}
+
+struct ChunkFiles {
+    files: channel::Receiver<ChunkedFile>,
+    task: Task<Result<()>>,
+}
+
+struct ChunkedFile {
+    #[allow(dead_code)]
+    pub worktree_root: Arc<Path>,
+    pub entry: Entry,
+    pub text: String,
+    pub chunks: Vec<Chunk>,
+}
+
+struct EmbedFiles {
+    files: channel::Receiver<EmbeddedFile>,
+    task: Task<Result<()>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct EmbeddedFile {
+    path: Arc<Path>,
+    mtime: Option<SystemTime>,
+    chunks: Vec<EmbeddedChunk>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct EmbeddedChunk {
+    chunk: Chunk,
+    embedding: Embedding,
+}
+
+fn db_key_for_path(path: &Arc<Path>) -> String {
+    path.to_string_lossy().replace('/', "\0")
+}
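`db_key_for_path` swaps `/` for NUL so that LMDB's bytewise key order matches the worktree's component-wise path order, which the merge-join in `scan_entries` depends on: with `/` (0x2F), `"foo.rs"` would sort before `"foo/bar"` because `.` (0x2E) is smaller, but the worktree yields `foo/bar` first. A quick illustration of the ordering effect:

```rust
// Why '/' becomes NUL in db keys: keep bytewise key order aligned with
// component-wise path order so directory contents form contiguous ranges.
fn db_key(path: &str) -> String {
    path.replace('/', "\0")
}

fn main() {
    let mut keys = vec![db_key("foo.rs"), db_key("foo/bar")];
    keys.sort();
    // "foo\0bar" < "foo.rs": entries inside `foo/` now group before `foo.rs`,
    // matching worktree traversal order.
    assert_eq!(keys[0], "foo\0bar");
}
```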
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use futures::channel::oneshot;
+    use futures::{future::BoxFuture, FutureExt};
+
+    use gpui::{Global, TestAppContext};
+    use language::language_settings::AllLanguageSettings;
+    use project::Project;
+    use settings::SettingsStore;
+    use std::{future, path::Path, sync::Arc};
+
+    fn init_test(cx: &mut TestAppContext) {
+        _ = cx.update(|cx| {
+            let store = SettingsStore::test(cx);
+            cx.set_global(store);
+            language::init(cx);
+            Project::init_settings(cx);
+            SettingsStore::update(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+            });
+        });
+    }
+
+    pub struct TestEmbeddingProvider;
+
+    impl EmbeddingProvider for TestEmbeddingProvider {
+        fn embed<'a>(
+            &'a self,
+            texts: &'a [TextToEmbed<'a>],
+        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+            let embeddings = texts
+                .iter()
+                .map(|text| {
+                    let mut embedding = vec![0f32; 2];
+                    // If the text contains garbage, give it a high value in the
+                    // corresponding dimension.
+                    if text.text.contains("garbage in") {
+                        embedding[0] = 0.9;
+                    } else {
+                        embedding[0] = -0.9;
+                    }
+
+                    if text.text.contains("garbage out") {
+                        embedding[1] = 0.9;
+                    } else {
+                        embedding[1] = -0.9;
+                    }
+
+                    Embedding::new(embedding)
+                })
+                .collect();
+            future::ready(Ok(embeddings)).boxed()
+        }
+
+        fn batch_size(&self) -> usize {
+            16
+        }
+    }
+
+    #[gpui::test]
+    async fn test_search(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        init_test(cx);
+
+        let temp_dir = tempfile::tempdir().unwrap();
+
+        let mut semantic_index = cx
+            .update(|cx| {
+                let semantic_index = SemanticIndex::new(
+                    Path::new(temp_dir.path()),
+                    Arc::new(TestEmbeddingProvider),
+                    cx,
+                );
+                semantic_index
+            })
+            .await
+            .unwrap();
+
+        // todo!(): use a fixture
+        let project_path = Path::new("./fixture");
+
+        let project = cx
+            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
+            .await;
+
+        cx.update(|cx| {
+            let language_registry = project.read(cx).languages().clone();
+            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
+            languages::init(language_registry, node_runtime, cx);
+        });
+
+        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
+
+        let (tx, rx) = oneshot::channel();
+        let mut tx = Some(tx);
+        let subscription = cx.update(|cx| {
+            cx.subscribe(&project_index, move |_, event, _| {
+                if let Some(tx) = tx.take() {
+                    _ = tx.send(*event);
+                }
+            })
+        });
+
+        rx.await.expect("no event emitted");
+        drop(subscription);
+
+        let results = cx
+            .update(|cx| {
+                let project_index = project_index.read(cx);
+                let query = "garbage in, garbage out";
+                project_index.search(query, 4, cx)
+            })
+            .await;
+
+        assert!(results.len() > 1, "should have found some results");
+
+        for result in &results {
+            println!("result: {:?}", result.path);
+            println!("score: {:?}", result.score);
+        }
+
+        // Find a result whose score clears the 0.9 similarity threshold.
+        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();
+
+        assert_eq!(search_result.path.to_string_lossy(), "needle.md");
+
+        let content = cx
+            .update(|cx| {
+                let worktree = search_result.worktree.read(cx);
+                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+                let fs = project.read(cx).fs().clone();
+                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+            })
+            .await;
+
+        let range = search_result.range.clone();
+        let content = content[range.clone()].to_owned();
+
+        assert!(content.contains("garbage in, garbage out"));
+    }
+}
diff --git a/crates/util/src/http.rs b/crates/util/src/http.rs
index 3ee7b96eac..01a061cd1a 100644
--- a/crates/util/src/http.rs
+++ b/crates/util/src/http.rs
@@ -71,19 +71,28 @@ impl HttpClientWithUrl {
 }
 
 impl HttpClient for Arc<HttpClientWithUrl> {
-    fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+    fn send(
+        &self,
+        req: Request<AsyncBody>,
+    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
         self.client.send(req)
     }
 }
 
 impl HttpClient for HttpClientWithUrl {
-    fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+    fn send(
+        &self,
+        req: Request<AsyncBody>,
+    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
         self.client.send(req)
     }
 }
 
 pub trait HttpClient: Send + Sync {
-    fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>>;
+    fn send(
+        &self,
+        req: Request<AsyncBody>,
+    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>>;
 
     fn get<'a>(
         &'a self,
@@ -135,8 +144,12 @@ pub fn client() -> Arc<dyn HttpClient> {
 }
 
 impl HttpClient for isahc::HttpClient {
-    fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
-        Box::pin(async move { self.send_async(req).await })
+    fn send(
+        &self,
+        req: Request<AsyncBody>,
+    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
+        let client = self.clone();
+        Box::pin(async move { client.send_async(req).await })
     }
 }
 
@@ -196,7 +209,10 @@ impl fmt::Debug for FakeHttpClient {
 
 #[cfg(feature = "test-support")]
 impl HttpClient for FakeHttpClient {
-    fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+    fn send(
+        &self,
+        req: Request<AsyncBody>,
+    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
         let future = (self.handler)(req);
         Box::pin(async move { future.await.map(Into::into) })
     }
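The http.rs change makes `HttpClient::send` return a `BoxFuture<'static, ...>` instead of a future borrowing `&self`, which is what lets the embedding providers above hold a response future across `.await` points without pinning it to the client borrow; the isahc impl satisfies `'static` by cloning the (cheaply cloneable) client into the future. A minimal illustration of the pattern, with `Client` as a hypothetical stand-in type:

```rust
use futures::future::BoxFuture;
use futures::FutureExt;

// Sketch: clone the handle into the future so it no longer borrows &self.
#[derive(Clone)]
struct Client;

impl Client {
    fn send(&self, req: String) -> BoxFuture<'static, String> {
        let client = self.clone(); // owned copy moves into the future
        async move { client.do_send(req).await }.boxed()
    }

    async fn do_send(&self, req: String) -> String {
        format!("response to {req}")
    }
}
```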