Remove 2 suffix for fs, db, semantic_index, prettier

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-01-03 12:09:08 -08:00
parent 324ac96977
commit 5ddd298b4d
75 changed files with 455 additions and 12973 deletions

262
Cargo.lock generated
View File

@ -310,7 +310,7 @@ dependencies = [
"ctor",
"editor",
"env_logger",
"fs2",
"fs",
"futures 0.3.28",
"gpui2",
"indoc",
@ -326,7 +326,7 @@ dependencies = [
"regex",
"schemars",
"search",
"semantic_index2",
"semantic_index",
"serde",
"serde_json",
"settings2",
@ -688,7 +688,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"db2",
"db",
"gpui2",
"isahc",
"lazy_static",
@ -1122,7 +1122,7 @@ dependencies = [
"audio2",
"client",
"collections",
"fs2",
"fs",
"futures 0.3.28",
"gpui2",
"image",
@ -1210,7 +1210,7 @@ dependencies = [
"client",
"clock",
"collections",
"db2",
"db",
"feature_flags",
"futures 0.3.28",
"gpui2",
@ -1221,7 +1221,7 @@ dependencies = [
"parking_lot 0.11.2",
"postage",
"rand 0.8.5",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -1384,7 +1384,7 @@ dependencies = [
"async-tungstenite",
"chrono",
"collections",
"db2",
"db",
"feature_flags",
"futures 0.3.28",
"gpui2",
@ -1394,7 +1394,7 @@ dependencies = [
"parking_lot 0.11.2",
"postage",
"rand 0.8.5",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -1553,7 +1553,7 @@ dependencies = [
"editor",
"env_logger",
"envy",
"fs2",
"fs",
"futures 0.3.28",
"git3",
"gpui2",
@ -1568,7 +1568,7 @@ dependencies = [
"lsp",
"nanoid",
"node_runtime",
"notifications2",
"notifications",
"parking_lot 0.11.2",
"pretty_assertions",
"project",
@ -1576,7 +1576,7 @@ dependencies = [
"prost 0.8.0",
"rand 0.8.5",
"reqwest",
"rpc2",
"rpc",
"scrypt",
"sea-orm",
"serde",
@ -1614,7 +1614,7 @@ dependencies = [
"client",
"clock",
"collections",
"db2",
"db",
"editor",
"feature_flags",
"feedback",
@ -1625,14 +1625,14 @@ dependencies = [
"lazy_static",
"log",
"menu2",
"notifications2",
"notifications",
"picker",
"postage",
"pretty_assertions",
"project",
"recent_projects",
"rich_text",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -1794,7 +1794,7 @@ dependencies = [
"lsp",
"node_runtime",
"parking_lot 0.11.2",
"rpc2",
"rpc",
"serde",
"serde_derive",
"settings2",
@ -1811,7 +1811,7 @@ dependencies = [
"anyhow",
"copilot",
"editor",
"fs2",
"fs",
"futures 0.3.28",
"gpui2",
"language",
@ -2208,28 +2208,6 @@ dependencies = [
[[package]]
name = "db"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"collections",
"env_logger",
"gpui",
"indoc",
"lazy_static",
"log",
"parking_lot 0.11.2",
"serde",
"serde_derive",
"smol",
"sqlez",
"sqlez_macros",
"tempdir",
"util",
]
[[package]]
name = "db2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
@ -2500,7 +2478,7 @@ dependencies = [
"convert_case 0.6.0",
"copilot",
"ctor",
"db2",
"db",
"env_logger",
"futures 0.3.28",
"fuzzy",
@ -2519,7 +2497,7 @@ dependencies = [
"project",
"rand 0.8.5",
"rich_text",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -2724,7 +2702,7 @@ dependencies = [
"anyhow",
"bitflags 2.4.1",
"client",
"db2",
"db",
"editor",
"futures 0.3.28",
"gpui2",
@ -2923,34 +2901,6 @@ dependencies = [
[[package]]
name = "fs"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"collections",
"fsevent",
"futures 0.3.28",
"git2",
"gpui",
"lazy_static",
"libc",
"log",
"parking_lot 0.11.2",
"regex",
"rope",
"serde",
"serde_derive",
"serde_json",
"smol",
"sum_tree",
"tempfile",
"text",
"time",
"util",
]
[[package]]
name = "fs2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
@ -4106,7 +4056,7 @@ dependencies = [
"pulldown-cmark",
"rand 0.8.5",
"regex",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -4957,28 +4907,8 @@ dependencies = [
"collections",
"db",
"feature_flags",
"gpui",
"rpc",
"settings",
"sum_tree",
"text",
"time",
"util",
]
[[package]]
name = "notifications2"
version = "0.1.0"
dependencies = [
"anyhow",
"channel",
"client",
"clock",
"collections",
"db2",
"feature_flags",
"gpui2",
"rpc2",
"rpc",
"settings2",
"sum_tree",
"text2",
@ -5786,27 +5716,6 @@ dependencies = [
"collections",
"fs",
"futures 0.3.28",
"gpui",
"language",
"log",
"lsp",
"node_runtime",
"parking_lot 0.11.2",
"serde",
"serde_derive",
"serde_json",
"util",
]
[[package]]
name = "prettier2"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"collections",
"fs2",
"futures 0.3.28",
"gpui2",
"language",
"log",
@ -5915,9 +5824,9 @@ dependencies = [
"collections",
"copilot",
"ctor",
"db2",
"db",
"env_logger",
"fs2",
"fs",
"fsevent",
"futures 0.3.28",
"fuzzy",
@ -5934,11 +5843,11 @@ dependencies = [
"node_runtime",
"parking_lot 0.11.2",
"postage",
"prettier2",
"prettier",
"pretty_assertions",
"rand 0.8.5",
"regex",
"rpc2",
"rpc",
"schemars",
"serde",
"serde_derive",
@ -5964,7 +5873,7 @@ dependencies = [
"anyhow",
"client",
"collections",
"db2",
"db",
"editor",
"futures 0.3.28",
"gpui2",
@ -6668,37 +6577,6 @@ dependencies = [
[[package]]
name = "rpc"
version = "0.1.0"
dependencies = [
"anyhow",
"async-lock",
"async-tungstenite",
"base64 0.13.1",
"clock",
"collections",
"ctor",
"env_logger",
"futures 0.3.28",
"gpui",
"parking_lot 0.11.2",
"prost 0.8.0",
"prost-build",
"rand 0.8.5",
"rsa 0.4.0",
"serde",
"serde_derive",
"serde_json",
"smol",
"smol-timeout",
"strum",
"tempdir",
"tracing",
"util",
"zstd",
]
[[package]]
name = "rpc2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-lock",
@ -7175,7 +7053,7 @@ dependencies = [
"menu2",
"postage",
"project",
"semantic_index2",
"semantic_index",
"serde",
"serde_derive",
"serde_json",
@ -7215,60 +7093,6 @@ dependencies = [
[[package]]
name = "semantic_index"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
"async-trait",
"client",
"collections",
"ctor",
"editor",
"env_logger",
"futures 0.3.28",
"globset",
"gpui",
"language",
"lazy_static",
"log",
"ndarray",
"node_runtime",
"ordered-float 2.10.0",
"parking_lot 0.11.2",
"picker",
"postage",
"pretty_assertions",
"project",
"rand 0.8.5",
"rpc",
"rusqlite",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"sha1",
"smol",
"tempdir",
"theme",
"tiktoken-rs",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
"tree-sitter-json 0.20.0",
"tree-sitter-lua",
"tree-sitter-php",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-toml",
"tree-sitter-typescript",
"unindent",
"util",
"workspace",
]
[[package]]
name = "semantic_index2"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
@ -7291,7 +7115,7 @@ dependencies = [
"pretty_assertions",
"project",
"rand 0.8.5",
"rpc2",
"rpc",
"rusqlite",
"rust-embed",
"schemars",
@ -7473,7 +7297,7 @@ dependencies = [
"anyhow",
"collections",
"feature_flags",
"fs2",
"fs",
"futures 0.3.28",
"gpui2",
"indoc",
@ -8424,7 +8248,7 @@ version = "0.1.0"
dependencies = [
"alacritty_terminal",
"anyhow",
"db2",
"db",
"dirs 4.0.0",
"futures 0.3.28",
"gpui2",
@ -8453,7 +8277,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"db2",
"db",
"dirs 4.0.0",
"editor",
"futures 0.3.28",
@ -8556,7 +8380,7 @@ name = "theme2"
version = "0.1.0"
dependencies = [
"anyhow",
"fs2",
"fs",
"gpui2",
"indexmap 1.9.3",
"itertools 0.11.0",
@ -8602,7 +8426,7 @@ dependencies = [
"client",
"editor",
"feature_flags",
"fs2",
"fs",
"fuzzy",
"gpui2",
"log",
@ -9672,7 +9496,7 @@ name = "vcs_menu"
version = "0.1.0"
dependencies = [
"anyhow",
"fs2",
"fs",
"fuzzy",
"gpui2",
"picker",
@ -10114,9 +9938,9 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"db2",
"db",
"editor",
"fs2",
"fs",
"fuzzy",
"gpui2",
"install_cli",
@ -10383,9 +10207,9 @@ dependencies = [
"call",
"client",
"collections",
"db2",
"db",
"env_logger",
"fs2",
"fs",
"futures 0.3.28",
"gpui2",
"indoc",
@ -10516,14 +10340,14 @@ dependencies = [
"copilot",
"copilot_button",
"ctor",
"db2",
"db",
"diagnostics",
"editor",
"env_logger",
"feature_flags",
"feedback",
"file_finder",
"fs2",
"fs",
"fsevent",
"futures 0.3.28",
"go_to_line",
@ -10543,7 +10367,7 @@ dependencies = [
"lsp",
"menu2",
"node_runtime",
"notifications2",
"notifications",
"num_cpus",
"outline",
"parking_lot 0.11.2",
@ -10556,12 +10380,12 @@ dependencies = [
"recent_projects",
"regex",
"rope2",
"rpc2",
"rpc",
"rsa 0.4.0",
"rust-embed",
"schemars",
"search",
"semantic_index2",
"semantic_index",
"serde",
"serde_derive",
"serde_json",

View File

@ -22,7 +22,6 @@ members = [
"crates/copilot",
"crates/copilot_button",
"crates/db",
"crates/db2",
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/diagnostics",
@ -32,7 +31,6 @@ members = [
"crates/feedback",
"crates/file_finder",
"crates/fs",
"crates/fs2",
"crates/fsevent",
"crates/fuzzy",
"crates/git",
@ -56,13 +54,11 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/notifications2",
"crates/outline",
"crates/picker",
"crates/plugin",
"crates/plugin_macros",
"crates/prettier",
"crates/prettier2",
"crates/project",
"crates/project_panel",
"crates/project_symbols",
@ -70,10 +66,8 @@ members = [
"crates/recent_projects",
"crates/rope",
"crates/rpc",
"crates/rpc2",
"crates/search",
"crates/semantic_index",
"crates/semantic_index2",
"crates/settings",
"crates/settings2",
"crates/snippet",

View File

@ -13,14 +13,14 @@ ai = { path = "../ai" }
client = { path = "../client" }
collections = { path = "../collections"}
editor = { path = "../editor" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }
menu = { package = "menu2", path = "../menu2" }
multi_buffer = { path = "../multi_buffer" }
project = { path = "../project" }
search = { path = "../search" }
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
semantic_index = { path = "../semantic_index" }
settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
ui = { package = "ui2", path = "../ui2" }

View File

@ -9,7 +9,7 @@ path = "src/auto_update.rs"
doctest = false
[dependencies]
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
client = { path = "../client" }
gpui = { package = "gpui2", path = "../gpui2" }
menu = { package = "menu2", path = "../menu2" }

View File

@ -25,7 +25,7 @@ collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2" }
log.workspace = true
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
language = { path = "../language" }
media = { path = "../media" }
project = { path = "../project" }
@ -45,7 +45,7 @@ smallvec.workspace = true
[dev-dependencies]
client = { path = "../client", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }

View File

@ -14,10 +14,10 @@ test-support = ["collections/test-support", "gpui/test-support", "rpc/test-suppo
[dependencies]
client = { path = "../client" }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
text = { package = "text2", path = "../text2" }
language = { path = "../language" }
settings = { package = "settings2", path = "../settings2" }
@ -48,7 +48,7 @@ tempfile = "3"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
client = { path = "../client", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -14,10 +14,10 @@ test-support = ["collections/test-support", "gpui/test-support", "rpc/test-suppo
[dependencies]
chrono = { version = "0.4", features = ["serde"] }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
text = { package = "text2", path = "../text2" }
settings = { package = "settings2", path = "../settings2" }
feature_flags = { path = "../feature_flags" }
@ -48,6 +48,6 @@ url = "2.2"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -18,7 +18,7 @@ clock = { path = "../clock" }
collections = { path = "../collections" }
live_kit_server = { path = "../live_kit_server" }
text = { package = "text2", path = "../text2" }
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
util = { path = "../util" }
anyhow.workspace = true
@ -68,15 +68,15 @@ client = { path = "../client", features = ["test-support"] }
channel = { path = "../channel" }
editor = { path = "../editor", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
git = { package = "git3", path = "../git3", features = ["test-support"] }
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
node_runtime = { path = "../node_runtime" }
notifications = { package = "notifications2", path = "../notifications2", features = ["test-support"] }
notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
theme = { package = "theme2", path = "../theme2" }
workspace = { path = "../workspace", features = ["test-support"] }

View File

@ -23,7 +23,7 @@ test-support = [
[dependencies]
auto_update = { path = "../auto_update" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
call = { path = "../call" }
client = { path = "../client" }
channel = { path = "../channel" }
@ -37,12 +37,12 @@ fuzzy = { path = "../fuzzy" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }
menu = { package = "menu2", path = "../menu2" }
notifications = { package = "notifications2", path = "../notifications2" }
notifications = { path = "../notifications" }
rich_text = { path = "../rich_text" }
picker = { path = "../picker" }
project = { path = "../project" }
recent_projects = { path = "../recent_projects" }
rpc = { package ="rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
settings = { package = "settings2", path = "../settings2" }
feature_flags = { path = "../feature_flags"}
theme = { package = "theme2", path = "../theme2" }
@ -70,9 +70,9 @@ client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
editor = { path = "../editor", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
notifications = { package = "notifications2", path = "../notifications2", features = ["test-support"] }
notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }

View File

@ -46,6 +46,6 @@ fs = { path = "../fs", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -11,7 +11,7 @@ doctest = false
[dependencies]
copilot = { path = "../copilot" }
editor = { path = "../editor" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
zed_actions = { path = "../zed_actions"}
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }

View File

@ -13,7 +13,7 @@ test-support = []
[dependencies]
collections = { path = "../collections" }
gpui = { path = "../gpui" }
gpui = { package = "gpui2", path = "../gpui2" }
sqlez = { path = "../sqlez" }
sqlez_macros = { path = "../sqlez_macros" }
util = { path = "../util" }
@ -28,6 +28,6 @@ serde_derive.workspace = true
smol.workspace = true
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
env_logger.workspace = true
tempdir.workspace = true

View File

@ -185,7 +185,7 @@ pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send
where
F: Future<Output = anyhow::Result<()>> + Send,
{
cx.background()
cx.background_executor()
.spawn(async move { db_write().await.log_err() })
.detach()
}
@ -226,7 +226,9 @@ mod tests {
/// Test that DB exists but corrupted (causing recreate)
#[gpui::test]
async fn test_db_corruption() {
async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
cx.executor().allow_parking();
enum CorruptedDB {}
impl Domain for CorruptedDB {
@ -268,7 +270,9 @@ mod tests {
/// Test that DB exists but corrupted (causing recreate)
#[gpui::test(iterations = 30)]
async fn test_simultaneous_db_corruption() {
async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
cx.executor().allow_parking();
enum CorruptedDB {}
impl Domain for CorruptedDB {

View File

@ -1,33 +0,0 @@
[package]
name = "db2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/db2.rs"
doctest = false
[features]
test-support = []
[dependencies]
collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2" }
sqlez = { path = "../sqlez" }
sqlez_macros = { path = "../sqlez_macros" }
util = { path = "../util" }
anyhow.workspace = true
indoc.workspace = true
async-trait.workspace = true
lazy_static.workspace = true
log.workspace = true
parking_lot.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
[dev-dependencies]
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
env_logger.workspace = true
tempdir.workspace = true

View File

@ -1,5 +0,0 @@
# Building Queries
First, craft your test data. The examples folder shows a template for building a test-db, and can be ran with `cargo run --example [your-example]`.
To actually use and test your queries, import the generated DB file into https://sqliteonline.com/

View File

@ -1,331 +0,0 @@
pub mod kvp;
pub mod query;
// Re-export
pub use anyhow;
use anyhow::Context;
use gpui::AppContext;
pub use indoc::indoc;
pub use lazy_static;
pub use smol;
pub use sqlez;
pub use sqlez_macros;
pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
pub use util::paths::DB_DIR;
use sqlez::domain::Migrator;
use sqlez::thread_safe_connection::ThreadSafeConnection;
use sqlez_macros::sql;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use util::channel::ReleaseChannel;
use util::{async_maybe, ResultExt};
const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA foreign_keys=TRUE;
);
const DB_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA journal_mode=WAL;
PRAGMA busy_timeout=1;
PRAGMA case_sensitive_like=TRUE;
PRAGMA synchronous=NORMAL;
);
const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
const DB_FILE_NAME: &'static str = "db.sqlite";
lazy_static::lazy_static! {
pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
}
/// Open or create a database at the given directory path.
/// This will retry a couple times if there are failures. If opening fails once, the db directory
/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
/// In either case, static variables are set so that the user can be notified.
pub async fn open_db<M: Migrator + 'static>(
db_dir: &Path,
release_channel: &ReleaseChannel,
) -> ThreadSafeConnection<M> {
if *ZED_STATELESS {
return open_fallback_db().await;
}
let release_channel_name = release_channel.dev_name();
let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
let connection = async_maybe!({
smol::fs::create_dir_all(&main_db_dir)
.await
.context("Could not create db directory")
.log_err()?;
let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
open_main_db(&db_path).await
})
.await;
if let Some(connection) = connection {
return connection;
}
// Set another static ref so that we can escalate the notification
ALL_FILE_DB_FAILED.store(true, Ordering::Release);
// If still failed, create an in memory db with a known name
open_fallback_db().await
}
async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
log::info!("Opening main db");
ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
.log_err()
}
async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
log::info!("Opening fallback db");
ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
.expect(
"Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
)
}
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
use sqlez::thread_safe_connection::locking_queue;
ThreadSafeConnection::<M>::builder(db_name, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
// Serialize queued writes via a mutex and run them synchronously
.with_write_queue_constructor(locking_queue())
.build()
.await
.unwrap()
}
/// Implements a basic DB wrapper for a given domain
#[macro_export]
macro_rules! define_connection {
(pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
}
#[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
}
};
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
}
#[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
}
};
}
pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
where
F: Future<Output = anyhow::Result<()>> + Send,
{
cx.background_executor()
.spawn(async move { db_write().await.log_err() })
.detach()
}
#[cfg(test)]
mod tests {
use std::thread;
use sqlez::domain::Domain;
use sqlez_macros::sql;
use tempdir::TempDir;
use crate::open_db;
// Test bad migration panics
#[gpui::test]
#[should_panic]
async fn test_bad_migration_panics() {
enum BadDB {}
impl Domain for BadDB {
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[
sql!(CREATE TABLE test(value);),
// failure because test already exists
sql!(CREATE TABLE test(value);),
]
}
}
let tempdir = TempDir::new("DbTests").unwrap();
let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
}
/// Test that DB exists but corrupted (causing recreate)
#[gpui::test]
async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
cx.executor().allow_parking();
enum CorruptedDB {}
impl Domain for CorruptedDB {
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
}
enum GoodDB {}
impl Domain for GoodDB {
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
}
let tempdir = TempDir::new("DbTests").unwrap();
{
let corrupt_db =
open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
assert!(corrupt_db.persistent());
}
let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
assert!(
good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
.unwrap()
.is_none()
);
}
/// Test that DB exists but corrupted (causing recreate)
#[gpui::test(iterations = 30)]
async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
cx.executor().allow_parking();
enum CorruptedDB {}
impl Domain for CorruptedDB {
fn name() -> &'static str {
"db_tests"
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test(value);)]
}
}
enum GoodDB {}
impl Domain for GoodDB {
fn name() -> &'static str {
"db_tests" //Notice same name
}
fn migrations() -> &'static [&'static str] {
&[sql!(CREATE TABLE test2(value);)] //But different migration
}
}
let tempdir = TempDir::new("DbTests").unwrap();
{
// Setup the bad database
let corrupt_db =
open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
assert!(corrupt_db.persistent());
}
// Try to connect to it a bunch of times at once
let mut guards = vec![];
for _ in 0..10 {
let tmp_path = tempdir.path().to_path_buf();
let guard = thread::spawn(move || {
let good_db = smol::block_on(open_db::<GoodDB>(
tmp_path.as_path(),
&util::channel::ReleaseChannel::Dev,
));
assert!(
good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
.unwrap()
.is_none()
);
});
guards.push(guard);
}
for guard in guards.into_iter() {
assert!(guard.join().is_ok());
}
}
}

View File

@ -1,62 +0,0 @@
use sqlez_macros::sql;
use crate::{define_connection, query};
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
);
impl KeyValueStore {
query! {
pub fn read_kvp(key: &str) -> Result<Option<String>> {
SELECT value FROM kv_store WHERE key = (?)
}
}
query! {
pub async fn write_kvp(key: String, value: String) -> Result<()> {
INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))
}
}
query! {
pub async fn delete_kvp(key: String) -> Result<()> {
DELETE FROM kv_store WHERE key = (?)
}
}
}
#[cfg(test)]
mod tests {
use crate::kvp::KeyValueStore;
#[gpui::test]
async fn test_kvp() {
let db = KeyValueStore(crate::open_test_db("test_kvp").await);
assert_eq!(db.read_kvp("key-1").unwrap(), None);
db.write_kvp("key-1".to_string(), "one".to_string())
.await
.unwrap();
assert_eq!(db.read_kvp("key-1").unwrap(), Some("one".to_string()));
db.write_kvp("key-1".to_string(), "one-2".to_string())
.await
.unwrap();
assert_eq!(db.read_kvp("key-1").unwrap(), Some("one-2".to_string()));
db.write_kvp("key-2".to_string(), "two".to_string())
.await
.unwrap();
assert_eq!(db.read_kvp("key-2").unwrap(), Some("two".to_string()));
db.delete_kvp("key-1".to_string()).await.unwrap();
assert_eq!(db.read_kvp("key-1").unwrap(), None);
}
}

View File

@ -1,314 +0,0 @@
#[macro_export]
macro_rules! query {
($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt,
))
}
};
($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<$arg_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
pub async fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(indoc! { $sql })?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
}

View File

@ -26,7 +26,7 @@ test-support = [
client = { path = "../client" }
clock = { path = "../clock" }
copilot = { path = "../copilot" }
db = { package="db2", path = "../db2" }
db = { path = "../db" }
collections = { path = "../collections" }
# context_menu = { path = "../context_menu" }
fuzzy = { path = "../fuzzy" }
@ -36,7 +36,7 @@ language = { path = "../language" }
lsp = { path = "../lsp" }
multi_buffer = { path = "../multi_buffer" }
project = { path = "../project" }
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
rich_text = { path = "../rich_text" }
settings = { package="settings2", path = "../settings2" }
snippet = { path = "../snippet" }

View File

@ -12,7 +12,7 @@ test-support = []
[dependencies]
client = { path = "../client" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
editor = { path = "../editor" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }

View File

@ -9,8 +9,8 @@ path = "src/fs.rs"
[dependencies]
collections = { path = "../collections" }
rope = { path = "../rope" }
text = { path = "../text" }
rope = { package = "rope2", path = "../rope2" }
text = { package = "text2", path = "../text2" }
util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
@ -31,10 +31,10 @@ log.workspace = true
libc = "0.2"
time.workspace = true
gpui = { path = "../gpui", optional = true}
gpui = { package = "gpui2", path = "../gpui2", optional = true}
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
[features]
test-support = ["gpui/test-support"]

View File

@ -27,8 +27,6 @@ use collections::{btree_map, BTreeMap};
use repository::{FakeGitRepositoryState, GitFileStatus};
#[cfg(any(test, feature = "test-support"))]
use std::ffi::OsStr;
#[cfg(any(test, feature = "test-support"))]
use std::sync::Weak;
#[async_trait::async_trait]
pub trait Fs: Send + Sync {
@ -290,7 +288,7 @@ impl Fs for RealFs {
pub struct FakeFs {
// Use an unfair lock to ensure tests are deterministic.
state: Mutex<FakeFsState>,
executor: Weak<gpui::executor::Background>,
executor: gpui::BackgroundExecutor,
}
#[cfg(any(test, feature = "test-support"))]
@ -436,9 +434,9 @@ lazy_static::lazy_static! {
#[cfg(any(test, feature = "test-support"))]
impl FakeFs {
pub fn new(executor: Arc<gpui::executor::Background>) -> Arc<Self> {
pub fn new(executor: gpui::BackgroundExecutor) -> Arc<Self> {
Arc::new(Self {
executor: Arc::downgrade(&executor),
executor,
state: Mutex::new(FakeFsState {
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
inode: 0,
@ -699,12 +697,8 @@ impl FakeFs {
self.state.lock().metadata_call_count
}
async fn simulate_random_delay(&self) {
self.executor
.upgrade()
.expect("executor has been dropped")
.simulate_random_delay()
.await;
fn simulate_random_delay(&self) -> impl futures::Future<Output = ()> {
self.executor.simulate_random_delay()
}
}
@ -1103,9 +1097,7 @@ impl Fs for FakeFs {
let result = events.iter().any(|event| event.path.starts_with(&path));
let executor = executor.clone();
async move {
if let Some(executor) = executor.clone().upgrade() {
executor.simulate_random_delay().await;
}
executor.simulate_random_delay().await;
result
}
}))
@ -1230,13 +1222,12 @@ pub fn copy_recursive<'a>(
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use gpui::BackgroundExecutor;
use serde_json::json;
#[gpui::test]
async fn test_fake_fs(cx: &mut TestAppContext) {
let fs = FakeFs::new(cx.background());
async fn test_fake_fs(executor: BackgroundExecutor) {
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
"/root",
json!({

View File

@ -1,40 +0,0 @@
[package]
name = "fs2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/fs2.rs"
[dependencies]
collections = { path = "../collections" }
rope = { package = "rope2", path = "../rope2" }
text = { package = "text2", path = "../text2" }
util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
anyhow.workspace = true
async-trait.workspace = true
futures.workspace = true
tempfile = "3"
fsevent = { path = "../fsevent" }
lazy_static.workspace = true
parking_lot.workspace = true
smol.workspace = true
regex.workspace = true
git2.workspace = true
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
log.workspace = true
libc = "0.2"
time.workspace = true
gpui = { package = "gpui2", path = "../gpui2", optional = true}
[dev-dependencies]
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
[features]
test-support = ["gpui/test-support"]

File diff suppressed because it is too large Load Diff

View File

@ -1,415 +0,0 @@
use anyhow::Result;
use collections::HashMap;
use git2::{BranchType, StatusShow};
use parking_lot::Mutex;
use serde_derive::{Deserialize, Serialize};
use std::{
cmp::Ordering,
ffi::OsStr,
os::unix::prelude::OsStrExt,
path::{Component, Path, PathBuf},
sync::Arc,
time::SystemTime,
};
use sum_tree::{MapSeekTarget, TreeMap};
use util::ResultExt;
pub use git2::Repository as LibGitRepository;
#[derive(Clone, Debug, Hash, PartialEq)]
pub struct Branch {
pub name: Box<str>,
/// Timestamp of most recent commit, normalized to Unix Epoch format.
pub unix_timestamp: Option<i64>,
}
pub trait GitRepository: Send {
fn reload_index(&self);
fn load_index_text(&self, relative_file_path: &Path) -> Option<String>;
fn branch_name(&self) -> Option<String>;
/// Get the statuses of all of the files in the index that start with the given
/// path and have changes with resepect to the HEAD commit. This is fast because
/// the index stores hashes of trees, so that unchanged directories can be skipped.
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus>;
/// Get the status of a given file in the working directory with respect to
/// the index. In the common case, when there are no changes, this only requires
/// an index lookup. The index stores the mtime of each file when it was added,
/// so there's no work to do if the mtime matches.
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
/// Get the status of a given file in the working directory with respect to
/// the HEAD commit. In the common case, when there are no changes, this only
/// requires an index lookup and blob comparison between the index and the HEAD
/// commit. The index stores the mtime of each file when it was added, so there's
/// no need to consider the working directory file if the mtime matches.
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
fn branches(&self) -> Result<Vec<Branch>>;
fn change_branch(&self, _: &str) -> Result<()>;
fn create_branch(&self, _: &str) -> Result<()>;
}
impl std::fmt::Debug for dyn GitRepository {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("dyn GitRepository<...>").finish()
}
}
impl GitRepository for LibGitRepository {
fn reload_index(&self) {
if let Ok(mut index) = self.index() {
_ = index.read(false);
}
}
fn load_index_text(&self, relative_file_path: &Path) -> Option<String> {
fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result<Option<String>> {
const STAGE_NORMAL: i32 = 0;
let index = repo.index()?;
// This check is required because index.get_path() unwraps internally :(
check_path_to_repo_path_errors(relative_file_path)?;
let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) {
Some(entry) => entry.id,
None => return Ok(None),
};
let content = repo.find_blob(oid)?.content().to_owned();
Ok(Some(String::from_utf8(content)?))
}
match logic(&self, relative_file_path) {
Ok(value) => return value,
Err(err) => log::error!("Error loading head text: {:?}", err),
}
None
}
fn branch_name(&self) -> Option<String> {
let head = self.head().log_err()?;
let branch = String::from_utf8_lossy(head.shorthand_bytes());
Some(branch.to_string())
}
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
let mut map = TreeMap::default();
let mut options = git2::StatusOptions::new();
options.pathspec(path_prefix);
options.show(StatusShow::Index);
if let Some(statuses) = self.statuses(Some(&mut options)).log_err() {
for status in statuses.iter() {
let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes())));
let status = status.status();
if !status.contains(git2::Status::IGNORED) {
if let Some(status) = read_status(status) {
map.insert(path, status)
}
}
}
}
map
}
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
// If the file has not changed since it was added to the index, then
// there can't be any changes.
if matches_index(self, path, mtime) {
return None;
}
let mut options = git2::StatusOptions::new();
options.pathspec(&path.0);
options.disable_pathspec_match(true);
options.include_untracked(true);
options.recurse_untracked_dirs(true);
options.include_unmodified(true);
options.show(StatusShow::Workdir);
let statuses = self.statuses(Some(&mut options)).log_err()?;
let status = statuses.get(0).and_then(|s| read_status(s.status()));
status
}
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
let mut options = git2::StatusOptions::new();
options.pathspec(&path.0);
options.disable_pathspec_match(true);
options.include_untracked(true);
options.recurse_untracked_dirs(true);
options.include_unmodified(true);
// If the file has not changed since it was added to the index, then
// there's no need to examine the working directory file: just compare
// the blob in the index to the one in the HEAD commit.
if matches_index(self, path, mtime) {
options.show(StatusShow::Index);
}
let statuses = self.statuses(Some(&mut options)).log_err()?;
let status = statuses.get(0).and_then(|s| read_status(s.status()));
status
}
fn branches(&self) -> Result<Vec<Branch>> {
let local_branches = self.branches(Some(BranchType::Local))?;
let valid_branches = local_branches
.filter_map(|branch| {
branch.ok().and_then(|(branch, _)| {
let name = branch.name().ok().flatten().map(Box::from)?;
let timestamp = branch.get().peel_to_commit().ok()?.time();
let unix_timestamp = timestamp.seconds();
let timezone_offset = timestamp.offset_minutes();
let utc_offset =
time::UtcOffset::from_whole_seconds(timezone_offset * 60).ok()?;
let unix_timestamp =
time::OffsetDateTime::from_unix_timestamp(unix_timestamp).ok()?;
Some(Branch {
name,
unix_timestamp: Some(unix_timestamp.to_offset(utc_offset).unix_timestamp()),
})
})
})
.collect();
Ok(valid_branches)
}
fn change_branch(&self, name: &str) -> Result<()> {
let revision = self.find_branch(name, BranchType::Local)?;
let revision = revision.get();
let as_tree = revision.peel_to_tree()?;
self.checkout_tree(as_tree.as_object(), None)?;
self.set_head(
revision
.name()
.ok_or_else(|| anyhow::anyhow!("Branch name could not be retrieved"))?,
)?;
Ok(())
}
fn create_branch(&self, name: &str) -> Result<()> {
let current_commit = self.head()?.peel_to_commit()?;
self.branch(name, &current_commit, false)?;
Ok(())
}
}
fn matches_index(repo: &LibGitRepository, path: &RepoPath, mtime: SystemTime) -> bool {
if let Some(index) = repo.index().log_err() {
if let Some(entry) = index.get_path(&path, 0) {
if let Some(mtime) = mtime.duration_since(SystemTime::UNIX_EPOCH).log_err() {
if entry.mtime.seconds() == mtime.as_secs() as i32
&& entry.mtime.nanoseconds() == mtime.subsec_nanos()
{
return true;
}
}
}
}
false
}
fn read_status(status: git2::Status) -> Option<GitFileStatus> {
if status.contains(git2::Status::CONFLICTED) {
Some(GitFileStatus::Conflict)
} else if status.intersects(
git2::Status::WT_MODIFIED
| git2::Status::WT_RENAMED
| git2::Status::INDEX_MODIFIED
| git2::Status::INDEX_RENAMED,
) {
Some(GitFileStatus::Modified)
} else if status.intersects(git2::Status::WT_NEW | git2::Status::INDEX_NEW) {
Some(GitFileStatus::Added)
} else {
None
}
}
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepository {
state: Arc<Mutex<FakeGitRepositoryState>>,
}
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepositoryState {
pub index_contents: HashMap<PathBuf, String>,
pub worktree_statuses: HashMap<RepoPath, GitFileStatus>,
pub branch_name: Option<String>,
}
impl FakeGitRepository {
pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<Mutex<dyn GitRepository>> {
Arc::new(Mutex::new(FakeGitRepository { state }))
}
}
impl GitRepository for FakeGitRepository {
fn reload_index(&self) {}
fn load_index_text(&self, path: &Path) -> Option<String> {
let state = self.state.lock();
state.index_contents.get(path).cloned()
}
fn branch_name(&self) -> Option<String> {
let state = self.state.lock();
state.branch_name.clone()
}
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
let mut map = TreeMap::default();
let state = self.state.lock();
for (repo_path, status) in state.worktree_statuses.iter() {
if repo_path.0.starts_with(path_prefix) {
map.insert(repo_path.to_owned(), status.to_owned());
}
}
map
}
fn unstaged_status(&self, _path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
None
}
fn status(&self, path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
let state = self.state.lock();
state.worktree_statuses.get(path).cloned()
}
fn branches(&self) -> Result<Vec<Branch>> {
Ok(vec![])
}
fn change_branch(&self, name: &str) -> Result<()> {
let mut state = self.state.lock();
state.branch_name = Some(name.to_owned());
Ok(())
}
fn create_branch(&self, name: &str) -> Result<()> {
let mut state = self.state.lock();
state.branch_name = Some(name.to_owned());
Ok(())
}
}
fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
match relative_file_path.components().next() {
None => anyhow::bail!("repo path should not be empty"),
Some(Component::Prefix(_)) => anyhow::bail!(
"repo path `{}` should be relative, not a windows prefix",
relative_file_path.to_string_lossy()
),
Some(Component::RootDir) => {
anyhow::bail!(
"repo path `{}` should be relative",
relative_file_path.to_string_lossy()
)
}
Some(Component::CurDir) => {
anyhow::bail!(
"repo path `{}` should not start with `.`",
relative_file_path.to_string_lossy()
)
}
Some(Component::ParentDir) => {
anyhow::bail!(
"repo path `{}` should not start with `..`",
relative_file_path.to_string_lossy()
)
}
_ => Ok(()),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GitFileStatus {
Added,
Modified,
Conflict,
}
impl GitFileStatus {
pub fn merge(
this: Option<GitFileStatus>,
other: Option<GitFileStatus>,
prefer_other: bool,
) -> Option<GitFileStatus> {
if prefer_other {
return other;
} else {
match (this, other) {
(Some(GitFileStatus::Conflict), _) | (_, Some(GitFileStatus::Conflict)) => {
Some(GitFileStatus::Conflict)
}
(Some(GitFileStatus::Modified), _) | (_, Some(GitFileStatus::Modified)) => {
Some(GitFileStatus::Modified)
}
(Some(GitFileStatus::Added), _) | (_, Some(GitFileStatus::Added)) => {
Some(GitFileStatus::Added)
}
_ => None,
}
}
}
}
#[derive(Clone, Debug, Ord, Hash, PartialOrd, Eq, PartialEq)]
pub struct RepoPath(pub PathBuf);
impl RepoPath {
pub fn new(path: PathBuf) -> Self {
debug_assert!(path.is_relative(), "Repo paths must be relative");
RepoPath(path)
}
}
impl From<&Path> for RepoPath {
fn from(value: &Path) -> Self {
RepoPath::new(value.to_path_buf())
}
}
impl From<PathBuf> for RepoPath {
fn from(value: PathBuf) -> Self {
RepoPath::new(value)
}
}
impl Default for RepoPath {
fn default() -> Self {
RepoPath(PathBuf::new())
}
}
impl AsRef<Path> for RepoPath {
fn as_ref(&self) -> &Path {
self.0.as_ref()
}
}
impl std::ops::Deref for RepoPath {
type Target = PathBuf;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
pub struct RepoPathDescendants<'a>(pub &'a Path);
impl<'a> MapSeekTarget<RepoPath> for RepoPathDescendants<'a> {
fn cmp_cursor(&self, key: &RepoPath) -> Ordering {
if key.starts_with(&self.0) {
Ordering::Greater
} else {
self.0.cmp(key)
}
}
}

View File

@ -28,7 +28,7 @@ fuzzy = { path = "../fuzzy" }
git = { package = "git3", path = "../git3" }
gpui = { package = "gpui2", path = "../gpui2" }
lsp = { path = "../lsp" }
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
settings = { package = "settings2", path = "../settings2" }
sum_tree = { path = "../sum_tree" }
text = { package = "text2", path = "../text2" }

View File

@ -5,7 +5,7 @@ edition = "2021"
publish = false
[lib]
path = "src/notification_store.rs"
path = "src/notification_store2.rs"
doctest = false
[features]
@ -23,11 +23,11 @@ clock = { path = "../clock" }
collections = { path = "../collections" }
db = { path = "../db" }
feature_flags = { path = "../feature_flags" }
gpui = { path = "../gpui" }
gpui = { package = "gpui2", path = "../gpui2" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
settings = { package = "settings2", path = "../settings2" }
sum_tree = { path = "../sum_tree" }
text = { path = "../text" }
text = { package = "text2", path = "../text2" }
util = { path = "../util" }
anyhow.workspace = true
@ -36,7 +36,7 @@ time.workspace = true
[dev-dependencies]
client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -1,459 +0,0 @@
use anyhow::Result;
use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
use client::{Client, UserStore};
use collections::HashMap;
use db::smol::stream::StreamExt;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use rpc::{proto, Notification, TypedEnvelope};
use std::{ops::Range, sync::Arc};
use sum_tree::{Bias, SumTree};
use time::OffsetDateTime;
use util::ResultExt;
pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
cx.set_global(notification_store);
}
pub struct NotificationStore {
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
channel_messages: HashMap<u64, ChannelMessage>,
channel_store: ModelHandle<ChannelStore>,
notifications: SumTree<NotificationEntry>,
loaded_all_notifications: bool,
_watch_connection_status: Task<Option<()>>,
_subscriptions: Vec<client::Subscription>,
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
NotificationsUpdated {
old_range: Range<usize>,
new_count: usize,
},
NewNotification {
entry: NotificationEntry,
},
NotificationRemoved {
entry: NotificationEntry,
},
NotificationRead {
entry: NotificationEntry,
},
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
pub id: u64,
pub notification: Notification,
pub timestamp: OffsetDateTime,
pub is_read: bool,
pub response: Option<bool>,
}
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
max_id: u64,
count: usize,
unread_count: usize,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
impl NotificationStore {
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
cx.global::<ModelHandle<Self>>().clone()
}
pub fn new(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
let mut connection_status = client.status();
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
while let Some(status) = connection_status.next().await {
let this = this.upgrade(&cx)?;
match status {
client::Status::Connected { .. } => {
if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
{
task.await.log_err()?;
}
}
_ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
}
}
Some(())
});
Self {
channel_store: ChannelStore::global(cx),
notifications: Default::default(),
loaded_all_notifications: false,
channel_messages: Default::default(),
_watch_connection_status: watch_connection_status,
_subscriptions: vec![
client.add_message_handler(cx.handle(), Self::handle_new_notification),
client.add_message_handler(cx.handle(), Self::handle_delete_notification),
],
user_store,
client,
}
}
pub fn notification_count(&self) -> usize {
self.notifications.summary().count
}
pub fn unread_notification_count(&self) -> usize {
self.notifications.summary().unread_count
}
pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
self.channel_messages.get(&id)
}
// Get the nth newest notification.
pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
let count = self.notifications.summary().count;
if ix >= count {
return None;
}
let ix = count - 1 - ix;
let mut cursor = self.notifications.cursor::<Count>();
cursor.seek(&Count(ix), Bias::Right, &());
cursor.item()
}
pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
let mut cursor = self.notifications.cursor::<NotificationId>();
cursor.seek(&NotificationId(id), Bias::Left, &());
if let Some(item) = cursor.item() {
if item.id == id {
return Some(item);
}
}
None
}
pub fn load_more_notifications(
&self,
clear_old: bool,
cx: &mut ModelContext<Self>,
) -> Option<Task<Result<()>>> {
if self.loaded_all_notifications && !clear_old {
return None;
}
let before_id = if clear_old {
None
} else {
self.notifications.first().map(|entry| entry.id)
};
let request = self.client.request(proto::GetNotifications { before_id });
Some(cx.spawn(|this, mut cx| async move {
let response = request.await?;
this.update(&mut cx, |this, _| {
this.loaded_all_notifications = response.done
});
Self::add_notifications(
this,
response.notifications,
AddNotificationsOptions {
is_new: false,
clear_old,
includes_first: response.done,
},
cx,
)
.await?;
Ok(())
}))
}
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
self.notifications = Default::default();
self.channel_messages = Default::default();
cx.notify();
self.load_more_notifications(true, cx)
}
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
cx.notify()
}
async fn handle_new_notification(
this: ModelHandle<Self>,
envelope: TypedEnvelope<proto::AddNotification>,
_: Arc<Client>,
cx: AsyncAppContext,
) -> Result<()> {
Self::add_notifications(
this,
envelope.payload.notification.into_iter().collect(),
AddNotificationsOptions {
is_new: true,
clear_old: false,
includes_first: false,
},
cx,
)
.await
}
async fn handle_delete_notification(
this: ModelHandle<Self>,
envelope: TypedEnvelope<proto::DeleteNotification>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
Ok(())
})
}
async fn add_notifications(
this: ModelHandle<Self>,
notifications: Vec<proto::Notification>,
options: AddNotificationsOptions,
mut cx: AsyncAppContext,
) -> Result<()> {
let mut user_ids = Vec::new();
let mut message_ids = Vec::new();
let notifications = notifications
.into_iter()
.filter_map(|message| {
Some(NotificationEntry {
id: message.id,
is_read: message.is_read,
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
.ok()?,
notification: Notification::from_proto(&message)?,
response: message.response,
})
})
.collect::<Vec<_>>();
if notifications.is_empty() {
return Ok(());
}
for entry in &notifications {
match entry.notification {
Notification::ChannelInvitation { inviter_id, .. } => {
user_ids.push(inviter_id);
}
Notification::ContactRequest {
sender_id: requester_id,
} => {
user_ids.push(requester_id);
}
Notification::ContactRequestAccepted {
responder_id: contact_id,
} => {
user_ids.push(contact_id);
}
Notification::ChannelMessageMention {
sender_id,
message_id,
..
} => {
user_ids.push(sender_id);
message_ids.push(message_id);
}
}
}
let (user_store, channel_store) = this.read_with(&cx, |this, _| {
(this.user_store.clone(), this.channel_store.clone())
});
user_store
.update(&mut cx, |store, cx| store.get_users(user_ids, cx))
.await?;
let messages = channel_store
.update(&mut cx, |store, cx| {
store.fetch_channel_messages(message_ids, cx)
})
.await?;
this.update(&mut cx, |this, cx| {
if options.clear_old {
cx.emit(NotificationEvent::NotificationsUpdated {
old_range: 0..this.notifications.summary().count,
new_count: 0,
});
this.notifications = SumTree::default();
this.channel_messages.clear();
this.loaded_all_notifications = false;
}
if options.includes_first {
this.loaded_all_notifications = true;
}
this.channel_messages
.extend(messages.into_iter().filter_map(|message| {
if let ChannelMessageId::Saved(id) = message.id {
Some((id, message))
} else {
None
}
}));
this.splice_notifications(
notifications
.into_iter()
.map(|notification| (notification.id, Some(notification))),
options.is_new,
cx,
);
});
Ok(())
}
fn splice_notifications(
&mut self,
notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
is_new: bool,
cx: &mut ModelContext<'_, NotificationStore>,
) {
let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
let mut new_notifications = SumTree::new();
let mut old_range = 0..0;
for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
if i == 0 {
old_range.start = cursor.start().1 .0;
}
let old_notification = cursor.item();
if let Some(old_notification) = old_notification {
if old_notification.id == id {
cursor.next(&());
if let Some(new_notification) = &new_notification {
if new_notification.is_read {
cx.emit(NotificationEvent::NotificationRead {
entry: new_notification.clone(),
});
}
} else {
cx.emit(NotificationEvent::NotificationRemoved {
entry: old_notification.clone(),
});
}
}
} else if let Some(new_notification) = &new_notification {
if is_new {
cx.emit(NotificationEvent::NewNotification {
entry: new_notification.clone(),
});
}
}
if let Some(notification) = new_notification {
new_notifications.push(notification, &());
}
}
old_range.end = cursor.start().1 .0;
let new_count = new_notifications.summary().count - old_range.start;
new_notifications.append(cursor.suffix(&()), &());
drop(cursor);
self.notifications = new_notifications;
cx.emit(NotificationEvent::NotificationsUpdated {
old_range,
new_count,
});
}
pub fn respond_to_notification(
&mut self,
notification: Notification,
response: bool,
cx: &mut ModelContext<Self>,
) {
match notification {
Notification::ContactRequest { sender_id } => {
self.user_store
.update(cx, |store, cx| {
store.respond_to_contact_request(sender_id, response, cx)
})
.detach();
}
Notification::ChannelInvitation { channel_id, .. } => {
self.channel_store
.update(cx, |store, cx| {
store.respond_to_channel_invite(channel_id, response, cx)
})
.detach();
}
_ => {}
}
}
}
impl Entity for NotificationStore {
type Event = NotificationEvent;
}
impl sum_tree::Item for NotificationEntry {
type Summary = NotificationSummary;
fn summary(&self) -> Self::Summary {
NotificationSummary {
max_id: self.id,
count: 1,
unread_count: if self.is_read { 0 } else { 1 },
}
}
}
impl sum_tree::Summary for NotificationSummary {
type Context = ();
fn add_summary(&mut self, summary: &Self, _: &()) {
self.max_id = self.max_id.max(summary.max_id);
self.count += summary.count;
self.unread_count += summary.unread_count;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
debug_assert!(summary.max_id > self.0);
self.0 = summary.max_id;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
self.0 += summary.count;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
self.0 += summary.unread_count;
}
}
struct AddNotificationsOptions {
is_new: bool,
clear_old: bool,
includes_first: bool,
}

View File

@ -1,42 +0,0 @@
[package]
name = "notifications2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/notification_store2.rs"
doctest = false
[features]
test-support = [
"channel/test-support",
"collections/test-support",
"gpui/test-support",
"rpc/test-support",
]
[dependencies]
channel = { path = "../channel" }
client = { path = "../client" }
clock = { path = "../clock" }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
feature_flags = { path = "../feature_flags" }
gpui = { package = "gpui2", path = "../gpui2" }
rpc = { package = "rpc2", path = "../rpc2" }
settings = { package = "settings2", path = "../settings2" }
sum_tree = { path = "../sum_tree" }
text = { package = "text2", path = "../text2" }
util = { path = "../util" }
anyhow.workspace = true
time.workspace = true
[dev-dependencies]
client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -15,7 +15,7 @@ test-support = []
client = { path = "../client" }
collections = { path = "../collections"}
language = { path = "../language" }
gpui = { path = "../gpui" }
gpui = { package = "gpui2", path = "../gpui2" }
fs = { path = "../fs" }
lsp = { path = "../lsp" }
node_runtime = { path = "../node_runtime"}
@ -31,5 +31,5 @@ parking_lot.workspace = true
[dev-dependencies]
language = { path = "../language", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }

View File

@ -1,16 +1,16 @@
use std::ops::ControlFlow;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Context;
use collections::{HashMap, HashSet};
use fs::Fs;
use gpui::{AsyncAppContext, ModelHandle};
use language::language_settings::language_settings;
use language::{Buffer, Diff};
use gpui::{AsyncAppContext, Model};
use language::{language_settings::language_settings, Buffer, Diff};
use lsp::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime;
use serde::{Deserialize, Serialize};
use std::{
ops::ControlFlow,
path::{Path, PathBuf},
sync::Arc,
};
use util::paths::{PathMatcher, DEFAULT_PRETTIER_DIR};
#[derive(Clone)]
@ -100,39 +100,39 @@ impl Prettier {
}
} else {
match package_json_contents.get("workspaces") {
Some(serde_json::Value::Array(workspaces)) => {
match &project_path_with_prettier_dependency {
Some(project_path_with_prettier_dependency) => {
let subproject_path = project_path_with_prettier_dependency.strip_prefix(&path_to_check).expect("traversing path parents, should be able to strip prefix");
if workspaces.iter().filter_map(|value| {
if let serde_json::Value::String(s) = value {
Some(s.clone())
Some(serde_json::Value::Array(workspaces)) => {
match &project_path_with_prettier_dependency {
Some(project_path_with_prettier_dependency) => {
let subproject_path = project_path_with_prettier_dependency.strip_prefix(&path_to_check).expect("traversing path parents, should be able to strip prefix");
if workspaces.iter().filter_map(|value| {
if let serde_json::Value::String(s) = value {
Some(s.clone())
} else {
log::warn!("Skipping non-string 'workspaces' value: {value:?}");
None
}
}).any(|workspace_definition| {
if let Some(path_matcher) = PathMatcher::new(&workspace_definition).ok() {
path_matcher.is_match(subproject_path)
} else {
workspace_definition == subproject_path.to_string_lossy()
}
}) {
anyhow::ensure!(has_prettier_in_node_modules(fs, &path_to_check).await?, "Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}, but it's not installed into workspace root's node_modules");
log::info!("Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}");
return Ok(ControlFlow::Continue(Some(path_to_check)));
} else {
log::warn!("Skipping non-string 'workspaces' value: {value:?}");
None
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but is not included in its package.json workspaces {workspaces:?}");
}
}).any(|workspace_definition| {
if let Some(path_matcher) = PathMatcher::new(&workspace_definition).ok() {
path_matcher.is_match(subproject_path)
} else {
workspace_definition == subproject_path.to_string_lossy()
}
}) {
anyhow::ensure!(has_prettier_in_node_modules(fs, &path_to_check).await?, "Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}, but it's not installed into workspace root's node_modules");
log::info!("Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}");
return Ok(ControlFlow::Continue(Some(path_to_check)));
} else {
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but is not included in its package.json workspaces {workspaces:?}");
}
None => {
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but has no prettier in its package.json");
}
}
None => {
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but has no prettier in its package.json");
}
}
},
Some(unknown) => log::error!("Failed to parse workspaces for {path_to_check:?} from package.json, got {unknown:?}. Skipping."),
None => log::warn!("Skipping path {path_to_check:?} that has no prettier dependency and no workspaces section in its package.json"),
}
},
Some(unknown) => log::error!("Failed to parse workspaces for {path_to_check:?} from package.json, got {unknown:?}. Skipping."),
None => log::warn!("Skipping path {path_to_check:?} that has no prettier dependency and no workspaces section in its package.json"),
}
}
}
@ -172,7 +172,7 @@ impl Prettier {
) -> anyhow::Result<Self> {
use lsp::LanguageServerBinary;
let background = cx.background();
let executor = cx.background_executor().clone();
anyhow::ensure!(
prettier_dir.is_dir(),
"Prettier dir {prettier_dir:?} is not a directory"
@ -183,7 +183,7 @@ impl Prettier {
"no prettier server package found at {prettier_server:?}"
);
let node_path = background
let node_path = executor
.spawn(async move { node.binary_path().await })
.await?;
let server = LanguageServer::new(
@ -198,7 +198,7 @@ impl Prettier {
cx,
)
.context("prettier server creation")?;
let server = background
let server = executor
.spawn(server.initialize(None))
.await
.context("prettier server initialization")?;
@ -211,124 +211,154 @@ impl Prettier {
pub async fn format(
&self,
buffer: &ModelHandle<Buffer>,
buffer: &Model<Buffer>,
buffer_path: Option<PathBuf>,
cx: &AsyncAppContext,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Diff> {
match self {
Self::Real(local) => {
let params = buffer.read_with(cx, |buffer, cx| {
let buffer_language = buffer.language();
let parser_with_plugins = buffer_language.and_then(|l| {
let prettier_parser = l.prettier_parser_name()?;
let mut prettier_plugins = l
.lsp_adapters()
.iter()
.flat_map(|adapter| adapter.prettier_plugins())
.collect::<Vec<_>>();
prettier_plugins.dedup();
Some((prettier_parser, prettier_plugins))
});
let params = buffer
.update(cx, |buffer, cx| {
let buffer_language = buffer.language();
let parser_with_plugins = buffer_language.and_then(|l| {
let prettier_parser = l.prettier_parser_name()?;
let mut prettier_plugins = l
.lsp_adapters()
.iter()
.flat_map(|adapter| adapter.prettier_plugins())
.collect::<Vec<_>>();
prettier_plugins.dedup();
Some((prettier_parser, prettier_plugins))
});
let prettier_node_modules = self.prettier_dir().join("node_modules");
anyhow::ensure!(prettier_node_modules.is_dir(), "Prettier node_modules dir does not exist: {prettier_node_modules:?}");
let plugin_name_into_path = |plugin_name: &str| {
let prettier_plugin_dir = prettier_node_modules.join(plugin_name);
for possible_plugin_path in [
prettier_plugin_dir.join("dist").join("index.mjs"),
prettier_plugin_dir.join("dist").join("index.js"),
prettier_plugin_dir.join("dist").join("plugin.js"),
prettier_plugin_dir.join("index.mjs"),
prettier_plugin_dir.join("index.js"),
prettier_plugin_dir.join("plugin.js"),
prettier_plugin_dir,
] {
if possible_plugin_path.is_file() {
return Some(possible_plugin_path);
}
}
None
};
let (parser, located_plugins) = match parser_with_plugins {
Some((parser, plugins)) => {
// Tailwind plugin requires being added last
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
let mut add_tailwind_back = false;
let mut plugins = plugins.into_iter().filter(|&&plugin_name| {
if plugin_name == TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME {
add_tailwind_back = true;
false
} else {
true
let prettier_node_modules = self.prettier_dir().join("node_modules");
anyhow::ensure!(
prettier_node_modules.is_dir(),
"Prettier node_modules dir does not exist: {prettier_node_modules:?}"
);
let plugin_name_into_path = |plugin_name: &str| {
let prettier_plugin_dir = prettier_node_modules.join(plugin_name);
for possible_plugin_path in [
prettier_plugin_dir.join("dist").join("index.mjs"),
prettier_plugin_dir.join("dist").join("index.js"),
prettier_plugin_dir.join("dist").join("plugin.js"),
prettier_plugin_dir.join("index.mjs"),
prettier_plugin_dir.join("index.js"),
prettier_plugin_dir.join("plugin.js"),
prettier_plugin_dir,
] {
if possible_plugin_path.is_file() {
return Some(possible_plugin_path);
}
}).map(|plugin_name| (plugin_name, plugin_name_into_path(plugin_name))).collect::<Vec<_>>();
if add_tailwind_back {
plugins.push((&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME, plugin_name_into_path(TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME)));
}
(Some(parser.to_string()), plugins)
},
None => (None, Vec::new()),
};
None
};
let (parser, located_plugins) = match parser_with_plugins {
Some((parser, plugins)) => {
// Tailwind plugin requires being added last
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
let mut add_tailwind_back = false;
let prettier_options = if self.is_default() {
let language_settings = language_settings(buffer_language, buffer.file(), cx);
let mut options = language_settings.prettier.clone();
if !options.contains_key("tabWidth") {
options.insert(
"tabWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.tab_size.get(),
)),
);
}
if !options.contains_key("printWidth") {
options.insert(
"printWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.preferred_line_length,
)),
);
}
Some(options)
} else {
None
};
let mut plugins = plugins
.into_iter()
.filter(|&&plugin_name| {
if plugin_name == TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME {
add_tailwind_back = true;
false
} else {
true
}
})
.map(|plugin_name| {
(plugin_name, plugin_name_into_path(plugin_name))
})
.collect::<Vec<_>>();
if add_tailwind_back {
plugins.push((
&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
plugin_name_into_path(
TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
),
));
}
(Some(parser.to_string()), plugins)
}
None => (None, Vec::new()),
};
let plugins = located_plugins.into_iter().filter_map(|(plugin_name, located_plugin_path)| {
match located_plugin_path {
Some(path) => Some(path),
None => {
log::error!("Have not found plugin path for {plugin_name:?} inside {prettier_node_modules:?}");
None},
}
}).collect();
log::debug!("Formatting file {:?} with prettier, plugins :{plugins:?}, options: {prettier_options:?}", buffer.file().map(|f| f.full_path(cx)));
let prettier_options = if self.is_default() {
let language_settings =
language_settings(buffer_language, buffer.file(), cx);
let mut options = language_settings.prettier.clone();
if !options.contains_key("tabWidth") {
options.insert(
"tabWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.tab_size.get(),
)),
);
}
if !options.contains_key("printWidth") {
options.insert(
"printWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.preferred_line_length,
)),
);
}
Some(options)
} else {
None
};
anyhow::Ok(FormatParams {
text: buffer.text(),
options: FormatOptions {
parser,
let plugins = located_plugins
.into_iter()
.filter_map(|(plugin_name, located_plugin_path)| {
match located_plugin_path {
Some(path) => Some(path),
None => {
log::error!(
"Have not found plugin path for {:?} inside {:?}",
plugin_name,
prettier_node_modules
);
None
}
}
})
.collect();
log::debug!(
"Formatting file {:?} with prettier, plugins :{:?}, options: {:?}",
plugins,
path: buffer_path,
prettier_options,
},
})
}).context("prettier params calculation")?;
buffer.file().map(|f| f.full_path(cx))
);
anyhow::Ok(FormatParams {
text: buffer.text(),
options: FormatOptions {
parser,
plugins,
path: buffer_path,
prettier_options,
},
})
})?
.context("prettier params calculation")?;
let response = local
.server
.request::<Format>(params)
.await
.context("prettier format request")?;
let diff_task = buffer.read_with(cx, |buffer, cx| buffer.diff(response.text, cx));
let diff_task = buffer.update(cx, |buffer, cx| buffer.diff(response.text, cx))?;
Ok(diff_task.await)
}
#[cfg(any(test, feature = "test-support"))]
Self::Test(_) => Ok(buffer
.read_with(cx, |buffer, cx| {
.update(cx, |buffer, cx| {
let formatted_text = buffer.text() + FORMAT_SUFFIX;
buffer.diff(formatted_text, cx)
})
})?
.await),
}
}
@ -471,7 +501,7 @@ mod tests {
#[gpui::test]
async fn test_prettier_lookup_finds_nothing(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
@ -547,7 +577,7 @@ mod tests {
#[gpui::test]
async fn test_prettier_lookup_in_simple_npm_projects(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
@ -612,7 +642,7 @@ mod tests {
#[gpui::test]
async fn test_prettier_lookup_for_not_installed(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
@ -662,6 +692,7 @@ mod tests {
assert!(message.contains("/root/work/web_blog"), "Error message should mention which project had prettier defined");
},
};
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
@ -704,7 +735,7 @@ mod tests {
#[gpui::test]
async fn test_prettier_lookup_in_npm_workspaces(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
@ -785,7 +816,7 @@ mod tests {
async fn test_prettier_lookup_in_npm_workspaces_for_not_installed(
cx: &mut gpui::TestAppContext,
) {
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({

View File

@ -1,35 +0,0 @@
[package]
name = "prettier2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/prettier2.rs"
doctest = false
[features]
test-support = []
[dependencies]
client = { path = "../client" }
collections = { path = "../collections"}
language = { path = "../language" }
gpui = { package = "gpui2", path = "../gpui2" }
fs = { package = "fs2", path = "../fs2" }
lsp = { path = "../lsp" }
node_runtime = { path = "../node_runtime"}
util = { path = "../util" }
log.workspace = true
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
anyhow.workspace = true
futures.workspace = true
parking_lot.workspace = true
[dev-dependencies]
language = { path = "../language", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }

View File

@ -1,869 +0,0 @@
use anyhow::Context;
use collections::{HashMap, HashSet};
use fs::Fs;
use gpui::{AsyncAppContext, Model};
use language::{language_settings::language_settings, Buffer, Diff};
use lsp::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime;
use serde::{Deserialize, Serialize};
use std::{
ops::ControlFlow,
path::{Path, PathBuf},
sync::Arc,
};
use util::paths::{PathMatcher, DEFAULT_PRETTIER_DIR};
#[derive(Clone)]
pub enum Prettier {
Real(RealPrettier),
#[cfg(any(test, feature = "test-support"))]
Test(TestPrettier),
}
#[derive(Clone)]
pub struct RealPrettier {
default: bool,
prettier_dir: PathBuf,
server: Arc<LanguageServer>,
}
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone)]
pub struct TestPrettier {
prettier_dir: PathBuf,
default: bool,
}
pub const FAIL_THRESHOLD: usize = 4;
pub const PRETTIER_SERVER_FILE: &str = "prettier_server.js";
pub const PRETTIER_SERVER_JS: &str = include_str!("./prettier_server.js");
const PRETTIER_PACKAGE_NAME: &str = "prettier";
const TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME: &str = "prettier-plugin-tailwindcss";
#[cfg(any(test, feature = "test-support"))]
pub const FORMAT_SUFFIX: &str = "\nformatted by test prettier";
impl Prettier {
pub const CONFIG_FILE_NAMES: &'static [&'static str] = &[
".prettierrc",
".prettierrc.json",
".prettierrc.json5",
".prettierrc.yaml",
".prettierrc.yml",
".prettierrc.toml",
".prettierrc.js",
".prettierrc.cjs",
"package.json",
"prettier.config.js",
"prettier.config.cjs",
".editorconfig",
];
pub async fn locate_prettier_installation(
fs: &dyn Fs,
installed_prettiers: &HashSet<PathBuf>,
locate_from: &Path,
) -> anyhow::Result<ControlFlow<(), Option<PathBuf>>> {
let mut path_to_check = locate_from
.components()
.take_while(|component| component.as_os_str().to_string_lossy() != "node_modules")
.collect::<PathBuf>();
if path_to_check != locate_from {
log::debug!(
"Skipping prettier location for path {path_to_check:?} that is inside node_modules"
);
return Ok(ControlFlow::Break(()));
}
let path_to_check_metadata = fs
.metadata(&path_to_check)
.await
.with_context(|| format!("failed to get metadata for initial path {path_to_check:?}"))?
.with_context(|| format!("empty metadata for initial path {path_to_check:?}"))?;
if !path_to_check_metadata.is_dir {
path_to_check.pop();
}
let mut project_path_with_prettier_dependency = None;
loop {
if installed_prettiers.contains(&path_to_check) {
log::debug!("Found prettier path {path_to_check:?} in installed prettiers");
return Ok(ControlFlow::Continue(Some(path_to_check)));
} else if let Some(package_json_contents) =
read_package_json(fs, &path_to_check).await?
{
if has_prettier_in_package_json(&package_json_contents) {
if has_prettier_in_node_modules(fs, &path_to_check).await? {
log::debug!("Found prettier path {path_to_check:?} in both package.json and node_modules");
return Ok(ControlFlow::Continue(Some(path_to_check)));
} else if project_path_with_prettier_dependency.is_none() {
project_path_with_prettier_dependency = Some(path_to_check.clone());
}
} else {
match package_json_contents.get("workspaces") {
Some(serde_json::Value::Array(workspaces)) => {
match &project_path_with_prettier_dependency {
Some(project_path_with_prettier_dependency) => {
let subproject_path = project_path_with_prettier_dependency.strip_prefix(&path_to_check).expect("traversing path parents, should be able to strip prefix");
if workspaces.iter().filter_map(|value| {
if let serde_json::Value::String(s) = value {
Some(s.clone())
} else {
log::warn!("Skipping non-string 'workspaces' value: {value:?}");
None
}
}).any(|workspace_definition| {
if let Some(path_matcher) = PathMatcher::new(&workspace_definition).ok() {
path_matcher.is_match(subproject_path)
} else {
workspace_definition == subproject_path.to_string_lossy()
}
}) {
anyhow::ensure!(has_prettier_in_node_modules(fs, &path_to_check).await?, "Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}, but it's not installed into workspace root's node_modules");
log::info!("Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}");
return Ok(ControlFlow::Continue(Some(path_to_check)));
} else {
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but is not included in its package.json workspaces {workspaces:?}");
}
}
None => {
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but has no prettier in its package.json");
}
}
},
Some(unknown) => log::error!("Failed to parse workspaces for {path_to_check:?} from package.json, got {unknown:?}. Skipping."),
None => log::warn!("Skipping path {path_to_check:?} that has no prettier dependency and no workspaces section in its package.json"),
}
}
}
if !path_to_check.pop() {
match project_path_with_prettier_dependency {
Some(closest_prettier_discovered) => {
anyhow::bail!("No prettier found in node_modules for ancestors of {locate_from:?}, but discovered prettier package.json dependency in {closest_prettier_discovered:?}")
}
None => {
log::debug!("Found no prettier in ancestors of {locate_from:?}");
return Ok(ControlFlow::Continue(None));
}
}
}
}
}
#[cfg(any(test, feature = "test-support"))]
pub async fn start(
_: LanguageServerId,
prettier_dir: PathBuf,
_: Arc<dyn NodeRuntime>,
_: AsyncAppContext,
) -> anyhow::Result<Self> {
Ok(Self::Test(TestPrettier {
default: prettier_dir == DEFAULT_PRETTIER_DIR.as_path(),
prettier_dir,
}))
}
#[cfg(not(any(test, feature = "test-support")))]
pub async fn start(
server_id: LanguageServerId,
prettier_dir: PathBuf,
node: Arc<dyn NodeRuntime>,
cx: AsyncAppContext,
) -> anyhow::Result<Self> {
use lsp::LanguageServerBinary;
let executor = cx.background_executor().clone();
anyhow::ensure!(
prettier_dir.is_dir(),
"Prettier dir {prettier_dir:?} is not a directory"
);
let prettier_server = DEFAULT_PRETTIER_DIR.join(PRETTIER_SERVER_FILE);
anyhow::ensure!(
prettier_server.is_file(),
"no prettier server package found at {prettier_server:?}"
);
let node_path = executor
.spawn(async move { node.binary_path().await })
.await?;
let server = LanguageServer::new(
Arc::new(parking_lot::Mutex::new(None)),
server_id,
LanguageServerBinary {
path: node_path,
arguments: vec![prettier_server.into(), prettier_dir.as_path().into()],
},
Path::new("/"),
None,
cx,
)
.context("prettier server creation")?;
let server = executor
.spawn(server.initialize(None))
.await
.context("prettier server initialization")?;
Ok(Self::Real(RealPrettier {
server,
default: prettier_dir == DEFAULT_PRETTIER_DIR.as_path(),
prettier_dir,
}))
}
pub async fn format(
&self,
buffer: &Model<Buffer>,
buffer_path: Option<PathBuf>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Diff> {
match self {
Self::Real(local) => {
let params = buffer
.update(cx, |buffer, cx| {
let buffer_language = buffer.language();
let parser_with_plugins = buffer_language.and_then(|l| {
let prettier_parser = l.prettier_parser_name()?;
let mut prettier_plugins = l
.lsp_adapters()
.iter()
.flat_map(|adapter| adapter.prettier_plugins())
.collect::<Vec<_>>();
prettier_plugins.dedup();
Some((prettier_parser, prettier_plugins))
});
let prettier_node_modules = self.prettier_dir().join("node_modules");
anyhow::ensure!(
prettier_node_modules.is_dir(),
"Prettier node_modules dir does not exist: {prettier_node_modules:?}"
);
let plugin_name_into_path = |plugin_name: &str| {
let prettier_plugin_dir = prettier_node_modules.join(plugin_name);
for possible_plugin_path in [
prettier_plugin_dir.join("dist").join("index.mjs"),
prettier_plugin_dir.join("dist").join("index.js"),
prettier_plugin_dir.join("dist").join("plugin.js"),
prettier_plugin_dir.join("index.mjs"),
prettier_plugin_dir.join("index.js"),
prettier_plugin_dir.join("plugin.js"),
prettier_plugin_dir,
] {
if possible_plugin_path.is_file() {
return Some(possible_plugin_path);
}
}
None
};
let (parser, located_plugins) = match parser_with_plugins {
Some((parser, plugins)) => {
// Tailwind plugin requires being added last
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
let mut add_tailwind_back = false;
let mut plugins = plugins
.into_iter()
.filter(|&&plugin_name| {
if plugin_name == TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME {
add_tailwind_back = true;
false
} else {
true
}
})
.map(|plugin_name| {
(plugin_name, plugin_name_into_path(plugin_name))
})
.collect::<Vec<_>>();
if add_tailwind_back {
plugins.push((
&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
plugin_name_into_path(
TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
),
));
}
(Some(parser.to_string()), plugins)
}
None => (None, Vec::new()),
};
let prettier_options = if self.is_default() {
let language_settings =
language_settings(buffer_language, buffer.file(), cx);
let mut options = language_settings.prettier.clone();
if !options.contains_key("tabWidth") {
options.insert(
"tabWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.tab_size.get(),
)),
);
}
if !options.contains_key("printWidth") {
options.insert(
"printWidth".to_string(),
serde_json::Value::Number(serde_json::Number::from(
language_settings.preferred_line_length,
)),
);
}
Some(options)
} else {
None
};
let plugins = located_plugins
.into_iter()
.filter_map(|(plugin_name, located_plugin_path)| {
match located_plugin_path {
Some(path) => Some(path),
None => {
log::error!(
"Have not found plugin path for {:?} inside {:?}",
plugin_name,
prettier_node_modules
);
None
}
}
})
.collect();
log::debug!(
"Formatting file {:?} with prettier, plugins :{:?}, options: {:?}",
plugins,
prettier_options,
buffer.file().map(|f| f.full_path(cx))
);
anyhow::Ok(FormatParams {
text: buffer.text(),
options: FormatOptions {
parser,
plugins,
path: buffer_path,
prettier_options,
},
})
})?
.context("prettier params calculation")?;
let response = local
.server
.request::<Format>(params)
.await
.context("prettier format request")?;
let diff_task = buffer.update(cx, |buffer, cx| buffer.diff(response.text, cx))?;
Ok(diff_task.await)
}
#[cfg(any(test, feature = "test-support"))]
Self::Test(_) => Ok(buffer
.update(cx, |buffer, cx| {
let formatted_text = buffer.text() + FORMAT_SUFFIX;
buffer.diff(formatted_text, cx)
})?
.await),
}
}
pub async fn clear_cache(&self) -> anyhow::Result<()> {
match self {
Self::Real(local) => local
.server
.request::<ClearCache>(())
.await
.context("prettier clear cache"),
#[cfg(any(test, feature = "test-support"))]
Self::Test(_) => Ok(()),
}
}
pub fn server(&self) -> Option<&Arc<LanguageServer>> {
match self {
Self::Real(local) => Some(&local.server),
#[cfg(any(test, feature = "test-support"))]
Self::Test(_) => None,
}
}
pub fn is_default(&self) -> bool {
match self {
Self::Real(local) => local.default,
#[cfg(any(test, feature = "test-support"))]
Self::Test(test_prettier) => test_prettier.default,
}
}
pub fn prettier_dir(&self) -> &Path {
match self {
Self::Real(local) => &local.prettier_dir,
#[cfg(any(test, feature = "test-support"))]
Self::Test(test_prettier) => &test_prettier.prettier_dir,
}
}
}
async fn has_prettier_in_node_modules(fs: &dyn Fs, path: &Path) -> anyhow::Result<bool> {
let possible_node_modules_location = path.join("node_modules").join(PRETTIER_PACKAGE_NAME);
if let Some(node_modules_location_metadata) = fs
.metadata(&possible_node_modules_location)
.await
.with_context(|| format!("fetching metadata for {possible_node_modules_location:?}"))?
{
return Ok(node_modules_location_metadata.is_dir);
}
Ok(false)
}
async fn read_package_json(
fs: &dyn Fs,
path: &Path,
) -> anyhow::Result<Option<HashMap<String, serde_json::Value>>> {
let possible_package_json = path.join("package.json");
if let Some(package_json_metadata) = fs
.metadata(&possible_package_json)
.await
.with_context(|| format!("fetching metadata for package json {possible_package_json:?}"))?
{
if !package_json_metadata.is_dir && !package_json_metadata.is_symlink {
let package_json_contents = fs
.load(&possible_package_json)
.await
.with_context(|| format!("reading {possible_package_json:?} file contents"))?;
return serde_json::from_str::<HashMap<String, serde_json::Value>>(
&package_json_contents,
)
.map(Some)
.with_context(|| format!("parsing {possible_package_json:?} file contents"));
}
}
Ok(None)
}
fn has_prettier_in_package_json(
package_json_contents: &HashMap<String, serde_json::Value>,
) -> bool {
if let Some(serde_json::Value::Object(o)) = package_json_contents.get("dependencies") {
if o.contains_key(PRETTIER_PACKAGE_NAME) {
return true;
}
}
if let Some(serde_json::Value::Object(o)) = package_json_contents.get("devDependencies") {
if o.contains_key(PRETTIER_PACKAGE_NAME) {
return true;
}
}
false
}
enum Format {}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct FormatParams {
text: String,
options: FormatOptions,
}
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct FormatOptions {
plugins: Vec<PathBuf>,
parser: Option<String>,
#[serde(rename = "filepath")]
path: Option<PathBuf>,
prettier_options: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct FormatResult {
text: String,
}
impl lsp::request::Request for Format {
type Params = FormatParams;
type Result = FormatResult;
const METHOD: &'static str = "prettier/format";
}
enum ClearCache {}
impl lsp::request::Request for ClearCache {
type Params = ();
type Result = ();
const METHOD: &'static str = "prettier/clear_cache";
}
#[cfg(test)]
mod tests {
use fs::FakeFs;
use serde_json::json;
use super::*;
#[gpui::test]
async fn test_prettier_lookup_finds_nothing(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
".config": {
"zed": {
"settings.json": r#"{ "formatter": "auto" }"#,
},
},
"work": {
"project": {
"src": {
"index.js": "// index.js file contents",
},
"node_modules": {
"expect": {
"build": {
"print.js": "// print.js file contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "2.5.1"
}
}"#,
},
"prettier": {
"index.js": "// Dummy prettier package file",
},
},
"package.json": r#"{}"#
},
}
}),
)
.await;
assert!(
matches!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/.config/zed/settings.json"),
)
.await,
Ok(ControlFlow::Continue(None))
),
"Should successfully find no prettier for path hierarchy without it"
);
assert!(
matches!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/project/src/index.js")
)
.await,
Ok(ControlFlow::Continue(None))
),
"Should successfully find no prettier for path hierarchy that has node_modules with prettier, but no package.json mentions of it"
);
assert!(
matches!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/project/node_modules/expect/build/print.js")
)
.await,
Ok(ControlFlow::Break(()))
),
"Should not format files inside node_modules/"
);
}
#[gpui::test]
async fn test_prettier_lookup_in_simple_npm_projects(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"web_blog": {
"node_modules": {
"prettier": {
"index.js": "// Dummy prettier package file",
},
"expect": {
"build": {
"print.js": "// print.js file contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "2.5.1"
}
}"#,
},
},
"pages": {
"[slug].tsx": "// [slug].tsx file contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "2.3.0"
},
"prettier": {
"semi": false,
"printWidth": 80,
"htmlWhitespaceSensitivity": "strict",
"tabWidth": 4
}
}"#
}
}),
)
.await;
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/web_blog/pages/[slug].tsx")
)
.await
.unwrap(),
ControlFlow::Continue(Some(PathBuf::from("/root/web_blog"))),
"Should find a preinstalled prettier in the project root"
);
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/web_blog/node_modules/expect/build/print.js")
)
.await
.unwrap(),
ControlFlow::Break(()),
"Should not allow formatting node_modules/ contents"
);
}
#[gpui::test]
async fn test_prettier_lookup_for_not_installed(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"work": {
"web_blog": {
"node_modules": {
"expect": {
"build": {
"print.js": "// print.js file contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "2.5.1"
}
}"#,
},
},
"pages": {
"[slug].tsx": "// [slug].tsx file contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "2.3.0"
},
"prettier": {
"semi": false,
"printWidth": 80,
"htmlWhitespaceSensitivity": "strict",
"tabWidth": 4
}
}"#
}
}
}),
)
.await;
match Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/web_blog/pages/[slug].tsx")
)
.await {
Ok(path) => panic!("Expected to fail for prettier in package.json but not in node_modules found, but got path {path:?}"),
Err(e) => {
let message = e.to_string();
assert!(message.contains("/root/work/web_blog"), "Error message should mention which project had prettier defined");
},
};
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::from_iter(
[PathBuf::from("/root"), PathBuf::from("/root/work")].into_iter()
),
Path::new("/root/work/web_blog/pages/[slug].tsx")
)
.await
.unwrap(),
ControlFlow::Continue(Some(PathBuf::from("/root/work"))),
"Should return closest cached value found without path checks"
);
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/web_blog/node_modules/expect/build/print.js")
)
.await
.unwrap(),
ControlFlow::Break(()),
"Should not allow formatting files inside node_modules/"
);
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::from_iter(
[PathBuf::from("/root"), PathBuf::from("/root/work")].into_iter()
),
Path::new("/root/work/web_blog/node_modules/expect/build/print.js")
)
.await
.unwrap(),
ControlFlow::Break(()),
"Should ignore cache lookup for files inside node_modules/"
);
}
#[gpui::test]
async fn test_prettier_lookup_in_npm_workspaces(cx: &mut gpui::TestAppContext) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"work": {
"full-stack-foundations": {
"exercises": {
"03.loading": {
"01.problem.loader": {
"app": {
"routes": {
"users+": {
"$username_+": {
"notes.tsx": "// notes.tsx file contents",
},
},
},
},
"node_modules": {
"test.js": "// test.js contents",
},
"package.json": r#"{
"devDependencies": {
"prettier": "^3.0.3"
}
}"#
},
},
},
"package.json": r#"{
"workspaces": ["exercises/*/*", "examples/*"]
}"#,
"node_modules": {
"prettier": {
"index.js": "// Dummy prettier package file",
},
},
},
}
}),
)
.await;
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/app/routes/users+/$username_+/notes.tsx"),
).await.unwrap(),
ControlFlow::Continue(Some(PathBuf::from("/root/work/full-stack-foundations"))),
"Should ascend to the multi-workspace root and find the prettier there",
);
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/full-stack-foundations/node_modules/prettier/index.js")
)
.await
.unwrap(),
ControlFlow::Break(()),
"Should not allow formatting files inside root node_modules/"
);
assert_eq!(
Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/node_modules/test.js")
)
.await
.unwrap(),
ControlFlow::Break(()),
"Should not allow formatting files inside submodule's node_modules/"
);
}
#[gpui::test]
async fn test_prettier_lookup_in_npm_workspaces_for_not_installed(
cx: &mut gpui::TestAppContext,
) {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"work": {
"full-stack-foundations": {
"exercises": {
"03.loading": {
"01.problem.loader": {
"app": {
"routes": {
"users+": {
"$username_+": {
"notes.tsx": "// notes.tsx file contents",
},
},
},
},
"node_modules": {},
"package.json": r#"{
"devDependencies": {
"prettier": "^3.0.3"
}
}"#
},
},
},
"package.json": r#"{
"workspaces": ["exercises/*/*", "examples/*"]
}"#,
},
}
}),
)
.await;
match Prettier::locate_prettier_installation(
fs.as_ref(),
&HashSet::default(),
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/app/routes/users+/$username_+/notes.tsx")
)
.await {
Ok(path) => panic!("Expected to fail for prettier in package.json but not in node_modules found, but got path {path:?}"),
Err(e) => {
let message = e.to_string();
assert!(message.contains("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader"), "Error message should mention which project had prettier defined");
assert!(message.contains("/root/work/full-stack-foundations"), "Error message should mention potential candidates without prettier node_modules contents");
},
};
}
}

