mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Remove 2 suffix for fs, db, semantic_index, prettier
Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
parent
324ac96977
commit
5ddd298b4d
262
Cargo.lock
generated
262
Cargo.lock
generated
@ -310,7 +310,7 @@ dependencies = [
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"indoc",
|
||||
@ -326,7 +326,7 @@ dependencies = [
|
||||
"regex",
|
||||
"schemars",
|
||||
"search",
|
||||
"semantic_index2",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings2",
|
||||
@ -688,7 +688,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"db2",
|
||||
"db",
|
||||
"gpui2",
|
||||
"isahc",
|
||||
"lazy_static",
|
||||
@ -1122,7 +1122,7 @@ dependencies = [
|
||||
"audio2",
|
||||
"client",
|
||||
"collections",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"image",
|
||||
@ -1210,7 +1210,7 @@ dependencies = [
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"db2",
|
||||
"db",
|
||||
"feature_flags",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
@ -1221,7 +1221,7 @@ dependencies = [
|
||||
"parking_lot 0.11.2",
|
||||
"postage",
|
||||
"rand 0.8.5",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -1384,7 +1384,7 @@ dependencies = [
|
||||
"async-tungstenite",
|
||||
"chrono",
|
||||
"collections",
|
||||
"db2",
|
||||
"db",
|
||||
"feature_flags",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
@ -1394,7 +1394,7 @@ dependencies = [
|
||||
"parking_lot 0.11.2",
|
||||
"postage",
|
||||
"rand 0.8.5",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -1553,7 +1553,7 @@ dependencies = [
|
||||
"editor",
|
||||
"env_logger",
|
||||
"envy",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"git3",
|
||||
"gpui2",
|
||||
@ -1568,7 +1568,7 @@ dependencies = [
|
||||
"lsp",
|
||||
"nanoid",
|
||||
"node_runtime",
|
||||
"notifications2",
|
||||
"notifications",
|
||||
"parking_lot 0.11.2",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
@ -1576,7 +1576,7 @@ dependencies = [
|
||||
"prost 0.8.0",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"scrypt",
|
||||
"sea-orm",
|
||||
"serde",
|
||||
@ -1614,7 +1614,7 @@ dependencies = [
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"db2",
|
||||
"db",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"feedback",
|
||||
@ -1625,14 +1625,14 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"menu2",
|
||||
"notifications2",
|
||||
"notifications",
|
||||
"picker",
|
||||
"postage",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"recent_projects",
|
||||
"rich_text",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -1794,7 +1794,7 @@ dependencies = [
|
||||
"lsp",
|
||||
"node_runtime",
|
||||
"parking_lot 0.11.2",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"settings2",
|
||||
@ -1811,7 +1811,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"copilot",
|
||||
"editor",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"language",
|
||||
@ -2208,28 +2208,6 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "db"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"collections",
|
||||
"env_logger",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"parking_lot 0.11.2",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"smol",
|
||||
"sqlez",
|
||||
"sqlez_macros",
|
||||
"tempdir",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "db2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
@ -2500,7 +2478,7 @@ dependencies = [
|
||||
"convert_case 0.6.0",
|
||||
"copilot",
|
||||
"ctor",
|
||||
"db2",
|
||||
"db",
|
||||
"env_logger",
|
||||
"futures 0.3.28",
|
||||
"fuzzy",
|
||||
@ -2519,7 +2497,7 @@ dependencies = [
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"rich_text",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -2724,7 +2702,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"bitflags 2.4.1",
|
||||
"client",
|
||||
"db2",
|
||||
"db",
|
||||
"editor",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
@ -2923,34 +2901,6 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "fs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"collections",
|
||||
"fsevent",
|
||||
"futures 0.3.28",
|
||||
"git2",
|
||||
"gpui",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"log",
|
||||
"parking_lot 0.11.2",
|
||||
"regex",
|
||||
"rope",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"sum_tree",
|
||||
"tempfile",
|
||||
"text",
|
||||
"time",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
@ -4106,7 +4056,7 @@ dependencies = [
|
||||
"pulldown-cmark",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -4957,28 +4907,8 @@ dependencies = [
|
||||
"collections",
|
||||
"db",
|
||||
"feature_flags",
|
||||
"gpui",
|
||||
"rpc",
|
||||
"settings",
|
||||
"sum_tree",
|
||||
"text",
|
||||
"time",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notifications2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"channel",
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"db2",
|
||||
"feature_flags",
|
||||
"gpui2",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"settings2",
|
||||
"sum_tree",
|
||||
"text2",
|
||||
@ -5786,27 +5716,6 @@ dependencies = [
|
||||
"collections",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"lsp",
|
||||
"node_runtime",
|
||||
"parking_lot 0.11.2",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettier2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"collections",
|
||||
"fs2",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"language",
|
||||
"log",
|
||||
@ -5915,9 +5824,9 @@ dependencies = [
|
||||
"collections",
|
||||
"copilot",
|
||||
"ctor",
|
||||
"db2",
|
||||
"db",
|
||||
"env_logger",
|
||||
"fs2",
|
||||
"fs",
|
||||
"fsevent",
|
||||
"futures 0.3.28",
|
||||
"fuzzy",
|
||||
@ -5934,11 +5843,11 @@ dependencies = [
|
||||
"node_runtime",
|
||||
"parking_lot 0.11.2",
|
||||
"postage",
|
||||
"prettier2",
|
||||
"prettier",
|
||||
"pretty_assertions",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@ -5964,7 +5873,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"collections",
|
||||
"db2",
|
||||
"db",
|
||||
"editor",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
@ -6668,37 +6577,6 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rpc"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-lock",
|
||||
"async-tungstenite",
|
||||
"base64 0.13.1",
|
||||
"clock",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"parking_lot 0.11.2",
|
||||
"prost 0.8.0",
|
||||
"prost-build",
|
||||
"rand 0.8.5",
|
||||
"rsa 0.4.0",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"smol-timeout",
|
||||
"strum",
|
||||
"tempdir",
|
||||
"tracing",
|
||||
"util",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rpc2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-lock",
|
||||
@ -7175,7 +7053,7 @@ dependencies = [
|
||||
"menu2",
|
||||
"postage",
|
||||
"project",
|
||||
"semantic_index2",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
@ -7215,60 +7093,6 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "semantic_index"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai",
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger",
|
||||
"futures 0.3.28",
|
||||
"globset",
|
||||
"gpui",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"ndarray",
|
||||
"node_runtime",
|
||||
"ordered-float 2.10.0",
|
||||
"parking_lot 0.11.2",
|
||||
"picker",
|
||||
"postage",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"rpc",
|
||||
"rusqlite",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"sha1",
|
||||
"smol",
|
||||
"tempdir",
|
||||
"theme",
|
||||
"tiktoken-rs",
|
||||
"tree-sitter",
|
||||
"tree-sitter-cpp",
|
||||
"tree-sitter-elixir",
|
||||
"tree-sitter-json 0.20.0",
|
||||
"tree-sitter-lua",
|
||||
"tree-sitter-php",
|
||||
"tree-sitter-ruby",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-toml",
|
||||
"tree-sitter-typescript",
|
||||
"unindent",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semantic_index2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai",
|
||||
"anyhow",
|
||||
@ -7291,7 +7115,7 @@ dependencies = [
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"rusqlite",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
@ -7473,7 +7297,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"feature_flags",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"indoc",
|
||||
@ -8424,7 +8248,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"alacritty_terminal",
|
||||
"anyhow",
|
||||
"db2",
|
||||
"db",
|
||||
"dirs 4.0.0",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
@ -8453,7 +8277,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"db2",
|
||||
"db",
|
||||
"dirs 4.0.0",
|
||||
"editor",
|
||||
"futures 0.3.28",
|
||||
@ -8556,7 +8380,7 @@ name = "theme2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"fs2",
|
||||
"fs",
|
||||
"gpui2",
|
||||
"indexmap 1.9.3",
|
||||
"itertools 0.11.0",
|
||||
@ -8602,7 +8426,7 @@ dependencies = [
|
||||
"client",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs2",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui2",
|
||||
"log",
|
||||
@ -9672,7 +9496,7 @@ name = "vcs_menu"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"fs2",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui2",
|
||||
"picker",
|
||||
@ -10114,9 +9938,9 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"db2",
|
||||
"db",
|
||||
"editor",
|
||||
"fs2",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui2",
|
||||
"install_cli",
|
||||
@ -10383,9 +10207,9 @@ dependencies = [
|
||||
"call",
|
||||
"client",
|
||||
"collections",
|
||||
"db2",
|
||||
"db",
|
||||
"env_logger",
|
||||
"fs2",
|
||||
"fs",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"indoc",
|
||||
@ -10516,14 +10340,14 @@ dependencies = [
|
||||
"copilot",
|
||||
"copilot_button",
|
||||
"ctor",
|
||||
"db2",
|
||||
"db",
|
||||
"diagnostics",
|
||||
"editor",
|
||||
"env_logger",
|
||||
"feature_flags",
|
||||
"feedback",
|
||||
"file_finder",
|
||||
"fs2",
|
||||
"fs",
|
||||
"fsevent",
|
||||
"futures 0.3.28",
|
||||
"go_to_line",
|
||||
@ -10543,7 +10367,7 @@ dependencies = [
|
||||
"lsp",
|
||||
"menu2",
|
||||
"node_runtime",
|
||||
"notifications2",
|
||||
"notifications",
|
||||
"num_cpus",
|
||||
"outline",
|
||||
"parking_lot 0.11.2",
|
||||
@ -10556,12 +10380,12 @@ dependencies = [
|
||||
"recent_projects",
|
||||
"regex",
|
||||
"rope2",
|
||||
"rpc2",
|
||||
"rpc",
|
||||
"rsa 0.4.0",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"search",
|
||||
"semantic_index2",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
|
@ -22,7 +22,6 @@ members = [
|
||||
"crates/copilot",
|
||||
"crates/copilot_button",
|
||||
"crates/db",
|
||||
"crates/db2",
|
||||
"crates/refineable",
|
||||
"crates/refineable/derive_refineable",
|
||||
"crates/diagnostics",
|
||||
@ -32,7 +31,6 @@ members = [
|
||||
"crates/feedback",
|
||||
"crates/file_finder",
|
||||
"crates/fs",
|
||||
"crates/fs2",
|
||||
"crates/fsevent",
|
||||
"crates/fuzzy",
|
||||
"crates/git",
|
||||
@ -56,13 +54,11 @@ members = [
|
||||
"crates/multi_buffer",
|
||||
"crates/node_runtime",
|
||||
"crates/notifications",
|
||||
"crates/notifications2",
|
||||
"crates/outline",
|
||||
"crates/picker",
|
||||
"crates/plugin",
|
||||
"crates/plugin_macros",
|
||||
"crates/prettier",
|
||||
"crates/prettier2",
|
||||
"crates/project",
|
||||
"crates/project_panel",
|
||||
"crates/project_symbols",
|
||||
@ -70,10 +66,8 @@ members = [
|
||||
"crates/recent_projects",
|
||||
"crates/rope",
|
||||
"crates/rpc",
|
||||
"crates/rpc2",
|
||||
"crates/search",
|
||||
"crates/semantic_index",
|
||||
"crates/semantic_index2",
|
||||
"crates/settings",
|
||||
"crates/settings2",
|
||||
"crates/snippet",
|
||||
|
@ -13,14 +13,14 @@ ai = { path = "../ai" }
|
||||
client = { path = "../client" }
|
||||
collections = { path = "../collections"}
|
||||
editor = { path = "../editor" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
menu = { package = "menu2", path = "../menu2" }
|
||||
multi_buffer = { path = "../multi_buffer" }
|
||||
project = { path = "../project" }
|
||||
search = { path = "../search" }
|
||||
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
ui = { package = "ui2", path = "../ui2" }
|
||||
|
@ -9,7 +9,7 @@ path = "src/auto_update.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
client = { path = "../client" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
menu = { package = "menu2", path = "../menu2" }
|
||||
|
@ -25,7 +25,7 @@ collections = { path = "../collections" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
log.workspace = true
|
||||
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
language = { path = "../language" }
|
||||
media = { path = "../media" }
|
||||
project = { path = "../project" }
|
||||
@ -45,7 +45,7 @@ smallvec.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { path = "../client", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
|
@ -14,10 +14,10 @@ test-support = ["collections/test-support", "gpui/test-support", "rpc/test-suppo
|
||||
[dependencies]
|
||||
client = { path = "../client" }
|
||||
collections = { path = "../collections" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
util = { path = "../util" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
language = { path = "../language" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
@ -48,7 +48,7 @@ tempfile = "3"
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
client = { path = "../client", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
|
@ -14,10 +14,10 @@ test-support = ["collections/test-support", "gpui/test-support", "rpc/test-suppo
|
||||
[dependencies]
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
collections = { path = "../collections" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
util = { path = "../util" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
@ -48,6 +48,6 @@ url = "2.2"
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
|
@ -18,7 +18,7 @@ clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
live_kit_server = { path = "../live_kit_server" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
util = { path = "../util" }
|
||||
|
||||
anyhow.workspace = true
|
||||
@ -68,15 +68,15 @@ client = { path = "../client", features = ["test-support"] }
|
||||
channel = { path = "../channel" }
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
git = { package = "git3", path = "../git3", features = ["test-support"] }
|
||||
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2", features = ["test-support"] }
|
||||
lsp = { path = "../lsp", features = ["test-support"] }
|
||||
node_runtime = { path = "../node_runtime" }
|
||||
notifications = { package = "notifications2", path = "../notifications2", features = ["test-support"] }
|
||||
notifications = { path = "../notifications", features = ["test-support"] }
|
||||
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
|
@ -23,7 +23,7 @@ test-support = [
|
||||
|
||||
[dependencies]
|
||||
auto_update = { path = "../auto_update" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
call = { path = "../call" }
|
||||
client = { path = "../client" }
|
||||
channel = { path = "../channel" }
|
||||
@ -37,12 +37,12 @@ fuzzy = { path = "../fuzzy" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
menu = { package = "menu2", path = "../menu2" }
|
||||
notifications = { package = "notifications2", path = "../notifications2" }
|
||||
notifications = { path = "../notifications" }
|
||||
rich_text = { path = "../rich_text" }
|
||||
picker = { path = "../picker" }
|
||||
project = { path = "../project" }
|
||||
recent_projects = { path = "../recent_projects" }
|
||||
rpc = { package ="rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
feature_flags = { path = "../feature_flags"}
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
@ -70,9 +70,9 @@ client = { path = "../client", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
notifications = { package = "notifications2", path = "../notifications2", features = ["test-support"] }
|
||||
notifications = { path = "../notifications", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
|
@ -46,6 +46,6 @@ fs = { path = "../fs", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
lsp = { path = "../lsp", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
|
@ -11,7 +11,7 @@ doctest = false
|
||||
[dependencies]
|
||||
copilot = { path = "../copilot" }
|
||||
editor = { path = "../editor" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
zed_actions = { path = "../zed_actions"}
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
|
@ -13,7 +13,7 @@ test-support = []
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
gpui = { path = "../gpui" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
sqlez = { path = "../sqlez" }
|
||||
sqlez_macros = { path = "../sqlez_macros" }
|
||||
util = { path = "../util" }
|
||||
@ -28,6 +28,6 @@ serde_derive.workspace = true
|
||||
smol.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
tempdir.workspace = true
|
||||
|
@ -185,7 +185,7 @@ pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send
|
||||
where
|
||||
F: Future<Output = anyhow::Result<()>> + Send,
|
||||
{
|
||||
cx.background()
|
||||
cx.background_executor()
|
||||
.spawn(async move { db_write().await.log_err() })
|
||||
.detach()
|
||||
}
|
||||
@ -226,7 +226,9 @@ mod tests {
|
||||
|
||||
/// Test that DB exists but corrupted (causing recreate)
|
||||
#[gpui::test]
|
||||
async fn test_db_corruption() {
|
||||
async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
enum CorruptedDB {}
|
||||
|
||||
impl Domain for CorruptedDB {
|
||||
@ -268,7 +270,9 @@ mod tests {
|
||||
|
||||
/// Test that DB exists but corrupted (causing recreate)
|
||||
#[gpui::test(iterations = 30)]
|
||||
async fn test_simultaneous_db_corruption() {
|
||||
async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
enum CorruptedDB {}
|
||||
|
||||
impl Domain for CorruptedDB {
|
||||
|
@ -1,33 +0,0 @@
|
||||
[package]
|
||||
name = "db2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/db2.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
sqlez = { path = "../sqlez" }
|
||||
sqlez_macros = { path = "../sqlez_macros" }
|
||||
util = { path = "../util" }
|
||||
anyhow.workspace = true
|
||||
indoc.workspace = true
|
||||
async-trait.workspace = true
|
||||
lazy_static.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
smol.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
tempdir.workspace = true
|
@ -1,5 +0,0 @@
|
||||
# Building Queries
|
||||
|
||||
First, craft your test data. The examples folder shows a template for building a test-db, and can be ran with `cargo run --example [your-example]`.
|
||||
|
||||
To actually use and test your queries, import the generated DB file into https://sqliteonline.com/
|
@ -1,331 +0,0 @@
|
||||
pub mod kvp;
|
||||
pub mod query;
|
||||
|
||||
// Re-export
|
||||
pub use anyhow;
|
||||
use anyhow::Context;
|
||||
use gpui::AppContext;
|
||||
pub use indoc::indoc;
|
||||
pub use lazy_static;
|
||||
pub use smol;
|
||||
pub use sqlez;
|
||||
pub use sqlez_macros;
|
||||
pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
|
||||
pub use util::paths::DB_DIR;
|
||||
|
||||
use sqlez::domain::Migrator;
|
||||
use sqlez::thread_safe_connection::ThreadSafeConnection;
|
||||
use sqlez_macros::sql;
|
||||
use std::future::Future;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use util::channel::ReleaseChannel;
|
||||
use util::{async_maybe, ResultExt};
|
||||
|
||||
const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
|
||||
PRAGMA foreign_keys=TRUE;
|
||||
);
|
||||
|
||||
const DB_INITIALIZE_QUERY: &'static str = sql!(
|
||||
PRAGMA journal_mode=WAL;
|
||||
PRAGMA busy_timeout=1;
|
||||
PRAGMA case_sensitive_like=TRUE;
|
||||
PRAGMA synchronous=NORMAL;
|
||||
);
|
||||
|
||||
const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
|
||||
|
||||
const DB_FILE_NAME: &'static str = "db.sqlite";
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
|
||||
pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
|
||||
}
|
||||
|
||||
/// Open or create a database at the given directory path.
|
||||
/// This will retry a couple times if there are failures. If opening fails once, the db directory
|
||||
/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
|
||||
/// In either case, static variables are set so that the user can be notified.
|
||||
pub async fn open_db<M: Migrator + 'static>(
|
||||
db_dir: &Path,
|
||||
release_channel: &ReleaseChannel,
|
||||
) -> ThreadSafeConnection<M> {
|
||||
if *ZED_STATELESS {
|
||||
return open_fallback_db().await;
|
||||
}
|
||||
|
||||
let release_channel_name = release_channel.dev_name();
|
||||
let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
|
||||
|
||||
let connection = async_maybe!({
|
||||
smol::fs::create_dir_all(&main_db_dir)
|
||||
.await
|
||||
.context("Could not create db directory")
|
||||
.log_err()?;
|
||||
let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
|
||||
open_main_db(&db_path).await
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(connection) = connection {
|
||||
return connection;
|
||||
}
|
||||
|
||||
// Set another static ref so that we can escalate the notification
|
||||
ALL_FILE_DB_FAILED.store(true, Ordering::Release);
|
||||
|
||||
// If still failed, create an in memory db with a known name
|
||||
open_fallback_db().await
|
||||
}
|
||||
|
||||
async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
|
||||
log::info!("Opening main db");
|
||||
ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
|
||||
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
||||
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
||||
.build()
|
||||
.await
|
||||
.log_err()
|
||||
}
|
||||
|
||||
async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
|
||||
log::info!("Opening fallback db");
|
||||
ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
|
||||
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
||||
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
||||
.build()
|
||||
.await
|
||||
.expect(
|
||||
"Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
|
||||
use sqlez::thread_safe_connection::locking_queue;
|
||||
|
||||
ThreadSafeConnection::<M>::builder(db_name, false)
|
||||
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
||||
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
||||
// Serialize queued writes via a mutex and run them synchronously
|
||||
.with_write_queue_constructor(locking_queue())
|
||||
.build()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Implements a basic DB wrapper for a given domain
|
||||
#[macro_export]
|
||||
macro_rules! define_connection {
|
||||
(pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
|
||||
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
|
||||
|
||||
impl ::std::ops::Deref for $t {
|
||||
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl $crate::sqlez::domain::Domain for $t {
|
||||
fn name() -> &'static str {
|
||||
stringify!($t)
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
$migrations
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
$crate::lazy_static::lazy_static! {
|
||||
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
|
||||
}
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
$crate::lazy_static::lazy_static! {
|
||||
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
|
||||
}
|
||||
};
|
||||
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
|
||||
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
|
||||
|
||||
impl ::std::ops::Deref for $t {
|
||||
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl $crate::sqlez::domain::Domain for $t {
|
||||
fn name() -> &'static str {
|
||||
stringify!($t)
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
$migrations
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
$crate::lazy_static::lazy_static! {
|
||||
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
|
||||
}
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
$crate::lazy_static::lazy_static! {
|
||||
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
|
||||
where
|
||||
F: Future<Output = anyhow::Result<()>> + Send,
|
||||
{
|
||||
cx.background_executor()
|
||||
.spawn(async move { db_write().await.log_err() })
|
||||
.detach()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::thread;
|
||||
|
||||
use sqlez::domain::Domain;
|
||||
use sqlez_macros::sql;
|
||||
use tempdir::TempDir;
|
||||
|
||||
use crate::open_db;
|
||||
|
||||
// Test bad migration panics
|
||||
#[gpui::test]
|
||||
#[should_panic]
|
||||
async fn test_bad_migration_panics() {
|
||||
enum BadDB {}
|
||||
|
||||
impl Domain for BadDB {
|
||||
fn name() -> &'static str {
|
||||
"db_tests"
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&[
|
||||
sql!(CREATE TABLE test(value);),
|
||||
// failure because test already exists
|
||||
sql!(CREATE TABLE test(value);),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
let tempdir = TempDir::new("DbTests").unwrap();
|
||||
let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
|
||||
}
|
||||
|
||||
/// Test that DB exists but corrupted (causing recreate)
|
||||
#[gpui::test]
|
||||
async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
enum CorruptedDB {}
|
||||
|
||||
impl Domain for CorruptedDB {
|
||||
fn name() -> &'static str {
|
||||
"db_tests"
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&[sql!(CREATE TABLE test(value);)]
|
||||
}
|
||||
}
|
||||
|
||||
enum GoodDB {}
|
||||
|
||||
impl Domain for GoodDB {
|
||||
fn name() -> &'static str {
|
||||
"db_tests" //Notice same name
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&[sql!(CREATE TABLE test2(value);)] //But different migration
|
||||
}
|
||||
}
|
||||
|
||||
let tempdir = TempDir::new("DbTests").unwrap();
|
||||
{
|
||||
let corrupt_db =
|
||||
open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
|
||||
assert!(corrupt_db.persistent());
|
||||
}
|
||||
|
||||
let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
|
||||
assert!(
|
||||
good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
|
||||
.unwrap()
|
||||
.is_none()
|
||||
);
|
||||
}
|
||||
|
||||
/// Test that DB exists but corrupted (causing recreate)
|
||||
#[gpui::test(iterations = 30)]
|
||||
async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
enum CorruptedDB {}
|
||||
|
||||
impl Domain for CorruptedDB {
|
||||
fn name() -> &'static str {
|
||||
"db_tests"
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&[sql!(CREATE TABLE test(value);)]
|
||||
}
|
||||
}
|
||||
|
||||
enum GoodDB {}
|
||||
|
||||
impl Domain for GoodDB {
|
||||
fn name() -> &'static str {
|
||||
"db_tests" //Notice same name
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&[sql!(CREATE TABLE test2(value);)] //But different migration
|
||||
}
|
||||
}
|
||||
|
||||
let tempdir = TempDir::new("DbTests").unwrap();
|
||||
{
|
||||
// Setup the bad database
|
||||
let corrupt_db =
|
||||
open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
|
||||
assert!(corrupt_db.persistent());
|
||||
}
|
||||
|
||||
// Try to connect to it a bunch of times at once
|
||||
let mut guards = vec![];
|
||||
for _ in 0..10 {
|
||||
let tmp_path = tempdir.path().to_path_buf();
|
||||
let guard = thread::spawn(move || {
|
||||
let good_db = smol::block_on(open_db::<GoodDB>(
|
||||
tmp_path.as_path(),
|
||||
&util::channel::ReleaseChannel::Dev,
|
||||
));
|
||||
assert!(
|
||||
good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
|
||||
.unwrap()
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
|
||||
guards.push(guard);
|
||||
}
|
||||
|
||||
for guard in guards.into_iter() {
|
||||
assert!(guard.join().is_ok());
|
||||
}
|
||||
}
|
||||
}
|
@ -1,62 +0,0 @@
|
||||
use sqlez_macros::sql;
|
||||
|
||||
use crate::{define_connection, query};
|
||||
|
||||
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
|
||||
&[sql!(
|
||||
CREATE TABLE IF NOT EXISTS kv_store(
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
) STRICT;
|
||||
)];
|
||||
);
|
||||
|
||||
impl KeyValueStore {
|
||||
query! {
|
||||
pub fn read_kvp(key: &str) -> Result<Option<String>> {
|
||||
SELECT value FROM kv_store WHERE key = (?)
|
||||
}
|
||||
}
|
||||
|
||||
query! {
|
||||
pub async fn write_kvp(key: String, value: String) -> Result<()> {
|
||||
INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))
|
||||
}
|
||||
}
|
||||
|
||||
query! {
|
||||
pub async fn delete_kvp(key: String) -> Result<()> {
|
||||
DELETE FROM kv_store WHERE key = (?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::kvp::KeyValueStore;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_kvp() {
|
||||
let db = KeyValueStore(crate::open_test_db("test_kvp").await);
|
||||
|
||||
assert_eq!(db.read_kvp("key-1").unwrap(), None);
|
||||
|
||||
db.write_kvp("key-1".to_string(), "one".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(db.read_kvp("key-1").unwrap(), Some("one".to_string()));
|
||||
|
||||
db.write_kvp("key-1".to_string(), "one-2".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(db.read_kvp("key-1").unwrap(), Some("one-2".to_string()));
|
||||
|
||||
db.write_kvp("key-2".to_string(), "two".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(db.read_kvp("key-2").unwrap(), Some("two".to_string()));
|
||||
|
||||
db.delete_kvp("key-1".to_string()).await.unwrap();
|
||||
assert_eq!(db.read_kvp("key-1").unwrap(), None);
|
||||
}
|
||||
}
|
@ -1,314 +0,0 @@
|
||||
#[macro_export]
|
||||
macro_rules! query {
|
||||
($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self) -> $crate::anyhow::Result<()> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.exec(sql_stmt)?().context(::std::format!(
|
||||
"Error in {}, exec failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt,
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self) -> $crate::anyhow::Result<()> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.exec(sql_stmt)?().context(::std::format!(
|
||||
"Error in {}, exec failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, exec_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(move |connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.exec_bound::<$arg_type>(sql_stmt)?($arg)
|
||||
.context(::std::format!(
|
||||
"Error in {}, exec_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(move |connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, exec_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select::<$return_type>(sql_stmt)?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
pub async fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select::<$return_type>(sql_stmt)?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, exec_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, exec_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row::<$return_type>(sql_stmt)?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select_row::<$return_type>(sql_stmt)?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<Option<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
|
||||
self.write(move |connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self) -> $crate::anyhow::Result<$return_type> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row::<$return_type>(indoc! { $sql })?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))?
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound expected single row result but found none for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select_row::<$return_type>(sql_stmt)?()
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))?
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound expected single row result but found none for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
|
||||
pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))?
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound expected single row result but found none for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
|
||||
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))?
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound expected single row result but found none for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}
|
||||
};
|
||||
($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
|
||||
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
|
||||
use $crate::anyhow::Context;
|
||||
|
||||
|
||||
self.write(|connection| {
|
||||
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
|
||||
|
||||
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound failed to execute or parse for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))?
|
||||
.context(::std::format!(
|
||||
"Error in {}, select_row_bound expected single row result but found none for: {}",
|
||||
::std::stringify!($id),
|
||||
sql_stmt
|
||||
))
|
||||
}).await
|
||||
}
|
||||
};
|
||||
}
|
@ -26,7 +26,7 @@ test-support = [
|
||||
client = { path = "../client" }
|
||||
clock = { path = "../clock" }
|
||||
copilot = { path = "../copilot" }
|
||||
db = { package="db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
collections = { path = "../collections" }
|
||||
# context_menu = { path = "../context_menu" }
|
||||
fuzzy = { path = "../fuzzy" }
|
||||
@ -36,7 +36,7 @@ language = { path = "../language" }
|
||||
lsp = { path = "../lsp" }
|
||||
multi_buffer = { path = "../multi_buffer" }
|
||||
project = { path = "../project" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
rich_text = { path = "../rich_text" }
|
||||
settings = { package="settings2", path = "../settings2" }
|
||||
snippet = { path = "../snippet" }
|
||||
|
@ -12,7 +12,7 @@ test-support = []
|
||||
|
||||
[dependencies]
|
||||
client = { path = "../client" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
editor = { path = "../editor" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
|
@ -9,8 +9,8 @@ path = "src/fs.rs"
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
rope = { path = "../rope" }
|
||||
text = { path = "../text" }
|
||||
rope = { package = "rope2", path = "../rope2" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
util = { path = "../util" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
|
||||
@ -31,10 +31,10 @@ log.workspace = true
|
||||
libc = "0.2"
|
||||
time.workspace = true
|
||||
|
||||
gpui = { path = "../gpui", optional = true}
|
||||
gpui = { package = "gpui2", path = "../gpui2", optional = true}
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support"]
|
||||
|
@ -27,8 +27,6 @@ use collections::{btree_map, BTreeMap};
|
||||
use repository::{FakeGitRepositoryState, GitFileStatus};
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
use std::ffi::OsStr;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
use std::sync::Weak;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Fs: Send + Sync {
|
||||
@ -290,7 +288,7 @@ impl Fs for RealFs {
|
||||
pub struct FakeFs {
|
||||
// Use an unfair lock to ensure tests are deterministic.
|
||||
state: Mutex<FakeFsState>,
|
||||
executor: Weak<gpui::executor::Background>,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
@ -436,9 +434,9 @@ lazy_static::lazy_static! {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl FakeFs {
|
||||
pub fn new(executor: Arc<gpui::executor::Background>) -> Arc<Self> {
|
||||
pub fn new(executor: gpui::BackgroundExecutor) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
executor: Arc::downgrade(&executor),
|
||||
executor,
|
||||
state: Mutex::new(FakeFsState {
|
||||
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
|
||||
inode: 0,
|
||||
@ -699,12 +697,8 @@ impl FakeFs {
|
||||
self.state.lock().metadata_call_count
|
||||
}
|
||||
|
||||
async fn simulate_random_delay(&self) {
|
||||
self.executor
|
||||
.upgrade()
|
||||
.expect("executor has been dropped")
|
||||
.simulate_random_delay()
|
||||
.await;
|
||||
fn simulate_random_delay(&self) -> impl futures::Future<Output = ()> {
|
||||
self.executor.simulate_random_delay()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1103,9 +1097,7 @@ impl Fs for FakeFs {
|
||||
let result = events.iter().any(|event| event.path.starts_with(&path));
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
if let Some(executor) = executor.clone().upgrade() {
|
||||
executor.simulate_random_delay().await;
|
||||
}
|
||||
result
|
||||
}
|
||||
}))
|
||||
@ -1230,13 +1222,12 @@ pub fn copy_recursive<'a>(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use gpui::BackgroundExecutor;
|
||||
use serde_json::json;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fake_fs(cx: &mut TestAppContext) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
|
||||
async fn test_fake_fs(executor: BackgroundExecutor) {
|
||||
let fs = FakeFs::new(executor.clone());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
|
@ -1,40 +0,0 @@
|
||||
[package]
|
||||
name = "fs2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/fs2.rs"
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
rope = { package = "rope2", path = "../rope2" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
util = { path = "../util" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
futures.workspace = true
|
||||
tempfile = "3"
|
||||
fsevent = { path = "../fsevent" }
|
||||
lazy_static.workspace = true
|
||||
parking_lot.workspace = true
|
||||
smol.workspace = true
|
||||
regex.workspace = true
|
||||
git2.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
serde_json.workspace = true
|
||||
log.workspace = true
|
||||
libc = "0.2"
|
||||
time.workspace = true
|
||||
|
||||
gpui = { package = "gpui2", path = "../gpui2", optional = true}
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support"]
|
File diff suppressed because it is too large
Load Diff
@ -1,415 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use git2::{BranchType, StatusShow};
|
||||
use parking_lot::Mutex;
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
ffi::OsStr,
|
||||
os::unix::prelude::OsStrExt,
|
||||
path::{Component, Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
};
|
||||
use sum_tree::{MapSeekTarget, TreeMap};
|
||||
use util::ResultExt;
|
||||
|
||||
pub use git2::Repository as LibGitRepository;
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub struct Branch {
|
||||
pub name: Box<str>,
|
||||
/// Timestamp of most recent commit, normalized to Unix Epoch format.
|
||||
pub unix_timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
pub trait GitRepository: Send {
|
||||
fn reload_index(&self);
|
||||
fn load_index_text(&self, relative_file_path: &Path) -> Option<String>;
|
||||
fn branch_name(&self) -> Option<String>;
|
||||
|
||||
/// Get the statuses of all of the files in the index that start with the given
|
||||
/// path and have changes with resepect to the HEAD commit. This is fast because
|
||||
/// the index stores hashes of trees, so that unchanged directories can be skipped.
|
||||
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus>;
|
||||
|
||||
/// Get the status of a given file in the working directory with respect to
|
||||
/// the index. In the common case, when there are no changes, this only requires
|
||||
/// an index lookup. The index stores the mtime of each file when it was added,
|
||||
/// so there's no work to do if the mtime matches.
|
||||
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
|
||||
|
||||
/// Get the status of a given file in the working directory with respect to
|
||||
/// the HEAD commit. In the common case, when there are no changes, this only
|
||||
/// requires an index lookup and blob comparison between the index and the HEAD
|
||||
/// commit. The index stores the mtime of each file when it was added, so there's
|
||||
/// no need to consider the working directory file if the mtime matches.
|
||||
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
|
||||
|
||||
fn branches(&self) -> Result<Vec<Branch>>;
|
||||
fn change_branch(&self, _: &str) -> Result<()>;
|
||||
fn create_branch(&self, _: &str) -> Result<()>;
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn GitRepository {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("dyn GitRepository<...>").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl GitRepository for LibGitRepository {
|
||||
fn reload_index(&self) {
|
||||
if let Ok(mut index) = self.index() {
|
||||
_ = index.read(false);
|
||||
}
|
||||
}
|
||||
|
||||
fn load_index_text(&self, relative_file_path: &Path) -> Option<String> {
|
||||
fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result<Option<String>> {
|
||||
const STAGE_NORMAL: i32 = 0;
|
||||
let index = repo.index()?;
|
||||
|
||||
// This check is required because index.get_path() unwraps internally :(
|
||||
check_path_to_repo_path_errors(relative_file_path)?;
|
||||
|
||||
let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) {
|
||||
Some(entry) => entry.id,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let content = repo.find_blob(oid)?.content().to_owned();
|
||||
Ok(Some(String::from_utf8(content)?))
|
||||
}
|
||||
|
||||
match logic(&self, relative_file_path) {
|
||||
Ok(value) => return value,
|
||||
Err(err) => log::error!("Error loading head text: {:?}", err),
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn branch_name(&self) -> Option<String> {
|
||||
let head = self.head().log_err()?;
|
||||
let branch = String::from_utf8_lossy(head.shorthand_bytes());
|
||||
Some(branch.to_string())
|
||||
}
|
||||
|
||||
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
|
||||
let mut map = TreeMap::default();
|
||||
|
||||
let mut options = git2::StatusOptions::new();
|
||||
options.pathspec(path_prefix);
|
||||
options.show(StatusShow::Index);
|
||||
|
||||
if let Some(statuses) = self.statuses(Some(&mut options)).log_err() {
|
||||
for status in statuses.iter() {
|
||||
let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes())));
|
||||
let status = status.status();
|
||||
if !status.contains(git2::Status::IGNORED) {
|
||||
if let Some(status) = read_status(status) {
|
||||
map.insert(path, status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
|
||||
// If the file has not changed since it was added to the index, then
|
||||
// there can't be any changes.
|
||||
if matches_index(self, path, mtime) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut options = git2::StatusOptions::new();
|
||||
options.pathspec(&path.0);
|
||||
options.disable_pathspec_match(true);
|
||||
options.include_untracked(true);
|
||||
options.recurse_untracked_dirs(true);
|
||||
options.include_unmodified(true);
|
||||
options.show(StatusShow::Workdir);
|
||||
|
||||
let statuses = self.statuses(Some(&mut options)).log_err()?;
|
||||
let status = statuses.get(0).and_then(|s| read_status(s.status()));
|
||||
status
|
||||
}
|
||||
|
||||
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
|
||||
let mut options = git2::StatusOptions::new();
|
||||
options.pathspec(&path.0);
|
||||
options.disable_pathspec_match(true);
|
||||
options.include_untracked(true);
|
||||
options.recurse_untracked_dirs(true);
|
||||
options.include_unmodified(true);
|
||||
|
||||
// If the file has not changed since it was added to the index, then
|
||||
// there's no need to examine the working directory file: just compare
|
||||
// the blob in the index to the one in the HEAD commit.
|
||||
if matches_index(self, path, mtime) {
|
||||
options.show(StatusShow::Index);
|
||||
}
|
||||
|
||||
let statuses = self.statuses(Some(&mut options)).log_err()?;
|
||||
let status = statuses.get(0).and_then(|s| read_status(s.status()));
|
||||
status
|
||||
}
|
||||
|
||||
fn branches(&self) -> Result<Vec<Branch>> {
|
||||
let local_branches = self.branches(Some(BranchType::Local))?;
|
||||
let valid_branches = local_branches
|
||||
.filter_map(|branch| {
|
||||
branch.ok().and_then(|(branch, _)| {
|
||||
let name = branch.name().ok().flatten().map(Box::from)?;
|
||||
let timestamp = branch.get().peel_to_commit().ok()?.time();
|
||||
let unix_timestamp = timestamp.seconds();
|
||||
let timezone_offset = timestamp.offset_minutes();
|
||||
let utc_offset =
|
||||
time::UtcOffset::from_whole_seconds(timezone_offset * 60).ok()?;
|
||||
let unix_timestamp =
|
||||
time::OffsetDateTime::from_unix_timestamp(unix_timestamp).ok()?;
|
||||
Some(Branch {
|
||||
name,
|
||||
unix_timestamp: Some(unix_timestamp.to_offset(utc_offset).unix_timestamp()),
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Ok(valid_branches)
|
||||
}
|
||||
fn change_branch(&self, name: &str) -> Result<()> {
|
||||
let revision = self.find_branch(name, BranchType::Local)?;
|
||||
let revision = revision.get();
|
||||
let as_tree = revision.peel_to_tree()?;
|
||||
self.checkout_tree(as_tree.as_object(), None)?;
|
||||
self.set_head(
|
||||
revision
|
||||
.name()
|
||||
.ok_or_else(|| anyhow::anyhow!("Branch name could not be retrieved"))?,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
fn create_branch(&self, name: &str) -> Result<()> {
|
||||
let current_commit = self.head()?.peel_to_commit()?;
|
||||
self.branch(name, ¤t_commit, false)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_index(repo: &LibGitRepository, path: &RepoPath, mtime: SystemTime) -> bool {
|
||||
if let Some(index) = repo.index().log_err() {
|
||||
if let Some(entry) = index.get_path(&path, 0) {
|
||||
if let Some(mtime) = mtime.duration_since(SystemTime::UNIX_EPOCH).log_err() {
|
||||
if entry.mtime.seconds() == mtime.as_secs() as i32
|
||||
&& entry.mtime.nanoseconds() == mtime.subsec_nanos()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn read_status(status: git2::Status) -> Option<GitFileStatus> {
|
||||
if status.contains(git2::Status::CONFLICTED) {
|
||||
Some(GitFileStatus::Conflict)
|
||||
} else if status.intersects(
|
||||
git2::Status::WT_MODIFIED
|
||||
| git2::Status::WT_RENAMED
|
||||
| git2::Status::INDEX_MODIFIED
|
||||
| git2::Status::INDEX_RENAMED,
|
||||
) {
|
||||
Some(GitFileStatus::Modified)
|
||||
} else if status.intersects(git2::Status::WT_NEW | git2::Status::INDEX_NEW) {
|
||||
Some(GitFileStatus::Added)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FakeGitRepository {
|
||||
state: Arc<Mutex<FakeGitRepositoryState>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FakeGitRepositoryState {
|
||||
pub index_contents: HashMap<PathBuf, String>,
|
||||
pub worktree_statuses: HashMap<RepoPath, GitFileStatus>,
|
||||
pub branch_name: Option<String>,
|
||||
}
|
||||
|
||||
impl FakeGitRepository {
|
||||
pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<Mutex<dyn GitRepository>> {
|
||||
Arc::new(Mutex::new(FakeGitRepository { state }))
|
||||
}
|
||||
}
|
||||
|
||||
impl GitRepository for FakeGitRepository {
|
||||
fn reload_index(&self) {}
|
||||
|
||||
fn load_index_text(&self, path: &Path) -> Option<String> {
|
||||
let state = self.state.lock();
|
||||
state.index_contents.get(path).cloned()
|
||||
}
|
||||
|
||||
fn branch_name(&self) -> Option<String> {
|
||||
let state = self.state.lock();
|
||||
state.branch_name.clone()
|
||||
}
|
||||
|
||||
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
|
||||
let mut map = TreeMap::default();
|
||||
let state = self.state.lock();
|
||||
for (repo_path, status) in state.worktree_statuses.iter() {
|
||||
if repo_path.0.starts_with(path_prefix) {
|
||||
map.insert(repo_path.to_owned(), status.to_owned());
|
||||
}
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
fn unstaged_status(&self, _path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
|
||||
None
|
||||
}
|
||||
|
||||
fn status(&self, path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
|
||||
let state = self.state.lock();
|
||||
state.worktree_statuses.get(path).cloned()
|
||||
}
|
||||
|
||||
fn branches(&self) -> Result<Vec<Branch>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn change_branch(&self, name: &str) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.branch_name = Some(name.to_owned());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_branch(&self, name: &str) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.branch_name = Some(name.to_owned());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
|
||||
match relative_file_path.components().next() {
|
||||
None => anyhow::bail!("repo path should not be empty"),
|
||||
Some(Component::Prefix(_)) => anyhow::bail!(
|
||||
"repo path `{}` should be relative, not a windows prefix",
|
||||
relative_file_path.to_string_lossy()
|
||||
),
|
||||
Some(Component::RootDir) => {
|
||||
anyhow::bail!(
|
||||
"repo path `{}` should be relative",
|
||||
relative_file_path.to_string_lossy()
|
||||
)
|
||||
}
|
||||
Some(Component::CurDir) => {
|
||||
anyhow::bail!(
|
||||
"repo path `{}` should not start with `.`",
|
||||
relative_file_path.to_string_lossy()
|
||||
)
|
||||
}
|
||||
Some(Component::ParentDir) => {
|
||||
anyhow::bail!(
|
||||
"repo path `{}` should not start with `..`",
|
||||
relative_file_path.to_string_lossy()
|
||||
)
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GitFileStatus {
|
||||
Added,
|
||||
Modified,
|
||||
Conflict,
|
||||
}
|
||||
|
||||
impl GitFileStatus {
|
||||
pub fn merge(
|
||||
this: Option<GitFileStatus>,
|
||||
other: Option<GitFileStatus>,
|
||||
prefer_other: bool,
|
||||
) -> Option<GitFileStatus> {
|
||||
if prefer_other {
|
||||
return other;
|
||||
} else {
|
||||
match (this, other) {
|
||||
(Some(GitFileStatus::Conflict), _) | (_, Some(GitFileStatus::Conflict)) => {
|
||||
Some(GitFileStatus::Conflict)
|
||||
}
|
||||
(Some(GitFileStatus::Modified), _) | (_, Some(GitFileStatus::Modified)) => {
|
||||
Some(GitFileStatus::Modified)
|
||||
}
|
||||
(Some(GitFileStatus::Added), _) | (_, Some(GitFileStatus::Added)) => {
|
||||
Some(GitFileStatus::Added)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Ord, Hash, PartialOrd, Eq, PartialEq)]
|
||||
pub struct RepoPath(pub PathBuf);
|
||||
|
||||
impl RepoPath {
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
debug_assert!(path.is_relative(), "Repo paths must be relative");
|
||||
|
||||
RepoPath(path)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Path> for RepoPath {
|
||||
fn from(value: &Path) -> Self {
|
||||
RepoPath::new(value.to_path_buf())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for RepoPath {
|
||||
fn from(value: PathBuf) -> Self {
|
||||
RepoPath::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RepoPath {
|
||||
fn default() -> Self {
|
||||
RepoPath(PathBuf::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<Path> for RepoPath {
|
||||
fn as_ref(&self) -> &Path {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for RepoPath {
|
||||
type Target = PathBuf;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RepoPathDescendants<'a>(pub &'a Path);
|
||||
|
||||
impl<'a> MapSeekTarget<RepoPath> for RepoPathDescendants<'a> {
|
||||
fn cmp_cursor(&self, key: &RepoPath) -> Ordering {
|
||||
if key.starts_with(&self.0) {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
self.0.cmp(key)
|
||||
}
|
||||
}
|
||||
}
|
@ -28,7 +28,7 @@ fuzzy = { path = "../fuzzy" }
|
||||
git = { package = "git3", path = "../git3" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
lsp = { path = "../lsp" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
|
@ -5,7 +5,7 @@ edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/notification_store.rs"
|
||||
path = "src/notification_store2.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
@ -23,11 +23,11 @@ clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
db = { path = "../db" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
gpui = { path = "../gpui" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { path = "../settings" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
text = { path = "../text" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
util = { path = "../util" }
|
||||
|
||||
anyhow.workspace = true
|
||||
@ -36,7 +36,7 @@ time.workspace = true
|
||||
[dev-dependencies]
|
||||
client = { path = "../client", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { path = "../settings", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
|
@ -1,459 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use db::smol::stream::StreamExt;
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
|
||||
use rpc::{proto, Notification, TypedEnvelope};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
use sum_tree::{Bias, SumTree};
|
||||
use time::OffsetDateTime;
|
||||
use util::ResultExt;
|
||||
|
||||
pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
|
||||
let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
|
||||
cx.set_global(notification_store);
|
||||
}
|
||||
|
||||
pub struct NotificationStore {
|
||||
client: Arc<Client>,
|
||||
user_store: ModelHandle<UserStore>,
|
||||
channel_messages: HashMap<u64, ChannelMessage>,
|
||||
channel_store: ModelHandle<ChannelStore>,
|
||||
notifications: SumTree<NotificationEntry>,
|
||||
loaded_all_notifications: bool,
|
||||
_watch_connection_status: Task<Option<()>>,
|
||||
_subscriptions: Vec<client::Subscription>,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub enum NotificationEvent {
|
||||
NotificationsUpdated {
|
||||
old_range: Range<usize>,
|
||||
new_count: usize,
|
||||
},
|
||||
NewNotification {
|
||||
entry: NotificationEntry,
|
||||
},
|
||||
NotificationRemoved {
|
||||
entry: NotificationEntry,
|
||||
},
|
||||
NotificationRead {
|
||||
entry: NotificationEntry,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct NotificationEntry {
|
||||
pub id: u64,
|
||||
pub notification: Notification,
|
||||
pub timestamp: OffsetDateTime,
|
||||
pub is_read: bool,
|
||||
pub response: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct NotificationSummary {
|
||||
max_id: u64,
|
||||
count: usize,
|
||||
unread_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct Count(usize);
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct UnreadCount(usize);
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct NotificationId(u64);
|
||||
|
||||
impl NotificationStore {
|
||||
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
|
||||
cx.global::<ModelHandle<Self>>().clone()
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
client: Arc<Client>,
|
||||
user_store: ModelHandle<UserStore>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let mut connection_status = client.status();
|
||||
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
|
||||
while let Some(status) = connection_status.next().await {
|
||||
let this = this.upgrade(&cx)?;
|
||||
match status {
|
||||
client::Status::Connected { .. } => {
|
||||
if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
|
||||
{
|
||||
task.await.log_err()?;
|
||||
}
|
||||
}
|
||||
_ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
|
||||
}
|
||||
}
|
||||
Some(())
|
||||
});
|
||||
|
||||
Self {
|
||||
channel_store: ChannelStore::global(cx),
|
||||
notifications: Default::default(),
|
||||
loaded_all_notifications: false,
|
||||
channel_messages: Default::default(),
|
||||
_watch_connection_status: watch_connection_status,
|
||||
_subscriptions: vec![
|
||||
client.add_message_handler(cx.handle(), Self::handle_new_notification),
|
||||
client.add_message_handler(cx.handle(), Self::handle_delete_notification),
|
||||
],
|
||||
user_store,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn notification_count(&self) -> usize {
|
||||
self.notifications.summary().count
|
||||
}
|
||||
|
||||
pub fn unread_notification_count(&self) -> usize {
|
||||
self.notifications.summary().unread_count
|
||||
}
|
||||
|
||||
pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
|
||||
self.channel_messages.get(&id)
|
||||
}
|
||||
|
||||
// Get the nth newest notification.
|
||||
pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
|
||||
let count = self.notifications.summary().count;
|
||||
if ix >= count {
|
||||
return None;
|
||||
}
|
||||
let ix = count - 1 - ix;
|
||||
let mut cursor = self.notifications.cursor::<Count>();
|
||||
cursor.seek(&Count(ix), Bias::Right, &());
|
||||
cursor.item()
|
||||
}
|
||||
|
||||
pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
|
||||
let mut cursor = self.notifications.cursor::<NotificationId>();
|
||||
cursor.seek(&NotificationId(id), Bias::Left, &());
|
||||
if let Some(item) = cursor.item() {
|
||||
if item.id == id {
|
||||
return Some(item);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn load_more_notifications(
|
||||
&self,
|
||||
clear_old: bool,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
if self.loaded_all_notifications && !clear_old {
|
||||
return None;
|
||||
}
|
||||
|
||||
let before_id = if clear_old {
|
||||
None
|
||||
} else {
|
||||
self.notifications.first().map(|entry| entry.id)
|
||||
};
|
||||
let request = self.client.request(proto::GetNotifications { before_id });
|
||||
Some(cx.spawn(|this, mut cx| async move {
|
||||
let response = request.await?;
|
||||
this.update(&mut cx, |this, _| {
|
||||
this.loaded_all_notifications = response.done
|
||||
});
|
||||
Self::add_notifications(
|
||||
this,
|
||||
response.notifications,
|
||||
AddNotificationsOptions {
|
||||
is_new: false,
|
||||
clear_old,
|
||||
includes_first: response.done,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
|
||||
self.notifications = Default::default();
|
||||
self.channel_messages = Default::default();
|
||||
cx.notify();
|
||||
self.load_more_notifications(true, cx)
|
||||
}
|
||||
|
||||
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
async fn handle_new_notification(
|
||||
this: ModelHandle<Self>,
|
||||
envelope: TypedEnvelope<proto::AddNotification>,
|
||||
_: Arc<Client>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
Self::add_notifications(
|
||||
this,
|
||||
envelope.payload.notification.into_iter().collect(),
|
||||
AddNotificationsOptions {
|
||||
is_new: true,
|
||||
clear_old: false,
|
||||
includes_first: false,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_delete_notification(
|
||||
this: ModelHandle<Self>,
|
||||
envelope: TypedEnvelope<proto::DeleteNotification>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
async fn add_notifications(
|
||||
this: ModelHandle<Self>,
|
||||
notifications: Vec<proto::Notification>,
|
||||
options: AddNotificationsOptions,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
let mut user_ids = Vec::new();
|
||||
let mut message_ids = Vec::new();
|
||||
|
||||
let notifications = notifications
|
||||
.into_iter()
|
||||
.filter_map(|message| {
|
||||
Some(NotificationEntry {
|
||||
id: message.id,
|
||||
is_read: message.is_read,
|
||||
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
|
||||
.ok()?,
|
||||
notification: Notification::from_proto(&message)?,
|
||||
response: message.response,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if notifications.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for entry in ¬ifications {
|
||||
match entry.notification {
|
||||
Notification::ChannelInvitation { inviter_id, .. } => {
|
||||
user_ids.push(inviter_id);
|
||||
}
|
||||
Notification::ContactRequest {
|
||||
sender_id: requester_id,
|
||||
} => {
|
||||
user_ids.push(requester_id);
|
||||
}
|
||||
Notification::ContactRequestAccepted {
|
||||
responder_id: contact_id,
|
||||
} => {
|
||||
user_ids.push(contact_id);
|
||||
}
|
||||
Notification::ChannelMessageMention {
|
||||
sender_id,
|
||||
message_id,
|
||||
..
|
||||
} => {
|
||||
user_ids.push(sender_id);
|
||||
message_ids.push(message_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (user_store, channel_store) = this.read_with(&cx, |this, _| {
|
||||
(this.user_store.clone(), this.channel_store.clone())
|
||||
});
|
||||
|
||||
user_store
|
||||
.update(&mut cx, |store, cx| store.get_users(user_ids, cx))
|
||||
.await?;
|
||||
let messages = channel_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.fetch_channel_messages(message_ids, cx)
|
||||
})
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if options.clear_old {
|
||||
cx.emit(NotificationEvent::NotificationsUpdated {
|
||||
old_range: 0..this.notifications.summary().count,
|
||||
new_count: 0,
|
||||
});
|
||||
this.notifications = SumTree::default();
|
||||
this.channel_messages.clear();
|
||||
this.loaded_all_notifications = false;
|
||||
}
|
||||
|
||||
if options.includes_first {
|
||||
this.loaded_all_notifications = true;
|
||||
}
|
||||
|
||||
this.channel_messages
|
||||
.extend(messages.into_iter().filter_map(|message| {
|
||||
if let ChannelMessageId::Saved(id) = message.id {
|
||||
Some((id, message))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}));
|
||||
|
||||
this.splice_notifications(
|
||||
notifications
|
||||
.into_iter()
|
||||
.map(|notification| (notification.id, Some(notification))),
|
||||
options.is_new,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn splice_notifications(
|
||||
&mut self,
|
||||
notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
|
||||
is_new: bool,
|
||||
cx: &mut ModelContext<'_, NotificationStore>,
|
||||
) {
|
||||
let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
|
||||
let mut new_notifications = SumTree::new();
|
||||
let mut old_range = 0..0;
|
||||
|
||||
for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
|
||||
new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
|
||||
|
||||
if i == 0 {
|
||||
old_range.start = cursor.start().1 .0;
|
||||
}
|
||||
|
||||
let old_notification = cursor.item();
|
||||
if let Some(old_notification) = old_notification {
|
||||
if old_notification.id == id {
|
||||
cursor.next(&());
|
||||
|
||||
if let Some(new_notification) = &new_notification {
|
||||
if new_notification.is_read {
|
||||
cx.emit(NotificationEvent::NotificationRead {
|
||||
entry: new_notification.clone(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
cx.emit(NotificationEvent::NotificationRemoved {
|
||||
entry: old_notification.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(new_notification) = &new_notification {
|
||||
if is_new {
|
||||
cx.emit(NotificationEvent::NewNotification {
|
||||
entry: new_notification.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(notification) = new_notification {
|
||||
new_notifications.push(notification, &());
|
||||
}
|
||||
}
|
||||
|
||||
old_range.end = cursor.start().1 .0;
|
||||
let new_count = new_notifications.summary().count - old_range.start;
|
||||
new_notifications.append(cursor.suffix(&()), &());
|
||||
drop(cursor);
|
||||
|
||||
self.notifications = new_notifications;
|
||||
cx.emit(NotificationEvent::NotificationsUpdated {
|
||||
old_range,
|
||||
new_count,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn respond_to_notification(
|
||||
&mut self,
|
||||
notification: Notification,
|
||||
response: bool,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
match notification {
|
||||
Notification::ContactRequest { sender_id } => {
|
||||
self.user_store
|
||||
.update(cx, |store, cx| {
|
||||
store.respond_to_contact_request(sender_id, response, cx)
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
Notification::ChannelInvitation { channel_id, .. } => {
|
||||
self.channel_store
|
||||
.update(cx, |store, cx| {
|
||||
store.respond_to_channel_invite(channel_id, response, cx)
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Entity for NotificationStore {
|
||||
type Event = NotificationEvent;
|
||||
}
|
||||
|
||||
impl sum_tree::Item for NotificationEntry {
|
||||
type Summary = NotificationSummary;
|
||||
|
||||
fn summary(&self) -> Self::Summary {
|
||||
NotificationSummary {
|
||||
max_id: self.id,
|
||||
count: 1,
|
||||
unread_count: if self.is_read { 0 } else { 1 },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl sum_tree::Summary for NotificationSummary {
|
||||
type Context = ();
|
||||
|
||||
fn add_summary(&mut self, summary: &Self, _: &()) {
|
||||
self.max_id = self.max_id.max(summary.max_id);
|
||||
self.count += summary.count;
|
||||
self.unread_count += summary.unread_count;
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
|
||||
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
|
||||
debug_assert!(summary.max_id > self.0);
|
||||
self.0 = summary.max_id;
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
|
||||
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
|
||||
self.0 += summary.count;
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
|
||||
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
|
||||
self.0 += summary.unread_count;
|
||||
}
|
||||
}
|
||||
|
||||
struct AddNotificationsOptions {
|
||||
is_new: bool,
|
||||
clear_old: bool,
|
||||
includes_first: bool,
|
||||
}
|
@ -1,42 +0,0 @@
|
||||
[package]
|
||||
name = "notifications2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/notification_store2.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = [
|
||||
"channel/test-support",
|
||||
"collections/test-support",
|
||||
"gpui/test-support",
|
||||
"rpc/test-support",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
channel = { path = "../channel" }
|
||||
client = { path = "../client" }
|
||||
clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
text = { package = "text2", path = "../text2" }
|
||||
util = { path = "../util" }
|
||||
|
||||
anyhow.workspace = true
|
||||
time.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { path = "../client", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
@ -15,7 +15,7 @@ test-support = []
|
||||
client = { path = "../client" }
|
||||
collections = { path = "../collections"}
|
||||
language = { path = "../language" }
|
||||
gpui = { path = "../gpui" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
fs = { path = "../fs" }
|
||||
lsp = { path = "../lsp" }
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
@ -31,5 +31,5 @@ parking_lot.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
|
@ -1,16 +1,16 @@
|
||||
use std::ops::ControlFlow;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use gpui::{AsyncAppContext, ModelHandle};
|
||||
use language::language_settings::language_settings;
|
||||
use language::{Buffer, Diff};
|
||||
use gpui::{AsyncAppContext, Model};
|
||||
use language::{language_settings::language_settings, Buffer, Diff};
|
||||
use lsp::{LanguageServer, LanguageServerId};
|
||||
use node_runtime::NodeRuntime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
ops::ControlFlow,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::paths::{PathMatcher, DEFAULT_PRETTIER_DIR};
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -172,7 +172,7 @@ impl Prettier {
|
||||
) -> anyhow::Result<Self> {
|
||||
use lsp::LanguageServerBinary;
|
||||
|
||||
let background = cx.background();
|
||||
let executor = cx.background_executor().clone();
|
||||
anyhow::ensure!(
|
||||
prettier_dir.is_dir(),
|
||||
"Prettier dir {prettier_dir:?} is not a directory"
|
||||
@ -183,7 +183,7 @@ impl Prettier {
|
||||
"no prettier server package found at {prettier_server:?}"
|
||||
);
|
||||
|
||||
let node_path = background
|
||||
let node_path = executor
|
||||
.spawn(async move { node.binary_path().await })
|
||||
.await?;
|
||||
let server = LanguageServer::new(
|
||||
@ -198,7 +198,7 @@ impl Prettier {
|
||||
cx,
|
||||
)
|
||||
.context("prettier server creation")?;
|
||||
let server = background
|
||||
let server = executor
|
||||
.spawn(server.initialize(None))
|
||||
.await
|
||||
.context("prettier server initialization")?;
|
||||
@ -211,13 +211,14 @@ impl Prettier {
|
||||
|
||||
pub async fn format(
|
||||
&self,
|
||||
buffer: &ModelHandle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
buffer_path: Option<PathBuf>,
|
||||
cx: &AsyncAppContext,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Diff> {
|
||||
match self {
|
||||
Self::Real(local) => {
|
||||
let params = buffer.read_with(cx, |buffer, cx| {
|
||||
let params = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
let buffer_language = buffer.language();
|
||||
let parser_with_plugins = buffer_language.and_then(|l| {
|
||||
let prettier_parser = l.prettier_parser_name()?;
|
||||
@ -231,7 +232,10 @@ impl Prettier {
|
||||
});
|
||||
|
||||
let prettier_node_modules = self.prettier_dir().join("node_modules");
|
||||
anyhow::ensure!(prettier_node_modules.is_dir(), "Prettier node_modules dir does not exist: {prettier_node_modules:?}");
|
||||
anyhow::ensure!(
|
||||
prettier_node_modules.is_dir(),
|
||||
"Prettier node_modules dir does not exist: {prettier_node_modules:?}"
|
||||
);
|
||||
let plugin_name_into_path = |plugin_name: &str| {
|
||||
let prettier_plugin_dir = prettier_node_modules.join(plugin_name);
|
||||
for possible_plugin_path in [
|
||||
@ -255,24 +259,36 @@ impl Prettier {
|
||||
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
|
||||
let mut add_tailwind_back = false;
|
||||
|
||||
let mut plugins = plugins.into_iter().filter(|&&plugin_name| {
|
||||
let mut plugins = plugins
|
||||
.into_iter()
|
||||
.filter(|&&plugin_name| {
|
||||
if plugin_name == TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME {
|
||||
add_tailwind_back = true;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}).map(|plugin_name| (plugin_name, plugin_name_into_path(plugin_name))).collect::<Vec<_>>();
|
||||
})
|
||||
.map(|plugin_name| {
|
||||
(plugin_name, plugin_name_into_path(plugin_name))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if add_tailwind_back {
|
||||
plugins.push((&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME, plugin_name_into_path(TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME)));
|
||||
plugins.push((
|
||||
&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
|
||||
plugin_name_into_path(
|
||||
TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
|
||||
),
|
||||
));
|
||||
}
|
||||
(Some(parser.to_string()), plugins)
|
||||
},
|
||||
}
|
||||
None => (None, Vec::new()),
|
||||
};
|
||||
|
||||
let prettier_options = if self.is_default() {
|
||||
let language_settings = language_settings(buffer_language, buffer.file(), cx);
|
||||
let language_settings =
|
||||
language_settings(buffer_language, buffer.file(), cx);
|
||||
let mut options = language_settings.prettier.clone();
|
||||
if !options.contains_key("tabWidth") {
|
||||
options.insert(
|
||||
@ -295,15 +311,28 @@ impl Prettier {
|
||||
None
|
||||
};
|
||||
|
||||
let plugins = located_plugins.into_iter().filter_map(|(plugin_name, located_plugin_path)| {
|
||||
let plugins = located_plugins
|
||||
.into_iter()
|
||||
.filter_map(|(plugin_name, located_plugin_path)| {
|
||||
match located_plugin_path {
|
||||
Some(path) => Some(path),
|
||||
None => {
|
||||
log::error!("Have not found plugin path for {plugin_name:?} inside {prettier_node_modules:?}");
|
||||
None},
|
||||
log::error!(
|
||||
"Have not found plugin path for {:?} inside {:?}",
|
||||
plugin_name,
|
||||
prettier_node_modules
|
||||
);
|
||||
None
|
||||
}
|
||||
}).collect();
|
||||
log::debug!("Formatting file {:?} with prettier, plugins :{plugins:?}, options: {prettier_options:?}", buffer.file().map(|f| f.full_path(cx)));
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
log::debug!(
|
||||
"Formatting file {:?} with prettier, plugins :{:?}, options: {:?}",
|
||||
plugins,
|
||||
prettier_options,
|
||||
buffer.file().map(|f| f.full_path(cx))
|
||||
);
|
||||
|
||||
anyhow::Ok(FormatParams {
|
||||
text: buffer.text(),
|
||||
@ -314,21 +343,22 @@ impl Prettier {
|
||||
prettier_options,
|
||||
},
|
||||
})
|
||||
}).context("prettier params calculation")?;
|
||||
})?
|
||||
.context("prettier params calculation")?;
|
||||
let response = local
|
||||
.server
|
||||
.request::<Format>(params)
|
||||
.await
|
||||
.context("prettier format request")?;
|
||||
let diff_task = buffer.read_with(cx, |buffer, cx| buffer.diff(response.text, cx));
|
||||
let diff_task = buffer.update(cx, |buffer, cx| buffer.diff(response.text, cx))?;
|
||||
Ok(diff_task.await)
|
||||
}
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(_) => Ok(buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
.update(cx, |buffer, cx| {
|
||||
let formatted_text = buffer.text() + FORMAT_SUFFIX;
|
||||
buffer.diff(formatted_text, cx)
|
||||
})
|
||||
})?
|
||||
.await),
|
||||
}
|
||||
}
|
||||
@ -471,7 +501,7 @@ mod tests {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_finds_nothing(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
@ -547,7 +577,7 @@ mod tests {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_in_simple_npm_projects(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
@ -612,7 +642,7 @@ mod tests {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_for_not_installed(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
@ -662,6 +692,7 @@ mod tests {
|
||||
assert!(message.contains("/root/work/web_blog"), "Error message should mention which project had prettier defined");
|
||||
},
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
@ -704,7 +735,7 @@ mod tests {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_in_npm_workspaces(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
@ -785,7 +816,7 @@ mod tests {
|
||||
async fn test_prettier_lookup_in_npm_workspaces_for_not_installed(
|
||||
cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
|
@ -1,35 +0,0 @@
|
||||
[package]
|
||||
name = "prettier2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/prettier2.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
client = { path = "../client" }
|
||||
collections = { path = "../collections"}
|
||||
language = { path = "../language" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
lsp = { path = "../lsp" }
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
util = { path = "../util" }
|
||||
|
||||
log.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
serde_json.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
parking_lot.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
@ -1,869 +0,0 @@
|
||||
use anyhow::Context;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use gpui::{AsyncAppContext, Model};
|
||||
use language::{language_settings::language_settings, Buffer, Diff};
|
||||
use lsp::{LanguageServer, LanguageServerId};
|
||||
use node_runtime::NodeRuntime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
ops::ControlFlow,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::paths::{PathMatcher, DEFAULT_PRETTIER_DIR};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Prettier {
|
||||
Real(RealPrettier),
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Test(TestPrettier),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealPrettier {
|
||||
default: bool,
|
||||
prettier_dir: PathBuf,
|
||||
server: Arc<LanguageServer>,
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[derive(Clone)]
|
||||
pub struct TestPrettier {
|
||||
prettier_dir: PathBuf,
|
||||
default: bool,
|
||||
}
|
||||
|
||||
pub const FAIL_THRESHOLD: usize = 4;
|
||||
pub const PRETTIER_SERVER_FILE: &str = "prettier_server.js";
|
||||
pub const PRETTIER_SERVER_JS: &str = include_str!("./prettier_server.js");
|
||||
const PRETTIER_PACKAGE_NAME: &str = "prettier";
|
||||
const TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME: &str = "prettier-plugin-tailwindcss";
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub const FORMAT_SUFFIX: &str = "\nformatted by test prettier";
|
||||
|
||||
impl Prettier {
|
||||
pub const CONFIG_FILE_NAMES: &'static [&'static str] = &[
|
||||
".prettierrc",
|
||||
".prettierrc.json",
|
||||
".prettierrc.json5",
|
||||
".prettierrc.yaml",
|
||||
".prettierrc.yml",
|
||||
".prettierrc.toml",
|
||||
".prettierrc.js",
|
||||
".prettierrc.cjs",
|
||||
"package.json",
|
||||
"prettier.config.js",
|
||||
"prettier.config.cjs",
|
||||
".editorconfig",
|
||||
];
|
||||
|
||||
pub async fn locate_prettier_installation(
|
||||
fs: &dyn Fs,
|
||||
installed_prettiers: &HashSet<PathBuf>,
|
||||
locate_from: &Path,
|
||||
) -> anyhow::Result<ControlFlow<(), Option<PathBuf>>> {
|
||||
let mut path_to_check = locate_from
|
||||
.components()
|
||||
.take_while(|component| component.as_os_str().to_string_lossy() != "node_modules")
|
||||
.collect::<PathBuf>();
|
||||
if path_to_check != locate_from {
|
||||
log::debug!(
|
||||
"Skipping prettier location for path {path_to_check:?} that is inside node_modules"
|
||||
);
|
||||
return Ok(ControlFlow::Break(()));
|
||||
}
|
||||
let path_to_check_metadata = fs
|
||||
.metadata(&path_to_check)
|
||||
.await
|
||||
.with_context(|| format!("failed to get metadata for initial path {path_to_check:?}"))?
|
||||
.with_context(|| format!("empty metadata for initial path {path_to_check:?}"))?;
|
||||
if !path_to_check_metadata.is_dir {
|
||||
path_to_check.pop();
|
||||
}
|
||||
|
||||
let mut project_path_with_prettier_dependency = None;
|
||||
loop {
|
||||
if installed_prettiers.contains(&path_to_check) {
|
||||
log::debug!("Found prettier path {path_to_check:?} in installed prettiers");
|
||||
return Ok(ControlFlow::Continue(Some(path_to_check)));
|
||||
} else if let Some(package_json_contents) =
|
||||
read_package_json(fs, &path_to_check).await?
|
||||
{
|
||||
if has_prettier_in_package_json(&package_json_contents) {
|
||||
if has_prettier_in_node_modules(fs, &path_to_check).await? {
|
||||
log::debug!("Found prettier path {path_to_check:?} in both package.json and node_modules");
|
||||
return Ok(ControlFlow::Continue(Some(path_to_check)));
|
||||
} else if project_path_with_prettier_dependency.is_none() {
|
||||
project_path_with_prettier_dependency = Some(path_to_check.clone());
|
||||
}
|
||||
} else {
|
||||
match package_json_contents.get("workspaces") {
|
||||
Some(serde_json::Value::Array(workspaces)) => {
|
||||
match &project_path_with_prettier_dependency {
|
||||
Some(project_path_with_prettier_dependency) => {
|
||||
let subproject_path = project_path_with_prettier_dependency.strip_prefix(&path_to_check).expect("traversing path parents, should be able to strip prefix");
|
||||
if workspaces.iter().filter_map(|value| {
|
||||
if let serde_json::Value::String(s) = value {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
log::warn!("Skipping non-string 'workspaces' value: {value:?}");
|
||||
None
|
||||
}
|
||||
}).any(|workspace_definition| {
|
||||
if let Some(path_matcher) = PathMatcher::new(&workspace_definition).ok() {
|
||||
path_matcher.is_match(subproject_path)
|
||||
} else {
|
||||
workspace_definition == subproject_path.to_string_lossy()
|
||||
}
|
||||
}) {
|
||||
anyhow::ensure!(has_prettier_in_node_modules(fs, &path_to_check).await?, "Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}, but it's not installed into workspace root's node_modules");
|
||||
log::info!("Found prettier path {path_to_check:?} in the workspace root for project in {project_path_with_prettier_dependency:?}");
|
||||
return Ok(ControlFlow::Continue(Some(path_to_check)));
|
||||
} else {
|
||||
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but is not included in its package.json workspaces {workspaces:?}");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
log::warn!("Skipping path {path_to_check:?} that has prettier in its 'node_modules' subdirectory, but has no prettier in its package.json");
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(unknown) => log::error!("Failed to parse workspaces for {path_to_check:?} from package.json, got {unknown:?}. Skipping."),
|
||||
None => log::warn!("Skipping path {path_to_check:?} that has no prettier dependency and no workspaces section in its package.json"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !path_to_check.pop() {
|
||||
match project_path_with_prettier_dependency {
|
||||
Some(closest_prettier_discovered) => {
|
||||
anyhow::bail!("No prettier found in node_modules for ancestors of {locate_from:?}, but discovered prettier package.json dependency in {closest_prettier_discovered:?}")
|
||||
}
|
||||
None => {
|
||||
log::debug!("Found no prettier in ancestors of {locate_from:?}");
|
||||
return Ok(ControlFlow::Continue(None));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub async fn start(
|
||||
_: LanguageServerId,
|
||||
prettier_dir: PathBuf,
|
||||
_: Arc<dyn NodeRuntime>,
|
||||
_: AsyncAppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(Self::Test(TestPrettier {
|
||||
default: prettier_dir == DEFAULT_PRETTIER_DIR.as_path(),
|
||||
prettier_dir,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
pub async fn start(
|
||||
server_id: LanguageServerId,
|
||||
prettier_dir: PathBuf,
|
||||
node: Arc<dyn NodeRuntime>,
|
||||
cx: AsyncAppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
use lsp::LanguageServerBinary;
|
||||
|
||||
let executor = cx.background_executor().clone();
|
||||
anyhow::ensure!(
|
||||
prettier_dir.is_dir(),
|
||||
"Prettier dir {prettier_dir:?} is not a directory"
|
||||
);
|
||||
let prettier_server = DEFAULT_PRETTIER_DIR.join(PRETTIER_SERVER_FILE);
|
||||
anyhow::ensure!(
|
||||
prettier_server.is_file(),
|
||||
"no prettier server package found at {prettier_server:?}"
|
||||
);
|
||||
|
||||
let node_path = executor
|
||||
.spawn(async move { node.binary_path().await })
|
||||
.await?;
|
||||
let server = LanguageServer::new(
|
||||
Arc::new(parking_lot::Mutex::new(None)),
|
||||
server_id,
|
||||
LanguageServerBinary {
|
||||
path: node_path,
|
||||
arguments: vec![prettier_server.into(), prettier_dir.as_path().into()],
|
||||
},
|
||||
Path::new("/"),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.context("prettier server creation")?;
|
||||
let server = executor
|
||||
.spawn(server.initialize(None))
|
||||
.await
|
||||
.context("prettier server initialization")?;
|
||||
Ok(Self::Real(RealPrettier {
|
||||
server,
|
||||
default: prettier_dir == DEFAULT_PRETTIER_DIR.as_path(),
|
||||
prettier_dir,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn format(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
buffer_path: Option<PathBuf>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Diff> {
|
||||
match self {
|
||||
Self::Real(local) => {
|
||||
let params = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
let buffer_language = buffer.language();
|
||||
let parser_with_plugins = buffer_language.and_then(|l| {
|
||||
let prettier_parser = l.prettier_parser_name()?;
|
||||
let mut prettier_plugins = l
|
||||
.lsp_adapters()
|
||||
.iter()
|
||||
.flat_map(|adapter| adapter.prettier_plugins())
|
||||
.collect::<Vec<_>>();
|
||||
prettier_plugins.dedup();
|
||||
Some((prettier_parser, prettier_plugins))
|
||||
});
|
||||
|
||||
let prettier_node_modules = self.prettier_dir().join("node_modules");
|
||||
anyhow::ensure!(
|
||||
prettier_node_modules.is_dir(),
|
||||
"Prettier node_modules dir does not exist: {prettier_node_modules:?}"
|
||||
);
|
||||
let plugin_name_into_path = |plugin_name: &str| {
|
||||
let prettier_plugin_dir = prettier_node_modules.join(plugin_name);
|
||||
for possible_plugin_path in [
|
||||
prettier_plugin_dir.join("dist").join("index.mjs"),
|
||||
prettier_plugin_dir.join("dist").join("index.js"),
|
||||
prettier_plugin_dir.join("dist").join("plugin.js"),
|
||||
prettier_plugin_dir.join("index.mjs"),
|
||||
prettier_plugin_dir.join("index.js"),
|
||||
prettier_plugin_dir.join("plugin.js"),
|
||||
prettier_plugin_dir,
|
||||
] {
|
||||
if possible_plugin_path.is_file() {
|
||||
return Some(possible_plugin_path);
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
let (parser, located_plugins) = match parser_with_plugins {
|
||||
Some((parser, plugins)) => {
|
||||
// Tailwind plugin requires being added last
|
||||
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
|
||||
let mut add_tailwind_back = false;
|
||||
|
||||
let mut plugins = plugins
|
||||
.into_iter()
|
||||
.filter(|&&plugin_name| {
|
||||
if plugin_name == TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME {
|
||||
add_tailwind_back = true;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.map(|plugin_name| {
|
||||
(plugin_name, plugin_name_into_path(plugin_name))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if add_tailwind_back {
|
||||
plugins.push((
|
||||
&TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
|
||||
plugin_name_into_path(
|
||||
TAILWIND_PRETTIER_PLUGIN_PACKAGE_NAME,
|
||||
),
|
||||
));
|
||||
}
|
||||
(Some(parser.to_string()), plugins)
|
||||
}
|
||||
None => (None, Vec::new()),
|
||||
};
|
||||
|
||||
let prettier_options = if self.is_default() {
|
||||
let language_settings =
|
||||
language_settings(buffer_language, buffer.file(), cx);
|
||||
let mut options = language_settings.prettier.clone();
|
||||
if !options.contains_key("tabWidth") {
|
||||
options.insert(
|
||||
"tabWidth".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(
|
||||
language_settings.tab_size.get(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
if !options.contains_key("printWidth") {
|
||||
options.insert(
|
||||
"printWidth".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(
|
||||
language_settings.preferred_line_length,
|
||||
)),
|
||||
);
|
||||
}
|
||||
Some(options)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let plugins = located_plugins
|
||||
.into_iter()
|
||||
.filter_map(|(plugin_name, located_plugin_path)| {
|
||||
match located_plugin_path {
|
||||
Some(path) => Some(path),
|
||||
None => {
|
||||
log::error!(
|
||||
"Have not found plugin path for {:?} inside {:?}",
|
||||
plugin_name,
|
||||
prettier_node_modules
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
log::debug!(
|
||||
"Formatting file {:?} with prettier, plugins :{:?}, options: {:?}",
|
||||
plugins,
|
||||
prettier_options,
|
||||
buffer.file().map(|f| f.full_path(cx))
|
||||
);
|
||||
|
||||
anyhow::Ok(FormatParams {
|
||||
text: buffer.text(),
|
||||
options: FormatOptions {
|
||||
parser,
|
||||
plugins,
|
||||
path: buffer_path,
|
||||
prettier_options,
|
||||
},
|
||||
})
|
||||
})?
|
||||
.context("prettier params calculation")?;
|
||||
let response = local
|
||||
.server
|
||||
.request::<Format>(params)
|
||||
.await
|
||||
.context("prettier format request")?;
|
||||
let diff_task = buffer.update(cx, |buffer, cx| buffer.diff(response.text, cx))?;
|
||||
Ok(diff_task.await)
|
||||
}
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(_) => Ok(buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
let formatted_text = buffer.text() + FORMAT_SUFFIX;
|
||||
buffer.diff(formatted_text, cx)
|
||||
})?
|
||||
.await),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear_cache(&self) -> anyhow::Result<()> {
|
||||
match self {
|
||||
Self::Real(local) => local
|
||||
.server
|
||||
.request::<ClearCache>(())
|
||||
.await
|
||||
.context("prettier clear cache"),
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server(&self) -> Option<&Arc<LanguageServer>> {
|
||||
match self {
|
||||
Self::Real(local) => Some(&local.server),
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_default(&self) -> bool {
|
||||
match self {
|
||||
Self::Real(local) => local.default,
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(test_prettier) => test_prettier.default,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prettier_dir(&self) -> &Path {
|
||||
match self {
|
||||
Self::Real(local) => &local.prettier_dir,
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
Self::Test(test_prettier) => &test_prettier.prettier_dir,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn has_prettier_in_node_modules(fs: &dyn Fs, path: &Path) -> anyhow::Result<bool> {
|
||||
let possible_node_modules_location = path.join("node_modules").join(PRETTIER_PACKAGE_NAME);
|
||||
if let Some(node_modules_location_metadata) = fs
|
||||
.metadata(&possible_node_modules_location)
|
||||
.await
|
||||
.with_context(|| format!("fetching metadata for {possible_node_modules_location:?}"))?
|
||||
{
|
||||
return Ok(node_modules_location_metadata.is_dir);
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn read_package_json(
|
||||
fs: &dyn Fs,
|
||||
path: &Path,
|
||||
) -> anyhow::Result<Option<HashMap<String, serde_json::Value>>> {
|
||||
let possible_package_json = path.join("package.json");
|
||||
if let Some(package_json_metadata) = fs
|
||||
.metadata(&possible_package_json)
|
||||
.await
|
||||
.with_context(|| format!("fetching metadata for package json {possible_package_json:?}"))?
|
||||
{
|
||||
if !package_json_metadata.is_dir && !package_json_metadata.is_symlink {
|
||||
let package_json_contents = fs
|
||||
.load(&possible_package_json)
|
||||
.await
|
||||
.with_context(|| format!("reading {possible_package_json:?} file contents"))?;
|
||||
return serde_json::from_str::<HashMap<String, serde_json::Value>>(
|
||||
&package_json_contents,
|
||||
)
|
||||
.map(Some)
|
||||
.with_context(|| format!("parsing {possible_package_json:?} file contents"));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn has_prettier_in_package_json(
|
||||
package_json_contents: &HashMap<String, serde_json::Value>,
|
||||
) -> bool {
|
||||
if let Some(serde_json::Value::Object(o)) = package_json_contents.get("dependencies") {
|
||||
if o.contains_key(PRETTIER_PACKAGE_NAME) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if let Some(serde_json::Value::Object(o)) = package_json_contents.get("devDependencies") {
|
||||
if o.contains_key(PRETTIER_PACKAGE_NAME) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
enum Format {}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct FormatParams {
|
||||
text: String,
|
||||
options: FormatOptions,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct FormatOptions {
|
||||
plugins: Vec<PathBuf>,
|
||||
parser: Option<String>,
|
||||
#[serde(rename = "filepath")]
|
||||
path: Option<PathBuf>,
|
||||
prettier_options: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct FormatResult {
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl lsp::request::Request for Format {
|
||||
type Params = FormatParams;
|
||||
type Result = FormatResult;
|
||||
const METHOD: &'static str = "prettier/format";
|
||||
}
|
||||
|
||||
enum ClearCache {}
|
||||
|
||||
impl lsp::request::Request for ClearCache {
|
||||
type Params = ();
|
||||
type Result = ();
|
||||
const METHOD: &'static str = "prettier/clear_cache";
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use fs::FakeFs;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_finds_nothing(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
".config": {
|
||||
"zed": {
|
||||
"settings.json": r#"{ "formatter": "auto" }"#,
|
||||
},
|
||||
},
|
||||
"work": {
|
||||
"project": {
|
||||
"src": {
|
||||
"index.js": "// index.js file contents",
|
||||
},
|
||||
"node_modules": {
|
||||
"expect": {
|
||||
"build": {
|
||||
"print.js": "// print.js file contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "2.5.1"
|
||||
}
|
||||
}"#,
|
||||
},
|
||||
"prettier": {
|
||||
"index.js": "// Dummy prettier package file",
|
||||
},
|
||||
},
|
||||
"package.json": r#"{}"#
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/.config/zed/settings.json"),
|
||||
)
|
||||
.await,
|
||||
Ok(ControlFlow::Continue(None))
|
||||
),
|
||||
"Should successfully find no prettier for path hierarchy without it"
|
||||
);
|
||||
assert!(
|
||||
matches!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/project/src/index.js")
|
||||
)
|
||||
.await,
|
||||
Ok(ControlFlow::Continue(None))
|
||||
),
|
||||
"Should successfully find no prettier for path hierarchy that has node_modules with prettier, but no package.json mentions of it"
|
||||
);
|
||||
assert!(
|
||||
matches!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/project/node_modules/expect/build/print.js")
|
||||
)
|
||||
.await,
|
||||
Ok(ControlFlow::Break(()))
|
||||
),
|
||||
"Should not format files inside node_modules/"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_in_simple_npm_projects(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"web_blog": {
|
||||
"node_modules": {
|
||||
"prettier": {
|
||||
"index.js": "// Dummy prettier package file",
|
||||
},
|
||||
"expect": {
|
||||
"build": {
|
||||
"print.js": "// print.js file contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "2.5.1"
|
||||
}
|
||||
}"#,
|
||||
},
|
||||
},
|
||||
"pages": {
|
||||
"[slug].tsx": "// [slug].tsx file contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "2.3.0"
|
||||
},
|
||||
"prettier": {
|
||||
"semi": false,
|
||||
"printWidth": 80,
|
||||
"htmlWhitespaceSensitivity": "strict",
|
||||
"tabWidth": 4
|
||||
}
|
||||
}"#
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/web_blog/pages/[slug].tsx")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Continue(Some(PathBuf::from("/root/web_blog"))),
|
||||
"Should find a preinstalled prettier in the project root"
|
||||
);
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/web_blog/node_modules/expect/build/print.js")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Break(()),
|
||||
"Should not allow formatting node_modules/ contents"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_for_not_installed(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"work": {
|
||||
"web_blog": {
|
||||
"node_modules": {
|
||||
"expect": {
|
||||
"build": {
|
||||
"print.js": "// print.js file contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "2.5.1"
|
||||
}
|
||||
}"#,
|
||||
},
|
||||
},
|
||||
"pages": {
|
||||
"[slug].tsx": "// [slug].tsx file contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "2.3.0"
|
||||
},
|
||||
"prettier": {
|
||||
"semi": false,
|
||||
"printWidth": 80,
|
||||
"htmlWhitespaceSensitivity": "strict",
|
||||
"tabWidth": 4
|
||||
}
|
||||
}"#
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
match Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/web_blog/pages/[slug].tsx")
|
||||
)
|
||||
.await {
|
||||
Ok(path) => panic!("Expected to fail for prettier in package.json but not in node_modules found, but got path {path:?}"),
|
||||
Err(e) => {
|
||||
let message = e.to_string();
|
||||
assert!(message.contains("/root/work/web_blog"), "Error message should mention which project had prettier defined");
|
||||
},
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::from_iter(
|
||||
[PathBuf::from("/root"), PathBuf::from("/root/work")].into_iter()
|
||||
),
|
||||
Path::new("/root/work/web_blog/pages/[slug].tsx")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Continue(Some(PathBuf::from("/root/work"))),
|
||||
"Should return closest cached value found without path checks"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/web_blog/node_modules/expect/build/print.js")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Break(()),
|
||||
"Should not allow formatting files inside node_modules/"
|
||||
);
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::from_iter(
|
||||
[PathBuf::from("/root"), PathBuf::from("/root/work")].into_iter()
|
||||
),
|
||||
Path::new("/root/work/web_blog/node_modules/expect/build/print.js")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Break(()),
|
||||
"Should ignore cache lookup for files inside node_modules/"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_in_npm_workspaces(cx: &mut gpui::TestAppContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"work": {
|
||||
"full-stack-foundations": {
|
||||
"exercises": {
|
||||
"03.loading": {
|
||||
"01.problem.loader": {
|
||||
"app": {
|
||||
"routes": {
|
||||
"users+": {
|
||||
"$username_+": {
|
||||
"notes.tsx": "// notes.tsx file contents",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"node_modules": {
|
||||
"test.js": "// test.js contents",
|
||||
},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "^3.0.3"
|
||||
}
|
||||
}"#
|
||||
},
|
||||
},
|
||||
},
|
||||
"package.json": r#"{
|
||||
"workspaces": ["exercises/*/*", "examples/*"]
|
||||
}"#,
|
||||
"node_modules": {
|
||||
"prettier": {
|
||||
"index.js": "// Dummy prettier package file",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/app/routes/users+/$username_+/notes.tsx"),
|
||||
).await.unwrap(),
|
||||
ControlFlow::Continue(Some(PathBuf::from("/root/work/full-stack-foundations"))),
|
||||
"Should ascend to the multi-workspace root and find the prettier there",
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/full-stack-foundations/node_modules/prettier/index.js")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Break(()),
|
||||
"Should not allow formatting files inside root node_modules/"
|
||||
);
|
||||
assert_eq!(
|
||||
Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/node_modules/test.js")
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
ControlFlow::Break(()),
|
||||
"Should not allow formatting files inside submodule's node_modules/"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prettier_lookup_in_npm_workspaces_for_not_installed(
|
||||
cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
json!({
|
||||
"work": {
|
||||
"full-stack-foundations": {
|
||||
"exercises": {
|
||||
"03.loading": {
|
||||
"01.problem.loader": {
|
||||
"app": {
|
||||
"routes": {
|
||||
"users+": {
|
||||
"$username_+": {
|
||||
"notes.tsx": "// notes.tsx file contents",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"node_modules": {},
|
||||
"package.json": r#"{
|
||||
"devDependencies": {
|
||||
"prettier": "^3.0.3"
|
||||
}
|
||||
}"#
|
||||
},
|
||||
},
|
||||
},
|
||||
"package.json": r#"{
|
||||
"workspaces": ["exercises/*/*", "examples/*"]
|
||||
}"#,
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
match Prettier::locate_prettier_installation(
|
||||
fs.as_ref(),
|
||||
&HashSet::default(),
|
||||
Path::new("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader/app/routes/users+/$username_+/notes.tsx")
|
||||
)
|
||||
.await {
|
||||
Ok(path) => panic!("Expected to fail for prettier in package.json but not in node_modules found, but got path {path:?}"),
|
||||
Err(e) => {
|
||||
let message = e.to_string();
|
||||
assert!(message.contains("/root/work/full-stack-foundations/exercises/03.loading/01.problem.loader"), "Error message should mention which project had prettier defined");
|
||||
assert!(message.contains("/root/work/full-stack-foundations"), "Error message should mention potential candidates without prettier node_modules contents");
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
@ -1,241 +0,0 @@
|
||||
const { Buffer } = require("buffer");
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
const { once } = require("events");
|
||||
|
||||
const prettierContainerPath = process.argv[2];
|
||||
if (prettierContainerPath == null || prettierContainerPath.length == 0) {
|
||||
process.stderr.write(
|
||||
`Prettier path argument was not specified or empty.\nUsage: ${process.argv[0]} ${process.argv[1]} prettier/path\n`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
fs.stat(prettierContainerPath, (err, stats) => {
|
||||
if (err) {
|
||||
process.stderr.write(`Path '${prettierContainerPath}' does not exist\n`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
if (!stats.isDirectory()) {
|
||||
process.stderr.write(`Path '${prettierContainerPath}' exists but is not a directory\n`);
|
||||
process.exit(1);
|
||||
}
|
||||
});
|
||||
const prettierPath = path.join(prettierContainerPath, "node_modules/prettier");
|
||||
|
||||
class Prettier {
|
||||
constructor(path, prettier, config) {
|
||||
this.path = path;
|
||||
this.prettier = prettier;
|
||||
this.config = config;
|
||||
}
|
||||
}
|
||||
|
||||
(async () => {
|
||||
let prettier;
|
||||
let config;
|
||||
try {
|
||||
prettier = await loadPrettier(prettierPath);
|
||||
config = (await prettier.resolveConfig(prettierPath)) || {};
|
||||
} catch (e) {
|
||||
process.stderr.write(`Failed to load prettier: ${e}\n`);
|
||||
process.exit(1);
|
||||
}
|
||||
process.stderr.write(`Prettier at path '${prettierPath}' loaded successfully, config: ${JSON.stringify(config)}\n`);
|
||||
process.stdin.resume();
|
||||
handleBuffer(new Prettier(prettierPath, prettier, config));
|
||||
})();
|
||||
|
||||
async function handleBuffer(prettier) {
|
||||
for await (const messageText of readStdin()) {
|
||||
let message;
|
||||
try {
|
||||
message = JSON.parse(messageText);
|
||||
} catch (e) {
|
||||
sendResponse(makeError(`Failed to parse message '${messageText}': ${e}`));
|
||||
continue;
|
||||
}
|
||||
// allow concurrent request handling by not `await`ing the message handling promise (async function)
|
||||
handleMessage(message, prettier).catch((e) => {
|
||||
const errorMessage = message;
|
||||
if ((errorMessage.params || {}).text !== undefined) {
|
||||
errorMessage.params.text = "..snip..";
|
||||
}
|
||||
sendResponse({
|
||||
id: message.id,
|
||||
...makeError(`error during message '${JSON.stringify(errorMessage)}' handling: ${e}`),
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const headerSeparator = "\r\n";
|
||||
const contentLengthHeaderName = "Content-Length";
|
||||
|
||||
async function* readStdin() {
|
||||
let buffer = Buffer.alloc(0);
|
||||
let streamEnded = false;
|
||||
process.stdin.on("end", () => {
|
||||
streamEnded = true;
|
||||
});
|
||||
process.stdin.on("data", (data) => {
|
||||
buffer = Buffer.concat([buffer, data]);
|
||||
});
|
||||
|
||||
async function handleStreamEnded(errorMessage) {
|
||||
sendResponse(makeError(errorMessage));
|
||||
buffer = Buffer.alloc(0);
|
||||
messageLength = null;
|
||||
await once(process.stdin, "readable");
|
||||
streamEnded = false;
|
||||
}
|
||||
|
||||
try {
|
||||
let headersLength = null;
|
||||
let messageLength = null;
|
||||
main_loop: while (true) {
|
||||
if (messageLength === null) {
|
||||
while (buffer.indexOf(`${headerSeparator}${headerSeparator}`) === -1) {
|
||||
if (streamEnded) {
|
||||
await handleStreamEnded("Unexpected end of stream: headers not found");
|
||||
continue main_loop;
|
||||
} else if (buffer.length > contentLengthHeaderName.length * 10) {
|
||||
await handleStreamEnded(
|
||||
`Unexpected stream of bytes: no headers end found after ${buffer.length} bytes of input`,
|
||||
);
|
||||
continue main_loop;
|
||||
}
|
||||
await once(process.stdin, "readable");
|
||||
}
|
||||
const headers = buffer
|
||||
.subarray(0, buffer.indexOf(`${headerSeparator}${headerSeparator}`))
|
||||
.toString("ascii");
|
||||
const contentLengthHeader = headers
|
||||
.split(headerSeparator)
|
||||
.map((header) => header.split(":"))
|
||||
.filter((header) => header[2] === undefined)
|
||||
.filter((header) => (header[1] || "").length > 0)
|
||||
.find((header) => (header[0] || "").trim() === contentLengthHeaderName);
|
||||
const contentLength = (contentLengthHeader || [])[1];
|
||||
if (contentLength === undefined) {
|
||||
await handleStreamEnded(`Missing or incorrect ${contentLengthHeaderName} header: ${headers}`);
|
||||
continue main_loop;
|
||||
}
|
||||
headersLength = headers.length + headerSeparator.length * 2;
|
||||
messageLength = parseInt(contentLength, 10);
|
||||
}
|
||||
|
||||
while (buffer.length < headersLength + messageLength) {
|
||||
if (streamEnded) {
|
||||
await handleStreamEnded(
|
||||
`Unexpected end of stream: buffer length ${buffer.length} does not match expected header length ${headersLength} + body length ${messageLength}`,
|
||||
);
|
||||
continue main_loop;
|
||||
}
|
||||
await once(process.stdin, "readable");
|
||||
}
|
||||
|
||||
const messageEnd = headersLength + messageLength;
|
||||
const message = buffer.subarray(headersLength, messageEnd);
|
||||
buffer = buffer.subarray(messageEnd);
|
||||
headersLength = null;
|
||||
messageLength = null;
|
||||
yield message.toString("utf8");
|
||||
}
|
||||
} catch (e) {
|
||||
sendResponse(makeError(`Error reading stdin: ${e}`));
|
||||
} finally {
|
||||
process.stdin.off("data", () => {});
|
||||
}
|
||||
}
|
||||
|
||||
async function handleMessage(message, prettier) {
|
||||
const { method, id, params } = message;
|
||||
if (method === undefined) {
|
||||
throw new Error(`Message method is undefined: ${JSON.stringify(message)}`);
|
||||
} else if (method == "initialized") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (id === undefined) {
|
||||
throw new Error(`Message id is undefined: ${JSON.stringify(message)}`);
|
||||
}
|
||||
|
||||
if (method === "prettier/format") {
|
||||
if (params === undefined || params.text === undefined) {
|
||||
throw new Error(`Message params.text is undefined: ${JSON.stringify(message)}`);
|
||||
}
|
||||
if (params.options === undefined) {
|
||||
throw new Error(`Message params.options is undefined: ${JSON.stringify(message)}`);
|
||||
}
|
||||
|
||||
let resolvedConfig = {};
|
||||
if (params.options.filepath !== undefined) {
|
||||
resolvedConfig = (await prettier.prettier.resolveConfig(params.options.filepath)) || {};
|
||||
}
|
||||
|
||||
const options = {
|
||||
...(params.options.prettierOptions || prettier.config),
|
||||
...resolvedConfig,
|
||||
parser: params.options.parser,
|
||||
plugins: params.options.plugins,
|
||||
path: params.options.filepath,
|
||||
};
|
||||
process.stderr.write(
|
||||
`Resolved config: ${JSON.stringify(resolvedConfig)}, will format file '${
|
||||
params.options.filepath || ""
|
||||
}' with options: ${JSON.stringify(options)}\n`,
|
||||
);
|
||||
const formattedText = await prettier.prettier.format(params.text, options);
|
||||
sendResponse({ id, result: { text: formattedText } });
|
||||
} else if (method === "prettier/clear_cache") {
|
||||
prettier.prettier.clearConfigCache();
|
||||
prettier.config = (await prettier.prettier.resolveConfig(prettier.path)) || {};
|
||||
sendResponse({ id, result: null });
|
||||
} else if (method === "initialize") {
|
||||
sendResponse({
|
||||
id,
|
||||
result: {
|
||||
capabilities: {},
|
||||
},
|
||||
});
|
||||
} else {
|
||||
throw new Error(`Unknown method: ${method}`);
|
||||
}
|
||||
}
|
||||
|
||||
function makeError(message) {
|
||||
return {
|
||||
error: {
|
||||
code: -32600, // invalid request code
|
||||
message,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function sendResponse(response) {
|
||||
const responsePayloadString = JSON.stringify({
|
||||
jsonrpc: "2.0",
|
||||
...response,
|
||||
});
|
||||
const headers = `${contentLengthHeaderName}: ${Buffer.byteLength(
|
||||
responsePayloadString,
|
||||
)}${headerSeparator}${headerSeparator}`;
|
||||
process.stdout.write(headers + responsePayloadString);
|
||||
}
|
||||
|
||||
function loadPrettier(prettierPath) {
|
||||
return new Promise((resolve, reject) => {
|
||||
fs.access(prettierPath, fs.constants.F_OK, (err) => {
|
||||
if (err) {
|
||||
reject(`Path '${prettierPath}' does not exist.Error: ${err}`);
|
||||
} else {
|
||||
try {
|
||||
resolve(require(prettierPath));
|
||||
} catch (err) {
|
||||
reject(`Error requiring prettier module from path '${prettierPath}'.Error: ${err}`);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
@ -25,8 +25,8 @@ copilot = { path = "../copilot" }
|
||||
client = { path = "../client" }
|
||||
clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
db = { path = "../db" }
|
||||
fs = { path = "../fs" }
|
||||
fsevent = { path = "../fsevent" }
|
||||
fuzzy = { path = "../fuzzy" }
|
||||
git = { package = "git3", path = "../git3" }
|
||||
@ -34,8 +34,8 @@ gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
lsp = { path = "../lsp" }
|
||||
node_runtime = { path = "../node_runtime" }
|
||||
prettier = { package = "prettier2", path = "../prettier2" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
prettier = { path = "../prettier" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
terminal = { package = "terminal2", path = "../terminal2" }
|
||||
@ -71,15 +71,15 @@ env_logger.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
client = { path = "../client", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
db = { package = "db2", path = "../db2", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
||||
db = { path = "../db", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
lsp = { path = "../lsp", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
prettier = { package = "prettier2", path = "../prettier2", features = ["test-support"] }
|
||||
prettier = { path = "../prettier", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
git2.workspace = true
|
||||
tempdir.workspace = true
|
||||
unindent.workspace = true
|
||||
|
@ -10,7 +10,7 @@ doctest = false
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
db = { path = "../db2", package = "db2" }
|
||||
db = { path = "../db" }
|
||||
editor = { path = "../editor" }
|
||||
gpui = { path = "../gpui2", package = "gpui2" }
|
||||
menu = { path = "../menu2", package = "menu2" }
|
||||
|
@ -15,9 +15,8 @@ test-support = ["collections/test-support", "gpui/test-support"]
|
||||
[dependencies]
|
||||
clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
gpui = { path = "../gpui", optional = true }
|
||||
gpui = { package = "gpui2", path = "../gpui2", optional = true }
|
||||
util = { path = "../util" }
|
||||
|
||||
anyhow.workspace = true
|
||||
async-lock = "2.4"
|
||||
async-tungstenite = "0.16"
|
||||
@ -27,8 +26,8 @@ parking_lot.workspace = true
|
||||
prost.workspace = true
|
||||
rand.workspace = true
|
||||
rsa = "0.4"
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
smol-timeout = "0.6"
|
||||
strum.workspace = true
|
||||
@ -40,7 +39,7 @@ prost-build = "0.9"
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
smol.workspace = true
|
||||
tempdir.workspace = true
|
||||
ctor.workspace = true
|
||||
|
@ -34,7 +34,7 @@ impl Connection {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn in_memory(
|
||||
executor: std::sync::Arc<gpui::executor::Background>,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
@ -53,7 +53,7 @@ impl Connection {
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn channel(
|
||||
killed: Arc<AtomicBool>,
|
||||
executor: Arc<gpui::executor::Background>,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> (
|
||||
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
|
||||
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
|
||||
@ -66,14 +66,12 @@ impl Connection {
|
||||
|
||||
let tx = tx.sink_map_err(|error| anyhow!(error)).with({
|
||||
let killed = killed.clone();
|
||||
let executor = Arc::downgrade(&executor);
|
||||
let executor = executor.clone();
|
||||
move |msg| {
|
||||
let killed = killed.clone();
|
||||
let executor = executor.clone();
|
||||
Box::pin(async move {
|
||||
if let Some(executor) = executor.upgrade() {
|
||||
executor.simulate_random_delay().await;
|
||||
}
|
||||
|
||||
// Writes to a half-open TCP connection will error.
|
||||
if killed.load(SeqCst) {
|
||||
@ -87,14 +85,12 @@ impl Connection {
|
||||
|
||||
let rx = rx.then({
|
||||
let killed = killed;
|
||||
let executor = Arc::downgrade(&executor);
|
||||
let executor = executor.clone();
|
||||
move |msg| {
|
||||
let killed = killed.clone();
|
||||
let executor = executor.clone();
|
||||
Box::pin(async move {
|
||||
if let Some(executor) = executor.upgrade() {
|
||||
executor.simulate_random_delay().await;
|
||||
}
|
||||
|
||||
// Reads from a half-open TCP connection will hang.
|
||||
if killed.load(SeqCst) {
|
||||
|
@ -1,46 +0,0 @@
|
||||
[package]
|
||||
description = "Shared logic for communication between the Zed app and the zed.dev server"
|
||||
edition = "2021"
|
||||
name = "rpc2"
|
||||
version = "0.1.0"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/rpc.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["collections/test-support", "gpui/test-support"]
|
||||
|
||||
[dependencies]
|
||||
clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
gpui = { package = "gpui2", path = "../gpui2", optional = true }
|
||||
util = { path = "../util" }
|
||||
anyhow.workspace = true
|
||||
async-lock = "2.4"
|
||||
async-tungstenite = "0.16"
|
||||
base64 = "0.13"
|
||||
futures.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prost.workspace = true
|
||||
rand.workspace = true
|
||||
rsa = "0.4"
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
smol-timeout = "0.6"
|
||||
strum.workspace = true
|
||||
tracing = { version = "0.1.34", features = ["log"] }
|
||||
zstd = "0.11"
|
||||
|
||||
[build-dependencies]
|
||||
prost-build = "0.9"
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
smol.workspace = true
|
||||
tempdir.workspace = true
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
@ -1,8 +0,0 @@
|
||||
fn main() {
|
||||
let mut build = prost_build::Config::new();
|
||||
// build.protoc_arg("--experimental_allow_proto3_optional");
|
||||
build
|
||||
.type_attribute(".", "#[derive(serde::Serialize)]")
|
||||
.compile_protos(&["proto/zed.proto"], &["proto"])
|
||||
.unwrap();
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,136 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use rand::{thread_rng, Rng as _};
|
||||
use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub struct PublicKey(RSAPublicKey);
|
||||
|
||||
pub struct PrivateKey(RSAPrivateKey);
|
||||
|
||||
/// Generate a public and private key for asymmetric encryption.
|
||||
pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
|
||||
let mut rng = thread_rng();
|
||||
let bits = 1024;
|
||||
let private_key = RSAPrivateKey::new(&mut rng, bits)?;
|
||||
let public_key = RSAPublicKey::from(&private_key);
|
||||
Ok((PublicKey(public_key), PrivateKey(private_key)))
|
||||
}
|
||||
|
||||
/// Generate a random 64-character base64 string.
|
||||
pub fn random_token() -> String {
|
||||
let mut rng = thread_rng();
|
||||
let mut token_bytes = [0; 48];
|
||||
for byte in token_bytes.iter_mut() {
|
||||
*byte = rng.gen();
|
||||
}
|
||||
base64::encode_config(token_bytes, base64::URL_SAFE)
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Convert a string to a base64-encoded string that can only be decoded with the corresponding
|
||||
/// private key.
|
||||
pub fn encrypt_string(&self, string: &str) -> Result<String> {
|
||||
let mut rng = thread_rng();
|
||||
let bytes = string.as_bytes();
|
||||
let encrypted_bytes = self
|
||||
.0
|
||||
.encrypt(&mut rng, PADDING_SCHEME, bytes)
|
||||
.context("failed to encrypt string with public key")?;
|
||||
let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
|
||||
Ok(encrypted_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateKey {
|
||||
/// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
|
||||
pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
|
||||
let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
|
||||
.context("failed to base64-decode encrypted string")?;
|
||||
let bytes = self
|
||||
.0
|
||||
.decrypt(PADDING_SCHEME, &encrypted_bytes)
|
||||
.context("failed to decrypt string with private key")?;
|
||||
let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
|
||||
Ok(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PublicKey> for String {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(key: PublicKey) -> Result<Self> {
|
||||
let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
|
||||
let string = base64::encode_config(&bytes, base64::URL_SAFE);
|
||||
Ok(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for PublicKey {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
let bytes = base64::decode_config(&value, base64::URL_SAFE)
|
||||
.context("failed to base64-decode public key string")?;
|
||||
let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
||||
const PADDING_SCHEME: rsa::PaddingScheme = rsa::PaddingScheme::PKCS1v15Encrypt;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_encrypt_and_decrypt_token() {
|
||||
// CLIENT:
|
||||
// * generate a keypair for asymmetric encryption
|
||||
// * serialize the public key to send it to the server.
|
||||
let (public, private) = keypair().unwrap();
|
||||
let public_string = String::try_from(public).unwrap();
|
||||
assert_printable(&public_string);
|
||||
|
||||
// SERVER:
|
||||
// * parse the public key
|
||||
// * generate a random token.
|
||||
// * encrypt the token using the public key.
|
||||
let public = PublicKey::try_from(public_string).unwrap();
|
||||
let token = random_token();
|
||||
let encrypted_token = public.encrypt_string(&token).unwrap();
|
||||
assert_eq!(token.len(), 64);
|
||||
assert_ne!(encrypted_token, token);
|
||||
assert_printable(&token);
|
||||
assert_printable(&encrypted_token);
|
||||
|
||||
// CLIENT:
|
||||
// * decrypt the token using the private key.
|
||||
let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
|
||||
assert_eq!(decrypted_token, token);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokens_are_always_url_safe() {
|
||||
for _ in 0..5 {
|
||||
let token = random_token();
|
||||
let (public_key, _) = keypair().unwrap();
|
||||
let encrypted_token = public_key.encrypt_string(&token).unwrap();
|
||||
let public_key_str = String::try_from(public_key).unwrap();
|
||||
|
||||
assert_printable(&token);
|
||||
assert_printable(&public_key_str);
|
||||
assert_printable(&encrypted_token);
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_printable(token: &str) {
|
||||
for c in token.chars() {
|
||||
assert!(
|
||||
c.is_ascii_graphic(),
|
||||
"token {:?} has non-printable char {}",
|
||||
token,
|
||||
c
|
||||
);
|
||||
assert_ne!(c, '/', "token {:?} is not URL-safe", token);
|
||||
assert_ne!(c, '&', "token {:?} is not URL-safe", token);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,108 +0,0 @@
|
||||
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
|
||||
pub struct Connection {
|
||||
pub(crate) tx:
|
||||
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
|
||||
pub(crate) rx: Box<
|
||||
dyn 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Sink<WebSocketMessage, Error = anyhow::Error>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
|
||||
{
|
||||
let (tx, rx) = stream.split();
|
||||
Self {
|
||||
tx: Box::new(tx),
|
||||
rx: Box::new(rx),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), anyhow::Error> {
|
||||
self.tx.send(message).await
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn in_memory(
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
};
|
||||
|
||||
let killed = Arc::new(AtomicBool::new(false));
|
||||
let (a_tx, a_rx) = channel(killed.clone(), executor.clone());
|
||||
let (b_tx, b_rx) = channel(killed.clone(), executor);
|
||||
return (
|
||||
Self { tx: a_tx, rx: b_rx },
|
||||
Self { tx: b_tx, rx: a_rx },
|
||||
killed,
|
||||
);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn channel(
|
||||
killed: Arc<AtomicBool>,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> (
|
||||
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
|
||||
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
|
||||
) {
|
||||
use anyhow::anyhow;
|
||||
use futures::channel::mpsc;
|
||||
use std::io::{Error, ErrorKind};
|
||||
|
||||
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
|
||||
|
||||
let tx = tx.sink_map_err(|error| anyhow!(error)).with({
|
||||
let killed = killed.clone();
|
||||
let executor = executor.clone();
|
||||
move |msg| {
|
||||
let killed = killed.clone();
|
||||
let executor = executor.clone();
|
||||
Box::pin(async move {
|
||||
executor.simulate_random_delay().await;
|
||||
|
||||
// Writes to a half-open TCP connection will error.
|
||||
if killed.load(SeqCst) {
|
||||
std::io::Result::Err(Error::new(ErrorKind::Other, "connection lost"))?;
|
||||
}
|
||||
|
||||
Ok(msg)
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let rx = rx.then({
|
||||
let killed = killed;
|
||||
let executor = executor.clone();
|
||||
move |msg| {
|
||||
let killed = killed.clone();
|
||||
let executor = executor.clone();
|
||||
Box::pin(async move {
|
||||
executor.simulate_random_delay().await;
|
||||
|
||||
// Reads from a half-open TCP connection will hang.
|
||||
if killed.load(SeqCst) {
|
||||
futures::future::pending::<()>().await;
|
||||
}
|
||||
|
||||
Ok(msg)
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
(Box::new(tx), Box::new(rx))
|
||||
}
|
||||
}
|
||||
}
|
@ -1,70 +0,0 @@
|
||||
#[macro_export]
|
||||
macro_rules! messages {
|
||||
($(($name:ident, $priority:ident)),* $(,)?) => {
|
||||
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
|
||||
match envelope.payload {
|
||||
$(Some(envelope::Payload::$name(payload)) => {
|
||||
Some(Box::new(TypedEnvelope {
|
||||
sender_id,
|
||||
original_sender_id: envelope.original_sender_id.map(|original_sender| PeerId {
|
||||
owner_id: original_sender.owner_id,
|
||||
id: original_sender.id
|
||||
}),
|
||||
message_id: envelope.id,
|
||||
payload,
|
||||
}))
|
||||
}, )*
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
$(
|
||||
impl EnvelopedMessage for $name {
|
||||
const NAME: &'static str = std::stringify!($name);
|
||||
const PRIORITY: MessagePriority = MessagePriority::$priority;
|
||||
|
||||
fn into_envelope(
|
||||
self,
|
||||
id: u32,
|
||||
responding_to: Option<u32>,
|
||||
original_sender_id: Option<PeerId>,
|
||||
) -> Envelope {
|
||||
Envelope {
|
||||
id,
|
||||
responding_to,
|
||||
original_sender_id,
|
||||
payload: Some(envelope::Payload::$name(self)),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_envelope(envelope: Envelope) -> Option<Self> {
|
||||
if let Some(envelope::Payload::$name(msg)) = envelope.payload {
|
||||
Some(msg)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! request_messages {
|
||||
($(($request_name:ident, $response_name:ident)),* $(,)?) => {
|
||||
$(impl RequestMessage for $request_name {
|
||||
type Response = $response_name;
|
||||
})*
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! entity_messages {
|
||||
($id_field:ident, $($name:ident),* $(,)?) => {
|
||||
$(impl EntityMessage for $name {
|
||||
fn remote_entity_id(&self) -> u64 {
|
||||
self.$id_field
|
||||
}
|
||||
})*
|
||||
};
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
use crate::proto;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{map, Value};
|
||||
use strum::{EnumVariantNames, VariantNames as _};
|
||||
|
||||
const KIND: &'static str = "kind";
|
||||
const ENTITY_ID: &'static str = "entity_id";
|
||||
|
||||
/// A notification that can be stored, associated with a given recipient.
|
||||
///
|
||||
/// This struct is stored in the collab database as JSON, so it shouldn't be
|
||||
/// changed in a backward-incompatible way. For example, when renaming a
|
||||
/// variant, add a serde alias for the old name.
|
||||
///
|
||||
/// Most notification types have a special field which is aliased to
|
||||
/// `entity_id`. This field is stored in its own database column, and can
|
||||
/// be used to query the notification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum Notification {
|
||||
ContactRequest {
|
||||
#[serde(rename = "entity_id")]
|
||||
sender_id: u64,
|
||||
},
|
||||
ContactRequestAccepted {
|
||||
#[serde(rename = "entity_id")]
|
||||
responder_id: u64,
|
||||
},
|
||||
ChannelInvitation {
|
||||
#[serde(rename = "entity_id")]
|
||||
channel_id: u64,
|
||||
channel_name: String,
|
||||
inviter_id: u64,
|
||||
},
|
||||
ChannelMessageMention {
|
||||
#[serde(rename = "entity_id")]
|
||||
message_id: u64,
|
||||
sender_id: u64,
|
||||
channel_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl Notification {
|
||||
pub fn to_proto(&self) -> proto::Notification {
|
||||
let mut value = serde_json::to_value(self).unwrap();
|
||||
let mut entity_id = None;
|
||||
let value = value.as_object_mut().unwrap();
|
||||
let Some(Value::String(kind)) = value.remove(KIND) else {
|
||||
unreachable!("kind is the enum tag")
|
||||
};
|
||||
if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) {
|
||||
if e.get().is_u64() {
|
||||
entity_id = e.remove().as_u64();
|
||||
}
|
||||
}
|
||||
proto::Notification {
|
||||
kind,
|
||||
entity_id,
|
||||
content: serde_json::to_string(&value).unwrap(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_proto(notification: &proto::Notification) -> Option<Self> {
|
||||
let mut value = serde_json::from_str::<Value>(¬ification.content).ok()?;
|
||||
let object = value.as_object_mut()?;
|
||||
object.insert(KIND.into(), notification.kind.to_string().into());
|
||||
if let Some(entity_id) = notification.entity_id {
|
||||
object.insert(ENTITY_ID.into(), entity_id.into());
|
||||
}
|
||||
serde_json::from_value(value).ok()
|
||||
}
|
||||
|
||||
pub fn all_variant_names() -> &'static [&'static str] {
|
||||
Self::VARIANTS
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_notification() {
|
||||
// Notifications can be serialized and deserialized.
|
||||
for notification in [
|
||||
Notification::ContactRequest { sender_id: 1 },
|
||||
Notification::ContactRequestAccepted { responder_id: 2 },
|
||||
Notification::ChannelInvitation {
|
||||
channel_id: 100,
|
||||
channel_name: "the-channel".into(),
|
||||
inviter_id: 50,
|
||||
},
|
||||
Notification::ChannelMessageMention {
|
||||
sender_id: 200,
|
||||
channel_id: 30,
|
||||
message_id: 1,
|
||||
},
|
||||
] {
|
||||
let message = notification.to_proto();
|
||||
let deserialized = Notification::from_proto(&message).unwrap();
|
||||
assert_eq!(deserialized, notification);
|
||||
}
|
||||
|
||||
// When notifications are serialized, the `kind` and `actor_id` fields are
|
||||
// stored separately, and do not appear redundantly in the JSON.
|
||||
let notification = Notification::ContactRequest { sender_id: 1 };
|
||||
assert_eq!(notification.to_proto().content, "{}");
|
||||
}
|
@ -1,934 +0,0 @@
|
||||
use super::{
|
||||
proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
|
||||
Connection,
|
||||
};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream::BoxStream,
|
||||
FutureExt, SinkExt, StreamExt, TryFutureExt,
|
||||
};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde::{ser::SerializeStruct, Serialize};
|
||||
use std::{fmt, sync::atomic::Ordering::SeqCst};
|
||||
use std::{
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
sync::{
|
||||
atomic::{self, AtomicU32},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
|
||||
pub struct ConnectionId {
|
||||
pub owner_id: u32,
|
||||
pub id: u32,
|
||||
}
|
||||
|
||||
impl Into<PeerId> for ConnectionId {
|
||||
fn into(self) -> PeerId {
|
||||
PeerId {
|
||||
owner_id: self.owner_id,
|
||||
id: self.id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PeerId> for ConnectionId {
|
||||
fn from(peer_id: PeerId) -> Self {
|
||||
Self {
|
||||
owner_id: peer_id.owner_id,
|
||||
id: peer_id.id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}/{}", self.owner_id, self.id)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Receipt<T> {
|
||||
pub sender_id: ConnectionId,
|
||||
pub message_id: u32,
|
||||
payload_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Receipt<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
sender_id: self.sender_id,
|
||||
message_id: self.message_id,
|
||||
payload_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Copy for Receipt<T> {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TypedEnvelope<T> {
|
||||
pub sender_id: ConnectionId,
|
||||
pub original_sender_id: Option<PeerId>,
|
||||
pub message_id: u32,
|
||||
pub payload: T,
|
||||
}
|
||||
|
||||
impl<T> TypedEnvelope<T> {
|
||||
pub fn original_sender_id(&self) -> Result<PeerId> {
|
||||
self.original_sender_id
|
||||
.ok_or_else(|| anyhow!("missing original_sender_id"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RequestMessage> TypedEnvelope<T> {
|
||||
pub fn receipt(&self) -> Receipt<T> {
|
||||
Receipt {
|
||||
sender_id: self.sender_id,
|
||||
message_id: self.message_id,
|
||||
payload_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
epoch: AtomicU32,
|
||||
pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
|
||||
next_connection_id: AtomicU32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct ConnectionState {
|
||||
#[serde(skip)]
|
||||
outgoing_tx: mpsc::UnboundedSender<proto::Message>,
|
||||
next_message_id: Arc<AtomicU32>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[serde(skip)]
|
||||
response_channels:
|
||||
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
|
||||
}
|
||||
|
||||
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
|
||||
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
impl Peer {
|
||||
pub fn new(epoch: u32) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
epoch: AtomicU32::new(epoch),
|
||||
connections: Default::default(),
|
||||
next_connection_id: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn epoch(&self) -> u32 {
|
||||
self.epoch.load(SeqCst)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn add_connection<F, Fut, Out>(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
create_timer: F,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
|
||||
)
|
||||
where
|
||||
F: Send + Fn(Duration) -> Fut,
|
||||
Fut: Send + Future<Output = Out>,
|
||||
Out: Send,
|
||||
{
|
||||
// For outgoing messages, use an unbounded channel so that application code
|
||||
// can always send messages without yielding. For incoming messages, use a
|
||||
// bounded channel so that other peers will receive backpressure if they send
|
||||
// messages faster than this peer can process them.
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
const INCOMING_BUFFER_SIZE: usize = 1;
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
const INCOMING_BUFFER_SIZE: usize = 64;
|
||||
let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
|
||||
|
||||
let connection_id = ConnectionId {
|
||||
owner_id: self.epoch.load(SeqCst),
|
||||
id: self.next_connection_id.fetch_add(1, SeqCst),
|
||||
};
|
||||
let connection_state = ConnectionState {
|
||||
outgoing_tx,
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
|
||||
};
|
||||
let mut writer = MessageStream::new(connection.tx);
|
||||
let mut reader = MessageStream::new(connection.rx);
|
||||
|
||||
let this = self.clone();
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let handle_io = async move {
|
||||
tracing::trace!(%connection_id, "handle io future: start");
|
||||
|
||||
let _end_connection = util::defer(|| {
|
||||
response_channels.lock().take();
|
||||
this.connections.write().remove(&connection_id);
|
||||
tracing::trace!(%connection_id, "handle io future: end");
|
||||
});
|
||||
|
||||
// Send messages on this frequency so the connection isn't closed.
|
||||
let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
|
||||
futures::pin_mut!(keepalive_timer);
|
||||
|
||||
// Disconnect if we don't receive messages at least this frequently.
|
||||
let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
|
||||
futures::pin_mut!(receive_timeout);
|
||||
|
||||
loop {
|
||||
tracing::trace!(%connection_id, "outer loop iteration start");
|
||||
let read_message = reader.read().fuse();
|
||||
futures::pin_mut!(read_message);
|
||||
|
||||
loop {
|
||||
tracing::trace!(%connection_id, "inner loop iteration start");
|
||||
futures::select_biased! {
|
||||
outgoing = outgoing_rx.next().fuse() => match outgoing {
|
||||
Some(outgoing) => {
|
||||
tracing::trace!(%connection_id, "outgoing rpc message: writing");
|
||||
futures::select_biased! {
|
||||
result = writer.write(outgoing).fuse() => {
|
||||
tracing::trace!(%connection_id, "outgoing rpc message: done writing");
|
||||
result.context("failed to write RPC message")?;
|
||||
tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
|
||||
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
|
||||
}
|
||||
_ = create_timer(WRITE_TIMEOUT).fuse() => {
|
||||
tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
|
||||
Err(anyhow!("timed out writing message"))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
|
||||
return Ok(())
|
||||
},
|
||||
},
|
||||
_ = keepalive_timer => {
|
||||
tracing::trace!(%connection_id, "keepalive interval: pinging");
|
||||
futures::select_biased! {
|
||||
result = writer.write(proto::Message::Ping).fuse() => {
|
||||
tracing::trace!(%connection_id, "keepalive interval: done pinging");
|
||||
result.context("failed to send keepalive")?;
|
||||
tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
|
||||
keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
|
||||
}
|
||||
_ = create_timer(WRITE_TIMEOUT).fuse() => {
|
||||
tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
|
||||
Err(anyhow!("timed out sending keepalive"))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
incoming = read_message => {
|
||||
let incoming = incoming.context("error reading rpc message from socket")?;
|
||||
tracing::trace!(%connection_id, "incoming rpc message: received");
|
||||
tracing::trace!(%connection_id, "receive timeout: resetting");
|
||||
receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
|
||||
if let proto::Message::Envelope(incoming) = incoming {
|
||||
tracing::trace!(%connection_id, "incoming rpc message: processing");
|
||||
futures::select_biased! {
|
||||
result = incoming_tx.send(incoming).fuse() => match result {
|
||||
Ok(_) => {
|
||||
tracing::trace!(%connection_id, "incoming rpc message: processed");
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::trace!(%connection_id, "incoming rpc message: channel closed");
|
||||
return Ok(())
|
||||
}
|
||||
},
|
||||
_ = create_timer(WRITE_TIMEOUT).fuse() => {
|
||||
tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
|
||||
Err(anyhow!("timed out processing incoming message"))?
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
},
|
||||
_ = receive_timeout => {
|
||||
tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
|
||||
Err(anyhow!("delay between messages too long"))?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
self.connections
|
||||
.write()
|
||||
.insert(connection_id, connection_state);
|
||||
|
||||
let incoming_rx = incoming_rx.filter_map(move |incoming| {
|
||||
let response_channels = response_channels.clone();
|
||||
async move {
|
||||
let message_id = incoming.id;
|
||||
tracing::trace!(?incoming, "incoming message future: start");
|
||||
let _end = util::defer(move || {
|
||||
tracing::trace!(%connection_id, message_id, "incoming message future: end");
|
||||
});
|
||||
|
||||
if let Some(responding_to) = incoming.responding_to {
|
||||
tracing::trace!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming response: received"
|
||||
);
|
||||
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
|
||||
if let Some(tx) = channel {
|
||||
let requester_resumed = oneshot::channel();
|
||||
if let Err(error) = tx.send((incoming, requester_resumed.0)) {
|
||||
tracing::trace!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to = responding_to,
|
||||
?error,
|
||||
"incoming response: request future dropped",
|
||||
);
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming response: waiting to resume requester"
|
||||
);
|
||||
let _ = requester_resumed.1.await;
|
||||
tracing::trace!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming response: requester resumed"
|
||||
);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming response: unknown request"
|
||||
);
|
||||
}
|
||||
|
||||
None
|
||||
} else {
|
||||
tracing::trace!(%connection_id, message_id, "incoming message: received");
|
||||
proto::build_typed_envelope(connection_id, incoming).or_else(|| {
|
||||
tracing::error!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
"unable to construct a typed envelope"
|
||||
);
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
});
|
||||
(connection_id, handle_io, incoming_rx.boxed())
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn add_test_connection(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
|
||||
) {
|
||||
let executor = executor.clone();
|
||||
self.add_connection(connection, move |duration| executor.timer(duration))
|
||||
}
|
||||
|
||||
pub fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().remove(&connection_id);
|
||||
}
|
||||
|
||||
pub fn reset(&self, epoch: u32) {
|
||||
self.teardown();
|
||||
self.next_connection_id.store(0, SeqCst);
|
||||
self.epoch.store(epoch, SeqCst);
|
||||
}
|
||||
|
||||
pub fn teardown(&self) {
|
||||
self.connections.write().clear();
|
||||
}
|
||||
|
||||
pub fn request<T: RequestMessage>(
|
||||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
self.request_internal(None, receiver_id, request)
|
||||
.map_ok(|envelope| envelope.payload)
|
||||
}
|
||||
|
||||
pub fn request_envelope<T: RequestMessage>(
|
||||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
|
||||
self.request_internal(None, receiver_id, request)
|
||||
}
|
||||
|
||||
pub fn forward_request<T: RequestMessage>(
|
||||
&self,
|
||||
sender_id: ConnectionId,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
self.request_internal(Some(sender_id), receiver_id, request)
|
||||
.map_ok(|envelope| envelope.payload)
|
||||
}
|
||||
|
||||
pub fn request_internal<T: RequestMessage>(
|
||||
&self,
|
||||
original_sender_id: Option<ConnectionId>,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let send = self.connection_state(receiver_id).and_then(|connection| {
|
||||
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
|
||||
connection
|
||||
.response_channels
|
||||
.lock()
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?
|
||||
.insert(message_id, tx);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(request.into_envelope(
|
||||
message_id,
|
||||
None,
|
||||
original_sender_id.map(Into::into),
|
||||
)))
|
||||
.map_err(|_| anyhow!("connection was closed"))?;
|
||||
Ok(())
|
||||
});
|
||||
async move {
|
||||
send?;
|
||||
let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
|
||||
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
|
||||
Err(anyhow!(
|
||||
"RPC request {} failed - {}",
|
||||
T::NAME,
|
||||
error.message
|
||||
))
|
||||
} else {
|
||||
Ok(TypedEnvelope {
|
||||
message_id: response.id,
|
||||
sender_id: receiver_id,
|
||||
original_sender_id: response.original_sender_id,
|
||||
payload: T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type"))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
|
||||
let connection = self.connection_state(receiver_id)?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(
|
||||
message.into_envelope(message_id, None, None),
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn forward_send<T: EnvelopedMessage>(
|
||||
&self,
|
||||
sender_id: ConnectionId,
|
||||
receiver_id: ConnectionId,
|
||||
message: T,
|
||||
) -> Result<()> {
|
||||
let connection = self.connection_state(receiver_id)?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(message.into_envelope(
|
||||
message_id,
|
||||
None,
|
||||
Some(sender_id.into()),
|
||||
)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn respond<T: RequestMessage>(
|
||||
&self,
|
||||
receipt: Receipt<T>,
|
||||
response: T::Response,
|
||||
) -> Result<()> {
|
||||
let connection = self.connection_state(receipt.sender_id)?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(response.into_envelope(
|
||||
message_id,
|
||||
Some(receipt.message_id),
|
||||
None,
|
||||
)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn respond_with_error<T: RequestMessage>(
|
||||
&self,
|
||||
receipt: Receipt<T>,
|
||||
response: proto::Error,
|
||||
) -> Result<()> {
|
||||
let connection = self.connection_state(receipt.sender_id)?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(response.into_envelope(
|
||||
message_id,
|
||||
Some(receipt.message_id),
|
||||
None,
|
||||
)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn respond_with_unhandled_message(
|
||||
&self,
|
||||
envelope: Box<dyn AnyTypedEnvelope>,
|
||||
) -> Result<()> {
|
||||
let connection = self.connection_state(envelope.sender_id())?;
|
||||
let response = proto::Error {
|
||||
message: format!("message {} was not handled", envelope.payload_type_name()),
|
||||
};
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(response.into_envelope(
|
||||
message_id,
|
||||
Some(envelope.message_id()),
|
||||
None,
|
||||
)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
|
||||
let connections = self.connections.read();
|
||||
let connection = connections
|
||||
.get(&connection_id)
|
||||
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
||||
Ok(connection.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Peer {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("Peer", 2)?;
|
||||
state.serialize_field("connections", &*self.connections.read())?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TypedEnvelope;
|
||||
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_request_response(cx: &mut TestAppContext) {
|
||||
init_logger();
|
||||
|
||||
let executor = cx.executor();
|
||||
|
||||
// create 2 clients connected to 1 server
|
||||
let server = Peer::new(0);
|
||||
let client1 = Peer::new(0);
|
||||
let client2 = Peer::new(0);
|
||||
|
||||
let (client1_to_server_conn, server_to_client_1_conn, _kill) =
|
||||
Connection::in_memory(cx.executor());
|
||||
let (client1_conn_id, io_task1, client1_incoming) =
|
||||
client1.add_test_connection(client1_to_server_conn, cx.executor());
|
||||
let (_, io_task2, server_incoming1) =
|
||||
server.add_test_connection(server_to_client_1_conn, cx.executor());
|
||||
|
||||
let (client2_to_server_conn, server_to_client_2_conn, _kill) =
|
||||
Connection::in_memory(cx.executor());
|
||||
let (client2_conn_id, io_task3, client2_incoming) =
|
||||
client2.add_test_connection(client2_to_server_conn, cx.executor());
|
||||
let (_, io_task4, server_incoming2) =
|
||||
server.add_test_connection(server_to_client_2_conn, cx.executor());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
executor.spawn(io_task3).detach();
|
||||
executor.spawn(io_task4).detach();
|
||||
executor
|
||||
.spawn(handle_messages(server_incoming1, server.clone()))
|
||||
.detach();
|
||||
executor
|
||||
.spawn(handle_messages(client1_incoming, client1.clone()))
|
||||
.detach();
|
||||
executor
|
||||
.spawn(handle_messages(server_incoming2, server.clone()))
|
||||
.detach();
|
||||
executor
|
||||
.spawn(handle_messages(client2_incoming, client2.clone()))
|
||||
.detach();
|
||||
|
||||
assert_eq!(
|
||||
client1
|
||||
.request(client1_conn_id, proto::Ping {},)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Ack {}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client2
|
||||
.request(client2_conn_id, proto::Ping {},)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Ack {}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client1
|
||||
.request(client1_conn_id, proto::Test { id: 1 },)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Test { id: 1 }
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client2
|
||||
.request(client2_conn_id, proto::Test { id: 2 })
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Test { id: 2 }
|
||||
);
|
||||
|
||||
client1.disconnect(client1_conn_id);
|
||||
client2.disconnect(client1_conn_id);
|
||||
|
||||
async fn handle_messages(
|
||||
mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
|
||||
peer: Arc<Peer>,
|
||||
) -> Result<()> {
|
||||
while let Some(envelope) = messages.next().await {
|
||||
let envelope = envelope.into_any();
|
||||
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
||||
let receipt = envelope.receipt();
|
||||
peer.respond(receipt, proto::Ack {})?
|
||||
} else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
|
||||
{
|
||||
peer.respond(envelope.receipt(), envelope.payload.clone())?
|
||||
} else {
|
||||
panic!("unknown message type");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
|
||||
let executor = cx.executor();
|
||||
let server = Peer::new(0);
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, server_to_client_conn, _kill) =
|
||||
Connection::in_memory(executor.clone());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, executor.clone());
|
||||
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) =
|
||||
server.add_test_connection(server_to_client_conn, executor.clone());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
|
||||
executor
|
||||
.spawn(async move {
|
||||
let future = server_incoming.next().await;
|
||||
let request = future
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Ping>>()
|
||||
.unwrap();
|
||||
|
||||
server
|
||||
.send(
|
||||
server_to_client_conn_id,
|
||||
proto::Error {
|
||||
message: "message 1".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
server
|
||||
.send(
|
||||
server_to_client_conn_id,
|
||||
proto::Error {
|
||||
message: "message 2".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
server.respond(request.receipt(), proto::Ack {}).unwrap();
|
||||
|
||||
// Prevent the connection from being dropped
|
||||
server_incoming.next().await;
|
||||
})
|
||||
.detach();
|
||||
|
||||
let events = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let response = client.request(client_to_server_conn_id, proto::Ping {});
|
||||
let response_task = executor.spawn({
|
||||
let events = events.clone();
|
||||
async move {
|
||||
response.await.unwrap();
|
||||
events.lock().push("response".to_string());
|
||||
}
|
||||
});
|
||||
|
||||
executor
|
||||
.spawn({
|
||||
let events = events.clone();
|
||||
async move {
|
||||
let incoming1 = client_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Error>>()
|
||||
.unwrap();
|
||||
events.lock().push(incoming1.payload.message);
|
||||
let incoming2 = client_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Error>>()
|
||||
.unwrap();
|
||||
events.lock().push(incoming2.payload.message);
|
||||
|
||||
// Prevent the connection from being dropped
|
||||
client_incoming.next().await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
response_task.await;
|
||||
assert_eq!(
|
||||
&*events.lock(),
|
||||
&[
|
||||
"message 1".to_string(),
|
||||
"message 2".to_string(),
|
||||
"response".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
|
||||
let executor = cx.executor();
|
||||
let server = Peer::new(0);
|
||||
let client = Peer::new(0);
|
||||
|
||||
let (client_to_server_conn, server_to_client_conn, _kill) =
|
||||
Connection::in_memory(cx.executor());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_test_connection(client_to_server_conn, cx.executor());
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) =
|
||||
server.add_test_connection(server_to_client_conn, cx.executor());
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
|
||||
executor
|
||||
.spawn(async move {
|
||||
let request1 = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Ping>>()
|
||||
.unwrap();
|
||||
let request2 = server_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Ping>>()
|
||||
.unwrap();
|
||||
|
||||
server
|
||||
.send(
|
||||
server_to_client_conn_id,
|
||||
proto::Error {
|
||||
message: "message 1".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
server
|
||||
.send(
|
||||
server_to_client_conn_id,
|
||||
proto::Error {
|
||||
message: "message 2".to_string(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
server.respond(request1.receipt(), proto::Ack {}).unwrap();
|
||||
server.respond(request2.receipt(), proto::Ack {}).unwrap();
|
||||
|
||||
// Prevent the connection from being dropped
|
||||
server_incoming.next().await;
|
||||
})
|
||||
.detach();
|
||||
|
||||
let events = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let request1 = client.request(client_to_server_conn_id, proto::Ping {});
|
||||
let request1_task = executor.spawn(request1);
|
||||
let request2 = client.request(client_to_server_conn_id, proto::Ping {});
|
||||
let request2_task = executor.spawn({
|
||||
let events = events.clone();
|
||||
async move {
|
||||
request2.await.unwrap();
|
||||
events.lock().push("response 2".to_string());
|
||||
}
|
||||
});
|
||||
|
||||
executor
|
||||
.spawn({
|
||||
let events = events.clone();
|
||||
async move {
|
||||
let incoming1 = client_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Error>>()
|
||||
.unwrap();
|
||||
events.lock().push(incoming1.payload.message);
|
||||
let incoming2 = client_incoming
|
||||
.next()
|
||||
.await
|
||||
.unwrap()
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<proto::Error>>()
|
||||
.unwrap();
|
||||
events.lock().push(incoming2.payload.message);
|
||||
|
||||
// Prevent the connection from being dropped
|
||||
client_incoming.next().await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// Allow the request to make some progress before dropping it.
|
||||
cx.executor().simulate_random_delay().await;
|
||||
drop(request1_task);
|
||||
|
||||
request2_task.await;
|
||||
assert_eq!(
|
||||
&*events.lock(),
|
||||
&[
|
||||
"message 1".to_string(),
|
||||
"message 2".to_string(),
|
||||
"response 2".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_disconnect(cx: &mut TestAppContext) {
|
||||
let executor = cx.executor();
|
||||
|
||||
let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
|
||||
|
||||
let client = Peer::new(0);
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
client.add_test_connection(client_conn, executor.clone());
|
||||
|
||||
let (io_ended_tx, io_ended_rx) = oneshot::channel();
|
||||
executor
|
||||
.spawn(async move {
|
||||
io_handler.await.ok();
|
||||
io_ended_tx.send(()).unwrap();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
|
||||
executor
|
||||
.spawn(async move {
|
||||
incoming.next().await;
|
||||
messages_ended_tx.send(()).unwrap();
|
||||
})
|
||||
.detach();
|
||||
|
||||
client.disconnect(connection_id);
|
||||
|
||||
let _ = io_ended_rx.await;
|
||||
let _ = messages_ended_rx.await;
|
||||
assert!(server_conn
|
||||
.send(WebSocketMessage::Binary(vec![]))
|
||||
.await
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 50)]
|
||||
async fn test_io_error(cx: &mut TestAppContext) {
|
||||
let executor = cx.executor();
|
||||
let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
|
||||
|
||||
let client = Peer::new(0);
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
client.add_test_connection(client_conn, executor.clone());
|
||||
executor.spawn(io_handler).detach();
|
||||
executor
|
||||
.spawn(async move { incoming.next().await })
|
||||
.detach();
|
||||
|
||||
let response = executor.spawn(client.request(connection_id, proto::Ping {}));
|
||||
let _request = server_conn.rx.next().await.unwrap().unwrap();
|
||||
|
||||
drop(server_conn);
|
||||
assert_eq!(
|
||||
response.await.unwrap_err().to_string(),
|
||||
"connection was closed"
|
||||
);
|
||||
}
|
||||
}
|
@ -1,692 +0,0 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use collections::HashMap;
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use prost::Message as _;
|
||||
use serde::Serialize;
|
||||
use std::any::{Any, TypeId};
|
||||
use std::{
|
||||
cmp,
|
||||
fmt::Debug,
|
||||
io, iter,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
use std::{fmt, mem};
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
|
||||
|
||||
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
|
||||
const NAME: &'static str;
|
||||
const PRIORITY: MessagePriority;
|
||||
fn into_envelope(
|
||||
self,
|
||||
id: u32,
|
||||
responding_to: Option<u32>,
|
||||
original_sender_id: Option<PeerId>,
|
||||
) -> Envelope;
|
||||
fn from_envelope(envelope: Envelope) -> Option<Self>;
|
||||
}
|
||||
|
||||
pub trait EntityMessage: EnvelopedMessage {
|
||||
fn remote_entity_id(&self) -> u64;
|
||||
}
|
||||
|
||||
pub trait RequestMessage: EnvelopedMessage {
|
||||
type Response: EnvelopedMessage;
|
||||
}
|
||||
|
||||
pub trait AnyTypedEnvelope: 'static + Send + Sync {
|
||||
fn payload_type_id(&self) -> TypeId;
|
||||
fn payload_type_name(&self) -> &'static str;
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
|
||||
fn is_background(&self) -> bool;
|
||||
fn original_sender_id(&self) -> Option<PeerId>;
|
||||
fn sender_id(&self) -> ConnectionId;
|
||||
fn message_id(&self) -> u32;
|
||||
}
|
||||
|
||||
pub enum MessagePriority {
|
||||
Foreground,
|
||||
Background,
|
||||
}
|
||||
|
||||
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
|
||||
fn payload_type_id(&self) -> TypeId {
|
||||
TypeId::of::<T>()
|
||||
}
|
||||
|
||||
fn payload_type_name(&self) -> &'static str {
|
||||
T::NAME
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
|
||||
self
|
||||
}
|
||||
|
||||
fn is_background(&self) -> bool {
|
||||
matches!(T::PRIORITY, MessagePriority::Background)
|
||||
}
|
||||
|
||||
fn original_sender_id(&self) -> Option<PeerId> {
|
||||
self.original_sender_id
|
||||
}
|
||||
|
||||
fn sender_id(&self) -> ConnectionId {
|
||||
self.sender_id
|
||||
}
|
||||
|
||||
fn message_id(&self) -> u32 {
|
||||
self.message_id
|
||||
}
|
||||
}
|
||||
|
||||
impl PeerId {
|
||||
pub fn from_u64(peer_id: u64) -> Self {
|
||||
let owner_id = (peer_id >> 32) as u32;
|
||||
let id = peer_id as u32;
|
||||
Self { owner_id, id }
|
||||
}
|
||||
|
||||
pub fn as_u64(self) -> u64 {
|
||||
((self.owner_id as u64) << 32) | (self.id as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Copy for PeerId {}
|
||||
|
||||
impl Eq for PeerId {}
|
||||
|
||||
impl Ord for PeerId {
|
||||
fn cmp(&self, other: &Self) -> cmp::Ordering {
|
||||
self.owner_id
|
||||
.cmp(&other.owner_id)
|
||||
.then_with(|| self.id.cmp(&other.id))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for PeerId {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for PeerId {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.owner_id.hash(state);
|
||||
self.id.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PeerId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}/{}", self.owner_id, self.id)
|
||||
}
|
||||
}
|
||||
|
||||
messages!(
|
||||
(Ack, Foreground),
|
||||
(AckBufferOperation, Background),
|
||||
(AckChannelMessage, Background),
|
||||
(AddNotification, Foreground),
|
||||
(AddProjectCollaborator, Foreground),
|
||||
(ApplyCodeAction, Background),
|
||||
(ApplyCodeActionResponse, Background),
|
||||
(ApplyCompletionAdditionalEdits, Background),
|
||||
(ApplyCompletionAdditionalEditsResponse, Background),
|
||||
(BufferReloaded, Foreground),
|
||||
(BufferSaved, Foreground),
|
||||
(Call, Foreground),
|
||||
(CallCanceled, Foreground),
|
||||
(CancelCall, Foreground),
|
||||
(ChannelMessageSent, Foreground),
|
||||
(CopyProjectEntry, Foreground),
|
||||
(CreateBufferForPeer, Foreground),
|
||||
(CreateChannel, Foreground),
|
||||
(CreateChannelResponse, Foreground),
|
||||
(CreateProjectEntry, Foreground),
|
||||
(CreateRoom, Foreground),
|
||||
(CreateRoomResponse, Foreground),
|
||||
(DeclineCall, Foreground),
|
||||
(DeleteChannel, Foreground),
|
||||
(DeleteNotification, Foreground),
|
||||
(DeleteProjectEntry, Foreground),
|
||||
(Error, Foreground),
|
||||
(ExpandProjectEntry, Foreground),
|
||||
(ExpandProjectEntryResponse, Foreground),
|
||||
(Follow, Foreground),
|
||||
(FollowResponse, Foreground),
|
||||
(FormatBuffers, Foreground),
|
||||
(FormatBuffersResponse, Foreground),
|
||||
(FuzzySearchUsers, Foreground),
|
||||
(GetChannelMembers, Foreground),
|
||||
(GetChannelMembersResponse, Foreground),
|
||||
(GetChannelMessages, Background),
|
||||
(GetChannelMessagesById, Background),
|
||||
(GetChannelMessagesResponse, Background),
|
||||
(GetCodeActions, Background),
|
||||
(GetCodeActionsResponse, Background),
|
||||
(GetCompletions, Background),
|
||||
(GetCompletionsResponse, Background),
|
||||
(GetDefinition, Background),
|
||||
(GetDefinitionResponse, Background),
|
||||
(GetDocumentHighlights, Background),
|
||||
(GetDocumentHighlightsResponse, Background),
|
||||
(GetHover, Background),
|
||||
(GetHoverResponse, Background),
|
||||
(GetNotifications, Foreground),
|
||||
(GetNotificationsResponse, Foreground),
|
||||
(GetPrivateUserInfo, Foreground),
|
||||
(GetPrivateUserInfoResponse, Foreground),
|
||||
(GetProjectSymbols, Background),
|
||||
(GetProjectSymbolsResponse, Background),
|
||||
(GetReferences, Background),
|
||||
(GetReferencesResponse, Background),
|
||||
(GetTypeDefinition, Background),
|
||||
(GetTypeDefinitionResponse, Background),
|
||||
(GetUsers, Foreground),
|
||||
(Hello, Foreground),
|
||||
(IncomingCall, Foreground),
|
||||
(InlayHints, Background),
|
||||
(InlayHintsResponse, Background),
|
||||
(InviteChannelMember, Foreground),
|
||||
(JoinChannel, Foreground),
|
||||
(JoinChannelBuffer, Foreground),
|
||||
(JoinChannelBufferResponse, Foreground),
|
||||
(JoinChannelChat, Foreground),
|
||||
(JoinChannelChatResponse, Foreground),
|
||||
(JoinProject, Foreground),
|
||||
(JoinProjectResponse, Foreground),
|
||||
(JoinRoom, Foreground),
|
||||
(JoinRoomResponse, Foreground),
|
||||
(LeaveChannelBuffer, Background),
|
||||
(LeaveChannelChat, Foreground),
|
||||
(LeaveProject, Foreground),
|
||||
(LeaveRoom, Foreground),
|
||||
(MarkNotificationRead, Foreground),
|
||||
(MoveChannel, Foreground),
|
||||
(OnTypeFormatting, Background),
|
||||
(OnTypeFormattingResponse, Background),
|
||||
(OpenBufferById, Background),
|
||||
(OpenBufferByPath, Background),
|
||||
(OpenBufferForSymbol, Background),
|
||||
(OpenBufferForSymbolResponse, Background),
|
||||
(OpenBufferResponse, Background),
|
||||
(PerformRename, Background),
|
||||
(PerformRenameResponse, Background),
|
||||
(Ping, Foreground),
|
||||
(PrepareRename, Background),
|
||||
(PrepareRenameResponse, Background),
|
||||
(ProjectEntryResponse, Foreground),
|
||||
(RefreshInlayHints, Foreground),
|
||||
(RejoinChannelBuffers, Foreground),
|
||||
(RejoinChannelBuffersResponse, Foreground),
|
||||
(RejoinRoom, Foreground),
|
||||
(RejoinRoomResponse, Foreground),
|
||||
(ReloadBuffers, Foreground),
|
||||
(ReloadBuffersResponse, Foreground),
|
||||
(RemoveChannelMember, Foreground),
|
||||
(RemoveChannelMessage, Foreground),
|
||||
(RemoveContact, Foreground),
|
||||
(RemoveProjectCollaborator, Foreground),
|
||||
(RenameChannel, Foreground),
|
||||
(RenameChannelResponse, Foreground),
|
||||
(RenameProjectEntry, Foreground),
|
||||
(RequestContact, Foreground),
|
||||
(ResolveCompletionDocumentation, Background),
|
||||
(ResolveCompletionDocumentationResponse, Background),
|
||||
(ResolveInlayHint, Background),
|
||||
(ResolveInlayHintResponse, Background),
|
||||
(RespondToChannelInvite, Foreground),
|
||||
(RespondToContactRequest, Foreground),
|
||||
(RoomUpdated, Foreground),
|
||||
(SaveBuffer, Foreground),
|
||||
(SetChannelMemberRole, Foreground),
|
||||
(SetChannelVisibility, Foreground),
|
||||
(SearchProject, Background),
|
||||
(SearchProjectResponse, Background),
|
||||
(SendChannelMessage, Background),
|
||||
(SendChannelMessageResponse, Background),
|
||||
(ShareProject, Foreground),
|
||||
(ShareProjectResponse, Foreground),
|
||||
(ShowContacts, Foreground),
|
||||
(StartLanguageServer, Foreground),
|
||||
(SynchronizeBuffers, Foreground),
|
||||
(SynchronizeBuffersResponse, Foreground),
|
||||
(Test, Foreground),
|
||||
(Unfollow, Foreground),
|
||||
(UnshareProject, Foreground),
|
||||
(UpdateBuffer, Foreground),
|
||||
(UpdateBufferFile, Foreground),
|
||||
(UpdateChannelBuffer, Foreground),
|
||||
(UpdateChannelBufferCollaborators, Foreground),
|
||||
(UpdateChannels, Foreground),
|
||||
(UpdateContacts, Foreground),
|
||||
(UpdateDiagnosticSummary, Foreground),
|
||||
(UpdateDiffBase, Foreground),
|
||||
(UpdateFollowers, Foreground),
|
||||
(UpdateInviteInfo, Foreground),
|
||||
(UpdateLanguageServer, Foreground),
|
||||
(UpdateParticipantLocation, Foreground),
|
||||
(UpdateProject, Foreground),
|
||||
(UpdateProjectCollaborator, Foreground),
|
||||
(UpdateWorktree, Foreground),
|
||||
(UpdateWorktreeSettings, Foreground),
|
||||
(UsersResponse, Foreground),
|
||||
(LspExtExpandMacro, Background),
|
||||
(LspExtExpandMacroResponse, Background),
|
||||
);
|
||||
|
||||
request_messages!(
|
||||
(ApplyCodeAction, ApplyCodeActionResponse),
|
||||
(
|
||||
ApplyCompletionAdditionalEdits,
|
||||
ApplyCompletionAdditionalEditsResponse
|
||||
),
|
||||
(Call, Ack),
|
||||
(CancelCall, Ack),
|
||||
(CopyProjectEntry, ProjectEntryResponse),
|
||||
(CreateChannel, CreateChannelResponse),
|
||||
(CreateProjectEntry, ProjectEntryResponse),
|
||||
(CreateRoom, CreateRoomResponse),
|
||||
(DeclineCall, Ack),
|
||||
(DeleteChannel, Ack),
|
||||
(DeleteProjectEntry, ProjectEntryResponse),
|
||||
(ExpandProjectEntry, ExpandProjectEntryResponse),
|
||||
(Follow, FollowResponse),
|
||||
(FormatBuffers, FormatBuffersResponse),
|
||||
(FuzzySearchUsers, UsersResponse),
|
||||
(GetChannelMembers, GetChannelMembersResponse),
|
||||
(GetChannelMessages, GetChannelMessagesResponse),
|
||||
(GetChannelMessagesById, GetChannelMessagesResponse),
|
||||
(GetCodeActions, GetCodeActionsResponse),
|
||||
(GetCompletions, GetCompletionsResponse),
|
||||
(GetDefinition, GetDefinitionResponse),
|
||||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||
(GetHover, GetHoverResponse),
|
||||
(GetNotifications, GetNotificationsResponse),
|
||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||
(GetReferences, GetReferencesResponse),
|
||||
(GetTypeDefinition, GetTypeDefinitionResponse),
|
||||
(GetUsers, UsersResponse),
|
||||
(IncomingCall, Ack),
|
||||
(InlayHints, InlayHintsResponse),
|
||||
(InviteChannelMember, Ack),
|
||||
(JoinChannel, JoinRoomResponse),
|
||||
(JoinChannelBuffer, JoinChannelBufferResponse),
|
||||
(JoinChannelChat, JoinChannelChatResponse),
|
||||
(JoinProject, JoinProjectResponse),
|
||||
(JoinRoom, JoinRoomResponse),
|
||||
(LeaveChannelBuffer, Ack),
|
||||
(LeaveRoom, Ack),
|
||||
(MarkNotificationRead, Ack),
|
||||
(MoveChannel, Ack),
|
||||
(OnTypeFormatting, OnTypeFormattingResponse),
|
||||
(OpenBufferById, OpenBufferResponse),
|
||||
(OpenBufferByPath, OpenBufferResponse),
|
||||
(OpenBufferForSymbol, OpenBufferForSymbolResponse),
|
||||
(PerformRename, PerformRenameResponse),
|
||||
(Ping, Ack),
|
||||
(PrepareRename, PrepareRenameResponse),
|
||||
(RefreshInlayHints, Ack),
|
||||
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
||||
(RejoinRoom, RejoinRoomResponse),
|
||||
(ReloadBuffers, ReloadBuffersResponse),
|
||||
(RemoveChannelMember, Ack),
|
||||
(RemoveChannelMessage, Ack),
|
||||
(RemoveContact, Ack),
|
||||
(RenameChannel, RenameChannelResponse),
|
||||
(RenameProjectEntry, ProjectEntryResponse),
|
||||
(RequestContact, Ack),
|
||||
(
|
||||
ResolveCompletionDocumentation,
|
||||
ResolveCompletionDocumentationResponse
|
||||
),
|
||||
(ResolveInlayHint, ResolveInlayHintResponse),
|
||||
(RespondToChannelInvite, Ack),
|
||||
(RespondToContactRequest, Ack),
|
||||
(SaveBuffer, BufferSaved),
|
||||
(SearchProject, SearchProjectResponse),
|
||||
(SendChannelMessage, SendChannelMessageResponse),
|
||||
(SetChannelMemberRole, Ack),
|
||||
(SetChannelVisibility, Ack),
|
||||
(ShareProject, ShareProjectResponse),
|
||||
(SynchronizeBuffers, SynchronizeBuffersResponse),
|
||||
(Test, Test),
|
||||
(UpdateBuffer, Ack),
|
||||
(UpdateParticipantLocation, Ack),
|
||||
(UpdateProject, Ack),
|
||||
(UpdateWorktree, Ack),
|
||||
(LspExtExpandMacro, LspExtExpandMacroResponse),
|
||||
);
|
||||
|
||||
entity_messages!(
|
||||
project_id,
|
||||
AddProjectCollaborator,
|
||||
ApplyCodeAction,
|
||||
ApplyCompletionAdditionalEdits,
|
||||
BufferReloaded,
|
||||
BufferSaved,
|
||||
CopyProjectEntry,
|
||||
CreateBufferForPeer,
|
||||
CreateProjectEntry,
|
||||
DeleteProjectEntry,
|
||||
ExpandProjectEntry,
|
||||
FormatBuffers,
|
||||
GetCodeActions,
|
||||
GetCompletions,
|
||||
GetDefinition,
|
||||
GetDocumentHighlights,
|
||||
GetHover,
|
||||
GetProjectSymbols,
|
||||
GetReferences,
|
||||
GetTypeDefinition,
|
||||
InlayHints,
|
||||
JoinProject,
|
||||
LeaveProject,
|
||||
OnTypeFormatting,
|
||||
OpenBufferById,
|
||||
OpenBufferByPath,
|
||||
OpenBufferForSymbol,
|
||||
PerformRename,
|
||||
PrepareRename,
|
||||
RefreshInlayHints,
|
||||
ReloadBuffers,
|
||||
RemoveProjectCollaborator,
|
||||
RenameProjectEntry,
|
||||
ResolveCompletionDocumentation,
|
||||
ResolveInlayHint,
|
||||
SaveBuffer,
|
||||
SearchProject,
|
||||
StartLanguageServer,
|
||||
SynchronizeBuffers,
|
||||
UnshareProject,
|
||||
UpdateBuffer,
|
||||
UpdateBufferFile,
|
||||
UpdateDiagnosticSummary,
|
||||
UpdateDiffBase,
|
||||
UpdateLanguageServer,
|
||||
UpdateProject,
|
||||
UpdateProjectCollaborator,
|
||||
UpdateWorktree,
|
||||
UpdateWorktreeSettings,
|
||||
LspExtExpandMacro,
|
||||
);
|
||||
|
||||
entity_messages!(
|
||||
channel_id,
|
||||
ChannelMessageSent,
|
||||
RemoveChannelMessage,
|
||||
UpdateChannelBuffer,
|
||||
UpdateChannelBufferCollaborators,
|
||||
);
|
||||
|
||||
const KIB: usize = 1024;
|
||||
const MIB: usize = KIB * 1024;
|
||||
const MAX_BUFFER_LEN: usize = MIB;
|
||||
|
||||
/// A stream of protobuf messages.
|
||||
pub struct MessageStream<S> {
|
||||
stream: S,
|
||||
encoding_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug)]
|
||||
pub enum Message {
|
||||
Envelope(Envelope),
|
||||
Ping,
|
||||
Pong,
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S> {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
encoding_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inner_mut(&mut self) -> &mut S {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S>
|
||||
where
|
||||
S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
|
||||
{
|
||||
pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
const COMPRESSION_LEVEL: i32 = -7;
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
const COMPRESSION_LEVEL: i32 = 4;
|
||||
|
||||
match message {
|
||||
Message::Envelope(message) => {
|
||||
self.encoding_buffer.reserve(message.encoded_len());
|
||||
message
|
||||
.encode(&mut self.encoding_buffer)
|
||||
.map_err(io::Error::from)?;
|
||||
let buffer =
|
||||
zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
|
||||
.unwrap();
|
||||
|
||||
self.encoding_buffer.clear();
|
||||
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
|
||||
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
|
||||
}
|
||||
Message::Ping => {
|
||||
self.stream
|
||||
.send(WebSocketMessage::Ping(Default::default()))
|
||||
.await?;
|
||||
}
|
||||
Message::Pong => {
|
||||
self.stream
|
||||
.send(WebSocketMessage::Pong(Default::default()))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S>
|
||||
where
|
||||
S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
|
||||
{
|
||||
pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
|
||||
while let Some(bytes) = self.stream.next().await {
|
||||
match bytes? {
|
||||
WebSocketMessage::Binary(bytes) => {
|
||||
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
|
||||
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
|
||||
.map_err(io::Error::from)?;
|
||||
|
||||
self.encoding_buffer.clear();
|
||||
self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
|
||||
return Ok(Message::Envelope(envelope));
|
||||
}
|
||||
WebSocketMessage::Ping(_) => return Ok(Message::Ping),
|
||||
WebSocketMessage::Pong(_) => return Ok(Message::Pong),
|
||||
WebSocketMessage::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("connection closed"))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Timestamp> for SystemTime {
|
||||
fn from(val: Timestamp) -> Self {
|
||||
UNIX_EPOCH
|
||||
.checked_add(Duration::new(val.seconds, val.nanos))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SystemTime> for Timestamp {
|
||||
fn from(time: SystemTime) -> Self {
|
||||
let duration = time.duration_since(UNIX_EPOCH).unwrap();
|
||||
Self {
|
||||
seconds: duration.as_secs(),
|
||||
nanos: duration.subsec_nanos(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for Nonce {
|
||||
fn from(nonce: u128) -> Self {
|
||||
let upper_half = (nonce >> 64) as u64;
|
||||
let lower_half = nonce as u64;
|
||||
Self {
|
||||
upper_half,
|
||||
lower_half,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Nonce> for u128 {
|
||||
fn from(nonce: Nonce) -> Self {
|
||||
let upper_half = (nonce.upper_half as u128) << 64;
|
||||
let lower_half = nonce.lower_half as u128;
|
||||
upper_half | lower_half
|
||||
}
|
||||
}
|
||||
|
||||
pub fn split_worktree_update(
|
||||
mut message: UpdateWorktree,
|
||||
max_chunk_size: usize,
|
||||
) -> impl Iterator<Item = UpdateWorktree> {
|
||||
let mut done_files = false;
|
||||
|
||||
let mut repository_map = message
|
||||
.updated_repositories
|
||||
.into_iter()
|
||||
.map(|repo| (repo.work_directory_id, repo))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
iter::from_fn(move || {
|
||||
if done_files {
|
||||
return None;
|
||||
}
|
||||
|
||||
let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
|
||||
let updated_entries: Vec<_> = message
|
||||
.updated_entries
|
||||
.drain(..updated_entries_chunk_size)
|
||||
.collect();
|
||||
|
||||
let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
|
||||
let removed_entries = message
|
||||
.removed_entries
|
||||
.drain(..removed_entries_chunk_size)
|
||||
.collect();
|
||||
|
||||
done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
|
||||
|
||||
let mut updated_repositories = Vec::new();
|
||||
|
||||
if !repository_map.is_empty() {
|
||||
for entry in &updated_entries {
|
||||
if let Some(repo) = repository_map.remove(&entry.id) {
|
||||
updated_repositories.push(repo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let removed_repositories = if done_files {
|
||||
mem::take(&mut message.removed_repositories)
|
||||
} else {
|
||||
Default::default()
|
||||
};
|
||||
|
||||
if done_files {
|
||||
updated_repositories.extend(mem::take(&mut repository_map).into_values());
|
||||
}
|
||||
|
||||
Some(UpdateWorktree {
|
||||
project_id: message.project_id,
|
||||
worktree_id: message.worktree_id,
|
||||
root_name: message.root_name.clone(),
|
||||
abs_path: message.abs_path.clone(),
|
||||
updated_entries,
|
||||
removed_entries,
|
||||
scan_id: message.scan_id,
|
||||
is_last_update: done_files && message.is_last_update,
|
||||
updated_repositories,
|
||||
removed_repositories,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_buffer_size() {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
|
||||
sink.write(Message::Envelope(Envelope {
|
||||
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
|
||||
root_name: "abcdefg".repeat(10),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
|
||||
sink.write(Message::Envelope(Envelope {
|
||||
payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
|
||||
root_name: "abcdefg".repeat(1000000),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
|
||||
|
||||
let mut stream = MessageStream::new(rx.map(anyhow::Ok));
|
||||
stream.read().await.unwrap();
|
||||
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
|
||||
stream.read().await.unwrap();
|
||||
assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_converting_peer_id_from_and_to_u64() {
|
||||
let peer_id = PeerId {
|
||||
owner_id: 10,
|
||||
id: 3,
|
||||
};
|
||||
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
|
||||
let peer_id = PeerId {
|
||||
owner_id: u32::MAX,
|
||||
id: 3,
|
||||
};
|
||||
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
|
||||
let peer_id = PeerId {
|
||||
owner_id: 10,
|
||||
id: u32::MAX,
|
||||
};
|
||||
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
|
||||
let peer_id = PeerId {
|
||||
owner_id: u32::MAX,
|
||||
id: u32::MAX,
|
||||
};
|
||||
assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
|
||||
}
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
pub mod auth;
|
||||
mod conn;
|
||||
mod notification;
|
||||
mod peer;
|
||||
pub mod proto;
|
||||
|
||||
pub use conn::Connection;
|
||||
pub use notification::*;
|
||||
pub use peer::*;
|
||||
mod macros;
|
||||
|
||||
pub const PROTOCOL_VERSION: u32 = 67;
|
@ -21,7 +21,7 @@ theme = { package = "theme2", path = "../theme2" }
|
||||
util = { path = "../util" }
|
||||
ui = {package = "ui2", path = "../ui2"}
|
||||
workspace = { path = "../workspace" }
|
||||
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
log.workspace = true
|
||||
|
@ -11,16 +11,13 @@ doctest = false
|
||||
[dependencies]
|
||||
ai = { path = "../ai" }
|
||||
collections = { path = "../collections" }
|
||||
gpui = { path = "../gpui" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
project = { path = "../project" }
|
||||
workspace = { path = "../workspace" }
|
||||
util = { path = "../util" }
|
||||
picker = { path = "../picker" }
|
||||
theme = { path = "../theme" }
|
||||
editor = { path = "../editor" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { path = "../settings" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
anyhow.workspace = true
|
||||
postage.workspace = true
|
||||
futures.workspace = true
|
||||
@ -44,12 +41,12 @@ ndarray = { version = "0.15.0" }
|
||||
[dev-dependencies]
|
||||
ai = { path = "../ai", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
settings = { path = "../settings", features = ["test-support"]}
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
|
||||
rust-embed = { version = "8.0", features = ["include-exclude"] }
|
||||
client = { path = "../client" }
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
|
@ -6,7 +6,7 @@ use ai::embedding::Embedding;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use gpui::BackgroundExecutor;
|
||||
use ndarray::{Array1, Array2};
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::Fs;
|
||||
@ -48,7 +48,7 @@ impl VectorDatabase {
|
||||
pub async fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
path: Arc<Path>,
|
||||
executor: Arc<executor::Background>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Result<Self> {
|
||||
if let Some(db_directory) = path.parent() {
|
||||
fs.create_dir(db_directory).await?;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{parsing::Span, JobHandle};
|
||||
use ai::embedding::EmbeddingProvider;
|
||||
use gpui::executor::Background;
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
|
||||
@ -37,7 +37,7 @@ impl PartialEq for FileToEmbed {
|
||||
pub struct EmbeddingQueue {
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
pending_batch: Vec<FileFragmentToEmbed>,
|
||||
executor: Arc<Background>,
|
||||
executor: BackgroundExecutor,
|
||||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
@ -50,7 +50,10 @@ pub struct FileFragmentToEmbed {
|
||||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
|
@ -9,12 +9,15 @@ mod semantic_index_tests;
|
||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use db::VectorDatabase;
|
||||
use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
||||
use futures::{future, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use gpui::{
|
||||
AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
|
||||
WeakModel,
|
||||
};
|
||||
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
|
||||
use lazy_static::lazy_static;
|
||||
use ordered_float::OrderedFloat;
|
||||
@ -22,6 +25,7 @@ use parking_lot::Mutex;
|
||||
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||
use postage::watch;
|
||||
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
|
||||
use settings::Settings;
|
||||
use smol::channel;
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
@ -35,7 +39,7 @@ use std::{
|
||||
};
|
||||
use util::paths::PathMatcher;
|
||||
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
|
||||
use workspace::WorkspaceCreated;
|
||||
use workspace::Workspace;
|
||||
|
||||
const SEMANTIC_INDEX_VERSION: usize = 11;
|
||||
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
|
||||
@ -51,54 +55,54 @@ pub fn init(
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
SemanticIndexSettings::register(cx);
|
||||
|
||||
let db_file_path = EMBEDDINGS_DIR
|
||||
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
|
||||
.join("embeddings_db");
|
||||
|
||||
cx.subscribe_global::<WorkspaceCreated, _>({
|
||||
move |event, cx| {
|
||||
cx.observe_new_views(
|
||||
|workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
|
||||
let Some(semantic_index) = SemanticIndex::global(cx) else {
|
||||
return;
|
||||
};
|
||||
let workspace = &event.0;
|
||||
if let Some(workspace) = workspace.upgrade(cx) {
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let project = workspace.project().clone();
|
||||
|
||||
if project.read(cx).is_local() {
|
||||
cx.spawn(|mut cx| async move {
|
||||
cx.app_mut()
|
||||
.spawn(|mut cx| async move {
|
||||
let previously_indexed = semantic_index
|
||||
.update(&mut cx, |index, cx| {
|
||||
index.project_previously_indexed(&project, cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
if previously_indexed {
|
||||
semantic_index
|
||||
.update(&mut cx, |index, cx| index.index_project(project, cx))
|
||||
.update(&mut cx, |index, cx| index.index_project(project, cx))?
|
||||
.await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
|
||||
cx.spawn(move |mut cx| async move {
|
||||
cx.spawn(move |cx| async move {
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs,
|
||||
db_file_path,
|
||||
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
|
||||
Arc::new(OpenAIEmbeddingProvider::new(
|
||||
http_client,
|
||||
cx.background_executor().clone(),
|
||||
)),
|
||||
language_registry,
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
cx.update(|cx| {
|
||||
cx.set_global(semantic_index.clone());
|
||||
});
|
||||
cx.update(|cx| cx.set_global(semantic_index.clone()))?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
@ -124,7 +128,7 @@ pub struct SemanticIndex {
|
||||
parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
|
||||
_embedding_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
projects: HashMap<WeakModel<Project>, ProjectState>,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
@ -229,12 +233,12 @@ impl ProjectState {
|
||||
pending_file_count_tx,
|
||||
pending_index: 0,
|
||||
_subscription: subscription,
|
||||
_observe_pending_file_count: cx.spawn_weak({
|
||||
_observe_pending_file_count: cx.spawn({
|
||||
let mut pending_file_count_rx = pending_file_count_rx.clone();
|
||||
|this, mut cx| async move {
|
||||
while let Some(_) = pending_file_count_rx.next().await {
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |_, cx| cx.notify());
|
||||
if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -264,21 +268,21 @@ pub struct PendingFile {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SearchResult {
|
||||
pub buffer: ModelHandle<Buffer>,
|
||||
pub buffer: Model<Buffer>,
|
||||
pub range: Range<Anchor>,
|
||||
pub similarity: OrderedFloat<f32>,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
|
||||
if cx.has_global::<ModelHandle<Self>>() {
|
||||
Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
|
||||
pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
|
||||
if cx.has_global::<Model<Self>>() {
|
||||
Some(cx.global::<Model<SemanticIndex>>().clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||
pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
|
||||
if !self.embedding_provider.has_credentials() {
|
||||
self.embedding_provider.retrieve_credentials(cx);
|
||||
} else {
|
||||
@ -293,10 +297,10 @@ impl SemanticIndex {
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
settings::get::<SemanticIndexSettings>(cx).enabled
|
||||
SemanticIndexSettings::get_global(cx).enabled
|
||||
}
|
||||
|
||||
pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
|
||||
pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
|
||||
if !self.is_authenticated() {
|
||||
return SemanticIndexStatus::NotAuthenticated;
|
||||
}
|
||||
@ -326,21 +330,22 @@ impl SemanticIndex {
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<ModelHandle<Self>> {
|
||||
) -> Result<Model<Self>> {
|
||||
let t0 = Instant::now();
|
||||
let database_path = Arc::from(database_path);
|
||||
let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
|
||||
let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
|
||||
.await?;
|
||||
|
||||
log::trace!(
|
||||
"db initialization took {:?} milliseconds",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
Ok(cx.add_model(|cx| {
|
||||
cx.new_model(|cx| {
|
||||
let t0 = Instant::now();
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||
let _embedding_task = cx.background().spawn({
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
|
||||
let _embedding_task = cx.background_executor().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db = db.clone();
|
||||
async move {
|
||||
@ -357,13 +362,13 @@ impl SemanticIndex {
|
||||
channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
|
||||
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
|
||||
let mut _parsing_files_tasks = Vec::new();
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
for _ in 0..cx.background_executor().num_cpus() {
|
||||
let fs = fs.clone();
|
||||
let mut parsing_files_rx = parsing_files_rx.clone();
|
||||
let embedding_provider = embedding_provider.clone();
|
||||
let embedding_queue = embedding_queue.clone();
|
||||
let background = cx.background().clone();
|
||||
_parsing_files_tasks.push(cx.background().spawn(async move {
|
||||
let background = cx.background_executor().clone();
|
||||
_parsing_files_tasks.push(cx.background_executor().spawn(async move {
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||
loop {
|
||||
let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
|
||||
@ -405,7 +410,7 @@ impl SemanticIndex {
|
||||
_parsing_files_tasks,
|
||||
projects: Default::default(),
|
||||
}
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
async fn parse_file(
|
||||
@ -449,12 +454,12 @@ impl SemanticIndex {
|
||||
|
||||
pub fn project_previously_indexed(
|
||||
&mut self,
|
||||
project: &ModelHandle<Project>,
|
||||
project: &Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<bool>> {
|
||||
let worktrees_indexed_previously = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.worktrees()
|
||||
.map(|worktree| {
|
||||
self.db
|
||||
.worktree_previously_indexed(&worktree.read(cx).abs_path())
|
||||
@ -473,7 +478,7 @@ impl SemanticIndex {
|
||||
|
||||
fn project_entries_changed(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
project: Model<Project>,
|
||||
worktree_id: WorktreeId,
|
||||
changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
@ -495,22 +500,25 @@ impl SemanticIndex {
|
||||
};
|
||||
worktree_state.paths_changed(changes, worktree);
|
||||
if let WorktreeState::Registered(_) = worktree_state {
|
||||
cx.spawn_weak(|this, mut cx| async move {
|
||||
cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
|
||||
if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
cx.background_executor()
|
||||
.timer(BACKGROUND_INDEXING_DELAY)
|
||||
.await;
|
||||
if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.index_project(project, cx).detach_and_log_err(cx)
|
||||
});
|
||||
})?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_worktree(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
worktree: ModelHandle<Worktree>,
|
||||
project: Model<Project>,
|
||||
worktree: Model<Worktree>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let project = project.downgrade();
|
||||
@ -536,16 +544,18 @@ impl SemanticIndex {
|
||||
scan_complete.await;
|
||||
let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
|
||||
let mut file_mtimes = db.get_file_mtimes(db_id).await?;
|
||||
let worktree = if let Some(project) = project.upgrade(&cx) {
|
||||
let worktree = if let Some(project) = project.upgrade() {
|
||||
project
|
||||
.read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
|
||||
.ok_or_else(|| anyhow!("worktree not found"))?
|
||||
.ok()
|
||||
.flatten()
|
||||
.context("worktree not found")?
|
||||
} else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
|
||||
let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
|
||||
let mut changed_paths = cx
|
||||
.background()
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
let mut changed_paths = BTreeMap::new();
|
||||
for file in worktree.files(false, 0) {
|
||||
@ -607,10 +617,8 @@ impl SemanticIndex {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project)
|
||||
.ok_or_else(|| anyhow!("project not registered"))?;
|
||||
let project = project
|
||||
.upgrade(cx)
|
||||
.ok_or_else(|| anyhow!("project was dropped"))?;
|
||||
.context("project not registered")?;
|
||||
let project = project.upgrade().context("project was dropped")?;
|
||||
|
||||
if let Some(WorktreeState::Registering(state)) =
|
||||
project_state.worktrees.remove(&worktree_id)
|
||||
@ -627,7 +635,7 @@ impl SemanticIndex {
|
||||
this.index_project(project, cx).detach_and_log_err(cx);
|
||||
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
};
|
||||
@ -639,6 +647,7 @@ impl SemanticIndex {
|
||||
project_state.worktrees.remove(&worktree_id);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
*done_tx.borrow_mut() = Some(());
|
||||
@ -654,11 +663,7 @@ impl SemanticIndex {
|
||||
);
|
||||
}
|
||||
|
||||
fn project_worktrees_changed(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
|
||||
let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
|
||||
{
|
||||
project_state
|
||||
@ -668,7 +673,7 @@ impl SemanticIndex {
|
||||
|
||||
let mut worktrees = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.worktrees()
|
||||
.filter(|worktree| worktree.read(cx).is_local())
|
||||
.collect::<Vec<_>>();
|
||||
let worktree_ids = worktrees
|
||||
@ -691,10 +696,7 @@ impl SemanticIndex {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pending_file_count(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
) -> Option<watch::Receiver<usize>> {
|
||||
pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
|
||||
Some(
|
||||
self.projects
|
||||
.get(&project.downgrade())?
|
||||
@ -705,7 +707,7 @@ impl SemanticIndex {
|
||||
|
||||
pub fn search_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
project: Model<Project>,
|
||||
query: String,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
@ -727,7 +729,7 @@ impl SemanticIndex {
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
.context("could not embed query")?;
|
||||
log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
|
||||
|
||||
let search_start = Instant::now();
|
||||
@ -740,10 +742,10 @@ impl SemanticIndex {
|
||||
&excludes,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
})?;
|
||||
let file_results = this.update(&mut cx, |this, cx| {
|
||||
this.search_files(project, query, limit, includes, excludes, cx)
|
||||
});
|
||||
})?;
|
||||
let (modified_buffer_results, file_results) =
|
||||
futures::join!(modified_buffer_results, file_results);
|
||||
|
||||
@ -768,7 +770,7 @@ impl SemanticIndex {
|
||||
|
||||
pub fn search_files(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
project: Model<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
@ -778,14 +780,18 @@ impl SemanticIndex {
|
||||
let db_path = self.db.path().clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let database =
|
||||
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
|
||||
let database = VectorDatabase::new(
|
||||
fs.clone(),
|
||||
db_path.clone(),
|
||||
cx.background_executor().clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let worktree_db_ids = this.read_with(&cx, |this, _| {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get(&project.downgrade())
|
||||
.ok_or_else(|| anyhow!("project was not indexed"))?;
|
||||
.context("project was not indexed")?;
|
||||
let worktree_db_ids = project_state
|
||||
.worktrees
|
||||
.values()
|
||||
@ -798,13 +804,13 @@ impl SemanticIndex {
|
||||
})
|
||||
.collect::<Vec<i64>>();
|
||||
anyhow::Ok(worktree_db_ids)
|
||||
})?;
|
||||
})??;
|
||||
|
||||
let file_ids = database
|
||||
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
|
||||
.await?;
|
||||
|
||||
let batch_n = cx.background().num_cpus();
|
||||
let batch_n = cx.background_executor().num_cpus();
|
||||
let ids_len = file_ids.clone().len();
|
||||
let minimum_batch_size = 50;
|
||||
|
||||
@ -824,7 +830,8 @@ impl SemanticIndex {
|
||||
let fs = fs.clone();
|
||||
let db_path = db_path.clone();
|
||||
let query = query.clone();
|
||||
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
|
||||
if let Some(db) =
|
||||
VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
@ -864,6 +871,7 @@ impl SemanticIndex {
|
||||
let mut ranges = Vec::new();
|
||||
let weak_project = project.downgrade();
|
||||
project.update(&mut cx, |project, cx| {
|
||||
let this = this.upgrade().context("index was dropped")?;
|
||||
for (worktree_db_id, file_path, byte_range) in spans {
|
||||
let project_state =
|
||||
if let Some(state) = this.read(cx).projects.get(&weak_project) {
|
||||
@ -878,7 +886,7 @@ impl SemanticIndex {
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
})??;
|
||||
|
||||
let buffers = futures::future::join_all(tasks).await;
|
||||
Ok(buffers
|
||||
@ -887,11 +895,13 @@ impl SemanticIndex {
|
||||
.zip(scores)
|
||||
.filter_map(|((buffer, range), similarity)| {
|
||||
let buffer = buffer.log_err()?;
|
||||
let range = buffer.read_with(&cx, |buffer, _| {
|
||||
let range = buffer
|
||||
.read_with(&cx, |buffer, _| {
|
||||
let start = buffer.clip_offset(range.start, Bias::Left);
|
||||
let end = buffer.clip_offset(range.end, Bias::Right);
|
||||
buffer.anchor_before(start)..buffer.anchor_after(end)
|
||||
});
|
||||
})
|
||||
.log_err()?;
|
||||
Some(SearchResult {
|
||||
buffer,
|
||||
range,
|
||||
@ -904,7 +914,7 @@ impl SemanticIndex {
|
||||
|
||||
fn search_modified_buffers(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
project: &Model<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
includes: &[PathMatcher],
|
||||
@ -913,7 +923,7 @@ impl SemanticIndex {
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let modified_buffers = project
|
||||
.read(cx)
|
||||
.opened_buffers(cx)
|
||||
.opened_buffers()
|
||||
.into_iter()
|
||||
.filter_map(|buffer_handle| {
|
||||
let buffer = buffer_handle.read(cx);
|
||||
@ -941,8 +951,8 @@ impl SemanticIndex {
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
cx.background().spawn(async move {
|
||||
let background = cx.background_executor().clone();
|
||||
cx.background_executor().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
||||
@ -996,7 +1006,7 @@ impl SemanticIndex {
|
||||
|
||||
pub fn index_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
project: Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if !self.is_authenticated() {
|
||||
@ -1038,7 +1048,7 @@ impl SemanticIndex {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.downgrade())
|
||||
.ok_or_else(|| anyhow!("project was dropped"))?;
|
||||
.context("project was dropped")?;
|
||||
let pending_file_count_tx = &project_state.pending_file_count_tx;
|
||||
|
||||
project_state
|
||||
@ -1080,9 +1090,9 @@ impl SemanticIndex {
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
})??;
|
||||
|
||||
cx.background()
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
for (worktree_db_id, path) in files_to_delete {
|
||||
db.delete_file(worktree_db_id, path).await.log_err();
|
||||
@ -1138,11 +1148,11 @@ impl SemanticIndex {
|
||||
let project_state = this
|
||||
.projects
|
||||
.get_mut(&project.downgrade())
|
||||
.ok_or_else(|| anyhow!("project was dropped"))?;
|
||||
.context("project was dropped")?;
|
||||
project_state.pending_index -= 1;
|
||||
cx.notify();
|
||||
anyhow::Ok(())
|
||||
})?;
|
||||
})??;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
@ -1150,15 +1160,15 @@ impl SemanticIndex {
|
||||
|
||||
fn wait_for_worktree_registration(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
project: &Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let project = project.downgrade();
|
||||
cx.spawn_weak(|this, cx| async move {
|
||||
cx.spawn(|this, cx| async move {
|
||||
loop {
|
||||
let mut pending_worktrees = Vec::new();
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("semantic index dropped"))?
|
||||
this.upgrade()
|
||||
.context("semantic index dropped")?
|
||||
.read_with(&cx, |this, _| {
|
||||
if let Some(project) = this.projects.get(&project) {
|
||||
for worktree in project.worktrees.values() {
|
||||
@ -1167,7 +1177,7 @@ impl SemanticIndex {
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
})?;
|
||||
|
||||
if pending_worktrees.is_empty() {
|
||||
break;
|
||||
@ -1230,17 +1240,13 @@ impl SemanticIndex {
|
||||
} else {
|
||||
embeddings.next()
|
||||
};
|
||||
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
|
||||
let embedding = embedding.context("failed to embed spans")?;
|
||||
span.embedding = Some(embedding);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Entity for SemanticIndex {
|
||||
type Event = ();
|
||||
}
|
||||
|
||||
impl Drop for JobHandle {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = Arc::get_mut(&mut self.tx) {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use anyhow;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Setting;
|
||||
use settings::Settings;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct SemanticIndexSettings {
|
||||
@ -13,7 +13,7 @@ pub struct SemanticIndexSettingsContent {
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
impl Setting for SemanticIndexSettings {
|
||||
impl Settings for SemanticIndexSettings {
|
||||
const KEY: Option<&'static str> = Some("semantic_index");
|
||||
|
||||
type FileContent = SemanticIndexSettingsContent;
|
||||
@ -21,7 +21,7 @@ impl Setting for SemanticIndexSettings {
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &gpui::AppContext,
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
}
|
||||
|
@ -6,14 +6,14 @@ use crate::{
|
||||
};
|
||||
use ai::test::FakeEmbeddingProvider;
|
||||
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use gpui::{Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{path::Path, sync::Arc, time::SystemTime};
|
||||
use unindent::Unindent;
|
||||
use util::{paths::PathMatcher, RandomCharIter};
|
||||
@ -26,10 +26,10 @@ fn init_logger() {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
|
||||
async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(
|
||||
"/the-root",
|
||||
json!({
|
||||
@ -91,9 +91,10 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
|
||||
});
|
||||
let pending_file_count =
|
||||
semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
|
||||
deterministic.run_until_parked();
|
||||
cx.background_executor.run_until_parked();
|
||||
assert_eq!(*pending_file_count.borrow(), 3);
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
cx.background_executor
|
||||
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
assert_eq!(*pending_file_count.borrow(), 0);
|
||||
|
||||
let search_results = search_results.await.unwrap();
|
||||
@ -170,13 +171,15 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
cx.background_executor
|
||||
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
|
||||
let prev_embedding_count = embedding_provider.embedding_count();
|
||||
let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
|
||||
deterministic.run_until_parked();
|
||||
cx.background_executor.run_until_parked();
|
||||
assert_eq!(*pending_file_count.borrow(), 1);
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
cx.background_executor
|
||||
.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
assert_eq!(*pending_file_count.borrow(), 0);
|
||||
index.await.unwrap();
|
||||
|
||||
@ -220,13 +223,13 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
queue.flush();
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
cx.background_executor.run_until_parked();
|
||||
let finished_files = queue.finished_files();
|
||||
let mut embedded_files: Vec<_> = files
|
||||
.iter()
|
||||
@ -1686,8 +1689,9 @@ fn test_subtract_ranges() {
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
settings::register::<ProjectSettings>(cx);
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
SemanticIndexSettings::register(cx);
|
||||
ProjectSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
@ -1,69 +0,0 @@
|
||||
[package]
|
||||
name = "semantic_index2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/semantic_index.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
ai = { path = "../ai" }
|
||||
collections = { path = "../collections" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
language = { path = "../language" }
|
||||
project = { path = "../project" }
|
||||
workspace = { path = "../workspace" }
|
||||
util = { path = "../util" }
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
anyhow.workspace = true
|
||||
postage.workspace = true
|
||||
futures.workspace = true
|
||||
ordered-float.workspace = true
|
||||
smol.workspace = true
|
||||
rusqlite.workspace = true
|
||||
log.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
lazy_static.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
async-trait.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rand.workspace = true
|
||||
schemars.workspace = true
|
||||
globset.workspace = true
|
||||
sha1 = "0.10.5"
|
||||
ndarray = { version = "0.15.0" }
|
||||
|
||||
[dev-dependencies]
|
||||
ai = { path = "../ai", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
|
||||
rust-embed = { version = "8.0", features = ["include-exclude"] }
|
||||
client = { path = "../client" }
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
|
||||
pretty_assertions.workspace = true
|
||||
rand.workspace = true
|
||||
unindent.workspace = true
|
||||
tempdir.workspace = true
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
||||
tree-sitter-typescript.workspace = true
|
||||
tree-sitter-json.workspace = true
|
||||
tree-sitter-rust.workspace = true
|
||||
tree-sitter-toml.workspace = true
|
||||
tree-sitter-cpp.workspace = true
|
||||
tree-sitter-elixir.workspace = true
|
||||
tree-sitter-lua.workspace = true
|
||||
tree-sitter-ruby.workspace = true
|
||||
tree-sitter-php.workspace = true
|
@ -1,20 +0,0 @@
|
||||
|
||||
# Semantic Index
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Metrics
|
||||
|
||||
nDCG@k:
|
||||
- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
|
||||
- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
|
||||
|
||||
MRR@k:
|
||||
- "Mean reciprocal rank quantifies the rank of the first relevant item found in teh recommendation list."
|
||||
|
||||
MAP@k:
|
||||
- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
|
||||
|
||||
Resources:
|
||||
- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
|
||||
- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)
|
@ -1,114 +0,0 @@
|
||||
{
|
||||
"repo": "https://github.com/AntonOsika/gpt-engineer.git",
|
||||
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
|
||||
"assertions": [
|
||||
{
|
||||
"query": "How do I contribute to this project?",
|
||||
"matches": [
|
||||
".github/CONTRIBUTING.md:1",
|
||||
"ROADMAP.md:48"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "What version of the openai package is active?",
|
||||
"matches": [
|
||||
"pyproject.toml:14"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Ask user for clarification",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:69"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "generate tests for python code",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:153"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "get item from database based on key",
|
||||
"matches": [
|
||||
"gpt_engineer/db.py:42",
|
||||
"gpt_engineer/db.py:68"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "prompt user to select files",
|
||||
"matches": [
|
||||
"gpt_engineer/file_selector.py:171",
|
||||
"gpt_engineer/file_selector.py:306",
|
||||
"gpt_engineer/file_selector.py:289",
|
||||
"gpt_engineer/file_selector.py:234"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "send to rudderstack",
|
||||
"matches": [
|
||||
"gpt_engineer/collect.py:11",
|
||||
"gpt_engineer/collect.py:38"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "parse code blocks from chat messages",
|
||||
"matches": [
|
||||
"gpt_engineer/chat_to_files.py:10",
|
||||
"docs/intro/chat_parsing.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "how do I use the docker cli?",
|
||||
"matches": [
|
||||
"docker/README.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "ask the user if the code ran successfully?",
|
||||
"matches": [
|
||||
"gpt_engineer/learning.py:54"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "how is consent granted by the user?",
|
||||
"matches": [
|
||||
"gpt_engineer/learning.py:107",
|
||||
"gpt_engineer/learning.py:130",
|
||||
"gpt_engineer/learning.py:152"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what are all the different steps the agent can take?",
|
||||
"matches": [
|
||||
"docs/intro/steps_module.md:1",
|
||||
"gpt_engineer/steps.py:391"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "ask the user for clarification?",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:69"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what models are available?",
|
||||
"matches": [
|
||||
"gpt_engineer/ai.py:315",
|
||||
"gpt_engineer/ai.py:341",
|
||||
"docs/open-models.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what is the current focus of the project?",
|
||||
"matches": [
|
||||
"ROADMAP.md:11"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "does the agent know how to fix code?",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:367"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
{
|
||||
"repo": "https://github.com/tree-sitter/tree-sitter.git",
|
||||
"commit": "46af27796a76c72d8466627d499f2bca4af958ee",
|
||||
"assertions": [
|
||||
{
|
||||
"query": "What attributes are available for the tags configuration struct?",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "create a new tag configuration",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:119"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "generate tags based on config",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:261"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "match on ts quantifier in rust",
|
||||
"matches": [
|
||||
"lib/binding_rust/lib.rs:139"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "cli command to generate tags",
|
||||
"matches": [
|
||||
"cli/src/tags.rs:10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what version of the tree-sitter-tags package is active?",
|
||||
"matches": [
|
||||
"tags/Cargo.toml:4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Insert a new parse state",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:153"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Handle conflict when numerous actions occur on the same symbol",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:363",
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:442"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Match based on associativity of actions",
|
||||
"matches": [
|
||||
"cri/src/generate/build_tables/build_parse_table.rs:542"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Format token set display",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/item.rs:246"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "extract choices from rule",
|
||||
"matches": [
|
||||
"cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we identify if a symbol is being used?",
|
||||
"matches": [
|
||||
"cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we launch the playground?",
|
||||
"matches": [
|
||||
"cli/src/playground.rs:46"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we test treesitter query matches in rust?",
|
||||
"matches": [
|
||||
"cli/src/query_testing.rs:152",
|
||||
"cli/src/tests/query_test.rs:781",
|
||||
"cli/src/tests/query_test.rs:2163",
|
||||
"cli/src/tests/query_test.rs:3781",
|
||||
"cli/src/tests/query_test.rs:887"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "What does the CLI do?",
|
||||
"matches": [
|
||||
"cli/README.md:10",
|
||||
"cli/loader/README.md:3",
|
||||
"docs/section-5-implementation.md:14",
|
||||
"docs/section-5-implementation.md:18"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -1,603 +0,0 @@
|
||||
use crate::{
|
||||
parsing::{Span, SpanDigest},
|
||||
SEMANTIC_INDEX_VERSION,
|
||||
};
|
||||
use ai::embedding::Embedding;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::BackgroundExecutor;
|
||||
use ndarray::{Array1, Array2};
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::Fs;
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::params;
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
future::Future,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
};
|
||||
use util::{paths::PathMatcher, TryFutureExt};
|
||||
|
||||
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
|
||||
let mut indices = (0..data.len()).collect::<Vec<_>>();
|
||||
indices.sort_by_key(|&i| &data[i]);
|
||||
indices.reverse();
|
||||
indices
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileRecord {
|
||||
pub id: usize,
|
||||
pub relative_path: String,
|
||||
pub mtime: Timestamp,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VectorDatabase {
|
||||
path: Arc<Path>,
|
||||
transactions:
|
||||
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
|
||||
}
|
||||
|
||||
impl VectorDatabase {
|
||||
pub async fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
path: Arc<Path>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Result<Self> {
|
||||
if let Some(db_directory) = path.parent() {
|
||||
fs.create_dir(db_directory).await?;
|
||||
}
|
||||
|
||||
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
|
||||
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
|
||||
>();
|
||||
executor
|
||||
.spawn({
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let mut connection = rusqlite::Connection::open(&path)?;
|
||||
|
||||
connection.pragma_update(None, "journal_mode", "wal")?;
|
||||
connection.pragma_update(None, "synchronous", "normal")?;
|
||||
connection.pragma_update(None, "cache_size", 1000000)?;
|
||||
connection.pragma_update(None, "temp_store", "MEMORY")?;
|
||||
|
||||
while let Ok(transaction) = transactions_rx.recv().await {
|
||||
transaction(&mut connection);
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
.detach();
|
||||
let this = Self {
|
||||
transactions: transactions_tx,
|
||||
path,
|
||||
};
|
||||
this.initialize_database().await?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Arc<Path> {
|
||||
&self.path
|
||||
}
|
||||
|
||||
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
|
||||
where
|
||||
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
|
||||
T: 'static + Send,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let transactions = self.transactions.clone();
|
||||
async move {
|
||||
if transactions
|
||||
.send(Box::new(|connection| {
|
||||
let result = connection
|
||||
.transaction()
|
||||
.map_err(|err| anyhow!(err))
|
||||
.and_then(|transaction| {
|
||||
let result = f(&transaction)?;
|
||||
transaction.commit()?;
|
||||
Ok(result)
|
||||
});
|
||||
let _ = tx.send(result);
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(anyhow!("connection was dropped"))?;
|
||||
}
|
||||
rx.await?
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
|
||||
self.transact(|db| {
|
||||
rusqlite::vtab::array::load_module(&db)?;
|
||||
|
||||
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
|
||||
let version_query = db.prepare("SELECT version from semantic_index_config");
|
||||
let version = version_query
|
||||
.and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
|
||||
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
|
||||
log::trace!("vector database schema up to date");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("vector database schema out of date. updating...");
|
||||
// We renamed the `documents` table to `spans`, so we want to drop
|
||||
// `documents` without recreating it if it exists.
|
||||
db.execute("DROP TABLE IF EXISTS documents", [])
|
||||
.context("failed to drop 'documents' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS spans", [])
|
||||
.context("failed to drop 'spans' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS files", [])
|
||||
.context("failed to drop 'files' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS worktrees", [])
|
||||
.context("failed to drop 'worktrees' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
|
||||
.context("failed to drop 'semantic_index_config' table")?;
|
||||
|
||||
// Initialize Vector Databasing Tables
|
||||
db.execute(
|
||||
"CREATE TABLE semantic_index_config (
|
||||
version INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"INSERT INTO semantic_index_config (version) VALUES (?1)",
|
||||
params![SEMANTIC_INDEX_VERSION],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE worktrees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
absolute_path VARCHAR NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
|
||||
",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
worktree_id INTEGER NOT NULL,
|
||||
relative_path VARCHAR NOT NULL,
|
||||
mtime_seconds INTEGER NOT NULL,
|
||||
mtime_nanos INTEGER NOT NULL,
|
||||
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE spans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_id INTEGER NOT NULL,
|
||||
start_byte INTEGER NOT NULL,
|
||||
end_byte INTEGER NOT NULL,
|
||||
name VARCHAR NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
digest BLOB NOT NULL,
|
||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
db.execute(
|
||||
"CREATE INDEX spans_digest ON spans (digest)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
delete_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
db.execute(
|
||||
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
|
||||
params![worktree_id, delete_path.to_str()],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
path: Arc<Path>,
|
||||
mtime: SystemTime,
|
||||
spans: Vec<Span>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
// Return the existing ID, if both the file and mtime match
|
||||
let mtime = Timestamp::from(mtime);
|
||||
|
||||
db.execute(
|
||||
"
|
||||
REPLACE INTO files
|
||||
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
|
||||
VALUES (?1, ?2, ?3, ?4)
|
||||
",
|
||||
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|
||||
)?;
|
||||
|
||||
let file_id = db.last_insert_rowid();
|
||||
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
INSERT INTO spans
|
||||
(file_id, start_byte, end_byte, name, embedding, digest)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
",
|
||||
)?;
|
||||
|
||||
for span in spans {
|
||||
query.execute(params![
|
||||
file_id,
|
||||
span.range.start.to_string(),
|
||||
span.range.end.to_string(),
|
||||
span.name,
|
||||
span.embedding,
|
||||
span.digest
|
||||
])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn worktree_previously_indexed(
|
||||
&self,
|
||||
worktree_root_path: &Path,
|
||||
) -> impl Future<Output = Result<bool>> {
|
||||
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(true);
|
||||
} else {
|
||||
return Ok(false);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_digests(
|
||||
&self,
|
||||
digests: Vec<SpanDigest>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
WHERE digest IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
let digests = Rc::new(
|
||||
digests
|
||||
.into_iter()
|
||||
.map(|p| Value::Blob(p.0.to_vec()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![digests], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
embeddings_by_digest.insert(row.0, row.1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_files(
|
||||
&self,
|
||||
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
LEFT JOIN files ON files.id = spans.file_id
|
||||
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
for (worktree_id, file_paths) in worktree_id_file_paths {
|
||||
let file_paths = Rc::new(
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![worktree_id, file_paths], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
embeddings_by_digest.insert(row.0, row.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_or_create_worktree(
|
||||
&self,
|
||||
worktree_root_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<i64>> {
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
|
||||
Ok(row.get::<_, i64>(0)?)
|
||||
});
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(worktree_id?);
|
||||
}
|
||||
|
||||
// If worktree_id is Err, insert new worktree
|
||||
db.execute(
|
||||
"INSERT into worktrees (absolute_path) VALUES (?1)",
|
||||
params![worktree_root_path.to_string_lossy()],
|
||||
)?;
|
||||
Ok(db.last_insert_rowid())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_file_mtimes(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT relative_path, mtime_seconds, mtime_nanos
|
||||
FROM files
|
||||
WHERE worktree_id = ?1
|
||||
ORDER BY relative_path",
|
||||
)?;
|
||||
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
|
||||
for row in statement.query_map(params![worktree_id], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?.into(),
|
||||
Timestamp {
|
||||
seconds: row.get(1)?,
|
||||
nanos: row.get(2)?,
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
})? {
|
||||
let row = row?;
|
||||
result.insert(row.0, row.1);
|
||||
}
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||
let file_ids = file_ids.to_vec();
|
||||
let query = query_embedding.clone().0;
|
||||
let query = Array1::from_vec(query);
|
||||
self.transact(move |db| {
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, embedding
|
||||
FROM
|
||||
spans
|
||||
WHERE
|
||||
file_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let deserialized_rows = query_statement
|
||||
.query_map(params![ids_to_sql(&file_ids)], |row| {
|
||||
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.collect::<Vec<(usize, Embedding)>>();
|
||||
|
||||
if deserialized_rows.len() == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Get Length of Embeddings Returned
|
||||
let embedding_len = deserialized_rows[0].1 .0.len();
|
||||
|
||||
let batch_n = 1000;
|
||||
let mut batches = Vec::new();
|
||||
let mut batch_ids = Vec::new();
|
||||
let mut batch_embeddings: Vec<f32> = Vec::new();
|
||||
deserialized_rows.iter().for_each(|(id, embedding)| {
|
||||
batch_ids.push(id);
|
||||
batch_embeddings.extend(&embedding.0);
|
||||
|
||||
if batch_ids.len() == batch_n {
|
||||
let embeddings = std::mem::take(&mut batch_embeddings);
|
||||
let ids = std::mem::take(&mut batch_ids);
|
||||
let array =
|
||||
Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((ids, array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if batch_ids.len() > 0 {
|
||||
let array = Array2::from_shape_vec(
|
||||
(batch_ids.len(), embedding_len),
|
||||
batch_embeddings.clone(),
|
||||
);
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((batch_ids.clone(), array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let mut ids: Vec<usize> = Vec::new();
|
||||
let mut results = Vec::new();
|
||||
for (batch_ids, array) in batches {
|
||||
let scores = array
|
||||
.dot(&query.t())
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|score| OrderedFloat(*score))
|
||||
.collect::<Vec<OrderedFloat<f32>>>();
|
||||
results.extend(scores);
|
||||
ids.extend(batch_ids);
|
||||
}
|
||||
|
||||
let sorted_idx = argsort(&results);
|
||||
let mut sorted_results = Vec::new();
|
||||
let last_idx = limit.min(sorted_idx.len());
|
||||
for idx in &sorted_idx[0..last_idx] {
|
||||
sorted_results.push((ids[*idx] as i64, results[*idx]))
|
||||
}
|
||||
|
||||
Ok(sorted_results)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieve_included_file_ids(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
includes: &[PathMatcher],
|
||||
excludes: &[PathMatcher],
|
||||
) -> impl Future<Output = Result<Vec<i64>>> {
|
||||
let worktree_ids = worktree_ids.to_vec();
|
||||
let includes = includes.to_vec();
|
||||
let excludes = excludes.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut file_query = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, relative_path
|
||||
FROM
|
||||
files
|
||||
WHERE
|
||||
worktree_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
let included =
|
||||
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
|
||||
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
|
||||
if included && !excluded {
|
||||
file_ids.push(file_id);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(file_ids)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn spans_for_ids(
|
||||
&self,
|
||||
ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
|
||||
let ids = ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
spans.id,
|
||||
files.worktree_id,
|
||||
files.relative_path,
|
||||
spans.start_byte,
|
||||
spans.end_byte
|
||||
FROM
|
||||
spans, files
|
||||
WHERE
|
||||
spans.file_id = files.id AND
|
||||
spans.id in rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)?,
|
||||
row.get::<_, i64>(1)?,
|
||||
row.get::<_, String>(2)?.into(),
|
||||
row.get(3)?..row.get(4)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
|
||||
for row in result_iter {
|
||||
let (id, worktree_id, path, range) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, range));
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(ids.len());
|
||||
for id in &ids {
|
||||
let value = values_by_id
|
||||
.remove(id)
|
||||
.ok_or(anyhow!("missing span id {}", id))?;
|
||||
results.push(value);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
|
||||
Rc::new(
|
||||
ids.iter()
|
||||
.copied()
|
||||
.map(|v| rusqlite::types::Value::from(v))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
@ -1,169 +0,0 @@
|
||||
use crate::{parsing::Span, JobHandle};
|
||||
use ai::embedding::EmbeddingProvider;
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileToEmbed {
|
||||
pub worktree_id: i64,
|
||||
pub path: Arc<Path>,
|
||||
pub mtime: SystemTime,
|
||||
pub spans: Vec<Span>,
|
||||
pub job_handle: JobHandle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for FileToEmbed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("FileToEmbed")
|
||||
.field("worktree_id", &self.worktree_id)
|
||||
.field("path", &self.path)
|
||||
.field("mtime", &self.mtime)
|
||||
.field("spans", &self.spans)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for FileToEmbed {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.worktree_id == other.worktree_id
|
||||
&& self.path == other.path
|
||||
&& self.mtime == other.mtime
|
||||
&& self.spans == other.spans
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EmbeddingQueue {
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
pending_batch: Vec<FileFragmentToEmbed>,
|
||||
executor: BackgroundExecutor,
|
||||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileFragmentToEmbed {
|
||||
file: Arc<Mutex<FileToEmbed>>,
|
||||
span_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
executor,
|
||||
pending_batch: Vec::new(),
|
||||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.spans.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let file = Arc::new(Mutex::new(file));
|
||||
|
||||
self.pending_batch.push(FileFragmentToEmbed {
|
||||
file: file.clone(),
|
||||
span_range: 0..0,
|
||||
});
|
||||
|
||||
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
|
||||
for (ix, span) in file.lock().spans.iter().enumerate() {
|
||||
let span_token_count = if span.embedding.is_none() {
|
||||
span.token_count
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let next_token_count = self.pending_batch_token_count + span_token_count;
|
||||
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
|
||||
let range_end = fragment_range.end;
|
||||
self.flush();
|
||||
self.pending_batch.push(FileFragmentToEmbed {
|
||||
file: file.clone(),
|
||||
span_range: range_end..range_end,
|
||||
});
|
||||
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
|
||||
}
|
||||
|
||||
fragment_range.end = ix + 1;
|
||||
self.pending_batch_token_count += span_token_count;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
let batch = mem::take(&mut self.pending_batch);
|
||||
self.pending_batch_token_count = 0;
|
||||
if batch.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
spans.extend(
|
||||
file.spans[fragment.span_range.clone()]
|
||||
.iter()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
.map(|d| d.content.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
// If spans is 0, just send the fragment to the finished files if its the last one.
|
||||
if spans.is_empty() {
|
||||
for fragment in batch.clone() {
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
|
||||
.iter_mut()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
{
|
||||
if let Some(embedding) = embeddings.next() {
|
||||
span.embedding = Some(embedding);
|
||||
} else {
|
||||
log::error!("number of embeddings != number of documents");
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("{:?}", error);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
|
||||
self.finished_files_rx.clone()
|
||||
}
|
||||
}
|
@ -1,414 +0,0 @@
|
||||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::TruncationDirection,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
|
||||
ToSql,
|
||||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::{self, Reverse},
|
||||
collections::HashSet,
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub struct SpanDigest(pub [u8; 20]);
|
||||
|
||||
impl FromSql for SpanDigest {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let blob = value.as_blob()?;
|
||||
let bytes =
|
||||
blob.try_into()
|
||||
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
|
||||
expected_size: 20,
|
||||
blob_size: blob.len(),
|
||||
})?;
|
||||
return Ok(SpanDigest(bytes));
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for SpanDigest {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
self.0.to_sql()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'_ str> for SpanDigest {
|
||||
fn from(value: &'_ str) -> Self {
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(value);
|
||||
Self(sha1.finalize().into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Span {
|
||||
pub name: String,
|
||||
pub range: Range<usize>,
|
||||
pub content: String,
|
||||
pub embedding: Option<Embedding>,
|
||||
pub digest: SpanDigest,
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
const CODE_CONTEXT_TEMPLATE: &str =
|
||||
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||
const ENTIRE_FILE_TEMPLATE: &str =
|
||||
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
|
||||
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
|
||||
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
|
||||
"TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
|
||||
];
|
||||
|
||||
pub struct CodeContextRetriever {
|
||||
pub parser: Parser,
|
||||
pub cursor: QueryCursor,
|
||||
pub embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
}
|
||||
|
||||
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
|
||||
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
|
||||
// If there are preceeding comments, we track this with a context capture
|
||||
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
|
||||
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodeContextMatch {
|
||||
pub start_col: usize,
|
||||
pub item_range: Option<Range<usize>>,
|
||||
pub name_range: Option<Range<usize>>,
|
||||
pub context_ranges: Vec<Range<usize>>,
|
||||
pub collapse_ranges: Vec<Range<usize>>,
|
||||
}
|
||||
|
||||
impl CodeContextRetriever {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
|
||||
Self {
|
||||
parser: Parser::new(),
|
||||
cursor: QueryCursor::new(),
|
||||
embedding_provider,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_entire_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
language_name: Arc<str>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = ENTIRE_FILE_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Default::default(),
|
||||
name: language_name.to_string(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
fn parse_markdown_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: None,
|
||||
name: "Markdown".to_string(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
fn get_matches_in_file(
|
||||
&mut self,
|
||||
content: &str,
|
||||
grammar: &Arc<Grammar>,
|
||||
) -> Result<Vec<CodeContextMatch>> {
|
||||
let embedding_config = grammar
|
||||
.embedding_config
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no embedding queries"))?;
|
||||
self.parser.set_language(&grammar.ts_language).unwrap();
|
||||
|
||||
let tree = self
|
||||
.parser
|
||||
.parse(&content, None)
|
||||
.ok_or_else(|| anyhow!("parsing failed"))?;
|
||||
|
||||
let mut captures: Vec<CodeContextMatch> = Vec::new();
|
||||
let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
|
||||
let mut keep_ranges: Vec<Range<usize>> = Vec::new();
|
||||
for mat in self.cursor.matches(
|
||||
&embedding_config.query,
|
||||
tree.root_node(),
|
||||
content.as_bytes(),
|
||||
) {
|
||||
let mut start_col = 0;
|
||||
let mut item_range: Option<Range<usize>> = None;
|
||||
let mut name_range: Option<Range<usize>> = None;
|
||||
let mut context_ranges: Vec<Range<usize>> = Vec::new();
|
||||
collapse_ranges.clear();
|
||||
keep_ranges.clear();
|
||||
for capture in mat.captures {
|
||||
if capture.index == embedding_config.item_capture_ix {
|
||||
item_range = Some(capture.node.byte_range());
|
||||
start_col = capture.node.start_position().column;
|
||||
} else if Some(capture.index) == embedding_config.name_capture_ix {
|
||||
name_range = Some(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.context_capture_ix {
|
||||
context_ranges.push(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.collapse_capture_ix {
|
||||
collapse_ranges.push(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.keep_capture_ix {
|
||||
keep_ranges.push(capture.node.byte_range());
|
||||
}
|
||||
}
|
||||
|
||||
captures.push(CodeContextMatch {
|
||||
start_col,
|
||||
item_range,
|
||||
name_range,
|
||||
context_ranges,
|
||||
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
||||
});
|
||||
}
|
||||
Ok(captures)
|
||||
}
|
||||
|
||||
pub fn parse_file_with_template(
|
||||
&mut self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Span>> {
|
||||
let language_name = language.name();
|
||||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
|
||||
return self.parse_markdown_file(relative_path, &content);
|
||||
}
|
||||
|
||||
let mut spans = self.parse_file(content, language)?;
|
||||
for span in &mut spans {
|
||||
let document_content = CODE_CONTEXT_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_content = model.truncate(
|
||||
&document_content,
|
||||
model.capacity()?,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_content)?;
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
}
|
||||
Ok(spans)
|
||||
}
|
||||
|
||||
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
|
||||
let grammar = language
|
||||
.grammar()
|
||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||
|
||||
// Iterate through query matches
|
||||
let matches = self.get_matches_in_file(content, grammar)?;
|
||||
|
||||
let language_scope = language.default_scope();
|
||||
let placeholder = language_scope.collapsed_placeholder();
|
||||
|
||||
let mut spans = Vec::new();
|
||||
let mut collapsed_ranges_within = Vec::new();
|
||||
let mut parsed_name_ranges = HashSet::new();
|
||||
for (i, context_match) in matches.iter().enumerate() {
|
||||
// Items which are collapsible but not embeddable have no item range
|
||||
let item_range = if let Some(item_range) = context_match.item_range.clone() {
|
||||
item_range
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Checks for deduplication
|
||||
let name;
|
||||
if let Some(name_range) = context_match.name_range.clone() {
|
||||
name = content
|
||||
.get(name_range.clone())
|
||||
.map_or(String::new(), |s| s.to_string());
|
||||
if parsed_name_ranges.contains(&name_range) {
|
||||
continue;
|
||||
}
|
||||
parsed_name_ranges.insert(name_range);
|
||||
} else {
|
||||
name = String::new();
|
||||
}
|
||||
|
||||
collapsed_ranges_within.clear();
|
||||
'outer: for remaining_match in &matches[(i + 1)..] {
|
||||
for collapsed_range in &remaining_match.collapse_ranges {
|
||||
if item_range.start <= collapsed_range.start
|
||||
&& item_range.end >= collapsed_range.end
|
||||
{
|
||||
collapsed_ranges_within.push(collapsed_range.clone());
|
||||
} else {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
|
||||
|
||||
let mut span_content = String::new();
|
||||
for context_range in &context_match.context_ranges {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
context_range.clone(),
|
||||
context_match.start_col,
|
||||
);
|
||||
span_content.push_str("\n");
|
||||
}
|
||||
|
||||
let mut offset = item_range.start;
|
||||
for collapsed_range in &collapsed_ranges_within {
|
||||
if collapsed_range.start > offset {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..collapsed_range.start,
|
||||
context_match.start_col,
|
||||
);
|
||||
offset = collapsed_range.start;
|
||||
}
|
||||
|
||||
if collapsed_range.end > offset {
|
||||
span_content.push_str(placeholder);
|
||||
offset = collapsed_range.end;
|
||||
}
|
||||
}
|
||||
|
||||
if offset < item_range.end {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..item_range.end,
|
||||
context_match.start_col,
|
||||
);
|
||||
}
|
||||
|
||||
let sha1 = SpanDigest::from(span_content.as_str());
|
||||
spans.push(Span {
|
||||
name,
|
||||
content: span_content,
|
||||
range: item_range.clone(),
|
||||
embedding: None,
|
||||
digest: sha1,
|
||||
token_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
return Ok(spans);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn subtract_ranges(
|
||||
ranges: &[Range<usize>],
|
||||
ranges_to_subtract: &[Range<usize>],
|
||||
) -> Vec<Range<usize>> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
|
||||
|
||||
for range in ranges {
|
||||
let mut offset = range.start;
|
||||
|
||||
while offset < range.end {
|
||||
if let Some(range_to_subtract) = ranges_to_subtract.peek() {
|
||||
if offset < range_to_subtract.start {
|
||||
let next_offset = cmp::min(range_to_subtract.start, range.end);
|
||||
result.push(offset..next_offset);
|
||||
offset = next_offset;
|
||||
} else {
|
||||
let next_offset = cmp::min(range_to_subtract.end, range.end);
|
||||
offset = next_offset;
|
||||
}
|
||||
|
||||
if offset >= range_to_subtract.end {
|
||||
ranges_to_subtract.next();
|
||||
}
|
||||
} else {
|
||||
result.push(offset..range.end);
|
||||
offset = range.end;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn add_content_from_range(
|
||||
output: &mut String,
|
||||
content: &str,
|
||||
range: Range<usize>,
|
||||
start_col: usize,
|
||||
) {
|
||||
for mut line in content.get(range.clone()).unwrap_or("").lines() {
|
||||
for _ in 0..start_col {
|
||||
if line.starts_with(' ') {
|
||||
line = &line[1..];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
}
|
||||
output.pop();
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,28 +0,0 @@
|
||||
use anyhow;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct SemanticIndexSettings {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct SemanticIndexSettingsContent {
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
impl Settings for SemanticIndexSettings {
|
||||
const KEY: Option<&'static str> = Some("semantic_index");
|
||||
|
||||
type FileContent = SemanticIndexSettingsContent;
|
||||
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -15,7 +15,7 @@ test-support = ["gpui/test-support", "fs/test-support"]
|
||||
collections = { path = "../collections" }
|
||||
gpui = {package = "gpui2", path = "../gpui2" }
|
||||
sqlez = { path = "../sqlez" }
|
||||
fs = {package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
util = { path = "../util" }
|
||||
|
||||
@ -36,7 +36,7 @@ tree-sitter-json = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = {package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
unindent.workspace = true
|
||||
|
@ -12,7 +12,7 @@ doctest = false
|
||||
[dependencies]
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
util = { path = "../util" }
|
||||
|
||||
|
@ -18,7 +18,7 @@ settings = { package = "settings2", path = "../settings2" }
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
util = { path = "../util" }
|
||||
workspace = { path = "../workspace" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
procinfo = { git = "https://github.com/zed-industries/wezterm", rev = "5cd757e5f2eb039ed0c6bb6512223e69d5efc64d", default-features = false }
|
||||
terminal = { package = "terminal2", path = "../terminal2" }
|
||||
ui = { package = "ui2", path = "../ui2" }
|
||||
|
@ -20,7 +20,7 @@ doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
indexmap = "1.6.2"
|
||||
parking_lot.workspace = true
|
||||
@ -38,5 +38,5 @@ itertools = { version = "0.11.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
|
||||
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
|
||||
|
@ -12,7 +12,7 @@ doctest = false
|
||||
client = { path = "../client" }
|
||||
editor = { path = "../editor" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
fuzzy = { path = "../fuzzy" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
picker = { path = "../picker" }
|
||||
|
@ -7,7 +7,7 @@ publish = false
|
||||
|
||||
[dependencies]
|
||||
fuzzy = { path = "../fuzzy"}
|
||||
fs = {package = "fs2", path = "../fs2"}
|
||||
fs = {path = "../fs"}
|
||||
gpui = {package = "gpui2", path = "../gpui2"}
|
||||
picker = {path = "../picker"}
|
||||
util = {path = "../util"}
|
||||
|
@ -13,11 +13,11 @@ test-support = []
|
||||
[dependencies]
|
||||
client = { path = "../client" }
|
||||
editor = { path = "../editor" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
fuzzy = { path = "../fuzzy" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
ui = { package = "ui2", path = "../ui2" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
install_cli = { path = "../install_cli" }
|
||||
project = { path = "../project" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
|
@ -19,12 +19,12 @@ test-support = [
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
db = { path = "../db2", package = "db2" }
|
||||
db = { path = "../db" }
|
||||
call = { path = "../call" }
|
||||
client = { path = "../client" }
|
||||
collections = { path = "../collections" }
|
||||
# context_menu = { path = "../context_menu" }
|
||||
fs = { path = "../fs2", package = "fs2" }
|
||||
fs = { path = "../fs" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
install_cli = { path = "../install_cli" }
|
||||
language = { path = "../language" }
|
||||
@ -59,8 +59,8 @@ client = { path = "../client", features = ["test-support"] }
|
||||
gpui = { path = "../gpui2", package = "gpui2", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
settings = { path = "../settings2", package = "settings2", features = ["test-support"] }
|
||||
fs = { path = "../fs2", package = "fs2", features = ["test-support"] }
|
||||
db = { path = "../db2", package = "db2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
db = { path = "../db", features = ["test-support"] }
|
||||
|
||||
indoc.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
@ -32,12 +32,12 @@ client = { path = "../client" }
|
||||
copilot = { path = "../copilot" }
|
||||
copilot_button = { path = "../copilot_button" }
|
||||
diagnostics = { path = "../diagnostics" }
|
||||
db = { package = "db2", path = "../db2" }
|
||||
db = { path = "../db" }
|
||||
editor = { path = "../editor" }
|
||||
feedback = { path = "../feedback" }
|
||||
file_finder = { path = "../file_finder" }
|
||||
search = { path = "../search" }
|
||||
fs = { package = "fs2", path = "../fs2" }
|
||||
fs = { path = "../fs" }
|
||||
fsevent = { path = "../fsevent" }
|
||||
go_to_line = { path = "../go_to_line" }
|
||||
gpui = { package = "gpui2", path = "../gpui2" }
|
||||
@ -49,7 +49,7 @@ lsp = { path = "../lsp" }
|
||||
menu = { package = "menu2", path = "../menu2" }
|
||||
language_tools = { path = "../language_tools" }
|
||||
node_runtime = { path = "../node_runtime" }
|
||||
notifications = { package = "notifications2", path = "../notifications2" }
|
||||
notifications = { path = "../notifications" }
|
||||
assistant = { path = "../assistant" }
|
||||
outline = { path = "../outline" }
|
||||
# plugin_runtime = { path = "../plugin_runtime",optional = true }
|
||||
@ -59,7 +59,7 @@ project_symbols = { path = "../project_symbols" }
|
||||
quick_action_bar = { path = "../quick_action_bar" }
|
||||
recent_projects = { path = "../recent_projects" }
|
||||
rope = { package = "rope2", path = "../rope2"}
|
||||
rpc = { package = "rpc2", path = "../rpc2" }
|
||||
rpc = { path = "../rpc" }
|
||||
settings = { package = "settings2", path = "../settings2" }
|
||||
feature_flags = { path = "../feature_flags" }
|
||||
sum_tree = { path = "../sum_tree" }
|
||||
@ -69,7 +69,7 @@ terminal_view = { path = "../terminal_view" }
|
||||
theme = { package = "theme2", path = "../theme2" }
|
||||
theme_selector = { path = "../theme_selector" }
|
||||
util = { path = "../util" }
|
||||
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
|
||||
semantic_index = { path = "../semantic_index" }
|
||||
vim = { path = "../vim" }
|
||||
workspace = { path = "../workspace" }
|
||||
welcome = { path = "../welcome" }
|
||||
|
Loading…
Reference in New Issue
Block a user