Semantic Index (#10329)

This introduces semantic indexing in Zed, based on chunking text from
files in the developer's workspace and creating vector embeddings using
an embedding model. As part of this, we've created an embedding
provider trait that allows us to work with OpenAI, a local Ollama
model, or Zed-hosted embeddings.

The semantic index is built by breaking text for known (programming)
languages into manageable chunks that are smaller than the model's
maximum token count. Each chunk is then fed to the embedding model to
produce a high-dimensional vector, which is normalized to a unit vector
so it can be compared against other vectors with a simple dot product.
Alongside each vector, we store the path of the file and the range
within the document that the chunk was sourced from.
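
As a rough illustration of why unit-length vectors make comparison cheap (a minimal sketch, not the code from this PR; the real implementation lives in `crates/semantic_index/src/embedding.rs` below):

```rust
// Normalize a vector to unit length, then compare with a dot product.
fn normalize(mut v: Vec<f32>) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in &mut v {
        *x /= norm;
    }
    v
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}

fn main() {
    let a = normalize(vec![1.0, 2.0, 2.0]);
    let b = normalize(vec![2.0, 4.0, 4.0]); // same direction, different magnitude
    assert!((dot(&a, &b) - 1.0).abs() < 1e-6); // cosine similarity ≈ 1
}
```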

Zed will soon grok contextual similarity across different text
snippets, allowing natural language search that goes beyond keyword
matching. This is being built both for human-driven search and to
provide results to Large Language Models so they can refine how they
help developers.

Remaining todo:

* [x] Change `provider` to `model` within the Zed-hosted embeddings
database (as it's currently a combination of the provider and the model
in one name)


Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Antonio <antonio@zed.dev>
Commit 49371b44cb (parent 4b40e83b8b), authored by Kyle Kelley on 2024-04-12 10:40:59 -07:00, committed via GitHub.
33 changed files with 2649 additions and 41 deletions

Cargo.lock (generated)

@ -3265,6 +3265,15 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
[[package]]
name = "doxygen-rs"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "415b6ec780d34dcf624666747194393603d0373b7141eef01d12ee58881507d9"
dependencies = [
"phf",
]
[[package]]
name = "dwrote"
version = "0.11.0"
@ -4085,6 +4094,17 @@ dependencies = [
"futures-util", "futures-util",
] ]
[[package]]
name = "futures-batch"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a"
dependencies = [
"futures 0.3.28",
"futures-timer",
"pin-utils",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
@ -4180,6 +4200,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.30"
@ -4659,6 +4685,41 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "heed"
version = "0.20.0-alpha.9"
source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
dependencies = [
"bitflags 2.4.2",
"byteorder",
"heed-traits",
"heed-types",
"libc",
"lmdb-master-sys",
"once_cell",
"page_size",
"serde",
"synchronoise",
"url",
]
[[package]]
name = "heed-traits"
version = "0.20.0-alpha.9"
source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
[[package]]
name = "heed-types"
version = "0.20.0-alpha.9"
source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
dependencies = [
"bincode",
"byteorder",
"heed-traits",
"serde",
"serde_json",
]
[[package]]
name = "hermit-abi"
version = "0.1.19"
@ -5664,6 +5725,16 @@ dependencies = [
"sha2 0.10.7", "sha2 0.10.7",
] ]
[[package]]
name = "lmdb-master-sys"
version = "0.1.0"
source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
dependencies = [
"cc",
"doxygen-rs",
"libc",
]
[[package]]
name = "lock_api"
version = "0.4.10"
@ -6683,6 +6754,16 @@ dependencies = [
"sha2 0.10.7", "sha2 0.10.7",
] ]
[[package]]
name = "page_size"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "palette"
version = "0.7.5"
@ -6856,9 +6937,33 @@ version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [ dependencies = [
"phf_macros",
"phf_shared", "phf_shared",
] ]
[[package]]
name = "phf_generator"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
dependencies = [
"phf_shared",
"rand 0.8.5",
]
[[package]]
name = "phf_macros"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]]
name = "phf_shared"
version = "0.11.2"
@ -8473,6 +8578,35 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba" checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
[[package]]
name = "semantic_index"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"clock",
"collections",
"env_logger",
"fs",
"futures 0.3.28",
"futures-batch",
"gpui",
"heed",
"language",
"languages",
"log",
"open_ai",
"project",
"serde",
"serde_json",
"settings",
"sha2 0.10.7",
"smol",
"tempfile",
"util",
"worktree",
]
[[package]]
name = "semantic_version"
version = "0.1.0"
@ -9478,6 +9612,15 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "synchronoise"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2"
dependencies = [
"crossbeam-queue",
]
[[package]]
name = "sys-locale"
version = "0.3.1"


@ -73,6 +73,7 @@ members = [
"crates/task", "crates/task",
"crates/tasks_ui", "crates/tasks_ui",
"crates/search", "crates/search",
"crates/semantic_index",
"crates/semantic_version", "crates/semantic_version",
"crates/settings", "crates/settings",
"crates/snippet", "crates/snippet",
@ -253,9 +254,11 @@ derive_more = "0.99.17"
emojis = "0.6.1" emojis = "0.6.1"
env_logger = "0.9" env_logger = "0.9"
futures = "0.3" futures = "0.3"
futures-batch = "0.6.1"
futures-lite = "1.13" futures-lite = "1.13"
git2 = { version = "0.15", default-features = false } git2 = { version = "0.15", default-features = false }
globset = "0.4" globset = "0.4"
heed = { git = "https://github.com/meilisearch/heed", rev = "036ac23f73a021894974b9adc815bc95b3e0482a", features = ["read-txn-no-tls"] }
hex = "0.4.3" hex = "0.4.3"
ignore = "0.4.22" ignore = "0.4.22"
indoc = "1" indoc = "1"


@ -264,7 +264,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
-channel.next_event(cx),
+channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
@ -317,7 +317,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
-channel.next_event(cx),
+channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,


@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS "embeddings" (
"model" TEXT,
"digest" BYTEA,
"dimensions" FLOAT4[1536],
"retrieved_at" TIMESTAMP NOT NULL DEFAULT now(),
PRIMARY KEY ("model", "digest")
);
CREATE INDEX IF NOT EXISTS "idx_retrieved_at_on_embeddings" ON "embeddings" ("retrieved_at");
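
For context on the `digest` column: it holds a SHA-256 hash of the chunk text, computed with the `sha2` crate as in the server code later in this diff. A minimal sketch of the cache key:

```rust
use sha2::{Digest, Sha256};

fn main() {
    // The (model, digest) pair is the cache key; digest = SHA-256 of the chunk text.
    let digest: Vec<u8> = Sha256::digest("fn main() {}".as_bytes()).to_vec();
    assert_eq!(digest.len(), 32); // the BYTEA column stores 32 bytes
}
```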


@ -6,6 +6,7 @@ pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod dev_servers;
pub mod embeddings;
pub mod extensions;
pub mod hosted_projects;
pub mod messages;


@ -0,0 +1,94 @@
use super::*;
use time::Duration;
use time::OffsetDateTime;
impl Database {
pub async fn get_embeddings(
&self,
model: &str,
digests: &[Vec<u8>],
) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
self.weak_transaction(|tx| async move {
let embeddings = {
let mut db_embeddings = embedding::Entity::find()
.filter(
embedding::Column::Model.eq(model).and(
embedding::Column::Digest
.is_in(digests.iter().map(|digest| digest.as_slice())),
),
)
.stream(&*tx)
.await?;
let mut embeddings = HashMap::default();
while let Some(db_embedding) = db_embeddings.next().await {
let db_embedding = db_embedding?;
embeddings.insert(db_embedding.digest, db_embedding.dimensions);
}
embeddings
};
if !embeddings.is_empty() {
let now = OffsetDateTime::now_utc();
let retrieved_at = PrimitiveDateTime::new(now.date(), now.time());
embedding::Entity::update_many()
.filter(
embedding::Column::Digest
.is_in(embeddings.keys().map(|digest| digest.as_slice())),
)
.col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
.exec(&*tx)
.await?;
}
Ok(embeddings)
})
.await
}
pub async fn save_embeddings(
&self,
model: &str,
embeddings: &HashMap<Vec<u8>, Vec<f32>>,
) -> Result<()> {
self.weak_transaction(|tx| async move {
embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
let now_offset_datetime = OffsetDateTime::now_utc();
let retrieved_at =
PrimitiveDateTime::new(now_offset_datetime.date(), now_offset_datetime.time());
embedding::ActiveModel {
model: ActiveValue::set(model.to_string()),
digest: ActiveValue::set(digest.clone()),
dimensions: ActiveValue::set(dimensions.clone()),
retrieved_at: ActiveValue::set(retrieved_at),
}
}))
.on_conflict(
OnConflict::columns([embedding::Column::Model, embedding::Column::Digest])
.do_nothing()
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn purge_old_embeddings(&self) -> Result<()> {
self.weak_transaction(|tx| async move {
embedding::Entity::delete_many()
.filter(
embedding::Column::RetrievedAt
.lte(OffsetDateTime::now_utc() - Duration::days(60)),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
}


@ -11,6 +11,7 @@ pub mod channel_message_mention;
pub mod contact;
pub mod contributor;
pub mod dev_server;
pub mod embedding;
pub mod extension;
pub mod extension_version;
pub mod feature_flag;


@ -0,0 +1,18 @@
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "embeddings")]
pub struct Model {
#[sea_orm(primary_key)]
pub model: String,
#[sea_orm(primary_key)]
pub digest: Vec<u8>,
pub dimensions: Vec<f32>,
pub retrieved_at: PrimitiveDateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}


@ -2,6 +2,7 @@ mod buffer_tests;
mod channel_tests;
mod contributor_tests;
mod db_tests;
mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;


@ -0,0 +1,84 @@
use super::TestDb;
use crate::db::embedding;
use collections::HashMap;
use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, QueryFilter};
use std::ops::Sub;
use time::{Duration, OffsetDateTime, PrimitiveDateTime};
// SQLite does not support array arguments, so we only test this against a real postgres instance
#[gpui::test]
async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) {
let test_db = TestDb::postgres(cx.executor().clone());
let db = test_db.db();
let provider = "test_model";
let digest1 = vec![1, 2, 3];
let digest2 = vec![4, 5, 6];
let embeddings = HashMap::from_iter([
(digest1.clone(), vec![0.1, 0.2, 0.3]),
(digest2.clone(), vec![0.4, 0.5, 0.6]),
]);
// Save embeddings
db.save_embeddings(provider, &embeddings).await.unwrap();
// Retrieve embeddings
let retrieved_embeddings = db
.get_embeddings(provider, &[digest1.clone(), digest2.clone()])
.await
.unwrap();
assert_eq!(retrieved_embeddings.len(), 2);
assert!(retrieved_embeddings.contains_key(&digest1));
assert!(retrieved_embeddings.contains_key(&digest2));
// Check if the retrieved embeddings are correct
assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]);
assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]);
}
#[gpui::test]
async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
let test_db = TestDb::postgres(cx.executor().clone());
let db = test_db.db();
let model = "test_model";
let digest = vec![7, 8, 9];
let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]);
// Save old embeddings
db.save_embeddings(model, &embeddings).await.unwrap();
// Reach into the DB and change the retrieved at to be > 60 days
db.weak_transaction(|tx| {
let digest = digest.clone();
async move {
let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time());
embedding::Entity::update_many()
.filter(
embedding::Column::Model
.eq(model)
.and(embedding::Column::Digest.eq(digest)),
)
.col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
.exec(&*tx)
.await
.unwrap();
Ok(())
}
})
.await
.unwrap();
// Purge old embeddings
db.purge_old_embeddings().await.unwrap();
// Try to retrieve the purged embeddings
let retrieved_embeddings = db.get_embeddings(model, &[digest.clone()]).await.unwrap();
assert!(
retrieved_embeddings.is_empty(),
"Old embeddings should have been purged"
);
}


@ -6,8 +6,8 @@ use axum::{
Extension, Router,
};
use collab::{
-api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
-Config, RateLimiter, Result,
+api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
+rpc::ResultExt, AppState, Config, RateLimiter, Result,
};
use db::Database;
use std::{
@ -23,7 +23,7 @@ use tower_http::trace::TraceLayer;
use tracing_subscriber::{
filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
};
-use util::ResultExt;
+use util::ResultExt as _;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@ -90,6 +90,7 @@ async fn main() -> Result<()> {
};
if is_collab {
state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
}


@ -32,6 +32,8 @@ use axum::{
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
use sha2::Digest;
use futures::{
channel::oneshot,
@ -568,6 +570,22 @@ impl Server {
app_state.config.google_ai_api_key.clone(),
)
})
})
.add_request_handler({
user_handler(move |request, response, session| {
get_cached_embeddings(request, response, session)
})
})
.add_request_handler({
let app_state = app_state.clone();
user_handler(move |request, response, session| {
compute_embeddings(
request,
response,
session,
app_state.config.openai_api_key.clone(),
)
})
});
Arc::new(server)
@ -4021,8 +4039,6 @@ async fn complete_with_open_ai(
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
-const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
let mut completion_stream = open_ai::stream_completion(
&session.http_client,
OPEN_AI_API_URL,
@ -4276,6 +4292,128 @@ async fn count_tokens_with_language_model(
Ok(())
}
struct ComputeEmbeddingsRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {
fn capacity() -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"compute-embeddings"
}
}
async fn compute_embeddings(
request: proto::ComputeEmbeddings,
response: Response<proto::ComputeEmbeddings>,
session: UserSession,
api_key: Option<Arc<str>>,
) -> Result<()> {
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<ComputeEmbeddingsRateLimit>(session.user_id())
.await?;
let embeddings = match request.model.as_str() {
"openai/text-embedding-3-small" => {
open_ai::embed(
&session.http_client,
OPEN_AI_API_URL,
&api_key,
OpenAiEmbeddingModel::TextEmbedding3Small,
request.texts.iter().map(|text| text.as_str()),
)
.await?
}
provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
};
let embeddings = request
.texts
.iter()
.map(|text| {
let mut hasher = sha2::Sha256::new();
hasher.update(text.as_bytes());
let result = hasher.finalize();
result.to_vec()
})
.zip(
embeddings
.data
.into_iter()
.map(|embedding| embedding.embedding),
)
.collect::<HashMap<_, _>>();
let db = session.db().await;
db.save_embeddings(&request.model, &embeddings)
.await
.context("failed to save embeddings")
.trace_err();
response.send(proto::ComputeEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
struct GetCachedEmbeddingsRateLimit;
impl RateLimit for GetCachedEmbeddingsRateLimit {
fn capacity() -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"get-cached-embeddings"
}
}
async fn get_cached_embeddings(
request: proto::GetCachedEmbeddings,
response: Response<proto::GetCachedEmbeddings>,
session: UserSession,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<GetCachedEmbeddingsRateLimit>(session.user_id())
.await?;
let db = session.db().await;
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
response.send(proto::GetCachedEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;


@ -396,7 +396,7 @@ mod tests {
let blame = cx.new_model(|cx| GitBlame::new(buffer.clone(), project.clone(), cx));
-let event = project.next_event(cx);
+let event = project.next_event(cx).await;
assert_eq!(
event,
project::Event::Notification(


@ -7,7 +7,7 @@ use crate::{
TextSystem, View, ViewContext, VisualContext, WindowContext, WindowHandle, WindowOptions,
};
use anyhow::{anyhow, bail};
-use futures::{Stream, StreamExt};
+use futures::{channel::oneshot, Stream, StreamExt};
use std::{cell::RefCell, future::Future, ops::Deref, rc::Rc, sync::Arc, time::Duration};
/// A TestAppContext is provided to tests created with `#[gpui::test]`, it provides
@ -479,32 +479,27 @@ impl TestAppContext {
impl<T: 'static> Model<T> {
/// Block until the next event is emitted by the model, then return it.
-pub fn next_event<Evt>(&self, cx: &mut TestAppContext) -> Evt
+pub fn next_event<Event>(&self, cx: &mut TestAppContext) -> impl Future<Output = Event>
where
-Evt: Send + Clone + 'static,
-T: EventEmitter<Evt>,
+Event: Send + Clone + 'static,
+T: EventEmitter<Event>,
{
-let (tx, mut rx) = futures::channel::mpsc::unbounded();
-let _subscription = self.update(cx, |_, cx| {
+let (tx, mut rx) = oneshot::channel();
+let mut tx = Some(tx);
+let subscription = self.update(cx, |_, cx| {
cx.subscribe(self, move |_, _, event, _| {
-tx.unbounded_send(event.clone()).ok();
+if let Some(tx) = tx.take() {
+_ = tx.send(event.clone());
+}
})
});
-// Run other tasks until the event is emitted.
-loop {
-match rx.try_next() {
-Ok(Some(event)) => return event,
-Ok(None) => panic!("model was dropped"),
-Err(_) => {
-if !cx.executor().tick() {
-break;
-}
-}
-}
-}
-panic!("no event received")
+async move {
+let event = rx.await.expect("no event emitted");
+drop(subscription);
+event
+}
}
/// Returns a future that resolves when the model notifies.
pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
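
With this change, `next_event` returns a future instead of blocking, so call sites (updated elsewhere in this diff) now await it. A one-line sketch of the new usage, where `model` is any `Model<T>` whose `T` emits events:

```rust
let event = model.next_event(cx).await;
```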


@ -372,7 +372,7 @@ impl BackgroundExecutor {
self.dispatcher.as_test().unwrap().rng()
}
-/// How many CPUs are available to the dispatcher
+/// How many CPUs are available to the dispatcher.
pub fn num_cpus(&self) -> usize {
num_cpus::get()
}
@ -440,6 +440,11 @@ impl<'a> Scope<'a> {
}
}
/// How many CPUs are available to the dispatcher.
pub fn num_cpus(&self) -> usize {
self.executor.num_cpus()
}
/// Spawn a future into this scope.
pub fn spawn<F>(&mut self, f: F)
where


@ -72,7 +72,7 @@ pub use lsp::LanguageServerId;
pub use outline::{Outline, OutlineItem};
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer};
pub use text::LineEnding;
-pub use tree_sitter::{Parser, Tree};
+pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
use crate::language_settings::SoftWrap;
@ -91,6 +91,16 @@ thread_local! {
};
}
pub fn with_parser<F, R>(func: F) -> R
where
F: FnOnce(&mut Parser) -> R,
{
PARSER.with(|parser| {
let mut parser = parser.borrow_mut();
func(&mut parser)
})
}
lazy_static! {
static ref NEXT_LANGUAGE_ID: AtomicUsize = Default::default();
static ref NEXT_GRAMMAR_ID: AtomicUsize = Default::default();


@ -1,9 +1,11 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
-use std::convert::TryFrom;
+use std::{convert::TryFrom, future::Future};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
@ -188,3 +190,68 @@ pub async fn stream_completion(
}
}
}
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
#[serde(rename = "text-embedding-3-small")]
TextEmbedding3Small,
#[serde(rename = "text-embedding-3-large")]
TextEmbedding3Large,
}
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
model: OpenAiEmbeddingModel,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
pub data: Vec<OpenAiEmbedding>,
}
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
pub embedding: Vec<f32>,
}
pub fn embed<'a>(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
model: OpenAiEmbeddingModel,
texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
let uri = format!("{api_url}/embeddings");
let request = OpenAiEmbeddingRequest {
model,
input: texts.into_iter().collect(),
};
let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(body)
.map(|request| client.send(request));
async move {
let mut response = request?.await?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
if response.status().is_success() {
let response: OpenAiEmbeddingResponse =
serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
Ok(response)
} else {
Err(anyhow!(
"error during embedding, status: {:?}, body: {:?}",
response.status(),
body
))
}
}
}
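
A hedged sketch of driving the new `open_ai::embed` function from a caller; the `http_client` and `api_key` bindings here are assumptions for illustration, not part of this diff:

```rust
// Hypothetical call site for embed(); returns OpenAiEmbeddingResponse.
let response = open_ai::embed(
    http_client.as_ref(),
    OPEN_AI_API_URL,
    &api_key,
    OpenAiEmbeddingModel::TextEmbedding3Small,
    ["fn main() {}", "struct Foo;"],
)
.await?;
for datum in response.data {
    println!("embedding with {} dimensions", datum.embedding.len());
}
```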


@ -978,6 +978,50 @@ impl Project {
}
}
#[cfg(any(test, feature = "test-support"))]
pub async fn example(
root_paths: impl IntoIterator<Item = &Path>,
cx: &mut AsyncAppContext,
) -> Model<Project> {
use clock::FakeSystemClock;
let fs = Arc::new(RealFs::default());
let languages = LanguageRegistry::test(cx.background_executor().clone());
let clock = Arc::new(FakeSystemClock::default());
let http_client = util::http::FakeHttpClient::with_404_response();
let client = cx
.update(|cx| client::Client::new(clock, http_client.clone(), cx))
.unwrap();
let user_store = cx
.new_model(|cx| UserStore::new(client.clone(), cx))
.unwrap();
let project = cx
.update(|cx| {
Project::local(
client,
node_runtime::FakeNodeRuntime::new(),
user_store,
Arc::new(languages),
fs,
cx,
)
})
.unwrap();
for path in root_paths {
let (tree, _) = project
.update(cx, |project, cx| {
project.find_or_create_local_worktree(path, true, cx)
})
.unwrap()
.await
.unwrap();
tree.update(cx, |tree, _| tree.as_local().unwrap().scan_complete())
.unwrap()
.await;
}
project
}
#[cfg(any(test, feature = "test-support"))]
pub async fn test(
fs: Arc<dyn Fs>,
@ -1146,6 +1190,10 @@ impl Project {
self.user_store.clone()
}
pub fn node_runtime(&self) -> Option<&Arc<dyn NodeRuntime>> {
self.node.as_ref()
}
pub fn opened_buffers(&self) -> Vec<Model<Buffer>> {
self.opened_buffers
.values()


@ -2661,7 +2661,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext)
)
.await
.unwrap();
-worktree.next_event(cx);
+worktree.next_event(cx).await;
// Change the buffer's file again. Depending on the random seed, the
// previous file change may still be in progress.
@ -2672,7 +2672,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext)
)
.await
.unwrap();
-worktree.next_event(cx);
+worktree.next_event(cx).await;
cx.executor().run_until_parked();
let on_disk_text = fs.load(Path::new("/dir/file1")).await.unwrap();
@ -2716,7 +2716,7 @@ async fn test_edit_buffer_while_it_reloads(cx: &mut gpui::TestAppContext) {
)
.await
.unwrap();
-worktree.next_event(cx);
+worktree.next_event(cx).await;
cx.executor()
.spawn(cx.executor().simulate_random_delay())


@ -204,6 +204,11 @@ message Envelope {
LanguageModelResponse language_model_response = 167;
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169;
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@ -216,7 +221,7 @@ message Envelope {
MultiLspQueryResponse multi_lsp_query_response = 176;
CreateRemoteProject create_remote_project = 177;
-CreateRemoteProjectResponse create_remote_project_response = 188; // current max
+CreateRemoteProjectResponse create_remote_project_response = 188;
CreateDevServer create_dev_server = 178;
CreateDevServerResponse create_dev_server_response = 179;
ShutdownDevServer shutdown_dev_server = 180;
@ -1892,6 +1897,29 @@ message CountTokensResponse {
uint32 token_count = 1;
}
message GetCachedEmbeddings {
string model = 1;
repeated bytes digests = 2;
}
message GetCachedEmbeddingsResponse {
repeated Embedding embeddings = 1;
}
message ComputeEmbeddings {
string model = 1;
repeated string texts = 2;
}
message ComputeEmbeddingsResponse {
repeated Embedding embeddings = 1;
}
message Embedding {
bytes digest = 1;
repeated float dimensions = 2;
}
message BlameBuffer { message BlameBuffer {
uint64 project_id = 1; uint64 project_id = 1;
uint64 buffer_id = 2; uint64 buffer_id = 2;


@ -151,6 +151,8 @@ messages!(
(ChannelMessageSent, Foreground),
(ChannelMessageUpdate, Foreground),
(CompleteWithLanguageModel, Background),
(ComputeEmbeddings, Background),
(ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
(CountTokensWithLanguageModel, Background),
(CountTokensResponse, Background),
@ -174,6 +176,8 @@ messages!(
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
(GetCachedEmbeddings, Background),
(GetCachedEmbeddingsResponse, Background),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
@ -325,6 +329,7 @@ request_messages!(
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CompleteWithLanguageModel, LanguageModelResponse),
(ComputeEmbeddings, ComputeEmbeddingsResponse),
(CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
@ -336,6 +341,7 @@ request_messages!(
(Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
(GetCachedEmbeddings, GetCachedEmbeddingsResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),


@ -0,0 +1,48 @@
[package]
name = "semantic_index"
description = "Process, chunk, and embed text as vectors for semantic search."
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lib]
path = "src/semantic_index.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
heed.workspace = true
open_ai.workspace = true
project.workspace = true
settings.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
smol.workspace = true
util.workspace = true
worktree.workspace = true
[dev-dependencies]
env_logger.workspace = true
client = { workspace = true, features = ["test-support"] }
fs = { workspace = true, features = ["test-support"] }
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
languages.workspace = true
project = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
util = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
[lints]
workspace = true


@ -0,0 +1 @@
../../LICENSE-GPL


@ -0,0 +1,140 @@
use client::Client;
use futures::channel::oneshot;
use gpui::{App, Global, TestAppContext};
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use settings::SettingsStore;
use std::{path::Path, sync::Arc};
use util::http::HttpClientWithUrl;
pub fn init_test(cx: &mut TestAppContext) {
_ = cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
});
}
fn main() {
env_logger::init();
use clock::FakeSystemClock;
App::new().run(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
let clock = Arc::new(FakeSystemClock::default());
let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
let client = client::Client::new(clock, http.clone(), cx);
Client::set_global(client.clone(), cx);
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: cargo run --example index -p semantic_index -- <project_path>");
cx.quit();
return;
}
// let embedding_provider = semantic_index::FakeEmbeddingProvider;
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let embedding_provider = OpenAiEmbeddingProvider::new(
http.clone(),
OpenAiEmbeddingModel::TextEmbedding3Small,
open_ai::OPEN_AI_API_URL.to_string(),
api_key,
);
let semantic_index = SemanticIndex::new(
Path::new("/tmp/semantic-index-db.mdb"),
Arc::new(embedding_provider),
cx,
);
cx.spawn(|mut cx| async move {
let mut semantic_index = semantic_index.await.unwrap();
let project_path = Path::new(&args[1]);
let project = Project::example([project_path], &mut cx).await;
cx.update(|cx| {
let language_registry = project.read(cx).languages().clone();
let node_runtime = project.read(cx).node_runtime().unwrap().clone();
languages::init(language_registry, node_runtime, cx);
})
.unwrap();
let project_index = cx
.update(|cx| semantic_index.project_index(project.clone(), cx))
.unwrap();
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let subscription = cx.update(|cx| {
cx.subscribe(&project_index, move |_, event, _| {
if let Some(tx) = tx.take() {
_ = tx.send(*event);
}
})
});
let index_start = std::time::Instant::now();
rx.await.expect("no event emitted");
drop(subscription);
println!("Index time: {:?}", index_start.elapsed());
let results = cx
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
project_index.search(query, 4, cx)
})
.unwrap()
.await;
for search_result in results {
let path = search_result.path.clone();
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
let fs = project.read(cx).fs().clone();
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
})
.unwrap()
.await;
let range = search_result.range.clone();
let content = content[search_result.range].to_owned();
println!(
"✄✄✄✄✄✄✄✄✄✄✄✄✄✄ {:?} @ {} ✄✄✄✄✄✄✄✄✄✄✄✄✄✄",
path, search_result.score
);
println!("{:?}:{:?}:{:?}", path, range.start, range.end);
println!("{}", content);
}
cx.background_executor()
.timer(std::time::Duration::from_secs(100000))
.await;
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
}


@ -0,0 +1,3 @@
fn main() {
println!("Hello Indexer!");
}


@ -0,0 +1,43 @@
# Searching for a needle in a haystack
When you have a large amount of text, it can be useful to search for a specific word or phrase. This is often referred to as "finding a needle in a haystack." In this markdown document, we're "hiding" a key phrase for our text search to find. Can you find it?
## Instructions
1. Use the search functionality in your text editor or markdown viewer to find the hidden phrase in this document.
2. Once you've found the **phrase**, write it down and proceed to the next step.
Honestly, I just want to fill up plenty of characters so that we chunk this markdown into several chunks.
## Tips
- Relax
- Take a deep breath
- Focus on the task at hand
- Don't get distracted by other text
- Use the search functionality to your advantage
## Example code
```python
def search_for_needle(haystack, needle):
if needle in haystack:
return True
else:
return False
```
```javascript
function searchForNeedle(haystack, needle) {
return haystack.includes(needle);
}
```
## Background
When creating an index for a book or searching for a specific term in a large document, the ability to quickly find a specific word or phrase is essential. This is where search functionality comes in handy. However, one should _remember_ that the search is only as good as the index that was built. As they say, garbage in, garbage out!
## Conclusion
Searching for a needle in a haystack can be a challenging task, but with the right tools and techniques, it becomes much easier. Whether you're looking for a specific word in a document or trying to find a key piece of information in a large dataset, the ability to search efficiently is a valuable skill to have.


@ -0,0 +1,409 @@
use language::{with_parser, Grammar, Tree};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{cmp, ops::Range, sync::Arc};
const CHUNK_THRESHOLD: usize = 1500;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub range: Range<usize>,
pub digest: [u8; 32],
}
pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
if let Some(grammar) = grammar {
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(&text, None).expect("invalid language")
});
chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
} else {
chunk_lines(&text)
}
}
fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
let mut chunk_ranges = Vec::new();
let mut cursor = tree.walk();
let mut range = 0..0;
loop {
let node = cursor.node();
// If adding the node to the current chunk exceeds the threshold
if node.end_byte() - range.start > chunk_threshold {
// Try to descend into its first child. If we can't, flush the current
// range and try again.
if cursor.goto_first_child() {
continue;
} else if !range.is_empty() {
chunk_ranges.push(range.clone());
range.start = range.end;
continue;
}
// If we get here, the node itself has no children but is larger than the threshold.
// Break its text into arbitrary chunks.
split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
}
range.end = node.end_byte();
// If we get here, we consumed the node. Advance to the next child, ascending if there isn't one.
while !cursor.goto_next_sibling() {
if !cursor.goto_parent() {
if !range.is_empty() {
chunk_ranges.push(range);
}
return chunk_ranges
.into_iter()
.map(|range| {
let digest = Sha256::digest(&text[range.clone()]).into();
Chunk { range, digest }
})
.collect();
}
}
}
}
fn chunk_lines(text: &str) -> Vec<Chunk> {
let mut chunk_ranges = Vec::new();
let mut range = 0..0;
let mut newlines = text.match_indices('\n').peekable();
while let Some((newline_ix, _)) = newlines.peek() {
let newline_ix = newline_ix + 1;
if newline_ix - range.start <= CHUNK_THRESHOLD {
range.end = newline_ix;
newlines.next();
} else {
if range.is_empty() {
split_text(text, range, newline_ix, &mut chunk_ranges);
range = newline_ix..newline_ix;
} else {
chunk_ranges.push(range.clone());
range.start = range.end;
}
}
}
if !range.is_empty() {
chunk_ranges.push(range);
}
chunk_ranges
.into_iter()
.map(|range| {
let mut hasher = Sha256::new();
hasher.update(&text[range.clone()]);
let mut digest = [0u8; 32];
digest.copy_from_slice(hasher.finalize().as_slice());
Chunk { range, digest }
})
.collect()
}
fn split_text(
text: &str,
mut range: Range<usize>,
max_end: usize,
chunk_ranges: &mut Vec<Range<usize>>,
) {
while range.start < max_end {
range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
while !text.is_char_boundary(range.end) {
range.end -= 1;
}
chunk_ranges.push(range.clone());
range.start = range.end;
}
}
#[cfg(test)]
mod tests {
use super::*;
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
// This example comes from crates/gpui/examples/window_positioning.rs which
// has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD
static TEXT: &str = r#"
use gpui::*;
struct WindowContent {
text: SharedString,
}
impl Render for WindowContent {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
div()
.flex()
.bg(rgb(0x1e2025))
.size_full()
.justify_center()
.items_center()
.text_xl()
.text_color(rgb(0xffffff))
.child(self.text.clone())
}
}
fn main() {
App::new().run(|cx: &mut AppContext| {
// Create several new windows, positioned in the top right corner of each screen
for screen in cx.displays() {
let options = {
let popup_margin_width = DevicePixels::from(16);
let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);
let window_size = Size {
width: px(400.),
height: px(72.),
};
let screen_bounds = screen.bounds();
let size: Size<DevicePixels> = window_size.into();
let bounds = gpui::Bounds::<DevicePixels> {
origin: screen_bounds.upper_right()
- point(size.width + popup_margin_width, popup_margin_height),
size: window_size.into(),
};
WindowOptions {
// Set the bounds of the window in screen coordinates
bounds: Some(bounds),
// Specify the display_id to ensure the window is created on the correct screen
display_id: Some(screen.id()),
titlebar: None,
window_background: WindowBackgroundAppearance::default(),
focus: false,
show: true,
kind: WindowKind::PopUp,
is_movable: false,
fullscreen: false,
}
};
cx.open_window(options, |cx| {
cx.new_view(|_| WindowContent {
text: format!("{:?}", screen.id()).into(),
})
});
}
});
}"#;
fn setup_rust_language() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
}
#[test]
fn test_chunk_text() {
let text = "a\n".repeat(1000);
let chunks = chunk_text(&text, None);
assert_eq!(
chunks.len(),
((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
);
}
#[test]
fn test_chunk_text_grammar() {
// Let's set up a big text with some known segments
// We'll then chunk it and verify that the chunks are correct
let language = setup_rust_language();
let chunks = chunk_text(TEXT, language.grammar());
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, 1498);
// The break between chunks is right before the "Specify the display_id" comment
assert_eq!(chunks[1].range.start, 1498);
assert_eq!(chunks[1].range.end, 2396);
}
#[test]
fn test_chunk_parse_tree() {
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(TEXT, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, TEXT, 250);
assert_eq!(chunks.len(), 11);
}
#[test]
fn test_chunk_unparsable() {
// Even if a chunk is unparsable, we should still be able to chunk it
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let text = r#"fn main() {"#;
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, text, 250);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, 11);
}
#[test]
fn test_empty_text() {
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse("", None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
assert!(chunks.is_empty(), "Chunks should be empty for empty text");
}
#[test]
fn test_single_large_node() {
let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(&large_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);
assert_eq!(
chunks.len(),
3,
"Large chunks are broken up according to grammar as best as possible"
);
// Expect chunks to be static, aaaaaa..., and = 2
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, "static".len());
assert_eq!(chunks[1].range.start, "static".len());
assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);
assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
assert_eq!(chunks[2].range.end, large_text.len());
}
#[test]
fn test_multiple_small_nodes() {
let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(small_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, small_text, 5);
assert!(
chunks.len() > 1,
"Should have multiple chunks for multiple small nodes"
);
}
#[test]
fn test_node_with_children() {
let nested_text = "fn main() { let a = 1; let b = 2; }";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(nested_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, nested_text, 10);
assert!(
chunks.len() > 1,
"Should have multiple chunks for a node with children"
);
}
#[test]
fn test_text_with_unparsable_sections() {
// This test uses purposefully hit-or-miss sizing of 11 characters per likely chunk
let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(mixed_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, mixed_text, 11);
assert!(
chunks.len() > 1,
"Should handle both parsable and unparsable sections correctly"
);
let expected_chunks = [
"fn main() {",
" let a = 1;",
" let b = 2;",
" }",
" unparsable",
" bits here",
];
for (i, chunk) in chunks.iter().enumerate() {
assert_eq!(
&mixed_text[chunk.range.clone()],
expected_chunks[i],
"Chunk {} should match",
i
);
}
}
}
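
A hedged sketch of the plain-text fallback path, where `chunk_text` is called without a grammar and chunking happens on line boundaries (assumes module-level access to the private `CHUNK_THRESHOLD`):

```rust
// No grammar: chunk_text falls back to line-based chunking.
let text = "a line of text\n".repeat(500);
let chunks = chunk_text(&text, None);
for chunk in &chunks {
    // Each chunk stays at or under CHUNK_THRESHOLD bytes and ends on a char boundary.
    assert!(chunk.range.len() <= CHUNK_THRESHOLD);
}
```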


@ -0,0 +1,125 @@
mod cloud;
mod ollama;
mod open_ai;
pub use cloud::*;
pub use ollama::*;
pub use open_ai::*;
use sha2::{Digest, Sha256};
use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use std::{fmt, future};
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
impl Embedding {
pub fn new(mut embedding: Vec<f32>) -> Self {
let len = embedding.len();
let mut norm = 0f32;
for i in 0..len {
norm += embedding[i] * embedding[i];
}
norm = norm.sqrt();
for dimension in &mut embedding {
*dimension /= norm;
}
Self(embedding)
}
fn len(&self) -> usize {
self.0.len()
}
pub fn similarity(self, other: &Embedding) -> f32 {
debug_assert_eq!(self.0.len(), other.0.len());
self.0
.iter()
.copied()
.zip(other.0.iter().copied())
.map(|(a, b)| a * b)
.sum()
}
}
impl fmt::Display for Embedding {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let digits_to_display = 3;
// Start the Embedding display format
write!(f, "Embedding(sized: {}; values: [", self.len())?;
for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
// Lead with comma if not the first element
if index != 0 {
write!(f, ", ")?;
}
write!(f, "{:.3}", value)?;
}
if self.len() > digits_to_display {
write!(f, "...")?;
}
write!(f, "])")
}
}
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug)]
pub struct TextToEmbed<'a> {
pub text: &'a str,
pub digest: [u8; 32],
}
impl<'a> TextToEmbed<'a> {
pub fn new(text: &'a str) -> Self {
let digest = Sha256::digest(text.as_bytes());
Self {
text,
digest: digest.into(),
}
}
}
pub struct FakeEmbeddingProvider;
impl EmbeddingProvider for FakeEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embeddings = texts
.iter()
.map(|_text| {
let mut embedding = vec![0f32; 1536];
for i in 0..embedding.len() {
embedding[i] = i as f32;
}
Embedding::new(embedding)
})
.collect();
future::ready(Ok(embeddings)).boxed()
}
fn batch_size(&self) -> usize {
16
}
}
#[cfg(test)]
mod test {
use super::*;
#[gpui::test]
fn test_normalize_embedding() {
let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
let value: f32 = 1.0 / 3.0_f32.sqrt();
assert_eq!(normalized, Embedding(vec![value; 3]));
}
}
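
A hedged sketch of exercising the provider trait end to end with the `FakeEmbeddingProvider` above (assumes an async test context):

```rust
// Embed two texts with the fake provider and compare them.
let provider = FakeEmbeddingProvider;
let texts = [TextToEmbed::new("fn main() {}"), TextToEmbed::new("struct Foo;")];
let embeddings = provider.embed(&texts).await?;
// The fake provider returns identical unit vectors, so similarity is ~1.
let score = embeddings[0].clone().similarity(&embeddings[1]);
assert!(score > 0.99);
```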


@ -0,0 +1,88 @@
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::{anyhow, Context, Result};
use client::{proto, Client};
use collections::HashMap;
use futures::{future::BoxFuture, FutureExt};
use std::sync::Arc;
pub struct CloudEmbeddingProvider {
model: String,
client: Arc<Client>,
}
impl CloudEmbeddingProvider {
pub fn new(client: Arc<Client>) -> Self {
Self {
model: "openai/text-embedding-3-small".into(),
client,
}
}
}
impl EmbeddingProvider for CloudEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
// First, fetch any embeddings that are cached based on the requested texts' digests
// Then compute any embeddings that are missing.
async move {
let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
model: self.model.clone(),
digests: texts
.iter()
.map(|to_embed| to_embed.digest.to_vec())
.collect(),
});
let mut embeddings = cached_embeddings
.await
.context("failed to fetch cached embeddings via cloud model")?
.embeddings
.into_iter()
.map(|embedding| {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
Ok((digest, embedding.dimensions))
})
.collect::<Result<HashMap<_, _>>>()?;
let compute_embeddings_request = proto::ComputeEmbeddings {
model: self.model.clone(),
texts: texts
.iter()
.filter_map(|to_embed| {
if embeddings.contains_key(&to_embed.digest) {
None
} else {
Some(to_embed.text.to_string())
}
})
.collect(),
};
if !compute_embeddings_request.texts.is_empty() {
let missing_embeddings = self.client.request(compute_embeddings_request).await?;
for embedding in missing_embeddings.embeddings {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
embeddings.insert(digest, embedding.dimensions);
}
}
texts
.iter()
.map(|to_embed| {
let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
format!("server did not return an embedding for {:?}", to_embed)
})?;
Ok(Embedding::new(dimensions))
})
.collect()
}
.boxed()
}
fn batch_size(&self) -> usize {
2048
}
}


@ -0,0 +1,74 @@
use anyhow::{Context as _, Result};
use futures::{future::BoxFuture, AsyncReadExt, FutureExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use util::http::HttpClient;
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
pub enum OllamaEmbeddingModel {
NomicEmbedText,
MxbaiEmbedLarge,
}
pub struct OllamaEmbeddingProvider {
client: Arc<dyn HttpClient>,
model: OllamaEmbeddingModel,
}
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
model: String,
prompt: String,
}
#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
embedding: Vec<f32>,
}
impl OllamaEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
Self { client, model }
}
}
impl EmbeddingProvider for OllamaEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        // Map the model enum to the name Ollama expects in requests.
let model = match self.model {
OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
};
futures::future::try_join_all(texts.into_iter().map(|to_embed| {
let request = OllamaEmbeddingRequest {
model: model.to_string(),
prompt: to_embed.text.to_string(),
};
let request = serde_json::to_string(&request).unwrap();
async {
let response = self
.client
.post_json("http://localhost:11434/api/embeddings", request.into())
.await?;
let mut body = String::new();
response.into_body().read_to_string(&mut body).await?;
                let response: OllamaEmbeddingResponse =
                    serde_json::from_str(&body).context("unable to parse Ollama embedding response")?;
Ok(Embedding::new(response.embedding))
}
}))
.boxed()
}
    fn batch_size(&self) -> usize {
        // TODO: Figure out a decent value. Ollama's /api/embeddings endpoint
        // accepts a single prompt per request, so this effectively bounds how
        // many requests `try_join_all` has in flight for one batch.
        10
    }
}
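
// A minimal usage sketch (illustrative, not part of this change). It assumes
// an Ollama server on localhost:11434 with `nomic-embed-text` already pulled;
// `client` is any `util::http::HttpClient` implementation.
#[allow(dead_code)]
async fn example_embed_with_ollama(client: Arc<dyn HttpClient>) -> Result<()> {
    let provider = OllamaEmbeddingProvider::new(client, OllamaEmbeddingModel::NomicEmbedText);
    let embeddings = provider
        .embed(&[TextToEmbed::new("fn main() {}")])
        .await?;
    assert_eq!(embeddings.len(), 1);
    Ok(())
}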


@ -0,0 +1,55 @@
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};
pub use open_ai::OpenAiEmbeddingModel;
use std::sync::Arc;
use util::http::HttpClient;
pub struct OpenAiEmbeddingProvider {
client: Arc<dyn HttpClient>,
model: OpenAiEmbeddingModel,
api_url: String,
api_key: String,
}
impl OpenAiEmbeddingProvider {
pub fn new(
client: Arc<dyn HttpClient>,
model: OpenAiEmbeddingModel,
api_url: String,
api_key: String,
) -> Self {
Self {
client,
model,
api_url,
api_key,
}
}
}
impl EmbeddingProvider for OpenAiEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embed = open_ai::embed(
self.client.as_ref(),
&self.api_url,
&self.api_key,
self.model,
texts.iter().map(|to_embed| to_embed.text),
);
async move {
let response = embed.await?;
Ok(response
.data
.into_iter()
.map(|data| Embedding::new(data.embedding))
.collect())
}
.boxed()
}
fn batch_size(&self) -> usize {
// From https://platform.openai.com/docs/api-reference/embeddings/create
2048
}
}
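
// A hedged wiring sketch (illustrative, not part of this change): building the
// provider as a trait object for the index. The API URL literal is an
// assumption, and concrete `OpenAiEmbeddingModel` variants live in the
// `open_ai` crate and are not shown in this diff.
#[allow(dead_code)]
fn example_openai_provider(
    client: Arc<dyn HttpClient>,
    model: OpenAiEmbeddingModel,
    api_key: String,
) -> Arc<dyn EmbeddingProvider> {
    Arc::new(OpenAiEmbeddingProvider::new(
        client,
        model,
        "https://api.openai.com/v1".to_string(),
        api_key,
    ))
}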


@ -0,0 +1,954 @@
mod chunking;
mod embedding;
use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, Project, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::LocalSnapshot;
pub struct SemanticIndex {
embedding_provider: Arc<dyn EmbeddingProvider>,
db_connection: heed::Env,
project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
impl Global for SemanticIndex {}
impl SemanticIndex {
pub fn new(
db_path: &Path,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Self>> {
let db_path = db_path.to_path_buf();
cx.spawn(|cx| async move {
let db_connection = cx
.background_executor()
.spawn(async move {
unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(3000)
.open(db_path)
}
})
.await?;
Ok(SemanticIndex {
db_connection,
embedding_provider,
project_indices: HashMap::default(),
})
})
}
pub fn project_index(
&mut self,
project: Model<Project>,
cx: &mut AppContext,
) -> Model<ProjectIndex> {
self.project_indices
.entry(project.downgrade())
.or_insert_with(|| {
cx.new_model(|cx| {
ProjectIndex::new(
project,
self.db_connection.clone(),
self.embedding_provider.clone(),
cx,
)
})
})
.clone()
}
}
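
// A hedged wiring sketch (illustrative only): create the index at startup and
// register it as a GPUI global so `project_index` can later be reached from
// anywhere with app-context access. `db_path` and `provider` are placeholders.
#[allow(dead_code)]
fn example_init_semantic_index(
    db_path: &Path,
    provider: Arc<dyn EmbeddingProvider>,
    cx: &mut AppContext,
) -> Task<Result<()>> {
    let index = SemanticIndex::new(db_path, provider, cx);
    cx.spawn(|mut cx| async move {
        let index = index.await?;
        cx.update(|cx| cx.set_global(index))
    })
}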
pub struct ProjectIndex {
db_connection: heed::Env,
project: Model<Project>,
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
embedding_provider: Arc<dyn EmbeddingProvider>,
_subscription: Subscription,
}
enum WorktreeIndexHandle {
Loading {
_task: Task<Result<()>>,
},
Loaded {
index: Model<WorktreeIndex>,
_subscription: Subscription,
},
}
impl ProjectIndex {
fn new(
project: Model<Project>,
db_connection: heed::Env,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
let mut this = ProjectIndex {
db_connection,
project: project.clone(),
worktree_indices: HashMap::default(),
language_registry,
fs,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
};
this.update_worktree_indices(cx);
this
}
fn handle_project_event(
&mut self,
_: Model<Project>,
event: &project::Event,
cx: &mut ModelContext<Self>,
) {
match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
self.update_worktree_indices(cx);
}
_ => {}
}
}
fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
let worktrees = self
.project
.read(cx)
.visible_worktrees(cx)
.filter_map(|worktree| {
if worktree.read(cx).is_local() {
Some((worktree.entity_id(), worktree))
} else {
None
}
})
.collect::<HashMap<_, _>>();
self.worktree_indices
.retain(|worktree_id, _| worktrees.contains_key(worktree_id));
for (worktree_id, worktree) in worktrees {
self.worktree_indices.entry(worktree_id).or_insert_with(|| {
let worktree_index = WorktreeIndex::load(
worktree.clone(),
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(|this, mut cx| async move {
if let Some(index) = worktree_index.await.log_err() {
this.update(&mut cx, |this, cx| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
_subscription: cx
.observe(&index, |this, _, cx| this.update_status(cx)),
index,
},
);
})?;
} else {
this.update(&mut cx, |this, _cx| {
this.worktree_indices.remove(&worktree_id)
})?;
}
this.update(&mut cx, |this, cx| this.update_status(cx))
});
WorktreeIndexHandle::Loading {
_task: load_worktree,
}
});
}
self.update_status(cx);
}
fn update_status(&mut self, cx: &mut ModelContext<Self>) {
let mut status = Status::Idle;
for index in self.worktree_indices.values() {
match index {
WorktreeIndexHandle::Loading { .. } => {
status = Status::Scanning;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
if index.read(cx).status == Status::Scanning {
status = Status::Scanning;
break;
}
}
}
}
if status != self.last_status {
self.last_status = status;
cx.emit(status);
}
}
pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
let mut worktree_searches = Vec::new();
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
worktree_searches
.push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
}
}
cx.spawn(|_| async move {
let mut results = Vec::new();
let worktree_searches = futures::future::join_all(worktree_searches).await;
for worktree_search_results in worktree_searches {
if let Some(worktree_search_results) = worktree_search_results.log_err() {
results.extend(worktree_search_results);
}
}
results
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
results.truncate(limit);
results
})
}
}
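
// A hedged usage sketch mirroring `test_search` below: embeddings are unit
// vectors, so each result's `score` is a dot product in [-1, 1] and results
// come back sorted most-similar first.
#[allow(dead_code)]
fn example_search(
    project_index: &Model<ProjectIndex>,
    cx: &AppContext,
) -> Task<Vec<SearchResult>> {
    project_index.read(cx).search("natural language query", 8, cx)
}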
pub struct SearchResult {
pub worktree: Model<Worktree>,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
Idle,
Scanning,
}
impl EventEmitter<Status> for ProjectIndex {}
struct WorktreeIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
status: Status,
_index_entries: Task<Result<()>>,
_subscription: Subscription,
}
impl WorktreeIndex {
pub fn load(
worktree: Model<Worktree>,
db_connection: heed::Env,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_abs_path = worktree.read(cx).abs_path();
cx.spawn(|mut cx| async move {
let db = cx
.background_executor()
.spawn({
let db_connection = db_connection.clone();
async move {
let mut txn = db_connection.write_txn()?;
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;
txn.commit()?;
anyhow::Ok(db)
}
})
.await?;
cx.new_model(|cx| {
Self::new(
worktree,
db_connection,
db,
language_registry,
fs,
embedding_provider,
cx,
)
})
})
}
fn new(
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
if let worktree::Event::UpdatedEntries(update) = event {
_ = updated_entries_tx.try_send(update.clone());
}
});
Self {
db_connection,
db,
worktree,
language_registry,
fs,
embedding_provider,
status: Status::Idle,
_index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
_subscription,
}
}
async fn index_entries(
this: WeakModel<Self>,
updated_entries: channel::Receiver<UpdatedEntriesSet>,
mut cx: AsyncAppContext,
) -> Result<()> {
let index = this.update(&mut cx, |this, cx| {
cx.notify();
this.status = Status::Scanning;
this.index_entries_changed_on_disk(cx)
})?;
index.await.log_err();
this.update(&mut cx, |this, cx| {
this.status = Status::Idle;
cx.notify();
})?;
while let Ok(updated_entries) = updated_entries.recv().await {
let index = this.update(&mut cx, |this, cx| {
cx.notify();
this.status = Status::Scanning;
this.index_updated_entries(updated_entries, cx)
})?;
index.await.log_err();
this.update(&mut cx, |this, cx| {
this.status = Status::Idle;
cx.notify();
})?;
}
Ok(())
}
fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
fn index_updated_entries(
&self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
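
    // Pipeline shape (descriptive note, not part of the original change):
    // both indexing paths above wire four concurrent stages over bounded
    // channels, so a slow embedding provider applies backpressure all the
    // way back to the scanner:
    //
    //   scan ──Entry──► chunk ──ChunkedFile──► embed ──EmbeddedFile──► persist
    //     └──────────────deleted key ranges──────────────────────────► persist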
fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let db = self.db;
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let mut db_entries = db
.iter(&txn)
.context("failed to create iterator")?
.move_between_keys()
.peekable();
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
let entry_db_key = db_key_for_path(&entry.path);
let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {
match db_entry {
Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
Ordering::Less => {
if let Some(deletion_range) = deletion_range.as_mut() {
deletion_range.1 = Bound::Included(db_path);
} else {
deletion_range =
Some((Bound::Included(db_path), Bound::Included(db_path)));
}
db_entries.next();
}
Ordering::Equal => {
if let Some(deletion_range) = deletion_range.take() {
deleted_entry_ranges_tx
.send((
deletion_range.0.map(ToString::to_string),
deletion_range.1.map(ToString::to_string),
))
.await?;
}
saved_mtime = db_embedded_file.mtime;
db_entries.next();
break;
}
Ordering::Greater => {
break;
}
},
Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
}
}
if entry.mtime != saved_mtime {
updated_entries_tx.send(entry.clone()).await?;
}
}
if let Some(db_entry) = db_entries.next() {
let (db_path, _) = db_entry?;
deleted_entry_ranges_tx
.send((Bound::Included(db_path.to_string()), Bound::Unbounded))
.await?;
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn scan_updated_entries(
&self,
worktree: LocalSnapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
updated_entries_tx.send(entry.clone()).await?;
}
}
project::PathChange::Removed => {
let db_path = db_key_for_path(path);
deleted_entry_ranges_tx
.send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
.await?;
}
project::PathChange::Loaded => {
// Do nothing.
}
}
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn chunk_files(
&self,
worktree_abs_path: Arc<Path>,
entries: channel::Receiver<Entry>,
cx: &AppContext,
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok(entry) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
continue;
};
let language = language_registry
.language_for_file_path(&entry.path)
.await
.ok();
let grammar =
language.as_ref().and_then(|language| language.grammar());
let chunked_file = ChunkedFile {
worktree_root: worktree_abs_path.clone(),
chunks: chunk_text(&text, grammar),
entry,
text,
};
if chunked_files_tx.send(chunked_file).await.is_err() {
return;
}
}
});
}
})
.await;
Ok(())
});
ChunkFiles {
files: chunked_files_rx,
task,
}
}
fn embed_files(
&self,
chunked_files: channel::Receiver<ChunkedFile>,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = self.embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));
while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into one list of chunks, embed
                // those chunks in provider-sized batches, then walk the
                // resulting embeddings back out to the files they came from.
let chunks = chunked_files
.iter()
.flat_map(|file| {
file.chunks.iter().map(|chunk| TextToEmbed {
text: &file.text[chunk.range.clone()],
digest: chunk.digest,
})
})
.collect::<Vec<_>>();
let mut embeddings = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
// todo!("add a retry facility")
embeddings.extend(embedding_provider.embed(embedding_batch).await?);
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let chunk_embeddings = embeddings
.by_ref()
.take(chunked_file.chunks.len())
.collect::<Vec<_>>();
let embedded_chunks = chunked_file
.chunks
.into_iter()
.zip(chunk_embeddings)
.map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
.collect();
let embedded_file = EmbeddedFile {
path: chunked_file.entry.path.clone(),
mtime: chunked_file.entry.mtime,
chunks: embedded_chunks,
};
embedded_files_tx.send(embedded_file).await?;
}
}
Ok(())
});
EmbedFiles {
files: embedded_files_rx,
task,
}
}
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
embedded_files: channel::Receiver<EmbeddedFile>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
while let Some(deletion_range) = deleted_entry_ranges.next().await {
let mut txn = db_connection.write_txn()?;
let start = deletion_range.0.as_ref().map(|start| start.as_str());
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
txn.commit()?;
}
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
while let Some(embedded_files) = embedded_files.next().await {
let mut txn = db_connection.write_txn()?;
for file in embedded_files {
log::debug!("saving embedding for file {:?}", file.path);
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, &file)?;
}
txn.commit()?;
log::debug!("committed");
}
Ok(())
})
}
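
    // Durability note (descriptive, not part of the original change): each
    // batch of deletions or embedded files is committed in its own heed write
    // transaction, so an interrupted run leaves the database at the last
    // committed batch rather than in a half-written state.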
fn search(
&self,
query: &str,
limit: usize,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let db_connection = self.db_connection.clone();
let db = self.db;
let scan_chunks = cx.background_executor().spawn({
async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((db_embedded_file.path.clone(), chunk))
.await?;
}
}
anyhow::Ok(())
}
});
let query = query.to_string();
let embedding_provider = self.embedding_provider.clone();
let worktree = self.worktree.clone();
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
let mut query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
.await?;
let query_embedding = query_embeddings
.pop()
.ok_or_else(|| anyhow!("no embedding for query"))?;
let mut workers = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
workers.push(Vec::<SearchResult>::new());
}
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for worker_results in workers.iter_mut() {
cx.spawn(async {
while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
let score = embedded_chunk.embedding.similarity(&query_embedding);
let ix = match worker_results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
worker_results.insert(
ix,
SearchResult {
worktree: worktree.clone(),
path: path.clone(),
range: embedded_chunk.chunk.range.clone(),
score,
},
);
worker_results.truncate(limit);
}
});
}
})
.await;
scan_chunks.await?;
let mut search_results = Vec::with_capacity(workers.len() * limit);
for worker_results in workers {
search_results.extend(worker_results);
}
search_results
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
search_results.truncate(limit);
#[cfg(debug_assertions)]
{
let search_elapsed = search_start.elapsed();
log::debug!(
"searched {} entries in {:?}",
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
Ok(search_results)
})
}
}
struct ScanEntries {
updated_entries: channel::Receiver<Entry>,
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
task: Task<Result<()>>,
}
struct ChunkFiles {
files: channel::Receiver<ChunkedFile>,
task: Task<Result<()>>,
}
struct ChunkedFile {
#[allow(dead_code)]
pub worktree_root: Arc<Path>,
pub entry: Entry,
pub text: String,
pub chunks: Vec<Chunk>,
}
struct EmbedFiles {
files: channel::Receiver<EmbeddedFile>,
task: Task<Result<()>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
path: Arc<Path>,
mtime: Option<SystemTime>,
chunks: Vec<EmbeddedChunk>,
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
chunk: Chunk,
embedding: Embedding,
}
/// Converts a path into a database key. Slashes are replaced with NUL so that
/// lexicographic key order matches the worktree's component-wise path
/// ordering, which the merge in `scan_entries` and the range deletes in
/// `persist_embeddings` rely on.
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}
#[cfg(test)]
mod tests {
use super::*;
use futures::channel::oneshot;
use futures::{future::BoxFuture, FutureExt};
use gpui::{Global, TestAppContext};
use language::language_settings::AllLanguageSettings;
use project::Project;
use settings::SettingsStore;
use std::{future, path::Path, sync::Arc};
fn init_test(cx: &mut TestAppContext) {
_ = cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
});
}
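
    // A hedged check (not in the original change) of the ordering property
    // behind `db_key_for_path`: NUL sorts before '.', so "dir/file.rs" keys
    // ahead of "dir.rs", matching the worktree's component-wise path order,
    // whereas a plain '/' would sort the other way.
    #[test]
    fn test_db_key_for_path_ordering() {
        let nested: Arc<Path> = Arc::from(Path::new("dir/file.rs"));
        let sibling: Arc<Path> = Arc::from(Path::new("dir.rs"));
        assert!(db_key_for_path(&nested) < db_key_for_path(&sibling));
    }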
pub struct TestEmbeddingProvider;
impl EmbeddingProvider for TestEmbeddingProvider {
fn embed<'a>(
&'a self,
texts: &'a [TextToEmbed<'a>],
) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embeddings = texts
.iter()
.map(|text| {
let mut embedding = vec![0f32; 2];
                    // Give the first dimension a high score (0.9) when the
                    // text mentions "garbage in" and a low one (-0.9)
                    // otherwise; the second dimension does the same for
                    // "garbage out".
if text.text.contains("garbage in") {
embedding[0] = 0.9;
} else {
embedding[0] = -0.9;
}
if text.text.contains("garbage out") {
embedding[1] = 0.9;
} else {
embedding[1] = -0.9;
}
Embedding::new(embedding)
})
.collect();
future::ready(Ok(embeddings)).boxed()
}
fn batch_size(&self) -> usize {
16
}
}
#[gpui::test]
async fn test_search(cx: &mut TestAppContext) {
cx.executor().allow_parking();
init_test(cx);
let temp_dir = tempfile::tempdir().unwrap();
let mut semantic_index = cx
.update(|cx| {
                let semantic_index = SemanticIndex::new(
                    temp_dir.path(),
                    Arc::new(TestEmbeddingProvider),
                    cx,
                );
semantic_index
})
.await
.unwrap();
// todo!(): use a fixture
let project_path = Path::new("./fixture");
let project = cx
.spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
.await;
cx.update(|cx| {
let language_registry = project.read(cx).languages().clone();
let node_runtime = project.read(cx).node_runtime().unwrap().clone();
languages::init(language_registry, node_runtime, cx);
});
let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let subscription = cx.update(|cx| {
cx.subscribe(&project_index, move |_, event, _| {
if let Some(tx) = tx.take() {
_ = tx.send(*event);
}
})
});
rx.await.expect("no event emitted");
drop(subscription);
let results = cx
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
project_index.search(query, 4, cx)
})
.await;
assert!(results.len() > 1, "should have found some results");
for result in &results {
println!("result: {:?}", result.path);
println!("score: {:?}", result.score);
}
        // Find a result whose score exceeds 0.9.
let search_result = results.iter().find(|result| result.score > 0.9).unwrap();
assert_eq!(search_result.path.to_string_lossy(), "needle.md");
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
let fs = project.read(cx).fs().clone();
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
})
.await;
let range = search_result.range.clone();
let content = content[range.clone()].to_owned();
assert!(content.contains("garbage in, garbage out"));
}
}


@ -71,19 +71,28 @@ impl HttpClientWithUrl {
}

impl HttpClient for Arc<HttpClientWithUrl> {
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
        self.client.send(req)
    }
}

impl HttpClient for HttpClientWithUrl {
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
        self.client.send(req)
    }
}

pub trait HttpClient: Send + Sync {
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>>;

    fn get<'a>(
        &'a self,
@ -135,8 +144,12 @@ pub fn client() -> Arc<dyn HttpClient> {
}

impl HttpClient for isahc::HttpClient {
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
        let client = self.clone();
        Box::pin(async move { client.send_async(req).await })
    }
}
@ -196,7 +209,10 @@ impl fmt::Debug for FakeHttpClient {
#[cfg(feature = "test-support")]
impl HttpClient for FakeHttpClient {
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
        let future = (self.handler)(req);
        Box::pin(async move { future.await.map(Into::into) })
    }
}
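
// Why `'static` here (descriptive note): the returned future may be spawned or
// stored beyond the borrow of the client itself, so it can no longer capture
// `&self`; that is why the isahc impl above clones the client before moving it
// into the async block.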