View File

@ -1,241 +0,0 @@
const { Buffer } = require("buffer");
const fs = require("fs");
const path = require("path");
const { once } = require("events");
const prettierContainerPath = process.argv[2];
if (prettierContainerPath == null || prettierContainerPath.length == 0) {
process.stderr.write(
`Prettier path argument was not specified or empty.\nUsage: ${process.argv[0]} ${process.argv[1]} prettier/path\n`,
);
process.exit(1);
}
fs.stat(prettierContainerPath, (err, stats) => {
if (err) {
process.stderr.write(`Path '${prettierContainerPath}' does not exist\n`);
process.exit(1);
}
if (!stats.isDirectory()) {
process.stderr.write(`Path '${prettierContainerPath}' exists but is not a directory\n`);
process.exit(1);
}
});
const prettierPath = path.join(prettierContainerPath, "node_modules/prettier");
class Prettier {
constructor(path, prettier, config) {
this.path = path;
this.prettier = prettier;
this.config = config;
}
}
(async () => {
let prettier;
let config;
try {
prettier = await loadPrettier(prettierPath);
config = (await prettier.resolveConfig(prettierPath)) || {};
} catch (e) {
process.stderr.write(`Failed to load prettier: ${e}\n`);
process.exit(1);
}
process.stderr.write(`Prettier at path '${prettierPath}' loaded successfully, config: ${JSON.stringify(config)}\n`);
process.stdin.resume();
handleBuffer(new Prettier(prettierPath, prettier, config));
})();
async function handleBuffer(prettier) {
for await (const messageText of readStdin()) {
let message;
try {
message = JSON.parse(messageText);
} catch (e) {
sendResponse(makeError(`Failed to parse message '${messageText}': ${e}`));
continue;
}
// allow concurrent request handling by not `await`ing the message handling promise (async function)
handleMessage(message, prettier).catch((e) => {
const errorMessage = message;
if ((errorMessage.params || {}).text !== undefined) {
errorMessage.params.text = "..snip..";
}
sendResponse({
id: message.id,
...makeError(`error during message '${JSON.stringify(errorMessage)}' handling: ${e}`),
});
});
}
}
const headerSeparator = "\r\n";
const contentLengthHeaderName = "Content-Length";
async function* readStdin() {
let buffer = Buffer.alloc(0);
let streamEnded = false;
process.stdin.on("end", () => {
streamEnded = true;
});
process.stdin.on("data", (data) => {
buffer = Buffer.concat([buffer, data]);
});
async function handleStreamEnded(errorMessage) {
sendResponse(makeError(errorMessage));
buffer = Buffer.alloc(0);
messageLength = null;
await once(process.stdin, "readable");
streamEnded = false;
}
try {
let headersLength = null;
let messageLength = null;
main_loop: while (true) {
if (messageLength === null) {
while (buffer.indexOf(`${headerSeparator}${headerSeparator}`) === -1) {
if (streamEnded) {
await handleStreamEnded("Unexpected end of stream: headers not found");
continue main_loop;
} else if (buffer.length > contentLengthHeaderName.length * 10) {
await handleStreamEnded(
`Unexpected stream of bytes: no headers end found after ${buffer.length} bytes of input`,
);
continue main_loop;
}
await once(process.stdin, "readable");
}
const headers = buffer
.subarray(0, buffer.indexOf(`${headerSeparator}${headerSeparator}`))
.toString("ascii");
const contentLengthHeader = headers
.split(headerSeparator)
.map((header) => header.split(":"))
.filter((header) => header[2] === undefined)
.filter((header) => (header[1] || "").length > 0)
.find((header) => (header[0] || "").trim() === contentLengthHeaderName);
const contentLength = (contentLengthHeader || [])[1];
if (contentLength === undefined) {
await handleStreamEnded(`Missing or incorrect ${contentLengthHeaderName} header: ${headers}`);
continue main_loop;
}
headersLength = headers.length + headerSeparator.length * 2;
messageLength = parseInt(contentLength, 10);
}
while (buffer.length < headersLength + messageLength) {
if (streamEnded) {
await handleStreamEnded(
`Unexpected end of stream: buffer length ${buffer.length} does not match expected header length ${headersLength} + body length ${messageLength}`,
);
continue main_loop;
}
await once(process.stdin, "readable");
}
const messageEnd = headersLength + messageLength;
const message = buffer.subarray(headersLength, messageEnd);
buffer = buffer.subarray(messageEnd);
headersLength = null;
messageLength = null;
yield message.toString("utf8");
}
} catch (e) {
sendResponse(makeError(`Error reading stdin: ${e}`));
} finally {
process.stdin.off("data", () => {});
}
}
async function handleMessage(message, prettier) {
const { method, id, params } = message;
if (method === undefined) {
throw new Error(`Message method is undefined: ${JSON.stringify(message)}`);
} else if (method == "initialized") {
return;
}
if (id === undefined) {
throw new Error(`Message id is undefined: ${JSON.stringify(message)}`);
}
if (method === "prettier/format") {
if (params === undefined || params.text === undefined) {
throw new Error(`Message params.text is undefined: ${JSON.stringify(message)}`);
}
if (params.options === undefined) {
throw new Error(`Message params.options is undefined: ${JSON.stringify(message)}`);
}
let resolvedConfig = {};
if (params.options.filepath !== undefined) {
resolvedConfig = (await prettier.prettier.resolveConfig(params.options.filepath)) || {};
}
const options = {
...(params.options.prettierOptions || prettier.config),
...resolvedConfig,
parser: params.options.parser,
plugins: params.options.plugins,
path: params.options.filepath,
};
process.stderr.write(
`Resolved config: ${JSON.stringify(resolvedConfig)}, will format file '${
params.options.filepath || ""
}' with options: ${JSON.stringify(options)}\n`,
);
const formattedText = await prettier.prettier.format(params.text, options);
sendResponse({ id, result: { text: formattedText } });
} else if (method === "prettier/clear_cache") {
prettier.prettier.clearConfigCache();
prettier.config = (await prettier.prettier.resolveConfig(prettier.path)) || {};
sendResponse({ id, result: null });
} else if (method === "initialize") {
sendResponse({
id,
result: {
capabilities: {},
},
});
} else {
throw new Error(`Unknown method: ${method}`);
}
}
function makeError(message) {
return {
error: {
code: -32600, // invalid request code
message,
},
};
}
function sendResponse(response) {
const responsePayloadString = JSON.stringify({
jsonrpc: "2.0",
...response,
});
const headers = `${contentLengthHeaderName}: ${Buffer.byteLength(
responsePayloadString,
)}${headerSeparator}${headerSeparator}`;
process.stdout.write(headers + responsePayloadString);
}
function loadPrettier(prettierPath) {
return new Promise((resolve, reject) => {
fs.access(prettierPath, fs.constants.F_OK, (err) => {
if (err) {
reject(`Path '${prettierPath}' does not exist.Error: ${err}`);
} else {
try {
resolve(require(prettierPath));
} catch (err) {
reject(`Error requiring prettier module from path '${prettierPath}'.Error: ${err}`);
}
}
});
});
}

View File

@ -25,8 +25,8 @@ copilot = { path = "../copilot" }
client = { path = "../client" }
clock = { path = "../clock" }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
fs = { package = "fs2", path = "../fs2" }
db = { path = "../db" }
fs = { path = "../fs" }
fsevent = { path = "../fsevent" }
fuzzy = { path = "../fuzzy" }
git = { package = "git3", path = "../git3" }
@ -34,8 +34,8 @@ gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }
lsp = { path = "../lsp" }
node_runtime = { path = "../node_runtime" }
prettier = { package = "prettier2", path = "../prettier2" }
rpc = { package = "rpc2", path = "../rpc2" }
prettier = { path = "../prettier" }
rpc = { path = "../rpc" }
settings = { package = "settings2", path = "../settings2" }
sum_tree = { path = "../sum_tree" }
terminal = { package = "terminal2", path = "../terminal2" }
@ -71,15 +71,15 @@ env_logger.workspace = true
pretty_assertions.workspace = true
client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
db = { package = "db2", path = "../db2", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
db = { path = "../db", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
prettier = { package = "prettier2", path = "../prettier2", features = ["test-support"] }
prettier = { path = "../prettier", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
git2.workspace = true
tempdir.workspace = true
unindent.workspace = true

View File

@ -10,7 +10,7 @@ doctest = false
[dependencies]
collections = { path = "../collections" }
db = { path = "../db2", package = "db2" }
db = { path = "../db" }
editor = { path = "../editor" }
gpui = { path = "../gpui2", package = "gpui2" }
menu = { path = "../menu2", package = "menu2" }

View File

@ -15,9 +15,8 @@ test-support = ["collections/test-support", "gpui/test-support"]
[dependencies]
clock = { path = "../clock" }
collections = { path = "../collections" }
gpui = { path = "../gpui", optional = true }
gpui = { package = "gpui2", path = "../gpui2", optional = true }
util = { path = "../util" }
anyhow.workspace = true
async-lock = "2.4"
async-tungstenite = "0.16"
@ -27,8 +26,8 @@ parking_lot.workspace = true
prost.workspace = true
rand.workspace = true
rsa = "0.4"
serde.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol-timeout = "0.6"
strum.workspace = true
@ -40,7 +39,7 @@ prost-build = "0.9"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
smol.workspace = true
tempdir.workspace = true
ctor.workspace = true

View File

@ -34,7 +34,7 @@ impl Connection {
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory(
executor: std::sync::Arc<gpui::executor::Background>,
executor: gpui::BackgroundExecutor,
) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
@ -53,7 +53,7 @@ impl Connection {
#[allow(clippy::type_complexity)]
fn channel(
killed: Arc<AtomicBool>,
executor: Arc<gpui::executor::Background>,
executor: gpui::BackgroundExecutor,
) -> (
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
@ -66,14 +66,12 @@ impl Connection {
let tx = tx.sink_map_err(|error| anyhow!(error)).with({
let killed = killed.clone();
let executor = Arc::downgrade(&executor);
let executor = executor.clone();
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
executor.simulate_random_delay().await;
// Writes to a half-open TCP connection will error.
if killed.load(SeqCst) {
@ -87,14 +85,12 @@ impl Connection {
let rx = rx.then({
let killed = killed;
let executor = Arc::downgrade(&executor);
let executor = executor.clone();
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
executor.simulate_random_delay().await;
// Reads from a half-open TCP connection will hang.
if killed.load(SeqCst) {

View File

@ -1,46 +0,0 @@
[package]
description = "Shared logic for communication between the Zed app and the zed.dev server"
edition = "2021"
name = "rpc2"
version = "0.1.0"
publish = false
[lib]
path = "src/rpc.rs"
doctest = false
[features]
test-support = ["collections/test-support", "gpui/test-support"]
[dependencies]
clock = { path = "../clock" }
collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2", optional = true }
util = { path = "../util" }
anyhow.workspace = true
async-lock = "2.4"
async-tungstenite = "0.16"
base64 = "0.13"
futures.workspace = true
parking_lot.workspace = true
prost.workspace = true
rand.workspace = true
rsa = "0.4"
serde_json.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol-timeout = "0.6"
strum.workspace = true
tracing = { version = "0.1.34", features = ["log"] }
zstd = "0.11"
[build-dependencies]
prost-build = "0.9"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
smol.workspace = true
tempdir.workspace = true
ctor.workspace = true
env_logger.workspace = true

View File

@ -1,8 +0,0 @@
fn main() {
let mut build = prost_build::Config::new();
// build.protoc_arg("--experimental_allow_proto3_optional");
build
.type_attribute(".", "#[derive(serde::Serialize)]")
.compile_protos(&["proto/zed.proto"], &["proto"])
.unwrap();
}

File diff suppressed because it is too large Load Diff

View File

@ -1,136 +0,0 @@
use anyhow::{Context, Result};
use rand::{thread_rng, Rng as _};
use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
use std::convert::TryFrom;
pub struct PublicKey(RSAPublicKey);
pub struct PrivateKey(RSAPrivateKey);
/// Generate a public and private key for asymmetric encryption.
pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
let mut rng = thread_rng();
let bits = 1024;
let private_key = RSAPrivateKey::new(&mut rng, bits)?;
let public_key = RSAPublicKey::from(&private_key);
Ok((PublicKey(public_key), PrivateKey(private_key)))
}
/// Generate a random 64-character base64 string.
pub fn random_token() -> String {
let mut rng = thread_rng();
let mut token_bytes = [0; 48];
for byte in token_bytes.iter_mut() {
*byte = rng.gen();
}
base64::encode_config(token_bytes, base64::URL_SAFE)
}
impl PublicKey {
/// Convert a string to a base64-encoded string that can only be decoded with the corresponding
/// private key.
pub fn encrypt_string(&self, string: &str) -> Result<String> {
let mut rng = thread_rng();
let bytes = string.as_bytes();
let encrypted_bytes = self
.0
.encrypt(&mut rng, PADDING_SCHEME, bytes)
.context("failed to encrypt string with public key")?;
let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
Ok(encrypted_string)
}
}
impl PrivateKey {
/// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
.context("failed to base64-decode encrypted string")?;
let bytes = self
.0
.decrypt(PADDING_SCHEME, &encrypted_bytes)
.context("failed to decrypt string with private key")?;
let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
Ok(string)
}
}
impl TryFrom<PublicKey> for String {
type Error = anyhow::Error;
fn try_from(key: PublicKey) -> Result<Self> {
let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
let string = base64::encode_config(&bytes, base64::URL_SAFE);
Ok(string)
}
}
impl TryFrom<String> for PublicKey {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
let bytes = base64::decode_config(&value, base64::URL_SAFE)
.context("failed to base64-decode public key string")?;
let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
Ok(key)
}
}
const PADDING_SCHEME: rsa::PaddingScheme = rsa::PaddingScheme::PKCS1v15Encrypt;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_encrypt_and_decrypt_token() {
// CLIENT:
// * generate a keypair for asymmetric encryption
// * serialize the public key to send it to the server.
let (public, private) = keypair().unwrap();
let public_string = String::try_from(public).unwrap();
assert_printable(&public_string);
// SERVER:
// * parse the public key
// * generate a random token.
// * encrypt the token using the public key.
let public = PublicKey::try_from(public_string).unwrap();
let token = random_token();
let encrypted_token = public.encrypt_string(&token).unwrap();
assert_eq!(token.len(), 64);
assert_ne!(encrypted_token, token);
assert_printable(&token);
assert_printable(&encrypted_token);
// CLIENT:
// * decrypt the token using the private key.
let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
assert_eq!(decrypted_token, token);
}
#[test]
fn test_tokens_are_always_url_safe() {
for _ in 0..5 {
let token = random_token();
let (public_key, _) = keypair().unwrap();
let encrypted_token = public_key.encrypt_string(&token).unwrap();
let public_key_str = String::try_from(public_key).unwrap();
assert_printable(&token);
assert_printable(&public_key_str);
assert_printable(&encrypted_token);
}
}
fn assert_printable(token: &str) {
for c in token.chars() {
assert!(
c.is_ascii_graphic(),
"token {:?} has non-printable char {}",
token,
c
);
assert_ne!(c, '/', "token {:?} is not URL-safe", token);
assert_ne!(c, '&', "token {:?} is not URL-safe", token);
}
}
}

View File

@ -1,108 +0,0 @@
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use futures::{SinkExt as _, StreamExt as _};
pub struct Connection {
pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
pub(crate) rx: Box<
dyn 'static
+ Send
+ Unpin
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
>,
}
impl Connection {
pub fn new<S>(stream: S) -> Self
where
S: 'static
+ Send
+ Unpin
+ futures::Sink<WebSocketMessage, Error = anyhow::Error>
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
{
let (tx, rx) = stream.split();
Self {
tx: Box::new(tx),
rx: Box::new(rx),
}
}
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), anyhow::Error> {
self.tx.send(message).await
}
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory(
executor: gpui::BackgroundExecutor,
) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};
let killed = Arc::new(AtomicBool::new(false));
let (a_tx, a_rx) = channel(killed.clone(), executor.clone());
let (b_tx, b_rx) = channel(killed.clone(), executor);
return (
Self { tx: a_tx, rx: b_rx },
Self { tx: b_tx, rx: a_rx },
killed,
);
#[allow(clippy::type_complexity)]
fn channel(
killed: Arc<AtomicBool>,
executor: gpui::BackgroundExecutor,
) -> (
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
) {
use anyhow::anyhow;
use futures::channel::mpsc;
use std::io::{Error, ErrorKind};
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
let tx = tx.sink_map_err(|error| anyhow!(error)).with({
let killed = killed.clone();
let executor = executor.clone();
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
executor.simulate_random_delay().await;
// Writes to a half-open TCP connection will error.
if killed.load(SeqCst) {
std::io::Result::Err(Error::new(ErrorKind::Other, "connection lost"))?;
}
Ok(msg)
})
}
});
let rx = rx.then({
let killed = killed;
let executor = executor.clone();
move |msg| {
let killed = killed.clone();
let executor = executor.clone();
Box::pin(async move {
executor.simulate_random_delay().await;
// Reads from a half-open TCP connection will hang.
if killed.load(SeqCst) {
futures::future::pending::<()>().await;
}
Ok(msg)
})
}
});
(Box::new(tx), Box::new(rx))
}
}
}

View File

@ -1,70 +0,0 @@
#[macro_export]
macro_rules! messages {
($(($name:ident, $priority:ident)),* $(,)?) => {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
match envelope.payload {
$(Some(envelope::Payload::$name(payload)) => {
Some(Box::new(TypedEnvelope {
sender_id,
original_sender_id: envelope.original_sender_id.map(|original_sender| PeerId {
owner_id: original_sender.owner_id,
id: original_sender.id
}),
message_id: envelope.id,
payload,
}))
}, )*
_ => None
}
}
$(
impl EnvelopedMessage for $name {
const NAME: &'static str = std::stringify!($name);
const PRIORITY: MessagePriority = MessagePriority::$priority;
fn into_envelope(
self,
id: u32,
responding_to: Option<u32>,
original_sender_id: Option<PeerId>,
) -> Envelope {
Envelope {
id,
responding_to,
original_sender_id,
payload: Some(envelope::Payload::$name(self)),
}
}
fn from_envelope(envelope: Envelope) -> Option<Self> {
if let Some(envelope::Payload::$name(msg)) = envelope.payload {
Some(msg)
} else {
None
}
}
}
)*
};
}
#[macro_export]
macro_rules! request_messages {
($(($request_name:ident, $response_name:ident)),* $(,)?) => {
$(impl RequestMessage for $request_name {
type Response = $response_name;
})*
};
}
#[macro_export]
macro_rules! entity_messages {
($id_field:ident, $($name:ident),* $(,)?) => {
$(impl EntityMessage for $name {
fn remote_entity_id(&self) -> u64 {
self.$id_field
}
})*
};
}

View File

@ -1,105 +0,0 @@
use crate::proto;
use serde::{Deserialize, Serialize};
use serde_json::{map, Value};
use strum::{EnumVariantNames, VariantNames as _};
const KIND: &'static str = "kind";
const ENTITY_ID: &'static str = "entity_id";
/// A notification that can be stored, associated with a given recipient.
///
/// This struct is stored in the collab database as JSON, so it shouldn't be
/// changed in a backward-incompatible way. For example, when renaming a
/// variant, add a serde alias for the old name.
///
/// Most notification types have a special field which is aliased to
/// `entity_id`. This field is stored in its own database column, and can
/// be used to query the notification.
#[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum Notification {
ContactRequest {
#[serde(rename = "entity_id")]
sender_id: u64,
},
ContactRequestAccepted {
#[serde(rename = "entity_id")]
responder_id: u64,
},
ChannelInvitation {
#[serde(rename = "entity_id")]
channel_id: u64,
channel_name: String,
inviter_id: u64,
},
ChannelMessageMention {
#[serde(rename = "entity_id")]
message_id: u64,
sender_id: u64,
channel_id: u64,
},
}
impl Notification {
pub fn to_proto(&self) -> proto::Notification {
let mut value = serde_json::to_value(self).unwrap();
let mut entity_id = None;
let value = value.as_object_mut().unwrap();
let Some(Value::String(kind)) = value.remove(KIND) else {
unreachable!("kind is the enum tag")
};
if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) {
if e.get().is_u64() {
entity_id = e.remove().as_u64();
}
}
proto::Notification {
kind,
entity_id,
content: serde_json::to_string(&value).unwrap(),
..Default::default()
}
}
pub fn from_proto(notification: &proto::Notification) -> Option<Self> {
let mut value = serde_json::from_str::<Value>(&notification.content).ok()?;
let object = value.as_object_mut()?;
object.insert(KIND.into(), notification.kind.to_string().into());
if let Some(entity_id) = notification.entity_id {
object.insert(ENTITY_ID.into(), entity_id.into());
}
serde_json::from_value(value).ok()
}
pub fn all_variant_names() -> &'static [&'static str] {
Self::VARIANTS
}
}
#[test]
fn test_notification() {
// Notifications can be serialized and deserialized.
for notification in [
Notification::ContactRequest { sender_id: 1 },
Notification::ContactRequestAccepted { responder_id: 2 },
Notification::ChannelInvitation {
channel_id: 100,
channel_name: "the-channel".into(),
inviter_id: 50,
},
Notification::ChannelMessageMention {
sender_id: 200,
channel_id: 30,
message_id: 1,
},
] {
let message = notification.to_proto();
let deserialized = Notification::from_proto(&message).unwrap();
assert_eq!(deserialized, notification);
}
// When notifications are serialized, the `kind` and `actor_id` fields are
// stored separately, and do not appear redundantly in the JSON.
let notification = Notification::ContactRequest { sender_id: 1 };
assert_eq!(notification.to_proto().content, "{}");
}

View File

@ -1,934 +0,0 @@
use super::{
proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
Connection,
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::{
channel::{mpsc, oneshot},
stream::BoxStream,
FutureExt, SinkExt, StreamExt, TryFutureExt,
};
use parking_lot::{Mutex, RwLock};
use serde::{ser::SerializeStruct, Serialize};
use std::{fmt, sync::atomic::Ordering::SeqCst};
use std::{
future::Future,
marker::PhantomData,
sync::{
atomic::{self, AtomicU32},
Arc,
},
time::Duration,
};
use tracing::instrument;
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
pub owner_id: u32,
pub id: u32,
}
impl Into<PeerId> for ConnectionId {
fn into(self) -> PeerId {
PeerId {
owner_id: self.owner_id,
id: self.id,
}
}
}
impl From<PeerId> for ConnectionId {
fn from(peer_id: PeerId) -> Self {
Self {
owner_id: peer_id.owner_id,
id: peer_id.id,
}
}
}
impl fmt::Display for ConnectionId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.owner_id, self.id)
}
}
pub struct Receipt<T> {
pub sender_id: ConnectionId,
pub message_id: u32,
payload_type: PhantomData<T>,
}
impl<T> Clone for Receipt<T> {
fn clone(&self) -> Self {
Self {
sender_id: self.sender_id,
message_id: self.message_id,
payload_type: PhantomData,
}
}
}
impl<T> Copy for Receipt<T> {}
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
pub sender_id: ConnectionId,
pub original_sender_id: Option<PeerId>,
pub message_id: u32,
pub payload: T,
}
impl<T> TypedEnvelope<T> {
pub fn original_sender_id(&self) -> Result<PeerId> {
self.original_sender_id
.ok_or_else(|| anyhow!("missing original_sender_id"))
}
}
impl<T: RequestMessage> TypedEnvelope<T> {
pub fn receipt(&self) -> Receipt<T> {
Receipt {
sender_id: self.sender_id,
message_id: self.message_id,
payload_type: PhantomData,
}
}
}
pub struct Peer {
epoch: AtomicU32,
pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
next_connection_id: AtomicU32,
}
#[derive(Clone, Serialize)]
pub struct ConnectionState {
#[serde(skip)]
outgoing_tx: mpsc::UnboundedSender<proto::Message>,
next_message_id: Arc<AtomicU32>,
#[allow(clippy::type_complexity)]
#[serde(skip)]
response_channels:
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
impl Peer {
pub fn new(epoch: u32) -> Arc<Self> {
Arc::new(Self {
epoch: AtomicU32::new(epoch),
connections: Default::default(),
next_connection_id: Default::default(),
})
}
pub fn epoch(&self) -> u32 {
self.epoch.load(SeqCst)
}
#[instrument(skip_all)]
pub fn add_connection<F, Fut, Out>(
self: &Arc<Self>,
connection: Connection,
create_timer: F,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
)
where
F: Send + Fn(Duration) -> Fut,
Fut: Send + Future<Output = Out>,
Out: Send,
{
// For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
// messages faster than this peer can process them.
#[cfg(any(test, feature = "test-support"))]
const INCOMING_BUFFER_SIZE: usize = 1;
#[cfg(not(any(test, feature = "test-support")))]
const INCOMING_BUFFER_SIZE: usize = 64;
let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
let connection_id = ConnectionId {
owner_id: self.epoch.load(SeqCst),
id: self.next_connection_id.fetch_add(1, SeqCst),
};
let connection_state = ConnectionState {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
tracing::trace!(%connection_id, "handle io future: start");
let _end_connection = util::defer(|| {
response_channels.lock().take();
this.connections.write().remove(&connection_id);
tracing::trace!(%connection_id, "handle io future: end");
});
// Send messages on this frequency so the connection isn't closed.
let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
futures::pin_mut!(keepalive_timer);
// Disconnect if we don't receive messages at least this frequently.
let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
futures::pin_mut!(receive_timeout);
loop {
tracing::trace!(%connection_id, "outer loop iteration start");
let read_message = reader.read().fuse();
futures::pin_mut!(read_message);
loop {
tracing::trace!(%connection_id, "inner loop iteration start");
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
tracing::trace!(%connection_id, "outgoing rpc message: writing");
futures::select_biased! {
result = writer.write(outgoing).fuse() => {
tracing::trace!(%connection_id, "outgoing rpc message: done writing");
result.context("failed to write RPC message")?;
tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
}
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
Err(anyhow!("timed out writing message"))?;
}
}
}
None => {
tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
return Ok(())
},
},
_ = keepalive_timer => {
tracing::trace!(%connection_id, "keepalive interval: pinging");
futures::select_biased! {
result = writer.write(proto::Message::Ping).fuse() => {
tracing::trace!(%connection_id, "keepalive interval: done pinging");
result.context("failed to send keepalive")?;
tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
}
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
Err(anyhow!("timed out sending keepalive"))?;
}
}
}
incoming = read_message => {
let incoming = incoming.context("error reading rpc message from socket")?;
tracing::trace!(%connection_id, "incoming rpc message: received");
tracing::trace!(%connection_id, "receive timeout: resetting");
receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
if let proto::Message::Envelope(incoming) = incoming {
tracing::trace!(%connection_id, "incoming rpc message: processing");
futures::select_biased! {
result = incoming_tx.send(incoming).fuse() => match result {
Ok(_) => {
tracing::trace!(%connection_id, "incoming rpc message: processed");
}
Err(_) => {
tracing::trace!(%connection_id, "incoming rpc message: channel closed");
return Ok(())
}
},
_ = create_timer(WRITE_TIMEOUT).fuse() => {
tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
Err(anyhow!("timed out processing incoming message"))?
}
}
}
break;
},
_ = receive_timeout => {
tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
Err(anyhow!("delay between messages too long"))?
}
}
}
}
};
let response_channels = connection_state.response_channels.clone();
self.connections
.write()
.insert(connection_id, connection_state);
let incoming_rx = incoming_rx.filter_map(move |incoming| {
let response_channels = response_channels.clone();
async move {
let message_id = incoming.id;
tracing::trace!(?incoming, "incoming message future: start");
let _end = util::defer(move || {
tracing::trace!(%connection_id, message_id, "incoming message future: end");
});
if let Some(responding_to) = incoming.responding_to {
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: received"
);
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
if let Some(tx) = channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.send((incoming, requester_resumed.0)) {
tracing::trace!(
%connection_id,
message_id,
responding_to = responding_to,
?error,
"incoming response: request future dropped",
);
}
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: waiting to resume requester"
);
let _ = requester_resumed.1.await;
tracing::trace!(
%connection_id,
message_id,
responding_to,
"incoming response: requester resumed"
);
} else {
tracing::warn!(
%connection_id,
message_id,
responding_to,
"incoming response: unknown request"
);
}
None
} else {
tracing::trace!(%connection_id, message_id, "incoming message: received");
proto::build_typed_envelope(connection_id, incoming).or_else(|| {
tracing::error!(
%connection_id,
message_id,
"unable to construct a typed envelope"
);
None
})
}
}
});
(connection_id, handle_io, incoming_rx.boxed())
}
#[cfg(any(test, feature = "test-support"))]
pub fn add_test_connection(
self: &Arc<Self>,
connection: Connection,
executor: gpui::BackgroundExecutor,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) {
let executor = executor.clone();
self.add_connection(connection, move |duration| executor.timer(duration))
}
pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().remove(&connection_id);
}
pub fn reset(&self, epoch: u32) {
self.teardown();
self.next_connection_id.store(0, SeqCst);
self.epoch.store(epoch, SeqCst);
}
pub fn teardown(&self) {
self.connections.write().clear();
}
pub fn request<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(None, receiver_id, request)
.map_ok(|envelope| envelope.payload)
}
pub fn request_envelope<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
self.request_internal(None, receiver_id, request)
}
pub fn forward_request<T: RequestMessage>(
&self,
sender_id: ConnectionId,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<T::Response>> {
self.request_internal(Some(sender_id), receiver_id, request)
.map_ok(|envelope| envelope.payload)
}
pub fn request_internal<T: RequestMessage>(
&self,
original_sender_id: Option<ConnectionId>,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
let (tx, rx) = oneshot::channel();
let send = self.connection_state(receiver_id).and_then(|connection| {
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
connection
.response_channels
.lock()
.as_mut()
.ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(request.into_envelope(
message_id,
None,
original_sender_id.map(Into::into),
)))
.map_err(|_| anyhow!("connection was closed"))?;
Ok(())
});
async move {
send?;
let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
Err(anyhow!(
"RPC request {} failed - {}",
T::NAME,
error.message
))
} else {
Ok(TypedEnvelope {
message_id: response.id,
sender_id: receiver_id,
original_sender_id: response.original_sender_id,
payload: T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))?,
})
}
}
}
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(
message.into_envelope(message_id, None, None),
))?;
Ok(())
}
pub fn forward_send<T: EnvelopedMessage>(
&self,
sender_id: ConnectionId,
receiver_id: ConnectionId,
message: T,
) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(message.into_envelope(
message_id,
None,
Some(sender_id.into()),
)))?;
Ok(())
}
pub fn respond<T: RequestMessage>(
&self,
receipt: Receipt<T>,
response: T::Response,
) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_error<T: RequestMessage>(
&self,
receipt: Receipt<T>,
response: proto::Error,
) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_unhandled_message(
&self,
envelope: Box<dyn AnyTypedEnvelope>,
) -> Result<()> {
let connection = self.connection_state(envelope.sender_id())?;
let response = proto::Error {
message: format!("message {} was not handled", envelope.payload_type_name()),
};
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(response.into_envelope(
message_id,
Some(envelope.message_id()),
None,
)))?;
Ok(())
}
fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
let connections = self.connections.read();
let connection = connections
.get(&connection_id)
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
Ok(connection.clone())
}
}
impl Serialize for Peer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("Peer", 2)?;
state.serialize_field("connections", &*self.connections.read())?;
state.end()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TypedEnvelope;
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use gpui::TestAppContext;
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}
#[gpui::test(iterations = 50)]
async fn test_request_response(cx: &mut TestAppContext) {
init_logger();
let executor = cx.executor();
// create 2 clients connected to 1 server
let server = Peer::new(0);
let client1 = Peer::new(0);
let client2 = Peer::new(0);
let (client1_to_server_conn, server_to_client_1_conn, _kill) =
Connection::in_memory(cx.executor());
let (client1_conn_id, io_task1, client1_incoming) =
client1.add_test_connection(client1_to_server_conn, cx.executor());
let (_, io_task2, server_incoming1) =
server.add_test_connection(server_to_client_1_conn, cx.executor());
let (client2_to_server_conn, server_to_client_2_conn, _kill) =
Connection::in_memory(cx.executor());
let (client2_conn_id, io_task3, client2_incoming) =
client2.add_test_connection(client2_to_server_conn, cx.executor());
let (_, io_task4, server_incoming2) =
server.add_test_connection(server_to_client_2_conn, cx.executor());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor.spawn(io_task3).detach();
executor.spawn(io_task4).detach();
executor
.spawn(handle_messages(server_incoming1, server.clone()))
.detach();
executor
.spawn(handle_messages(client1_incoming, client1.clone()))
.detach();
executor
.spawn(handle_messages(server_incoming2, server.clone()))
.detach();
executor
.spawn(handle_messages(client2_incoming, client2.clone()))
.detach();
assert_eq!(
client1
.request(client1_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Ack {}
);
assert_eq!(
client2
.request(client2_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Ack {}
);
assert_eq!(
client1
.request(client1_conn_id, proto::Test { id: 1 },)
.await
.unwrap(),
proto::Test { id: 1 }
);
assert_eq!(
client2
.request(client2_conn_id, proto::Test { id: 2 })
.await
.unwrap(),
proto::Test { id: 2 }
);
client1.disconnect(client1_conn_id);
client2.disconnect(client1_conn_id);
async fn handle_messages(
mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
peer: Arc<Peer>,
) -> Result<()> {
while let Some(envelope) = messages.next().await {
let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
peer.respond(receipt, proto::Ack {})?
} else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
{
peer.respond(envelope.receipt(), envelope.payload.clone())?
} else {
panic!("unknown message type");
}
}
Ok(())
}
}
#[gpui::test(iterations = 50)]
async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
let executor = cx.executor();
let server = Peer::new(0);
let client = Peer::new(0);
let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(executor.clone());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_test_connection(client_to_server_conn, executor.clone());
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_test_connection(server_to_client_conn, executor.clone());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor
.spawn(async move {
let future = server_incoming.next().await;
let request = future
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 1".to_string(),
},
)
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 2".to_string(),
},
)
.unwrap();
server.respond(request.receipt(), proto::Ack {}).unwrap();
// Prevent the connection from being dropped
server_incoming.next().await;
})
.detach();
let events = Arc::new(Mutex::new(Vec::new()));
let response = client.request(client_to_server_conn_id, proto::Ping {});
let response_task = executor.spawn({
let events = events.clone();
async move {
response.await.unwrap();
events.lock().push("response".to_string());
}
});
executor
.spawn({
let events = events.clone();
async move {
let incoming1 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming1.payload.message);
let incoming2 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming2.payload.message);
// Prevent the connection from being dropped
client_incoming.next().await;
}
})
.detach();
response_task.await;
assert_eq!(
&*events.lock(),
&[
"message 1".to_string(),
"message 2".to_string(),
"response".to_string()
]
);
}
#[gpui::test(iterations = 50)]
async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
let executor = cx.executor();
let server = Peer::new(0);
let client = Peer::new(0);
let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.executor());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_test_connection(client_to_server_conn, cx.executor());
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_test_connection(server_to_client_conn, cx.executor());
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
executor
.spawn(async move {
let request1 = server_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
let request2 = server_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Ping>>()
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 1".to_string(),
},
)
.unwrap();
server
.send(
server_to_client_conn_id,
proto::Error {
message: "message 2".to_string(),
},
)
.unwrap();
server.respond(request1.receipt(), proto::Ack {}).unwrap();
server.respond(request2.receipt(), proto::Ack {}).unwrap();
// Prevent the connection from being dropped
server_incoming.next().await;
})
.detach();
let events = Arc::new(Mutex::new(Vec::new()));
let request1 = client.request(client_to_server_conn_id, proto::Ping {});
let request1_task = executor.spawn(request1);
let request2 = client.request(client_to_server_conn_id, proto::Ping {});
let request2_task = executor.spawn({
let events = events.clone();
async move {
request2.await.unwrap();
events.lock().push("response 2".to_string());
}
});
executor
.spawn({
let events = events.clone();
async move {
let incoming1 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming1.payload.message);
let incoming2 = client_incoming
.next()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<proto::Error>>()
.unwrap();
events.lock().push(incoming2.payload.message);
// Prevent the connection from being dropped
client_incoming.next().await;
}
})
.detach();
// Allow the request to make some progress before dropping it.
cx.executor().simulate_random_delay().await;
drop(request1_task);
request2_task.await;
assert_eq!(
&*events.lock(),
&[
"message 1".to_string(),
"message 2".to_string(),
"response 2".to_string()
]
);
}
#[gpui::test(iterations = 50)]
async fn test_disconnect(cx: &mut TestAppContext) {
let executor = cx.executor();
let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
let client = Peer::new(0);
let (connection_id, io_handler, mut incoming) =
client.add_test_connection(client_conn, executor.clone());
let (io_ended_tx, io_ended_rx) = oneshot::channel();
executor
.spawn(async move {
io_handler.await.ok();
io_ended_tx.send(()).unwrap();
})
.detach();
let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
executor
.spawn(async move {
incoming.next().await;
messages_ended_tx.send(()).unwrap();
})
.detach();
client.disconnect(connection_id);
let _ = io_ended_rx.await;
let _ = messages_ended_rx.await;
assert!(server_conn
.send(WebSocketMessage::Binary(vec![]))
.await
.is_err());
}
#[gpui::test(iterations = 50)]
async fn test_io_error(cx: &mut TestAppContext) {
let executor = cx.executor();
let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
let client = Peer::new(0);
let (connection_id, io_handler, mut incoming) =
client.add_test_connection(client_conn, executor.clone());
executor.spawn(io_handler).detach();
executor
.spawn(async move { incoming.next().await })
.detach();
let response = executor.spawn(client.request(connection_id, proto::Ping {}));
let _request = server_conn.rx.next().await.unwrap().unwrap();
drop(server_conn);
assert_eq!(
response.await.unwrap_err().to_string(),
"connection was closed"
);
}
}

View File

@ -1,692 +0,0 @@
#![allow(non_snake_case)]
use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
use anyhow::{anyhow, Result};
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use collections::HashMap;
use futures::{SinkExt as _, StreamExt as _};
use prost::Message as _;
use serde::Serialize;
use std::any::{Any, TypeId};
use std::{
cmp,
fmt::Debug,
io, iter,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use std::{fmt, mem};
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
const NAME: &'static str;
const PRIORITY: MessagePriority;
fn into_envelope(
self,
id: u32,
responding_to: Option<u32>,
original_sender_id: Option<PeerId>,
) -> Envelope;
fn from_envelope(envelope: Envelope) -> Option<Self>;
}
pub trait EntityMessage: EnvelopedMessage {
fn remote_entity_id(&self) -> u64;
}
pub trait RequestMessage: EnvelopedMessage {
type Response: EnvelopedMessage;
}
pub trait AnyTypedEnvelope: 'static + Send + Sync {
fn payload_type_id(&self) -> TypeId;
fn payload_type_name(&self) -> &'static str;
fn as_any(&self) -> &dyn Any;
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
fn is_background(&self) -> bool;
fn original_sender_id(&self) -> Option<PeerId>;
fn sender_id(&self) -> ConnectionId;
fn message_id(&self) -> u32;
}
pub enum MessagePriority {
Foreground,
Background,
}
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
fn payload_type_id(&self) -> TypeId {
TypeId::of::<T>()
}
fn payload_type_name(&self) -> &'static str {
T::NAME
}
fn as_any(&self) -> &dyn Any {
self
}
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
self
}
fn is_background(&self) -> bool {
matches!(T::PRIORITY, MessagePriority::Background)
}
fn original_sender_id(&self) -> Option<PeerId> {
self.original_sender_id
}
fn sender_id(&self) -> ConnectionId {
self.sender_id
}
fn message_id(&self) -> u32 {
self.message_id
}
}
impl PeerId {
pub fn from_u64(peer_id: u64) -> Self {
let owner_id = (peer_id >> 32) as u32;
let id = peer_id as u32;
Self { owner_id, id }
}
pub fn as_u64(self) -> u64 {
((self.owner_id as u64) << 32) | (self.id as u64)
}
}
impl Copy for PeerId {}
impl Eq for PeerId {}
impl Ord for PeerId {
fn cmp(&self, other: &Self) -> cmp::Ordering {
self.owner_id
.cmp(&other.owner_id)
.then_with(|| self.id.cmp(&other.id))
}
}
impl PartialOrd for PeerId {
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
Some(self.cmp(other))
}
}
impl std::hash::Hash for PeerId {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.owner_id.hash(state);
self.id.hash(state);
}
}
impl fmt::Display for PeerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", self.owner_id, self.id)
}
}
messages!(
(Ack, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
(AddNotification, Foreground),
(AddProjectCollaborator, Foreground),
(ApplyCodeAction, Background),
(ApplyCodeActionResponse, Background),
(ApplyCompletionAdditionalEdits, Background),
(ApplyCompletionAdditionalEditsResponse, Background),
(BufferReloaded, Foreground),
(BufferSaved, Foreground),
(Call, Foreground),
(CallCanceled, Foreground),
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(CopyProjectEntry, Foreground),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
(CreateProjectEntry, Foreground),
(CreateRoom, Foreground),
(CreateRoomResponse, Foreground),
(DeclineCall, Foreground),
(DeleteChannel, Foreground),
(DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
(ExpandProjectEntryResponse, Foreground),
(Follow, Foreground),
(FollowResponse, Foreground),
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
(GetChannelMessagesById, Background),
(GetChannelMessagesResponse, Background),
(GetCodeActions, Background),
(GetCodeActionsResponse, Background),
(GetCompletions, Background),
(GetCompletionsResponse, Background),
(GetDefinition, Background),
(GetDefinitionResponse, Background),
(GetDocumentHighlights, Background),
(GetDocumentHighlightsResponse, Background),
(GetHover, Background),
(GetHoverResponse, Background),
(GetNotifications, Foreground),
(GetNotificationsResponse, Foreground),
(GetPrivateUserInfo, Foreground),
(GetPrivateUserInfoResponse, Foreground),
(GetProjectSymbols, Background),
(GetProjectSymbolsResponse, Background),
(GetReferences, Background),
(GetReferencesResponse, Background),
(GetTypeDefinition, Background),
(GetTypeDefinitionResponse, Background),
(GetUsers, Foreground),
(Hello, Foreground),
(IncomingCall, Foreground),
(InlayHints, Background),
(InlayHintsResponse, Background),
(InviteChannelMember, Foreground),
(JoinChannel, Foreground),
(JoinChannelBuffer, Foreground),
(JoinChannelBufferResponse, Foreground),
(JoinChannelChat, Foreground),
(JoinChannelChatResponse, Foreground),
(JoinProject, Foreground),
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
(LeaveRoom, Foreground),
(MarkNotificationRead, Foreground),
(MoveChannel, Foreground),
(OnTypeFormatting, Background),
(OnTypeFormattingResponse, Background),
(OpenBufferById, Background),
(OpenBufferByPath, Background),
(OpenBufferForSymbol, Background),
(OpenBufferForSymbolResponse, Background),
(OpenBufferResponse, Background),
(PerformRename, Background),
(PerformRenameResponse, Background),
(Ping, Foreground),
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
(RejoinRoom, Foreground),
(RejoinRoomResponse, Foreground),
(ReloadBuffers, Foreground),
(ReloadBuffersResponse, Foreground),
(RemoveChannelMember, Foreground),
(RemoveChannelMessage, Foreground),
(RemoveContact, Foreground),
(RemoveProjectCollaborator, Foreground),
(RenameChannel, Foreground),
(RenameChannelResponse, Foreground),
(RenameProjectEntry, Foreground),
(RequestContact, Foreground),
(ResolveCompletionDocumentation, Background),
(ResolveCompletionDocumentationResponse, Background),
(ResolveInlayHint, Background),
(ResolveInlayHintResponse, Background),
(RespondToChannelInvite, Foreground),
(RespondToContactRequest, Foreground),
(RoomUpdated, Foreground),
(SaveBuffer, Foreground),
(SetChannelMemberRole, Foreground),
(SetChannelVisibility, Foreground),
(SearchProject, Background),
(SearchProjectResponse, Background),
(SendChannelMessage, Background),
(SendChannelMessageResponse, Background),
(ShareProject, Foreground),
(ShareProjectResponse, Foreground),
(ShowContacts, Foreground),
(StartLanguageServer, Foreground),
(SynchronizeBuffers, Foreground),
(SynchronizeBuffersResponse, Foreground),
(Test, Foreground),
(Unfollow, Foreground),
(UnshareProject, Foreground),
(UpdateBuffer, Foreground),
(UpdateBufferFile, Foreground),
(UpdateChannelBuffer, Foreground),
(UpdateChannelBufferCollaborators, Foreground),
(UpdateChannels, Foreground),
(UpdateContacts, Foreground),
(UpdateDiagnosticSummary, Foreground),
(UpdateDiffBase, Foreground),
(UpdateFollowers, Foreground),
(UpdateInviteInfo, Foreground),
(UpdateLanguageServer, Foreground),
(UpdateParticipantLocation, Foreground),
(UpdateProject, Foreground),
(UpdateProjectCollaborator, Foreground),
(UpdateWorktree, Foreground),
(UpdateWorktreeSettings, Foreground),
(UsersResponse, Foreground),
(LspExtExpandMacro, Background),
(LspExtExpandMacroResponse, Background),
);
request_messages!(
(ApplyCodeAction, ApplyCodeActionResponse),
(
ApplyCompletionAdditionalEdits,
ApplyCompletionAdditionalEditsResponse
),
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
(DeclineCall, Ack),
(DeleteChannel, Ack),
(DeleteProjectEntry, ProjectEntryResponse),
(ExpandProjectEntry, ExpandProjectEntryResponse),
(Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),
(GetCodeActions, GetCodeActionsResponse),
(GetCompletions, GetCompletionsResponse),
(GetDefinition, GetDefinitionResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetHover, GetHoverResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
(GetReferences, GetReferencesResponse),
(GetTypeDefinition, GetTypeDefinitionResponse),
(GetUsers, UsersResponse),
(IncomingCall, Ack),
(InlayHints, InlayHintsResponse),
(InviteChannelMember, Ack),
(JoinChannel, JoinRoomResponse),
(JoinChannelBuffer, JoinChannelBufferResponse),
(JoinChannelChat, JoinChannelChatResponse),
(JoinProject, JoinProjectResponse),
(JoinRoom, JoinRoomResponse),
(LeaveChannelBuffer, Ack),
(LeaveRoom, Ack),
(MarkNotificationRead, Ack),
(MoveChannel, Ack),
(OnTypeFormatting, OnTypeFormattingResponse),
(OpenBufferById, OpenBufferResponse),
(OpenBufferByPath, OpenBufferResponse),
(OpenBufferForSymbol, OpenBufferForSymbolResponse),
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),
(ReloadBuffers, ReloadBuffersResponse),
(RemoveChannelMember, Ack),
(RemoveChannelMessage, Ack),
(RemoveContact, Ack),
(RenameChannel, RenameChannelResponse),
(RenameProjectEntry, ProjectEntryResponse),
(RequestContact, Ack),
(
ResolveCompletionDocumentation,
ResolveCompletionDocumentationResponse
),
(ResolveInlayHint, ResolveInlayHintResponse),
(RespondToChannelInvite, Ack),
(RespondToContactRequest, Ack),
(SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse),
(SendChannelMessage, SendChannelMessageResponse),
(SetChannelMemberRole, Ack),
(SetChannelVisibility, Ack),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
(Test, Test),
(UpdateBuffer, Ack),
(UpdateParticipantLocation, Ack),
(UpdateProject, Ack),
(UpdateWorktree, Ack),
(LspExtExpandMacro, LspExtExpandMacroResponse),
);
entity_messages!(
project_id,
AddProjectCollaborator,
ApplyCodeAction,
ApplyCompletionAdditionalEdits,
BufferReloaded,
BufferSaved,
CopyProjectEntry,
CreateBufferForPeer,
CreateProjectEntry,
DeleteProjectEntry,
ExpandProjectEntry,
FormatBuffers,
GetCodeActions,
GetCompletions,
GetDefinition,
GetDocumentHighlights,
GetHover,
GetProjectSymbols,
GetReferences,
GetTypeDefinition,
InlayHints,
JoinProject,
LeaveProject,
OnTypeFormatting,
OpenBufferById,
OpenBufferByPath,
OpenBufferForSymbol,
PerformRename,
PrepareRename,
RefreshInlayHints,
ReloadBuffers,
RemoveProjectCollaborator,
RenameProjectEntry,
ResolveCompletionDocumentation,
ResolveInlayHint,
SaveBuffer,
SearchProject,
StartLanguageServer,
SynchronizeBuffers,
UnshareProject,
UpdateBuffer,
UpdateBufferFile,
UpdateDiagnosticSummary,
UpdateDiffBase,
UpdateLanguageServer,
UpdateProject,
UpdateProjectCollaborator,
UpdateWorktree,
UpdateWorktreeSettings,
LspExtExpandMacro,
);
entity_messages!(
channel_id,
ChannelMessageSent,
RemoveChannelMessage,
UpdateChannelBuffer,
UpdateChannelBufferCollaborators,
);
const KIB: usize = 1024;
const MIB: usize = KIB * 1024;
const MAX_BUFFER_LEN: usize = MIB;
/// A stream of protobuf messages.
pub struct MessageStream<S> {
stream: S,
encoding_buffer: Vec<u8>,
}
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
Envelope(Envelope),
Ping,
Pong,
}
impl<S> MessageStream<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
encoding_buffer: Vec::new(),
}
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.stream
}
}
impl<S> MessageStream<S>
where
S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
{
pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
#[cfg(any(test, feature = "test-support"))]
const COMPRESSION_LEVEL: i32 = -7;
#[cfg(not(any(test, feature = "test-support")))]
const COMPRESSION_LEVEL: i32 = 4;
match message {
Message::Envelope(message) => {
self.encoding_buffer.reserve(message.encoded_len());
message
.encode(&mut self.encoding_buffer)
.map_err(io::Error::from)?;
let buffer =
zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
.unwrap();
self.encoding_buffer.clear();
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
}
Message::Ping => {
self.stream
.send(WebSocketMessage::Ping(Default::default()))
.await?;
}
Message::Pong => {
self.stream
.send(WebSocketMessage::Pong(Default::default()))
.await?;
}
}
Ok(())
}
}
impl<S> MessageStream<S>
where
S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
{
pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
while let Some(bytes) = self.stream.next().await {
match bytes? {
WebSocketMessage::Binary(bytes) => {
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
.map_err(io::Error::from)?;
self.encoding_buffer.clear();
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
return Ok(Message::Envelope(envelope));
}
WebSocketMessage::Ping(_) => return Ok(Message::Ping),
WebSocketMessage::Pong(_) => return Ok(Message::Pong),
WebSocketMessage::Close(_) => break,
_ => {}
}
}
Err(anyhow!("connection closed"))
}
}
impl From<Timestamp> for SystemTime {
fn from(val: Timestamp) -> Self {
UNIX_EPOCH
.checked_add(Duration::new(val.seconds, val.nanos))
.unwrap()
}
}
impl From<SystemTime> for Timestamp {
fn from(time: SystemTime) -> Self {
let duration = time.duration_since(UNIX_EPOCH).unwrap();
Self {
seconds: duration.as_secs(),
nanos: duration.subsec_nanos(),
}
}
}
impl From<u128> for Nonce {
fn from(nonce: u128) -> Self {
let upper_half = (nonce >> 64) as u64;
let lower_half = nonce as u64;
Self {
upper_half,
lower_half,
}
}
}
impl From<Nonce> for u128 {
fn from(nonce: Nonce) -> Self {
let upper_half = (nonce.upper_half as u128) << 64;
let lower_half = nonce.lower_half as u128;
upper_half | lower_half
}
}
pub fn split_worktree_update(
mut message: UpdateWorktree,
max_chunk_size: usize,
) -> impl Iterator<Item = UpdateWorktree> {
let mut done_files = false;
let mut repository_map = message
.updated_repositories
.into_iter()
.map(|repo| (repo.work_directory_id, repo))
.collect::<HashMap<_, _>>();
iter::from_fn(move || {
if done_files {
return None;
}
let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
let updated_entries: Vec<_> = message
.updated_entries
.drain(..updated_entries_chunk_size)
.collect();
let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
let removed_entries = message
.removed_entries
.drain(..removed_entries_chunk_size)
.collect();
done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
let mut updated_repositories = Vec::new();
if !repository_map.is_empty() {
for entry in &updated_entries {
if let Some(repo) = repository_map.remove(&entry.id) {
updated_repositories.push(repo)
}
}
}
let removed_repositories = if done_files {
mem::take(&mut message.removed_repositories)
} else {
Default::default()
};
if done_files {
updated_repositories.extend(mem::take(&mut repository_map).into_values());
}
Some(UpdateWorktree {
project_id: message.project_id,
worktree_id: message.worktree_id,
root_name: message.root_name.clone(),
abs_path: message.abs_path.clone(),
updated_entries,
removed_entries,
scan_id: message.scan_id,
is_last_update: done_files && message.is_last_update,
updated_repositories,
removed_repositories,
})
})
}
#[cfg(test)]
mod tests {
use super::*;
#[gpui::test]
async fn test_buffer_size() {
let (tx, rx) = futures::channel::mpsc::unbounded();
let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
sink.write(Message::Envelope(Envelope {
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
root_name: "abcdefg".repeat(10),
..Default::default()
})),
..Default::default()
}))
.await
.unwrap();
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
sink.write(Message::Envelope(Envelope {
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
root_name: "abcdefg".repeat(1000000),
..Default::default()
})),
..Default::default()
}))
.await
.unwrap();
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
let mut stream = MessageStream::new(rx.map(anyhow::Ok));
stream.read().await.unwrap();
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
stream.read().await.unwrap();
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
}
#[gpui::test]
fn test_converting_peer_id_from_and_to_u64() {
let peer_id = PeerId {
owner_id: 10,
id: 3,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: u32::MAX,
id: 3,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: 10,
id: u32::MAX,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
let peer_id = PeerId {
owner_id: u32::MAX,
id: u32::MAX,
};
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
}
}

View File

@ -1,12 +0,0 @@
pub mod auth;
mod conn;
mod notification;
mod peer;
pub mod proto;
pub use conn::Connection;
pub use notification::*;
pub use peer::*;
mod macros;
pub const PROTOCOL_VERSION: u32 = 67;

View File

@ -21,7 +21,7 @@ theme = { package = "theme2", path = "../theme2" }
util = { path = "../util" }
ui = {package = "ui2", path = "../ui2"}
workspace = { path = "../workspace" }
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
semantic_index = { path = "../semantic_index" }
anyhow.workspace = true
futures.workspace = true
log.workspace = true

View File

@ -11,16 +11,13 @@ doctest = false
[dependencies]
ai = { path = "../ai" }
collections = { path = "../collections" }
gpui = { path = "../gpui" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }
project = { path = "../project" }
workspace = { path = "../workspace" }
util = { path = "../util" }
picker = { path = "../picker" }
theme = { path = "../theme" }
editor = { path = "../editor" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
settings = { package = "settings2", path = "../settings2" }
anyhow.workspace = true
postage.workspace = true
futures.workspace = true
@ -44,12 +41,12 @@ ndarray = { version = "0.15.0" }
[dev-dependencies]
ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"]}
settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
rust-embed = { version = "8.0", features = ["include-exclude"] }
client = { path = "../client" }
node_runtime = { path = "../node_runtime"}

View File

@ -6,7 +6,7 @@ use ai::embedding::Embedding;
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
use gpui::BackgroundExecutor;
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use project::Fs;
@ -48,7 +48,7 @@ impl VectorDatabase {
pub async fn new(
fs: Arc<dyn Fs>,
path: Arc<Path>,
executor: Arc<executor::Background>,
executor: BackgroundExecutor,
) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;

View File

@ -1,6 +1,6 @@
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use gpui::executor::Background;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use smol::channel;
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
@ -37,7 +37,7 @@ impl PartialEq for FileToEmbed {
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileFragmentToEmbed>,
executor: Arc<Background>,
executor: BackgroundExecutor,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
@ -50,7 +50,10 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: BackgroundExecutor,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,

View File

@ -9,12 +9,15 @@ mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use gpui::{
AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
@ -22,6 +25,7 @@ use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use settings::Settings;
use smol::channel;
use std::{
cmp::Reverse,
@ -35,7 +39,7 @@ use std::{
};
use util::paths::PathMatcher;
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::WorkspaceCreated;
use workspace::Workspace;
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
@ -51,54 +55,54 @@ pub fn init(
language_registry: Arc<LanguageRegistry>,
cx: &mut AppContext,
) {
settings::register::<SemanticIndexSettings>(cx);
SemanticIndexSettings::register(cx);
let db_file_path = EMBEDDINGS_DIR
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
cx.subscribe_global::<WorkspaceCreated, _>({
move |event, cx| {
cx.observe_new_views(
|workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
let Some(semantic_index) = SemanticIndex::global(cx) else {
return;
};
let workspace = &event.0;
if let Some(workspace) = workspace.upgrade(cx) {
let project = workspace.read(cx).project().clone();
if project.read(cx).is_local() {
cx.spawn(|mut cx| async move {
let project = workspace.project().clone();
if project.read(cx).is_local() {
cx.app_mut()
.spawn(|mut cx| async move {
let previously_indexed = semantic_index
.update(&mut cx, |index, cx| {
index.project_previously_indexed(&project, cx)
})
})?
.await?;
if previously_indexed {
semantic_index
.update(&mut cx, |index, cx| index.index_project(project, cx))
.update(&mut cx, |index, cx| index.index_project(project, cx))?
.await?;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
}
})
},
)
.detach();
cx.spawn(move |mut cx| async move {
cx.spawn(move |cx| async move {
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
Arc::new(OpenAIEmbeddingProvider::new(
http_client,
cx.background_executor().clone(),
)),
language_registry,
cx.clone(),
)
.await?;
cx.update(|cx| {
cx.set_global(semantic_index.clone());
});
cx.update(|cx| cx.set_global(semantic_index.clone()))?;
anyhow::Ok(())
})
@ -124,7 +128,7 @@ pub struct SemanticIndex {
parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
projects: HashMap<WeakModel<Project>, ProjectState>,
}
struct ProjectState {
@ -229,12 +233,12 @@ impl ProjectState {
pending_file_count_tx,
pending_index: 0,
_subscription: subscription,
_observe_pending_file_count: cx.spawn_weak({
_observe_pending_file_count: cx.spawn({
let mut pending_file_count_rx = pending_file_count_rx.clone();
|this, mut cx| async move {
while let Some(_) = pending_file_count_rx.next().await {
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |_, cx| cx.notify());
if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
break;
}
}
}
@ -264,21 +268,21 @@ pub struct PendingFile {
#[derive(Clone)]
pub struct SearchResult {
pub buffer: ModelHandle<Buffer>,
pub buffer: Model<Buffer>,
pub range: Range<Anchor>,
pub similarity: OrderedFloat<f32>,
}
impl SemanticIndex {
pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
if cx.has_global::<Model<Self>>() {
Some(cx.global::<Model<SemanticIndex>>().clone())
} else {
None
}
}
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
if !self.embedding_provider.has_credentials() {
self.embedding_provider.retrieve_credentials(cx);
} else {
@ -293,10 +297,10 @@ impl SemanticIndex {
}
pub fn enabled(cx: &AppContext) -> bool {
settings::get::<SemanticIndexSettings>(cx).enabled
SemanticIndexSettings::get_global(cx).enabled
}
pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
if !self.is_authenticated() {
return SemanticIndexStatus::NotAuthenticated;
}
@ -326,21 +330,22 @@ impl SemanticIndex {
embedding_provider: Arc<dyn EmbeddingProvider>,
language_registry: Arc<LanguageRegistry>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
let t0 = Instant::now();
let database_path = Arc::from(database_path);
let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
.await?;
log::trace!(
"db initialization took {:?} milliseconds",
t0.elapsed().as_millis()
);
Ok(cx.add_model(|cx| {
cx.new_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
let _embedding_task = cx.background_executor().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
async move {
@ -357,13 +362,13 @@ impl SemanticIndex {
channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
for _ in 0..cx.background_executor().num_cpus() {
let fs = fs.clone();
let mut parsing_files_rx = parsing_files_rx.clone();
let embedding_provider = embedding_provider.clone();
let embedding_queue = embedding_queue.clone();
let background = cx.background().clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let background = cx.background_executor().clone();
_parsing_files_tasks.push(cx.background_executor().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
loop {
let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
@ -405,7 +410,7 @@ impl SemanticIndex {
_parsing_files_tasks,
projects: Default::default(),
}
}))
})
}
async fn parse_file(
@ -449,12 +454,12 @@ impl SemanticIndex {
pub fn project_previously_indexed(
&mut self,
project: &ModelHandle<Project>,
project: &Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<bool>> {
let worktrees_indexed_previously = project
.read(cx)
.worktrees(cx)
.worktrees()
.map(|worktree| {
self.db
.worktree_previously_indexed(&worktree.read(cx).abs_path())
@ -473,7 +478,7 @@ impl SemanticIndex {
fn project_entries_changed(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
worktree_id: WorktreeId,
changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
cx: &mut ModelContext<Self>,
@ -495,22 +500,25 @@ impl SemanticIndex {
};
worktree_state.paths_changed(changes, worktree);
if let WorktreeState::Registered(_) = worktree_state {
cx.spawn_weak(|this, mut cx| async move {
cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
cx.spawn(|this, mut cx| async move {
cx.background_executor()
.timer(BACKGROUND_INDEXING_DELAY)
.await;
if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
this.update(&mut cx, |this, cx| {
this.index_project(project, cx).detach_and_log_err(cx)
});
})?;
}
anyhow::Ok(())
})
.detach();
.detach_and_log_err(cx);
}
}
fn register_worktree(
&mut self,
project: ModelHandle<Project>,
worktree: ModelHandle<Worktree>,
project: Model<Project>,
worktree: Model<Worktree>,
cx: &mut ModelContext<Self>,
) {
let project = project.downgrade();
@ -536,16 +544,18 @@ impl SemanticIndex {
scan_complete.await;
let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
let mut file_mtimes = db.get_file_mtimes(db_id).await?;
let worktree = if let Some(project) = project.upgrade(&cx) {
let worktree = if let Some(project) = project.upgrade() {
project
.read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
.ok_or_else(|| anyhow!("worktree not found"))?
.ok()
.flatten()
.context("worktree not found")?
} else {
return anyhow::Ok(());
};
let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
let mut changed_paths = cx
.background()
.background_executor()
.spawn(async move {
let mut changed_paths = BTreeMap::new();
for file in worktree.files(false, 0) {
@ -607,10 +617,8 @@ impl SemanticIndex {
let project_state = this
.projects
.get_mut(&project)
.ok_or_else(|| anyhow!("project not registered"))?;
let project = project
.upgrade(cx)
.ok_or_else(|| anyhow!("project was dropped"))?;
.context("project not registered")?;
let project = project.upgrade().context("project was dropped")?;
if let Some(WorktreeState::Registering(state)) =
project_state.worktrees.remove(&worktree_id)
@ -627,7 +635,7 @@ impl SemanticIndex {
this.index_project(project, cx).detach_and_log_err(cx);
anyhow::Ok(())
})?;
})??;
anyhow::Ok(())
};
@ -639,6 +647,7 @@ impl SemanticIndex {
project_state.worktrees.remove(&worktree_id);
});
})
.ok();
}
*done_tx.borrow_mut() = Some(());
@ -654,11 +663,7 @@ impl SemanticIndex {
);
}
fn project_worktrees_changed(
&mut self,
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) {
fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
{
project_state
@ -668,7 +673,7 @@ impl SemanticIndex {
let mut worktrees = project
.read(cx)
.worktrees(cx)
.worktrees()
.filter(|worktree| worktree.read(cx).is_local())
.collect::<Vec<_>>();
let worktree_ids = worktrees
@ -691,10 +696,7 @@ impl SemanticIndex {
}
}
pub fn pending_file_count(
&self,
project: &ModelHandle<Project>,
) -> Option<watch::Receiver<usize>> {
pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
Some(
self.projects
.get(&project.downgrade())?
@ -705,7 +707,7 @@ impl SemanticIndex {
pub fn search_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
query: String,
limit: usize,
includes: Vec<PathMatcher>,
@ -727,7 +729,7 @@ impl SemanticIndex {
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
.context("could not embed query")?;
log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
let search_start = Instant::now();
@ -740,10 +742,10 @@ impl SemanticIndex {
&excludes,
cx,
)
});
})?;
let file_results = this.update(&mut cx, |this, cx| {
this.search_files(project, query, limit, includes, excludes, cx)
});
})?;
let (modified_buffer_results, file_results) =
futures::join!(modified_buffer_results, file_results);
@ -768,7 +770,7 @@ impl SemanticIndex {
pub fn search_files(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
query: Embedding,
limit: usize,
includes: Vec<PathMatcher>,
@ -778,14 +780,18 @@ impl SemanticIndex {
let db_path = self.db.path().clone();
let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move {
let database =
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
let database = VectorDatabase::new(
fs.clone(),
db_path.clone(),
cx.background_executor().clone(),
)
.await?;
let worktree_db_ids = this.read_with(&cx, |this, _| {
let project_state = this
.projects
.get(&project.downgrade())
.ok_or_else(|| anyhow!("project was not indexed"))?;
.context("project was not indexed")?;
let worktree_db_ids = project_state
.worktrees
.values()
@ -798,13 +804,13 @@ impl SemanticIndex {
})
.collect::<Vec<i64>>();
anyhow::Ok(worktree_db_ids)
})?;
})??;
let file_ids = database
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
.await?;
let batch_n = cx.background().num_cpus();
let batch_n = cx.background_executor().num_cpus();
let ids_len = file_ids.clone().len();
let minimum_batch_size = 50;
@ -824,9 +830,10 @@ impl SemanticIndex {
let fs = fs.clone();
let db_path = db_path.clone();
let query = query.clone();
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
.await
.log_err()
if let Some(db) =
VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
.await
.log_err()
{
batch_results.push(async move {
db.top_k_search(&query, limit, batch.as_slice()).await
@ -864,6 +871,7 @@ impl SemanticIndex {
let mut ranges = Vec::new();
let weak_project = project.downgrade();
project.update(&mut cx, |project, cx| {
let this = this.upgrade().context("index was dropped")?;
for (worktree_db_id, file_path, byte_range) in spans {
let project_state =
if let Some(state) = this.read(cx).projects.get(&weak_project) {
@ -878,7 +886,7 @@ impl SemanticIndex {
}
Ok(())
})?;
})??;
let buffers = futures::future::join_all(tasks).await;
Ok(buffers
@ -887,11 +895,13 @@ impl SemanticIndex {
.zip(scores)
.filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end)
});
let range = buffer
.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end)
})
.log_err()?;
Some(SearchResult {
buffer,
range,
@ -904,7 +914,7 @@ impl SemanticIndex {
fn search_modified_buffers(
&self,
project: &ModelHandle<Project>,
project: &Model<Project>,
query: Embedding,
limit: usize,
includes: &[PathMatcher],
@ -913,7 +923,7 @@ impl SemanticIndex {
) -> Task<Result<Vec<SearchResult>>> {
let modified_buffers = project
.read(cx)
.opened_buffers(cx)
.opened_buffers()
.into_iter()
.filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx);
@ -941,8 +951,8 @@ impl SemanticIndex {
let embedding_provider = self.embedding_provider.clone();
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
cx.background().spawn(async move {
let background = cx.background_executor().clone();
cx.background_executor().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@ -996,7 +1006,7 @@ impl SemanticIndex {
pub fn index_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.is_authenticated() {
@ -1038,7 +1048,7 @@ impl SemanticIndex {
let project_state = this
.projects
.get_mut(&project.downgrade())
.ok_or_else(|| anyhow!("project was dropped"))?;
.context("project was dropped")?;
let pending_file_count_tx = &project_state.pending_file_count_tx;
project_state
@ -1080,9 +1090,9 @@ impl SemanticIndex {
});
anyhow::Ok(())
})?;
})??;
cx.background()
cx.background_executor()
.spawn(async move {
for (worktree_db_id, path) in files_to_delete {
db.delete_file(worktree_db_id, path).await.log_err();
@ -1138,11 +1148,11 @@ impl SemanticIndex {
let project_state = this
.projects
.get_mut(&project.downgrade())
.ok_or_else(|| anyhow!("project was dropped"))?;
.context("project was dropped")?;
project_state.pending_index -= 1;
cx.notify();
anyhow::Ok(())
})?;
})??;
Ok(())
})
@ -1150,15 +1160,15 @@ impl SemanticIndex {
fn wait_for_worktree_registration(
&self,
project: &ModelHandle<Project>,
project: &Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let project = project.downgrade();
cx.spawn_weak(|this, cx| async move {
cx.spawn(|this, cx| async move {
loop {
let mut pending_worktrees = Vec::new();
this.upgrade(&cx)
.ok_or_else(|| anyhow!("semantic index dropped"))?
this.upgrade()
.context("semantic index dropped")?
.read_with(&cx, |this, _| {
if let Some(project) = this.projects.get(&project) {
for worktree in project.worktrees.values() {
@ -1167,7 +1177,7 @@ impl SemanticIndex {
}
}
}
});
})?;
if pending_worktrees.is_empty() {
break;
@ -1230,17 +1240,13 @@ impl SemanticIndex {
} else {
embeddings.next()
};
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
let embedding = embedding.context("failed to embed spans")?;
span.embedding = Some(embedding);
}
Ok(())
}
}
impl Entity for SemanticIndex {
type Event = ();
}
impl Drop for JobHandle {
fn drop(&mut self) {
if let Some(inner) = Arc::get_mut(&mut self.tx) {

View File

@ -1,7 +1,7 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Setting;
use settings::Settings;
#[derive(Deserialize, Debug)]
pub struct SemanticIndexSettings {
@ -13,7 +13,7 @@ pub struct SemanticIndexSettingsContent {
pub enabled: Option<bool>,
}
impl Setting for SemanticIndexSettings {
impl Settings for SemanticIndexSettings {
const KEY: Option<&'static str> = Some("semantic_index");
type FileContent = SemanticIndexSettingsContent;
@ -21,7 +21,7 @@ impl Setting for SemanticIndexSettings {
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &gpui::AppContext,
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}

View File

@ -6,14 +6,14 @@ use crate::{
};
use ai::test::FakeEmbeddingProvider;
use gpui::{executor::Deterministic, Task, TestAppContext};
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use settings::{Settings, SettingsStore};
use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::{paths::PathMatcher, RandomCharIter};
@ -26,10 +26,10 @@ fn init_logger() {
}
#[gpui::test]
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
async fn test_semantic_index(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background());
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
"/the-root",
json!({
@ -91,9 +91,10 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
});
let pending_file_count =
semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
assert_eq!(*pending_file_count.borrow(), 3);
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
cx.background_executor
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*pending_file_count.borrow(), 0);
let search_results = search_results.await.unwrap();
@ -170,13 +171,15 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
.await
.unwrap();
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
cx.background_executor
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
let prev_embedding_count = embedding_provider.embedding_count();
let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
assert_eq!(*pending_file_count.borrow(), 1);
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
cx.background_executor
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*pending_file_count.borrow(), 0);
index.await.unwrap();
@ -220,13 +223,13 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
for file in &files {
queue.push(file.clone());
}
queue.flush();
cx.foreground().run_until_parked();
cx.background_executor.run_until_parked();
let finished_files = queue.finished_files();
let mut embedded_files: Vec<_> = files
.iter()
@ -1686,8 +1689,9 @@ fn test_subtract_ranges() {
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<SemanticIndexSettings>(cx);
settings::register::<ProjectSettings>(cx);
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
SemanticIndexSettings::register(cx);
ProjectSettings::register(cx);
});
}

View File

@ -1,69 +0,0 @@
[package]
name = "semantic_index2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/semantic_index.rs"
doctest = false
[dependencies]
ai = { path = "../ai" }
collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { path = "../language" }
project = { path = "../project" }
workspace = { path = "../workspace" }
util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
settings = { package = "settings2", path = "../settings2" }
anyhow.workspace = true
postage.workspace = true
futures.workspace = true
ordered-float.workspace = true
smol.workspace = true
rusqlite.workspace = true
log.workspace = true
tree-sitter.workspace = true
lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
tiktoken-rs.workspace = true
parking_lot.workspace = true
rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
rust-embed = { version = "8.0", features = ["include-exclude"] }
client = { path = "../client" }
node_runtime = { path = "../node_runtime"}
pretty_assertions.workspace = true
rand.workspace = true
unindent.workspace = true
tempdir.workspace = true
ctor.workspace = true
env_logger.workspace = true
tree-sitter-typescript.workspace = true
tree-sitter-json.workspace = true
tree-sitter-rust.workspace = true
tree-sitter-toml.workspace = true
tree-sitter-cpp.workspace = true
tree-sitter-elixir.workspace = true
tree-sitter-lua.workspace = true
tree-sitter-ruby.workspace = true
tree-sitter-php.workspace = true

View File

@ -1,20 +0,0 @@
# Semantic Index
## Evaluation
### Metrics
nDCG@k:
- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
MRR@k:
- "Mean reciprocal rank quantifies the rank of the first relevant item found in teh recommendation list."
MAP@k:
- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
Resources:
- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)

View File

@ -1,114 +0,0 @@
{
"repo": "https://github.com/AntonOsika/gpt-engineer.git",
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
"assertions": [
{
"query": "How do I contribute to this project?",
"matches": [
".github/CONTRIBUTING.md:1",
"ROADMAP.md:48"
]
},
{
"query": "What version of the openai package is active?",
"matches": [
"pyproject.toml:14"
]
},
{
"query": "Ask user for clarification",
"matches": [
"gpt_engineer/steps.py:69"
]
},
{
"query": "generate tests for python code",
"matches": [
"gpt_engineer/steps.py:153"
]
},
{
"query": "get item from database based on key",
"matches": [
"gpt_engineer/db.py:42",
"gpt_engineer/db.py:68"
]
},
{
"query": "prompt user to select files",
"matches": [
"gpt_engineer/file_selector.py:171",
"gpt_engineer/file_selector.py:306",
"gpt_engineer/file_selector.py:289",
"gpt_engineer/file_selector.py:234"
]
},
{
"query": "send to rudderstack",
"matches": [
"gpt_engineer/collect.py:11",
"gpt_engineer/collect.py:38"
]
},
{
"query": "parse code blocks from chat messages",
"matches": [
"gpt_engineer/chat_to_files.py:10",
"docs/intro/chat_parsing.md:1"
]
},
{
"query": "how do I use the docker cli?",
"matches": [
"docker/README.md:1"
]
},
{
"query": "ask the user if the code ran successfully?",
"matches": [
"gpt_engineer/learning.py:54"
]
},
{
"query": "how is consent granted by the user?",
"matches": [
"gpt_engineer/learning.py:107",
"gpt_engineer/learning.py:130",
"gpt_engineer/learning.py:152"
]
},
{
"query": "what are all the different steps the agent can take?",
"matches": [
"docs/intro/steps_module.md:1",
"gpt_engineer/steps.py:391"
]
},
{
"query": "ask the user for clarification?",
"matches": [
"gpt_engineer/steps.py:69"
]
},
{
"query": "what models are available?",
"matches": [
"gpt_engineer/ai.py:315",
"gpt_engineer/ai.py:341",
"docs/open-models.md:1"
]
},
{
"query": "what is the current focus of the project?",
"matches": [
"ROADMAP.md:11"
]
},
{
"query": "does the agent know how to fix code?",
"matches": [
"gpt_engineer/steps.py:367"
]
}
]
}

View File

@ -1,104 +0,0 @@
{
"repo": "https://github.com/tree-sitter/tree-sitter.git",
"commit": "46af27796a76c72d8466627d499f2bca4af958ee",
"assertions": [
{
"query": "What attributes are available for the tags configuration struct?",
"matches": [
"tags/src/lib.rs:24"
]
},
{
"query": "create a new tag configuration",
"matches": [
"tags/src/lib.rs:119"
]
},
{
"query": "generate tags based on config",
"matches": [
"tags/src/lib.rs:261"
]
},
{
"query": "match on ts quantifier in rust",
"matches": [
"lib/binding_rust/lib.rs:139"
]
},
{
"query": "cli command to generate tags",
"matches": [
"cli/src/tags.rs:10"
]
},
{
"query": "what version of the tree-sitter-tags package is active?",
"matches": [
"tags/Cargo.toml:4"
]
},
{
"query": "Insert a new parse state",
"matches": [
"cli/src/generate/build_tables/build_parse_table.rs:153"
]
},
{
"query": "Handle conflict when numerous actions occur on the same symbol",
"matches": [
"cli/src/generate/build_tables/build_parse_table.rs:363",
"cli/src/generate/build_tables/build_parse_table.rs:442"
]
},
{
"query": "Match based on associativity of actions",
"matches": [
"cri/src/generate/build_tables/build_parse_table.rs:542"
]
},
{
"query": "Format token set display",
"matches": [
"cli/src/generate/build_tables/item.rs:246"
]
},
{
"query": "extract choices from rule",
"matches": [
"cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
]
},
{
"query": "How do we identify if a symbol is being used?",
"matches": [
"cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
]
},
{
"query": "How do we launch the playground?",
"matches": [
"cli/src/playground.rs:46"
]
},
{
"query": "How do we test treesitter query matches in rust?",
"matches": [
"cli/src/query_testing.rs:152",
"cli/src/tests/query_test.rs:781",
"cli/src/tests/query_test.rs:2163",
"cli/src/tests/query_test.rs:3781",
"cli/src/tests/query_test.rs:887"
]
},
{
"query": "What does the CLI do?",
"matches": [
"cli/README.md:10",
"cli/loader/README.md:3",
"docs/section-5-implementation.md:14",
"docs/section-5-implementation.md:18"
]
}
]
}

View File

@ -1,603 +0,0 @@
use crate::{
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
};
use ai::embedding::Embedding;
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::BackgroundExecutor;
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
future::Future,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
};
use util::{paths::PathMatcher, TryFutureExt};
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
let mut indices = (0..data.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &data[i]);
indices.reverse();
indices
}
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
pub relative_path: String,
pub mtime: Timestamp,
}
#[derive(Clone)]
pub struct VectorDatabase {
path: Arc<Path>,
transactions:
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
impl VectorDatabase {
pub async fn new(
fs: Arc<dyn Fs>,
path: Arc<Path>,
executor: BackgroundExecutor,
) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
>();
executor
.spawn({
let path = path.clone();
async move {
let mut connection = rusqlite::Connection::open(&path)?;
connection.pragma_update(None, "journal_mode", "wal")?;
connection.pragma_update(None, "synchronous", "normal")?;
connection.pragma_update(None, "cache_size", 1000000)?;
connection.pragma_update(None, "temp_store", "MEMORY")?;
while let Ok(transaction) = transactions_rx.recv().await {
transaction(&mut connection);
}
anyhow::Ok(())
}
.log_err()
})
.detach();
let this = Self {
transactions: transactions_tx,
path,
};
this.initialize_database().await?;
Ok(this)
}
pub fn path(&self) -> &Arc<Path> {
&self.path
}
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
where
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
T: 'static + Send,
{
let (tx, rx) = oneshot::channel();
let transactions = self.transactions.clone();
async move {
if transactions
.send(Box::new(|connection| {
let result = connection
.transaction()
.map_err(|err| anyhow!(err))
.and_then(|transaction| {
let result = f(&transaction)?;
transaction.commit()?;
Ok(result)
});
let _ = tx.send(result);
}))
.await
.is_err()
{
return Err(anyhow!("connection was dropped"))?;
}
rx.await?
}
}
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
self.transact(|db| {
rusqlite::vtab::array::load_module(&db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
let version_query = db.prepare("SELECT version from semantic_index_config");
let version = version_query
.and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
log::trace!("vector database schema up to date");
return Ok(());
}
log::trace!("vector database schema out of date. updating...");
// We renamed the `documents` table to `spans`, so we want to drop
// `documents` without recreating it if it exists.
db.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
db.execute("DROP TABLE IF EXISTS spans", [])
.context("failed to drop 'spans' table")?;
db.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
db.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
[],
)?;
db.execute(
"CREATE TABLE spans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
digest BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE INDEX spans_digest ON spans (digest)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
})
}
pub fn delete_file(
&self,
worktree_id: i64,
delete_path: Arc<Path>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
})
}
pub fn insert_file(
&self,
worktree_id: i64,
path: Arc<Path>,
mtime: SystemTime,
spans: Vec<Span>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
db.execute(
"
REPLACE INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
VALUES (?1, ?2, ?3, ?4)
",
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
)?;
let file_id = db.last_insert_rowid();
let mut query = db.prepare(
"
INSERT INTO spans
(file_id, start_byte, end_byte, name, embedding, digest)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)?;
for span in spans {
query.execute(params![
file_id,
span.range.start.to_string(),
span.range.end.to_string(),
span.name,
span.embedding,
span.digest
])?;
}
Ok(())
})
}
pub fn worktree_previously_indexed(
&self,
worktree_root_path: &Path,
) -> impl Future<Output = Result<bool>> {
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
if worktree_id.is_ok() {
return Ok(true);
} else {
return Ok(false);
}
})
}
pub fn embeddings_for_digests(
&self,
digests: Vec<SpanDigest>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
WHERE digest IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
let digests = Rc::new(
digests
.into_iter()
.map(|p| Value::Blob(p.0.to_vec()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![digests], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
Ok(embeddings_by_digest)
})
}
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
LEFT JOIN files ON files.id = spans.file_id
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
for (worktree_id, file_paths) in worktree_id_file_paths {
let file_paths = Rc::new(
file_paths
.into_iter()
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![worktree_id, file_paths], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
}
Ok(embeddings_by_digest)
})
}
pub fn find_or_create_worktree(
&self,
worktree_root_path: Arc<Path>,
) -> impl Future<Output = Result<i64>> {
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
});
if worktree_id.is_ok() {
return Ok(worktree_id?);
}
// If worktree_id is Err, insert new worktree
db.execute(
"INSERT into worktrees (absolute_path) VALUES (?1)",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(db.last_insert_rowid())
})
}
pub fn get_file_mtimes(
&self,
worktree_id: i64,
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
})
}
pub fn top_k_search(
&self,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let file_ids = file_ids.to_vec();
let query = query_embedding.clone().0;
let query = Array1::from_vec(query);
self.transact(move |db| {
let mut query_statement = db.prepare(
"
SELECT
id, embedding
FROM
spans
WHERE
file_id IN rarray(?)
",
)?;
let deserialized_rows = query_statement
.query_map(params![ids_to_sql(&file_ids)], |row| {
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.collect::<Vec<(usize, Embedding)>>();
if deserialized_rows.len() == 0 {
return Ok(Vec::new());
}
// Get Length of Embeddings Returned
let embedding_len = deserialized_rows[0].1 .0.len();
let batch_n = 1000;
let mut batches = Vec::new();
let mut batch_ids = Vec::new();
let mut batch_embeddings: Vec<f32> = Vec::new();
deserialized_rows.iter().for_each(|(id, embedding)| {
batch_ids.push(id);
batch_embeddings.extend(&embedding.0);
if batch_ids.len() == batch_n {
let embeddings = std::mem::take(&mut batch_embeddings);
let ids = std::mem::take(&mut batch_ids);
let array =
Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
match array {
Ok(array) => {
batches.push((ids, array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
}
});
if batch_ids.len() > 0 {
let array = Array2::from_shape_vec(
(batch_ids.len(), embedding_len),
batch_embeddings.clone(),
);
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
}
let mut ids: Vec<usize> = Vec::new();
let mut results = Vec::new();
for (batch_ids, array) in batches {
let scores = array
.dot(&query.t())
.to_vec()
.iter()
.map(|score| OrderedFloat(*score))
.collect::<Vec<OrderedFloat<f32>>>();
results.extend(scores);
ids.extend(batch_ids);
}
let sorted_idx = argsort(&results);
let mut sorted_results = Vec::new();
let last_idx = limit.min(sorted_idx.len());
for idx in &sorted_idx[0..last_idx] {
sorted_results.push((ids[*idx] as i64, results[*idx]))
}
Ok(sorted_results)
})
}
pub fn retrieve_included_file_ids(
&self,
worktree_ids: &[i64],
includes: &[PathMatcher],
excludes: &[PathMatcher],
) -> impl Future<Output = Result<Vec<i64>>> {
let worktree_ids = worktree_ids.to_vec();
let includes = includes.to_vec();
let excludes = excludes.to_vec();
self.transact(move |db| {
let mut file_query = db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
}
}
anyhow::Ok(file_ids)
})
}
pub fn spans_for_ids(
&self,
ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
let ids = ids.to_vec();
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT
spans.id,
files.worktree_id,
files.relative_path,
spans.start_byte,
spans.end_byte
FROM
spans, files
WHERE
spans.file_id = files.id AND
spans.id in rarray(?)
",
)?;
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut results = Vec::with_capacity(ids.len());
for id in &ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing span id {}", id))?;
results.push(value);
}
Ok(results)
})
}
}
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
Rc::new(
ids.iter()
.copied()
.map(|v| rusqlite::types::Value::from(v))
.collect::<Vec<_>>(),
)
}

View File

@ -1,169 +0,0 @@
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use smol::channel;
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
pub worktree_id: i64,
pub path: Arc<Path>,
pub mtime: SystemTime,
pub spans: Vec<Span>,
pub job_handle: JobHandle,
}
impl std::fmt::Debug for FileToEmbed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FileToEmbed")
.field("worktree_id", &self.worktree_id)
.field("path", &self.path)
.field("mtime", &self.mtime)
.field("spans", &self.spans)
.finish_non_exhaustive()
}
}
impl PartialEq for FileToEmbed {
fn eq(&self, other: &Self) -> bool {
self.worktree_id == other.worktree_id
&& self.path == other.path
&& self.mtime == other.mtime
&& self.spans == other.spans
}
}
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileFragmentToEmbed>,
executor: BackgroundExecutor,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileFragmentToEmbed {
file: Arc<Mutex<FileToEmbed>>,
span_range: Range<usize>,
}
impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: BackgroundExecutor,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
}
}
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: 0..0,
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
for (ix, span) in file.lock().spans.iter().enumerate() {
let span_token_count = if span.embedding.is_none() {
span.token_count
} else {
0
};
let next_token_count = self.pending_batch_token_count + span_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: range_end..range_end,
});
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += span_token_count;
}
}
pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
return;
}
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor
.spawn(async move {
let mut spans = Vec::new();
for fragment in &batch {
let file = fragment.file.lock();
spans.extend(
file.spans[fragment.span_range.clone()]
.iter()
.filter(|d| d.embedding.is_none())
.map(|d| d.content.clone()),
);
}
// If spans is 0, just send the fragment to the finished files if its the last one.
if spans.is_empty() {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
.iter_mut()
.filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
span.embedding = Some(embedding);
} else {
log::error!("number of embeddings != number of documents");
}
}
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
}
Err(error) => {
log::error!("{:?}", error);
}
}
})
.detach();
}
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
self.finished_files_rx.clone()
}
}

View File

@ -1,414 +0,0 @@
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
cmp::{self, Reverse},
collections::HashSet,
ops::Range,
path::Path,
sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
blob.try_into()
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(SpanDigest(bytes));
}
}
impl ToSql for SpanDigest {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for SpanDigest {
fn from(value: &'_ str) -> Self {
let mut sha1 = Sha1::new();
sha1.update(value);
Self(sha1.finalize().into())
}
}
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Option<Embedding>,
pub digest: SpanDigest,
pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
"TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceeding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
pub start_col: usize,
pub item_range: Option<Range<usize>>,
pub name_range: Option<Range<usize>>,
pub context_ranges: Vec<Range<usize>>,
pub collapse_ranges: Vec<Range<usize>>,
}
impl CodeContextRetriever {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
embedding_provider,
}
}
fn parse_entire_file(
&self,
relative_path: Option<&Path>,
language_name: Arc<str>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: Default::default(),
name: language_name.to_string(),
digest,
token_count,
}])
}
fn parse_markdown_file(
&self,
relative_path: Option<&Path>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: None,
name: "Markdown".to_string(),
digest,
token_count,
}])
}
fn get_matches_in_file(
&mut self,
content: &str,
grammar: &Arc<Grammar>,
) -> Result<Vec<CodeContextMatch>> {
let embedding_config = grammar
.embedding_config
.as_ref()
.ok_or_else(|| anyhow!("no embedding queries"))?;
self.parser.set_language(&grammar.ts_language).unwrap();
let tree = self
.parser
.parse(&content, None)
.ok_or_else(|| anyhow!("parsing failed"))?;
let mut captures: Vec<CodeContextMatch> = Vec::new();
let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
let mut keep_ranges: Vec<Range<usize>> = Vec::new();
for mat in self.cursor.matches(
&embedding_config.query,
tree.root_node(),
content.as_bytes(),
) {
let mut start_col = 0;
let mut item_range: Option<Range<usize>> = None;
let mut name_range: Option<Range<usize>> = None;
let mut context_ranges: Vec<Range<usize>> = Vec::new();
collapse_ranges.clear();
keep_ranges.clear();
for capture in mat.captures {
if capture.index == embedding_config.item_capture_ix {
item_range = Some(capture.node.byte_range());
start_col = capture.node.start_position().column;
} else if Some(capture.index) == embedding_config.name_capture_ix {
name_range = Some(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.context_capture_ix {
context_ranges.push(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.collapse_capture_ix {
collapse_ranges.push(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.keep_capture_ix {
keep_ranges.push(capture.node.byte_range());
}
}
captures.push(CodeContextMatch {
start_col,
item_range,
name_range,
context_ranges,
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
});
}
Ok(captures)
}
pub fn parse_file_with_template(
&mut self,
relative_path: Option<&Path>,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Span>> {
let language_name = language.name();
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content);
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
return self.parse_markdown_file(relative_path, &content);
}
let mut spans = self.parse_file(content, language)?;
for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
let model = self.embedding_provider.base_model();
let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;
}
Ok(spans)
}
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
// Iterate through query matches
let matches = self.get_matches_in_file(content, grammar)?;
let language_scope = language.default_scope();
let placeholder = language_scope.collapsed_placeholder();
let mut spans = Vec::new();
let mut collapsed_ranges_within = Vec::new();
let mut parsed_name_ranges = HashSet::new();
for (i, context_match) in matches.iter().enumerate() {
// Items which are collapsible but not embeddable have no item range
let item_range = if let Some(item_range) = context_match.item_range.clone() {
item_range
} else {
continue;
};
// Checks for deduplication
let name;
if let Some(name_range) = context_match.name_range.clone() {
name = content
.get(name_range.clone())
.map_or(String::new(), |s| s.to_string());
if parsed_name_ranges.contains(&name_range) {
continue;
}
parsed_name_ranges.insert(name_range);
} else {
name = String::new();
}
collapsed_ranges_within.clear();
'outer: for remaining_match in &matches[(i + 1)..] {
for collapsed_range in &remaining_match.collapse_ranges {
if item_range.start <= collapsed_range.start
&& item_range.end >= collapsed_range.end
{
collapsed_ranges_within.push(collapsed_range.clone());
} else {
break 'outer;
}
}
}
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
let mut span_content = String::new();
for context_range in &context_match.context_ranges {
add_content_from_range(
&mut span_content,
content,
context_range.clone(),
context_match.start_col,
);
span_content.push_str("\n");
}
let mut offset = item_range.start;
for collapsed_range in &collapsed_ranges_within {
if collapsed_range.start > offset {
add_content_from_range(
&mut span_content,
content,
offset..collapsed_range.start,
context_match.start_col,
);
offset = collapsed_range.start;
}
if collapsed_range.end > offset {
span_content.push_str(placeholder);
offset = collapsed_range.end;
}
}
if offset < item_range.end {
add_content_from_range(
&mut span_content,
content,
offset..item_range.end,
context_match.start_col,
);
}
let sha1 = SpanDigest::from(span_content.as_str());
spans.push(Span {
name,
content: span_content,
range: item_range.clone(),
embedding: None,
digest: sha1,
token_count: 0,
})
}
return Ok(spans);
}
}
pub(crate) fn subtract_ranges(
ranges: &[Range<usize>],
ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
let mut result = Vec::new();
let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
for range in ranges {
let mut offset = range.start;
while offset < range.end {
if let Some(range_to_subtract) = ranges_to_subtract.peek() {
if offset < range_to_subtract.start {
let next_offset = cmp::min(range_to_subtract.start, range.end);
result.push(offset..next_offset);
offset = next_offset;
} else {
let next_offset = cmp::min(range_to_subtract.end, range.end);
offset = next_offset;
}
if offset >= range_to_subtract.end {
ranges_to_subtract.next();
}
} else {
result.push(offset..range.end);
offset = range.end;
}
}
}
result
}
fn add_content_from_range(
output: &mut String,
content: &str,
range: Range<usize>,
start_col: usize,
) {
for mut line in content.get(range.clone()).unwrap_or("").lines() {
for _ in 0..start_col {
if line.starts_with(' ') {
line = &line[1..];
} else {
break;
}
}
output.push_str(line);
output.push('\n');
}
output.pop();
}

File diff suppressed because it is too large Load Diff

View File

@ -1,28 +0,0 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Deserialize, Debug)]
pub struct SemanticIndexSettings {
pub enabled: bool,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct SemanticIndexSettingsContent {
pub enabled: Option<bool>,
}
impl Settings for SemanticIndexSettings {
const KEY: Option<&'static str> = Some("semantic_index");
type FileContent = SemanticIndexSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -15,7 +15,7 @@ test-support = ["gpui/test-support", "fs/test-support"]
collections = { path = "../collections" }
gpui = {package = "gpui2", path = "../gpui2" }
sqlez = { path = "../sqlez" }
fs = {package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
feature_flags = { path = "../feature_flags" }
util = { path = "../util" }
@ -36,7 +36,7 @@ tree-sitter-json = "*"
[dev-dependencies]
gpui = {package = "gpui2", path = "../gpui2", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
indoc.workspace = true
pretty_assertions.workspace = true
unindent.workspace = true

View File

@ -12,7 +12,7 @@ doctest = false
[dependencies]
gpui = { package = "gpui2", path = "../gpui2" }
settings = { package = "settings2", path = "../settings2" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
theme = { package = "theme2", path = "../theme2" }
util = { path = "../util" }

View File

@ -18,7 +18,7 @@ settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
util = { path = "../util" }
workspace = { path = "../workspace" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
procinfo = { git = "https://github.com/zed-industries/wezterm", rev = "5cd757e5f2eb039ed0c6bb6512223e69d5efc64d", default-features = false }
terminal = { package = "terminal2", path = "../terminal2" }
ui = { package = "ui2", path = "../ui2" }

View File

@ -20,7 +20,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
gpui = { package = "gpui2", path = "../gpui2" }
indexmap = "1.6.2"
parking_lot.workspace = true
@ -38,5 +38,5 @@ itertools = { version = "0.11.0", optional = true }
[dev-dependencies]
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }

View File

@ -12,7 +12,7 @@ doctest = false
client = { path = "../client" }
editor = { path = "../editor" }
feature_flags = { path = "../feature_flags" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
fuzzy = { path = "../fuzzy" }
gpui = { package = "gpui2", path = "../gpui2" }
picker = { path = "../picker" }

View File

@ -7,7 +7,7 @@ publish = false
[dependencies]
fuzzy = { path = "../fuzzy"}
fs = {package = "fs2", path = "../fs2"}
fs = {path = "../fs"}
gpui = {package = "gpui2", path = "../gpui2"}
picker = {path = "../picker"}
util = {path = "../util"}

View File

@ -13,11 +13,11 @@ test-support = []
[dependencies]
client = { path = "../client" }
editor = { path = "../editor" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
fuzzy = { path = "../fuzzy" }
gpui = { package = "gpui2", path = "../gpui2" }
ui = { package = "ui2", path = "../ui2" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
install_cli = { path = "../install_cli" }
project = { path = "../project" }
settings = { package = "settings2", path = "../settings2" }

View File

@ -19,12 +19,12 @@ test-support = [
]
[dependencies]
db = { path = "../db2", package = "db2" }
db = { path = "../db" }
call = { path = "../call" }
client = { path = "../client" }
collections = { path = "../collections" }
# context_menu = { path = "../context_menu" }
fs = { path = "../fs2", package = "fs2" }
fs = { path = "../fs" }
gpui = { package = "gpui2", path = "../gpui2" }
install_cli = { path = "../install_cli" }
language = { path = "../language" }
@ -59,8 +59,8 @@ client = { path = "../client", features = ["test-support"] }
gpui = { path = "../gpui2", package = "gpui2", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
settings = { path = "../settings2", package = "settings2", features = ["test-support"] }
fs = { path = "../fs2", package = "fs2", features = ["test-support"] }
db = { path = "../db2", package = "db2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
db = { path = "../db", features = ["test-support"] }
indoc.workspace = true
env_logger.workspace = true

View File

@ -32,12 +32,12 @@ client = { path = "../client" }
copilot = { path = "../copilot" }
copilot_button = { path = "../copilot_button" }
diagnostics = { path = "../diagnostics" }
db = { package = "db2", path = "../db2" }
db = { path = "../db" }
editor = { path = "../editor" }
feedback = { path = "../feedback" }
file_finder = { path = "../file_finder" }
search = { path = "../search" }
fs = { package = "fs2", path = "../fs2" }
fs = { path = "../fs" }
fsevent = { path = "../fsevent" }
go_to_line = { path = "../go_to_line" }
gpui = { package = "gpui2", path = "../gpui2" }
@ -49,7 +49,7 @@ lsp = { path = "../lsp" }
menu = { package = "menu2", path = "../menu2" }
language_tools = { path = "../language_tools" }
node_runtime = { path = "../node_runtime" }
notifications = { package = "notifications2", path = "../notifications2" }
notifications = { path = "../notifications" }
assistant = { path = "../assistant" }
outline = { path = "../outline" }
# plugin_runtime = { path = "../plugin_runtime",optional = true }
@ -59,7 +59,7 @@ project_symbols = { path = "../project_symbols" }
quick_action_bar = { path = "../quick_action_bar" }
recent_projects = { path = "../recent_projects" }
rope = { package = "rope2", path = "../rope2"}
rpc = { package = "rpc2", path = "../rpc2" }
rpc = { path = "../rpc" }
settings = { package = "settings2", path = "../settings2" }
feature_flags = { path = "../feature_flags" }
sum_tree = { path = "../sum_tree" }
@ -69,7 +69,7 @@ terminal_view = { path = "../terminal_view" }
theme = { package = "theme2", path = "../theme2" }
theme_selector = { path = "../theme_selector" }
util = { path = "../util" }
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
semantic_index = { path = "../semantic_index" }
vim = { path = "../vim" }
workspace = { path = "../workspace" }
welcome = { path = "../welcome" }