mirror of https://github.com/SilasMarvin/lsp-ai.git
synced 2024-08-15 23:30:34 +03:00

Merge pull request #31 from SilasMarvin/silas-rag-force
Introduce RAG and PostgresML support

This commit is contained in:
commit 17ea67a6a7

Cargo.lock (generated, 341 lines changed)
@@ -149,6 +149,18 @@ dependencies = [
 "num-traits",
]

[[package]]
name = "auto_enums"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393"
dependencies = [
 "derive_utils",
 "proc-macro2",
 "quote",
 "syn 2.0.52",
]

[[package]]
name = "autocfg"
version = "1.1.0"
@@ -356,7 +368,7 @@ version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47"
dependencies = [
 "heck",
 "heck 0.4.1",
 "proc-macro2",
 "quote",
 "syn 2.0.52",
@@ -662,6 +674,17 @@ dependencies = [
 "syn 1.0.109",
]

[[package]]
name = "derive_utils"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.52",
]

[[package]]
name = "difflib"
version = "0.4.0"
@@ -730,9 +753,9 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"

[[package]]
name = "either"
version = "1.10.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b"
dependencies = [
 "serde",
]
@@ -1056,6 +1079,12 @@ dependencies = [
 "unicode-segmentation",
]

[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"

[[package]]
name = "hermit-abi"
version = "0.3.9"
@@ -1364,6 +1393,15 @@ dependencies = [
 "either",
]

[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
 "either",
]

[[package]]
name = "itoa"
version = "1.0.10"
@@ -1523,6 +1561,7 @@ dependencies = [
 "anyhow",
 "assert_cmd",
 "async-trait",
 "cc",
 "directories",
 "hf-hub",
 "ignore",
@@ -1530,6 +1569,7 @@ dependencies = [
 "llama-cpp-2",
 "lsp-server",
 "lsp-types",
 "md5",
 "minijinja",
 "once_cell",
 "parking_lot",
@@ -1539,10 +1579,14 @@ dependencies = [
 "ropey",
 "serde",
 "serde_json",
 "splitter-tree-sitter",
 "text-splitter",
 "tokenizers",
 "tokio",
 "tracing",
 "tracing-subscriber",
 "tree-sitter",
 "utils-tree-sitter",
 "xxhash-rust",
]

@@ -2196,9 +2240,9 @@ dependencies = [

[[package]]
name = "regex"
version = "1.10.3"
version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [
 "aho-corasick",
 "memchr",
@@ -2415,6 +2459,12 @@ dependencies = [
 "untrusted",
]

[[package]]
name = "rustversion"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"

[[package]]
name = "ryu"
version = "1.0.17"
@@ -2475,7 +2525,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808"
dependencies = [
 "darling 0.14.4",
 "heck",
 "heck 0.4.1",
 "quote",
 "syn 1.0.109",
]
@@ -2498,7 +2548,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc"
dependencies = [
 "heck",
 "heck 0.4.1",
 "proc-macro2",
 "quote",
 "syn 2.0.52",
@@ -2756,6 +2806,17 @@ dependencies = [
 "der",
]

[[package]]
name = "splitter-tree-sitter"
version = "0.1.0"
dependencies = [
 "cc",
 "thiserror",
 "tree-sitter",
 "tree-sitter-rust",
 "tree-sitter-zig",
]

[[package]]
name = "spm_precompiled"
version = "0.1.4"
@@ -2857,7 +2918,7 @@ checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
dependencies = [
 "dotenvy",
 "either",
 "heck",
 "heck 0.4.1",
 "hex",
 "once_cell",
 "proc-macro2",
@@ -3013,6 +3074,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"

[[package]]
name = "strum"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
dependencies = [
 "strum_macros",
]

[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
 "heck 0.5.0",
 "proc-macro2",
 "quote",
 "rustversion",
 "syn 2.0.52",
]

[[package]]
name = "subtle"
version = "2.5.0"
@@ -3087,19 +3170,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"

[[package]]
name = "thiserror"
version = "1.0.58"
name = "text-splitter"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
checksum = "2ab9dc04b7cf08eb01c07c272bf699fa55679a326ddf7dd075e14094efc80fb9"
dependencies = [
 "ahash",
 "auto_enums",
 "either",
 "itertools 0.13.0",
 "once_cell",
 "regex",
 "strum",
 "thiserror",
 "unicode-segmentation",
]

[[package]]
name = "thiserror"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
 "thiserror-impl",
]

[[package]]
name = "thiserror-impl"
version = "1.0.58"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
 "proc-macro2",
 "quote",
@@ -3339,6 +3439,195 @@ dependencies = [
 "tracing-serde",
]

[[package]]
name = "tree-sitter"
version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca"
dependencies = [
 "cc",
 "regex",
]

[[package]]
name = "tree-sitter-bash"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5244703ad2e08a616d859a0557d7aa290adcd5e0990188a692e628ffe9dce40"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-c"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f956d5351d62652864a4ff3ae861747e7a1940dc96c9998ae400ac0d3ce30427"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-c-sharp"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff899037068a1ffbb891891b7e94db1400ddf12c3d934b85b8c9e30be5cd18da"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-cpp"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "537b7e0f0d8c89b8dd6f4d195814da94832f20720c09016c2a3ac3dc3c437993"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-css"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2f806f96136762b0121f5fdd7172a3dcd8f42d37a2f23ed7f11b35895e20eb4"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-elixir"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94bf7f057768b1cab2ee1f14812ed4ae33f9e04d09254043eeaa797db4ef70"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-erlang"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8db61152e6d8a5b3b5895ecbb85848f85d028f84b4633a2368075c35e5817b34"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-go"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cb318be5ccf75f44e054acf6898a5c95d59b53443eed578e16be0cd7ec037f"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-haskell"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef25a7e6c73cc1cbe0c0b7dbd5406e7b3485b370bd61c5d8d852ae0781f9bf9a"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-html"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95b3492b08a786bf5cc79feb0ef2ff3b115d5174364e0ddfd7860e0b9b088b53"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-java"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33bc21adf831a773c075d9d00107ab43965e6a6ea7607b47fd9ec6f3db4b481b"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-javascript"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fced510d43e6627cd8e19adfd994ac9cfa3b1d71b0d522b41f74145de37feef"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-json"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b737dcb73c35d74b7d64a5f3dde158113c86a012bf3cee2bfdf2150d23b05db"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-lua"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9fe6fc87bd480e1943fc1fcb02453fb2da050e4e8ce0daa67d801544046856"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-ocaml"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2f2e8e848902d12ca6778d31d0e66b5709fc1ad0c84fd8b0c078472fff20dd2"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-python"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-rust"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "tree-sitter-zig"
version = "0.0.1"
source = "git+https://github.com/maxxnino/tree-sitter-zig#7c5a29b721d409be8842017351bf007d7e384401"
dependencies = [
 "cc",
 "tree-sitter",
]

[[package]]
name = "try-lock"
version = "0.2.5"
@@ -3450,6 +3739,32 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"

[[package]]
name = "utils-tree-sitter"
version = "0.1.0"
dependencies = [
 "cc",
 "thiserror",
 "tree-sitter",
 "tree-sitter-bash",
 "tree-sitter-c",
 "tree-sitter-c-sharp",
 "tree-sitter-cpp",
 "tree-sitter-css",
 "tree-sitter-elixir",
 "tree-sitter-erlang",
 "tree-sitter-go",
 "tree-sitter-haskell",
 "tree-sitter-html",
 "tree-sitter-java",
 "tree-sitter-javascript",
 "tree-sitter-json",
 "tree-sitter-lua",
 "tree-sitter-ocaml",
 "tree-sitter-python",
 "tree-sitter-rust",
]

[[package]]
name = "uuid"
version = "1.7.0"
Cargo.toml (46 lines changed)
@@ -1,44 +1,12 @@
[package]
name = "lsp-ai"
version = "0.3.0"
[workspace]
members = [
    "crates/*",
]
resolver = "2"

[workspace.package]
edition = "2021"
license = "MIT"
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
repository = "https://github.com/SilasMarvin/lsp-ai"
readme = "README.md"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.6"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = { version = "0.1.55", optional = true }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
reqwest = { version = "0.11.25", features = ["blocking", "json"] }
ignore = "0.4.22"
pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"

[features]
default = []
llama_cpp = ["dep:llama-cpp-2"]
metal = ["llama-cpp-2/metal"]
cuda = ["llama-cpp-2/cuda"]

[dev-dependencies]
assert_cmd = "2.0.14"

crates/lsp-ai/Cargo.toml (new file, 51 lines)
@@ -0,0 +1,51 @@
[package]
name = "lsp-ai"
version = "0.3.0"

description.workspace = true
repository.workspace = true
readme.workspace = true
edition.workspace = true
license.workspace = true

[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.6"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = { version = "0.1.55", optional = true }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
reqwest = { version = "0.11.25", features = ["blocking", "json"] }
ignore = "0.4.22"
pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"
tree-sitter = "0.22"
utils-tree-sitter = { path = "../utils-tree-sitter", features = ["all"], version = "0.1.0" }
splitter-tree-sitter = { path = "../splitter-tree-sitter", version = "0.1.0" }
text-splitter = { version = "0.13.3" }
md5 = "0.7.0"

[build-dependencies]
cc = "*"

[features]
default = []
llama_cpp = ["dep:llama-cpp-2"]
metal = ["llama-cpp-2/metal"]
cuda = ["llama-cpp-2/cuda"]

[dev-dependencies]
assert_cmd = "2.0.14"
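The `llama_cpp` feature above gates the optional `llama-cpp-2` dependency, so a plain build compiles without any local-model code. As a minimal sketch (not code from this repository; the function names are illustrative), a feature gate like this is typically consumed as follows:

// Compiled only when built with `cargo build --features llama_cpp`.
#[cfg(feature = "llama_cpp")]
fn backend_name() -> &'static str {
    "llama.cpp"
}

// Fallback used by the default build, which omits llama-cpp-2 entirely.
#[cfg(not(feature = "llama_cpp"))]
fn backend_name() -> &'static str {
    "api-only"
}

fn main() {
    println!("active backend: {}", backend_name());
}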
crates/lsp-ai/src/config.rs
@@ -24,6 +24,51 @@ impl Default for PostProcess {
    }
}

#[derive(Debug, Clone, Deserialize)]
pub enum ValidSplitter {
    #[serde(rename = "tree_sitter")]
    TreeSitter(TreeSitter),
    #[serde(rename = "text_splitter")]
    TextSplitter(TextSplitter),
}

impl Default for ValidSplitter {
    fn default() -> Self {
        ValidSplitter::TreeSitter(TreeSitter::default())
    }
}

const fn chunk_size_default() -> usize {
    1500
}

const fn chunk_overlap_default() -> usize {
    0
}

#[derive(Debug, Clone, Deserialize)]
pub struct TreeSitter {
    #[serde(default = "chunk_size_default")]
    pub chunk_size: usize,
    #[serde(default = "chunk_overlap_default")]
    pub chunk_overlap: usize,
}

impl Default for TreeSitter {
    fn default() -> Self {
        Self {
            chunk_size: 1500,
            chunk_overlap: 0,
        }
    }
}

#[derive(Debug, Clone, Deserialize)]
pub struct TextSplitter {
    #[serde(default = "chunk_size_default")]
    pub chunk_size: usize,
}
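For illustration, a standalone sketch of how these serde defaults behave when a splitter is configured (the structs are re-declared here so the snippet compiles on its own; the JSON shape follows the `rename` attributes above):

use serde::Deserialize; // serde = { version = "1", features = ["derive"] }

const fn chunk_size_default() -> usize {
    1500
}

#[derive(Debug, Deserialize)]
pub struct TreeSitter {
    #[serde(default = "chunk_size_default")]
    pub chunk_size: usize,
}

#[derive(Debug, Deserialize)]
pub enum ValidSplitter {
    #[serde(rename = "tree_sitter")]
    TreeSitter(TreeSitter),
}

fn main() {
    // Omitted fields fall back to their `default = "..."` functions.
    let s: ValidSplitter = serde_json::from_str(r#"{"tree_sitter": {}}"#).unwrap();
    let ValidSplitter::TreeSitter(ts) = s;
    assert_eq!(ts.chunk_size, 1500);
}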

#[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend {
    #[serde(rename = "file_store")]
@@ -67,15 +112,6 @@ impl ChatMessage {
    }
}

#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Chat {
    pub completion: Option<Vec<ChatMessage>>,
    pub generation: Option<Vec<ChatMessage>>,
    pub chat_template: Option<String>,
    pub chat_format: Option<String>,
}

#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
#[serde(deny_unknown_fields)]
@@ -85,27 +121,52 @@ pub struct FIM {
    pub end: String,
}

const fn max_crawl_memory_default() -> u64 {
    100_000_000
}

const fn max_crawl_file_size_default() -> u64 {
    10_000_000
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Crawl {
    #[serde(default = "max_crawl_file_size_default")]
    pub max_file_size: u64,
    #[serde(default = "max_crawl_memory_default")]
    pub max_crawl_memory: u64,
    #[serde(default)]
    pub all_files: bool,
}

#[derive(Clone, Debug, Deserialize)]
pub struct PostgresMLEmbeddingModel {
    pub model: String,
    pub embed_parameters: Option<Value>,
    pub query_parameters: Option<Value>,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PostgresML {
    pub database_url: Option<String>,
    pub crawl: Option<Crawl>,
    #[serde(default)]
    pub crawl: bool,
    pub splitter: ValidSplitter,
    pub embedding_model: Option<PostgresMLEmbeddingModel>,
}

#[derive(Clone, Debug, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct FileStore {
    #[serde(default)]
    pub crawl: bool,
    pub crawl: Option<Crawl>,
}

const fn n_gpu_layers_default() -> u32 {
    1000
}

const fn n_ctx_default() -> u32 {
    1000
impl FileStore {
    pub fn new_without_crawl() -> Self {
        Self { crawl: None }
    }
}

#[derive(Clone, Debug, Deserialize)]
@@ -137,6 +198,17 @@ pub struct MistralFIM {
    pub max_requests_per_second: f32,
}

#[cfg(feature = "llama_cpp")]
const fn n_gpu_layers_default() -> u32 {
    1000
}

#[cfg(feature = "llama_cpp")]
const fn n_ctx_default() -> u32 {
    1000
}

#[cfg(feature = "llama_cpp")]
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
@@ -230,15 +302,14 @@ pub struct ValidConfig {

#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidClientParams {
    #[serde(alias = "rootURI")]
    _root_uri: Option<String>,
    _workspace_folders: Option<Vec<String>>,
    #[serde(alias = "rootUri")]
    pub root_uri: Option<String>,
}

#[derive(Clone, Debug)]
pub struct Config {
    pub config: ValidConfig,
    _client_params: ValidClientParams,
    pub client_params: ValidClientParams,
}

impl Config {
@@ -255,7 +326,7 @@ impl Config {
        let client_params: ValidClientParams = serde_json::from_value(args)?;
        Ok(Self {
            config: valid_args,
            _client_params: client_params,
            client_params,
        })
    }

@@ -300,20 +371,17 @@ impl Config {
    }
}

// This makes testing much easier.
// For testing use only
#[cfg(test)]
impl Config {
    pub fn default_with_file_store_without_models() -> Self {
        Self {
            config: ValidConfig {
                memory: ValidMemoryBackend::FileStore(FileStore { crawl: false }),
                memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }),
                models: HashMap::new(),
                completion: None,
            },
            _client_params: ValidClientParams {
                _root_uri: None,
                _workspace_folders: None,
            },
            client_params: ValidClientParams { root_uri: None },
        }
    }
}
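The `rootURI` to `rootUri` alias change matters because LSP clients send the field in camelCase as `rootUri`; with the old alias the value was silently dropped. A standalone sketch of the alias behavior (struct re-declared here for illustration):

use serde::Deserialize; // serde = { version = "1", features = ["derive"] }

#[derive(Debug, Deserialize, Default)]
pub struct ValidClientParams {
    // Accept the LSP-standard camelCase key while keeping a snake_case field.
    #[serde(alias = "rootUri")]
    pub root_uri: Option<String>,
}

fn main() {
    let params: ValidClientParams =
        serde_json::from_str(r#"{"rootUri": "file:///home/user/project"}"#).unwrap();
    assert_eq!(params.root_uri.as_deref(), Some("file:///home/user/project"));
}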
crates/lsp-ai/src/crawl.rs (new file, 103 lines)
@@ -0,0 +1,103 @@
use ignore::WalkBuilder;
use std::collections::HashSet;
use tracing::{error, instrument};

use crate::config::{self, Config};

pub struct Crawl {
    crawl_config: config::Crawl,
    config: Config,
    crawled_file_types: HashSet<String>,
    crawled_all: bool,
}

impl Crawl {
    pub fn new(crawl_config: config::Crawl, config: Config) -> Self {
        Self {
            crawl_config,
            config,
            crawled_file_types: HashSet::new(),
            crawled_all: false,
        }
    }

    #[instrument(skip(self, f))]
    pub fn maybe_do_crawl(
        &mut self,
        triggered_file: Option<String>,
        mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<bool>,
    ) -> anyhow::Result<()> {
        if self.crawled_all {
            return Ok(());
        }

        if let Some(root_uri) = &self.config.client_params.root_uri {
            if !root_uri.starts_with("file://") {
                anyhow::bail!("Skipping crawling as root_uri does not begin with file://")
            }

            let extension_to_match = triggered_file
                .map(|tf| {
                    let path = std::path::Path::new(&tf);
                    path.extension().map(|f| f.to_str().map(|f| f.to_owned()))
                })
                .flatten()
                .flatten();

            if let Some(extension_to_match) = &extension_to_match {
                if self.crawled_file_types.contains(extension_to_match) {
                    return Ok(());
                }
            }

            if !self.crawl_config.all_files && extension_to_match.is_none() {
                return Ok(());
            }

            for result in WalkBuilder::new(&root_uri[7..]).build() {
                let result = result?;
                let path = result.path();
                if !path.is_dir() {
                    if let Some(path_str) = path.to_str() {
                        if self.crawl_config.all_files {
                            match f(&self.crawl_config, path_str) {
                                Ok(c) => {
                                    if !c {
                                        break;
                                    }
                                }
                                Err(e) => error!("{e:?}"),
                            }
                        } else {
                            match (
                                path.extension().map(|pe| pe.to_str()).flatten(),
                                &extension_to_match,
                            ) {
                                (Some(path_extension), Some(extension_to_match)) => {
                                    if path_extension == extension_to_match {
                                        match f(&self.crawl_config, path_str) {
                                            Ok(c) => {
                                                if !c {
                                                    break;
                                                }
                                            }
                                            Err(e) => error!("{e:?}"),
                                        }
                                    }
                                }
                                _ => continue,
                            }
                        }
                    }
                }
            }

            if let Some(extension_to_match) = extension_to_match {
                self.crawled_file_types.insert(extension_to_match);
            } else {
                self.crawled_all = true
            }
        }
        Ok(())
    }
}
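The callback contract here is that `f` returns Ok(true) to keep walking and Ok(false) to stop early; the file store uses that below to enforce its memory budget. As a standalone sketch of the walk underneath `maybe_do_crawl` (values are made up; `ignore = "0.4"`, `anyhow = "1"`), note that `WalkBuilder` honors .gitignore by default, which is why crawling skips ignored files:

use ignore::WalkBuilder;

fn main() -> anyhow::Result<()> {
    let mut seen = 0usize;
    for entry in WalkBuilder::new(".").build() {
        let entry = entry?;
        if entry.path().is_dir() {
            continue;
        }
        // Mirrors returning Ok(false) from the crawl callback: stop early.
        seen += 1;
        if seen >= 10 {
            break;
        }
    }
    println!("visited {seen} files");
    Ok(())
}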
crates/lsp-ai/src/main.rs
@@ -14,9 +14,11 @@ use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};

mod config;
mod crawl;
mod custom_requests;
mod memory_backends;
mod memory_worker;
mod splitters;
#[cfg(feature = "llama_cpp")]
mod template;
mod transformer_backends;
@@ -50,15 +52,19 @@ where
    req.extract(R::METHOD)
}

fn main() -> Result<()> {
    // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
    // If the variable's value is malformed or missing, sets the default log level to ERROR
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
fn init_logger() {
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .without_time()
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();
}

fn main() -> Result<()> {
    init_logger();

    let (connection, io_threads) = Connection::stdio();
    let server_capabilities = serde_json::to_value(ServerCapabilities {
@@ -83,7 +89,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
    let connection = Arc::new(connection);

    // Our channel we use to communicate with our transformer worker
    // let last_worker_request = Arc::new(Mutex::new(None));
    let (transformer_tx, transformer_rx) = mpsc::channel();

    // The channel we use to communicate with our memory worker
@@ -94,8 +99,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
    thread::spawn(move || memory_worker::run(memory_backend, memory_rx));

    // Setup our transformer worker
    // let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
    //     config.clone().try_into()?;
    let transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>> = config
        .config
        .models

crates/lsp-ai/src/memory_backends/file_store.rs (new file, 924 lines)
@@ -0,0 +1,924 @@
use anyhow::Context;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::{collections::HashMap, io::Read};
use tracing::{error, instrument, warn};
use tree_sitter::{InputEdit, Point, Tree};

use crate::{
    config::{self, Config},
    crawl::Crawl,
    utils::{parse_tree, tokens_to_estimated_characters},
};

use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};

#[derive(Default)]
pub struct AdditionalFileStoreParams {
    build_tree: bool,
}

impl AdditionalFileStoreParams {
    pub fn new(build_tree: bool) -> Self {
        Self { build_tree }
    }
}

#[derive(Clone)]
pub struct File {
    rope: Rope,
    tree: Option<Tree>,
}

impl File {
    fn new(rope: Rope, tree: Option<Tree>) -> Self {
        Self { rope, tree }
    }

    pub fn rope(&self) -> &Rope {
        &self.rope
    }

    pub fn tree(&self) -> Option<&Tree> {
        self.tree.as_ref()
    }
}

pub struct FileStore {
    params: AdditionalFileStoreParams,
    file_map: Mutex<HashMap<String, File>>,
    accessed_files: Mutex<IndexSet<String>>,
    crawl: Option<Mutex<Crawl>>,
}

impl FileStore {
    pub fn new(mut file_store_config: config::FileStore, config: Config) -> anyhow::Result<Self> {
        let crawl = file_store_config
            .crawl
            .take()
            .map(|x| Mutex::new(Crawl::new(x, config.clone())));
        let s = Self {
            params: AdditionalFileStoreParams::default(),
            file_map: Mutex::new(HashMap::new()),
            accessed_files: Mutex::new(IndexSet::new()),
            crawl,
        };
        if let Err(e) = s.maybe_do_crawl(None) {
            error!("{e:?}")
        }
        Ok(s)
    }

    pub fn new_with_params(
        mut file_store_config: config::FileStore,
        config: Config,
        params: AdditionalFileStoreParams,
    ) -> anyhow::Result<Self> {
        let crawl = file_store_config
            .crawl
            .take()
            .map(|x| Mutex::new(Crawl::new(x, config.clone())));
        let s = Self {
            params,
            file_map: Mutex::new(HashMap::new()),
            accessed_files: Mutex::new(IndexSet::new()),
            crawl,
        };
        if let Err(e) = s.maybe_do_crawl(None) {
            error!("{e:?}")
        }
        Ok(s)
    }

    fn add_new_file(&self, uri: &str, contents: String) {
        let tree = if self.params.build_tree {
            match parse_tree(uri, &contents, None) {
                Ok(tree) => Some(tree),
                Err(e) => {
                    error!(
                        "Failed to parse tree for {uri} with error {e}, falling back to no tree"
                    );
                    None
                }
            }
        } else {
            None
        };
        self.file_map
            .lock()
            .insert(uri.to_string(), File::new(Rope::from_str(&contents), tree));
        self.accessed_files.lock().insert(uri.to_string());
    }

    fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
        let mut total_bytes = 0;
        let mut current_bytes = 0;
        if let Some(crawl) = &self.crawl {
            crawl
                .lock()
                .maybe_do_crawl(triggered_file, |config, path| {
                    // Break if total bytes is over the max crawl memory
                    if total_bytes as u64 >= config.max_crawl_memory {
                        warn!("Ending crawl early due to `max_crawl_memory` restraint");
                        return Ok(false);
                    }
                    // This means it has been opened before
                    let insert_uri = format!("file:///{path}");
                    if self.file_map.lock().contains_key(&insert_uri) {
                        return Ok(true);
                    }
                    // Open the file and see if it is small enough to read
                    let mut f = std::fs::File::open(path)?;
                    let metadata = f.metadata()?;
                    if metadata.len() > config.max_file_size {
                        warn!("Skipping file: {path} because it is too large");
                        return Ok(true);
                    }
                    // Read the file contents
                    let mut contents = vec![];
                    f.read_to_end(&mut contents)?;
                    let contents = String::from_utf8(contents)?;
                    current_bytes += contents.len();
                    total_bytes += contents.len();
                    self.add_new_file(&insert_uri, contents);
                    Ok(true)
                })?;
        }
        Ok(())
    }

    fn get_rope_for_position(
        &self,
        position: &TextDocumentPositionParams,
        characters: usize,
        pull_from_multiple_files: bool,
    ) -> anyhow::Result<(Rope, usize)> {
        // Get the rope and set our initial cursor index
        let current_document_uri = position.text_document.uri.to_string();
        let mut rope = self
            .file_map
            .lock()
            .get(&current_document_uri)
            .context("Error file not found")?
            .rope
            .clone();
        let mut cursor_index = rope.line_to_char(position.position.line as usize)
            + position.position.character as usize;
        // Add to our rope if we need to
        for file in self
            .accessed_files
            .lock()
            .iter()
            .filter(|f| **f != current_document_uri)
        {
            let needed = characters.saturating_sub(rope.len_chars() + 1);
            if needed == 0 || !pull_from_multiple_files {
                break;
            }
            let file_map = self.file_map.lock();
            let r = &file_map.get(file).context("Error file not found")?.rope;
            let slice_max = needed.min(r.len_chars() + 1);
            let rope_str_slice = r
                .get_slice(0..slice_max - 1)
                .context("Error getting slice")?
                .to_string();
            rope.insert(0, "\n");
            rope.insert(0, &rope_str_slice);
            cursor_index += slice_max;
        }
        Ok((rope, cursor_index))
    }

    pub fn get_characters_around_position(
        &self,
        position: &TextDocumentPositionParams,
        characters: usize,
    ) -> anyhow::Result<String> {
        let rope = self
            .file_map
            .lock()
            .get(position.text_document.uri.as_str())
            .context("Error file not found")?
            .rope
            .clone();
        let cursor_index = rope.line_to_char(position.position.line as usize)
            + position.position.character as usize;
        let start = cursor_index.saturating_sub(characters / 2);
        let end = rope
            .len_chars()
            .min(cursor_index + (characters - (cursor_index - start)));
        let rope_slice = rope
            .get_slice(start..end)
            .context("Error getting rope slice")?;
        Ok(rope_slice.to_string())
    }

    pub fn build_code(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: MemoryRunParams,
        pull_from_multiple_files: bool,
    ) -> anyhow::Result<Prompt> {
        let (mut rope, cursor_index) =
            self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?;

        Ok(match prompt_type {
            PromptType::ContextAndCode => {
                if params.is_for_chat {
                    let max_length = tokens_to_estimated_characters(params.max_context);
                    let start = cursor_index.saturating_sub(max_length / 2);
                    let end = rope
                        .len_chars()
                        .min(cursor_index + (max_length - (cursor_index - start)));

                    rope.insert(cursor_index, "<CURSOR>");
                    let rope_slice = rope
                        .get_slice(start..end + "<CURSOR>".chars().count())
                        .context("Error getting rope slice")?;
                    Prompt::ContextAndCode(ContextAndCodePrompt::new(
                        "".to_string(),
                        rope_slice.to_string(),
                    ))
                } else {
                    let start = cursor_index
                        .saturating_sub(tokens_to_estimated_characters(params.max_context));
                    let rope_slice = rope
                        .get_slice(start..cursor_index)
                        .context("Error getting rope slice")?;
                    Prompt::ContextAndCode(ContextAndCodePrompt::new(
                        "".to_string(),
                        rope_slice.to_string(),
                    ))
                }
            }
            PromptType::FIM => {
                let max_length = tokens_to_estimated_characters(params.max_context);
                let start = cursor_index.saturating_sub(max_length / 2);
                let end = rope
                    .len_chars()
                    .min(cursor_index + (max_length - (cursor_index - start)));
                let prefix = rope
                    .get_slice(start..cursor_index)
                    .context("Error getting rope slice")?;
                let suffix = rope
                    .get_slice(cursor_index..end)
                    .context("Error getting rope slice")?;
                Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
            }
        })
    }

    pub fn file_map(&self) -> &Mutex<HashMap<String, File>> {
        &self.file_map
    }

    pub fn contains_file(&self, uri: &str) -> bool {
        self.file_map.lock().contains_key(uri)
    }

    pub fn position_to_byte(&self, position: &TextDocumentPositionParams) -> anyhow::Result<usize> {
        let file_map = self.file_map.lock();
        let uri = position.text_document.uri.to_string();
        let file = file_map
            .get(&uri)
            .with_context(|| format!("trying to get file that does not exist {uri}"))?;
        let line_char_index = file
            .rope
            .try_line_to_char(position.position.line as usize)?;
        Ok(file
            .rope
            .try_char_to_byte(line_char_index + position.position.character as usize)?)
    }
}
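To make the windowing arithmetic in `build_code` concrete, here is a standalone sketch using `ropey` alone (the text and budget are made up): the window is centered on the cursor, clamped at the start by `saturating_sub` and at the end by the rope length, with any unused left-hand budget rolled into the right-hand side.

// ropey = "1"
use ropey::Rope;

fn main() {
    let rope = Rope::from_str("fn main() {\n    println!(\"hi\");\n}\n");
    // Cursor at line 1, character 4, flattened to a char index as in the file store.
    let cursor_index = rope.line_to_char(1) + 4;

    let max_length = 16; // pretend character budget
    let start = cursor_index.saturating_sub(max_length / 2);
    // Whatever the left side could not use is spent on the right side.
    let end = rope
        .len_chars()
        .min(cursor_index + (max_length - (cursor_index - start)));

    let prefix = rope.slice(start..cursor_index);
    let suffix = rope.slice(cursor_index..end);
    println!("prefix: {prefix:?} suffix: {suffix:?}");
}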

#[async_trait::async_trait]
impl MemoryBackend for FileStore {
    #[instrument(skip(self))]
    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        let rope = self
            .file_map
            .lock()
            .get(position.text_document.uri.as_str())
            .context("Error file not found")?
            .rope
            .clone();
        let line = rope
            .get_line(position.position.line as usize)
            .context("Error getting filter text")?
            .get_slice(0..position.position.character as usize)
            .context("Error getting filter text")?
            .to_string();
        Ok(line)
    }

    #[instrument(skip(self))]
    async fn build_prompt(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: &Value,
    ) -> anyhow::Result<Prompt> {
        let params: MemoryRunParams = params.try_into()?;
        self.build_code(position, prompt_type, params, true)
    }

    #[instrument(skip(self))]
    fn opened_text_document(
        &self,
        params: lsp_types::DidOpenTextDocumentParams,
    ) -> anyhow::Result<()> {
        let uri = params.text_document.uri.to_string();
        self.add_new_file(&uri, params.text_document.text);
        if let Err(e) = self.maybe_do_crawl(Some(uri)) {
            error!("{e:?}")
        }
        Ok(())
    }

    #[instrument(skip(self))]
    fn changed_text_document(
        &self,
        params: lsp_types::DidChangeTextDocumentParams,
    ) -> anyhow::Result<()> {
        let uri = params.text_document.uri.to_string();
        let mut file_map = self.file_map.lock();
        let file = file_map
            .get_mut(&uri)
            .with_context(|| format!("Trying to get file that does not exist {uri}"))?;
        for change in params.content_changes {
            // If range is omitted, text is the new text of the document
            if let Some(range) = change.range {
                // Record old positions
                let (old_end_position, old_end_byte) = {
                    let last_line_index = file.rope.len_lines() - 1;
                    (
                        file.rope
                            .get_line(last_line_index)
                            .context("getting last line for edit")
                            .map(|last_line| Point::new(last_line_index, last_line.len_chars())),
                        file.rope.bytes().count(),
                    )
                };
                // Update the document
                let start_index = file.rope.line_to_char(range.start.line as usize)
                    + range.start.character as usize;
                let end_index =
                    file.rope.line_to_char(range.end.line as usize) + range.end.character as usize;
                file.rope.remove(start_index..end_index);
                file.rope.insert(start_index, &change.text);
                // Set new end positions
                let (new_end_position, new_end_byte) = {
                    let last_line_index = file.rope.len_lines() - 1;
                    (
                        file.rope
                            .get_line(last_line_index)
                            .context("getting last line for edit")
                            .map(|last_line| Point::new(last_line_index, last_line.len_chars())),
                        file.rope.bytes().count(),
                    )
                };
                // Update the tree
                if self.params.build_tree {
                    let mut old_tree = file.tree.take();
                    let start_byte = file
                        .rope
                        .try_line_to_char(range.start.line as usize)
                        .and_then(|start_char| {
                            file.rope
                                .try_char_to_byte(start_char + range.start.character as usize)
                        })
                        .map_err(anyhow::Error::msg);
                    if let Some(old_tree) = &mut old_tree {
                        match (start_byte, old_end_position, new_end_position) {
                            (Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => {
                                old_tree.edit(&InputEdit {
                                    start_byte,
                                    old_end_byte,
                                    new_end_byte,
                                    start_position: Point::new(
                                        range.start.line as usize,
                                        range.start.character as usize,
                                    ),
                                    old_end_position,
                                    new_end_position,
                                });
                                file.tree = match parse_tree(
                                    &uri,
                                    &file.rope.to_string(),
                                    Some(old_tree),
                                ) {
                                    Ok(tree) => Some(tree),
                                    Err(e) => {
                                        error!("failed to edit tree: {e:?}");
                                        None
                                    }
                                };
                            }
                            (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
                                error!("failed to build tree edit: {e:?}");
                            }
                        }
                    }
                }
            } else {
                file.rope = Rope::from_str(&change.text);
                if self.params.build_tree {
                    file.tree = match parse_tree(&uri, &change.text, None) {
                        Ok(tree) => Some(tree),
                        Err(e) => {
                            error!("failed to parse new tree: {e:?}");
                            None
                        }
                    };
                }
            }
        }
        self.accessed_files.lock().shift_insert(0, uri);
        Ok(())
    }

    #[instrument(skip(self))]
    fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
        for file_rename in params.files {
            let mut file_map = self.file_map.lock();
            if let Some(rope) = file_map.remove(&file_rename.old_uri) {
                file_map.insert(file_rename.new_uri, rope);
            }
        }
        Ok(())
    }
}
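The `shift_insert(0, uri)` call above is what keeps `accessed_files` ordered most-recently-used first, which `get_rope_for_position` relies on when pulling context from other files. A standalone illustration of that `indexmap` behavior (example values are mine; indexmap = "2.2"):

use indexmap::IndexSet;

fn main() {
    let mut accessed: IndexSet<&str> = IndexSet::new();
    accessed.insert("file:///a.rs");
    accessed.insert("file:///b.rs");

    // Moving an existing entry to index 0 marks it as most recently used.
    accessed.shift_insert(0, "file:///b.rs");
    assert_eq!(accessed.get_index(0), Some(&"file:///b.rs"));
    assert_eq!(accessed.get_index(1), Some(&"file:///a.rs"));
}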
|
||||
|
||||
// For teesting use only
|
||||
#[cfg(test)]
|
||||
impl FileStore {
|
||||
pub fn default_with_filler_file() -> anyhow::Result<Self> {
|
||||
let config = Config::default_with_file_store_without_models();
|
||||
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
|
||||
config.config.memory.clone()
|
||||
{
|
||||
file_store_config
|
||||
} else {
|
||||
anyhow::bail!("requires a file_store_config")
|
||||
};
|
||||
let f = FileStore::new(file_store_config, config)?;
|
||||
|
||||
let uri = "file:///filler.py";
|
||||
let text = r#"# Multiplies two numbers
|
||||
def multiply_two_numbers(x, y):
|
||||
return
|
||||
|
||||
# A singular test
|
||||
assert multiply_two_numbers(2, 3) == 6
|
||||
"#;
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: lsp_types::TextDocumentItem {
|
||||
uri: reqwest::Url::parse(uri).unwrap(),
|
||||
language_id: "filler".to_string(),
|
||||
version: 0,
|
||||
text: text.to_string(),
|
||||
},
|
||||
};
|
||||
f.opened_text_document(params)?;
|
||||
|
||||
Ok(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use lsp_types::{
|
||||
DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams,
|
||||
TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem,
|
||||
VersionedTextDocumentIdentifier,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
fn generate_base_file_store() -> anyhow::Result<FileStore> {
|
||||
let config = Config::default_with_file_store_without_models();
|
||||
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
|
||||
config.config.memory.clone()
|
||||
{
|
||||
file_store_config
|
||||
} else {
|
||||
anyhow::bail!("requires a file_store_config")
|
||||
};
|
||||
FileStore::new(file_store_config, config)
|
||||
}
|
||||
|
||||
fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem {
|
||||
let uri = uri.unwrap_or("file:///filler/");
|
||||
let text = text.unwrap_or("Here is the document body");
|
||||
TextDocumentItem {
|
||||
uri: reqwest::Url::parse(uri).unwrap(),
|
||||
language_id: "filler".to_string(),
|
||||
version: 0,
|
||||
text: text.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_open_document() -> anyhow::Result<()> {
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: generate_filler_text_document(None, None),
|
||||
};
|
||||
let file_store = generate_base_file_store()?;
|
||||
file_store.opened_text_document(params)?;
|
||||
let file = file_store
|
||||
.file_map
|
||||
.lock()
|
||||
.get("file:///filler/")
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(file.rope.to_string(), "Here is the document body");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_rename_document() -> anyhow::Result<()> {
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: generate_filler_text_document(None, None),
|
||||
};
|
||||
let file_store = generate_base_file_store()?;
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
let params = RenameFilesParams {
|
||||
files: vec![FileRename {
|
||||
old_uri: "file:///filler/".to_string(),
|
||||
new_uri: "file:///filler2/".to_string(),
|
||||
}],
|
||||
};
|
||||
file_store.renamed_files(params)?;
|
||||
|
||||
let file = file_store
|
||||
.file_map
|
||||
.lock()
|
||||
.get("file:///filler2/")
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(file.rope.to_string(), "Here is the document body");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_change_document() -> anyhow::Result<()> {
|
||||
let text_document = generate_filler_text_document(None, None);
|
||||
|
||||
let params = DidOpenTextDocumentParams {
|
||||
text_document: text_document.clone(),
|
||||
};
|
||||
let file_store = generate_base_file_store()?;
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
let params = lsp_types::DidChangeTextDocumentParams {
|
||||
text_document: VersionedTextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
version: 1,
|
||||
},
|
||||
content_changes: vec![TextDocumentContentChangeEvent {
|
||||
range: Some(Range {
|
||||
start: Position {
|
||||
line: 0,
|
||||
character: 1,
|
||||
},
|
||||
end: Position {
|
||||
line: 0,
|
||||
character: 3,
|
||||
},
|
||||
}),
|
||||
range_length: None,
|
||||
text: "a".to_string(),
|
||||
}],
|
||||
};
|
||||
file_store.changed_text_document(params)?;
|
||||
let file = file_store
|
||||
.file_map
|
||||
.lock()
|
||||
.get("file:///filler/")
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(file.rope.to_string(), "Hae is the document body");
|
||||
|
||||
let params = lsp_types::DidChangeTextDocumentParams {
|
||||
text_document: VersionedTextDocumentIdentifier {
|
||||
uri: text_document.uri,
|
||||
version: 1,
|
||||
},
|
||||
content_changes: vec![TextDocumentContentChangeEvent {
|
||||
range: None,
|
||||
range_length: None,
|
||||
text: "abc".to_string(),
|
||||
}],
|
||||
};
|
||||
file_store.changed_text_document(params)?;
|
||||
let file = file_store
|
||||
.file_map
|
||||
.lock()
|
||||
.get("file:///filler/")
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(file.rope.to_string(), "abc");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn can_build_prompt() -> anyhow::Result<()> {
|
||||
let text_document = generate_filler_text_document(
|
||||
None,
|
||||
Some(
|
||||
r#"Document Top
|
||||
Here is a more complicated document
|
||||
|
||||
Some text
|
||||
|
||||
The end with a trailing new line
|
||||
"#,
|
||||
),
|
||||
);
|
||||
|
||||
// Test basic completion
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: text_document.clone(),
|
||||
};
|
||||
let file_store = generate_base_file_store()?;
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
let prompt = file_store
|
||||
.build_prompt(
|
||||
&TextDocumentPositionParams {
|
||||
text_document: TextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
},
|
||||
position: Position {
|
||||
line: 0,
|
||||
character: 10,
|
||||
},
|
||||
},
|
||||
PromptType::ContextAndCode,
|
||||
&json!({}),
|
||||
)
|
||||
.await?;
|
||||
let prompt: ContextAndCodePrompt = prompt.try_into()?;
|
||||
assert_eq!(prompt.context, "");
|
||||
assert_eq!("Document T", prompt.code);
|
||||
|
||||
// Test FIM
|
||||
let prompt = file_store
|
||||
.build_prompt(
|
||||
&TextDocumentPositionParams {
|
||||
text_document: TextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
},
|
||||
position: Position {
|
||||
line: 0,
|
||||
character: 10,
|
||||
},
|
||||
},
|
||||
PromptType::FIM,
|
||||
&json!({}),
|
||||
)
|
||||
.await?;
|
||||
let prompt: FIMPrompt = prompt.try_into()?;
|
||||
assert_eq!(prompt.prompt, r#"Document T"#);
|
||||
assert_eq!(
|
||||
prompt.suffix,
|
||||
r#"op
|
||||
Here is a more complicated document
|
||||
|
||||
Some text
|
||||
|
||||
The end with a trailing new line
|
||||
"#
|
||||
);
|
||||
|
||||
// Test chat
|
||||
let prompt = file_store
|
||||
.build_prompt(
|
||||
&TextDocumentPositionParams {
|
||||
text_document: TextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
},
|
||||
position: Position {
|
||||
line: 0,
|
||||
character: 10,
|
||||
},
|
||||
},
|
||||
PromptType::ContextAndCode,
|
||||
&json!({
|
||||
"messages": []
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let prompt: ContextAndCodePrompt = prompt.try_into()?;
|
||||
assert_eq!(prompt.context, "");
|
||||
let text = r#"Document T<CURSOR>op
|
||||
Here is a more complicated document
|
||||
|
||||
Some text
|
||||
|
||||
The end with a trailing new line
|
||||
"#
|
||||
.to_string();
|
||||
assert_eq!(text, prompt.code);
|
||||
|
||||
// Test multi-file
|
||||
let text_document2 = generate_filler_text_document(
|
||||
Some("file:///filler2"),
|
||||
Some(
|
||||
r#"Document Top2
|
||||
Here is a more complicated document
|
||||
|
||||
Some text
|
||||
|
||||
The end with a trailing new line
|
||||
"#,
|
||||
),
|
||||
);
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: text_document2.clone(),
|
||||
};
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
let prompt = file_store
|
||||
.build_prompt(
|
||||
&TextDocumentPositionParams {
|
||||
text_document: TextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
},
|
||||
position: Position {
|
||||
line: 0,
|
||||
character: 10,
|
||||
},
|
||||
},
|
||||
PromptType::ContextAndCode,
|
||||
&json!({}),
|
||||
)
|
||||
.await?;
|
||||
let prompt: ContextAndCodePrompt = prompt.try_into()?;
|
||||
assert_eq!(prompt.context, "");
|
||||
assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> {
|
||||
let text_document = generate_filler_text_document(None, Some("test\n"));
|
||||
let params = lsp_types::DidOpenTextDocumentParams {
|
||||
text_document: text_document.clone(),
|
||||
};
|
||||
let file_store = generate_base_file_store()?;
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
// Test chat
|
||||
let prompt = file_store
|
||||
.build_prompt(
|
||||
&TextDocumentPositionParams {
|
||||
text_document: TextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
},
|
||||
position: Position {
|
||||
line: 1,
|
||||
character: 0,
|
||||
},
|
||||
},
|
||||
PromptType::ContextAndCode,
|
||||
&json!({"messages": []}),
|
||||
)
|
||||
.await?;
|
||||
let prompt: ContextAndCodePrompt = prompt.try_into()?;
|
||||
assert_eq!(prompt.context, "");
|
||||
let text = r#"test
|
||||
<CURSOR>"#
|
||||
.to_string();
|
||||
assert_eq!(text, prompt.code);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_store_tree_sitter() -> anyhow::Result<()> {
|
||||
crate::init_logger();
|
||||
|
||||
let config = Config::default_with_file_store_without_models();
|
||||
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
|
||||
config.config.memory.clone()
|
||||
{
|
||||
file_store_config
|
||||
} else {
|
||||
anyhow::bail!("requires a file_store_config")
|
||||
};
|
||||
let params = AdditionalFileStoreParams { build_tree: true };
|
||||
let file_store = FileStore::new_with_params(file_store_config, config, params)?;
|
||||
|
||||
let uri = "file:///filler/test.rs";
|
||||
let text = r#"#[derive(Debug)]
|
||||
struct Rectangle {
|
||||
width: u32,
|
||||
height: u32,
|
||||
}
|
||||
|
||||
impl Rectangle {
|
||||
fn area(&self) -> u32 {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let rect1 = Rectangle {
|
||||
width: 30,
|
||||
height: 50,
|
||||
};
|
||||
|
||||
println!(
|
||||
"The area of the rectangle is {} square pixels.",
|
||||
rect1.area()
|
||||
);
|
||||
}"#;
|
||||
let text_document = TextDocumentItem {
|
||||
uri: reqwest::Url::parse(uri).unwrap(),
|
||||
language_id: "".to_string(),
|
||||
version: 0,
|
||||
text: text.to_string(),
|
||||
};
|
||||
let params = DidOpenTextDocumentParams {
|
||||
text_document: text_document.clone(),
|
||||
};
|
||||
|
||||
file_store.opened_text_document(params)?;
|
||||
|
||||
// Test insert
|
||||
let params = lsp_types::DidChangeTextDocumentParams {
|
||||
text_document: VersionedTextDocumentIdentifier {
|
||||
uri: text_document.uri.clone(),
|
||||
version: 1,
|
||||
},
|
||||
content_changes: vec![TextDocumentContentChangeEvent {
|
||||
range: Some(Range {
|
||||
start: Position {
|
||||
line: 8,
|
||||
character: 0,
|
||||
},
|
||||
end: Position {
|
||||
line: 8,
|
||||
character: 0,
|
||||
},
|
||||
}),
|
||||
range_length: None,
|
||||
text: " self.width * self.height".to_string(),
|
||||
}],
|
||||
};
|
||||
file_store.changed_text_document(params)?;
|
||||
let file = file_store.file_map.lock().get(uri).unwrap().clone();
|
||||
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");

        // Test delete
        let params = lsp_types::DidChangeTextDocumentParams {
            text_document: VersionedTextDocumentIdentifier {
                uri: text_document.uri.clone(),
                version: 1,
            },
            content_changes: vec![TextDocumentContentChangeEvent {
                range: Some(Range {
                    start: Position {
                        line: 0,
                        character: 0,
                    },
                    end: Position {
                        line: 12,
                        character: 0,
                    },
                }),
                range_length: None,
                text: "".to_string(),
            }],
        };
        file_store.changed_text_document(params)?;
        let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");

        // Test replace
        let params = lsp_types::DidChangeTextDocumentParams {
            text_document: VersionedTextDocumentIdentifier {
                uri: text_document.uri,
                version: 1,
            },
            content_changes: vec![TextDocumentContentChangeEvent {
                range: None,
                range_length: None,
                text: "fn main() {}".to_string(),
            }],
        };
        file_store.changed_text_document(params)?;
        let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))");

        Ok(())
    }
}
@ -18,13 +18,13 @@ pub enum PromptType {
#[derive(Clone)]
pub struct MemoryRunParams {
    pub is_for_chat: bool,
    pub max_context_length: usize,
    pub max_context: usize,
}

impl From<&Value> for MemoryRunParams {
    fn from(value: &Value) -> Self {
        Self {
            max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
            max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
            // messages are for most backends, contents are for Gemini
            is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
        }
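A quick standalone check of how these fields come out of a request's JSON params (only serde_json is assumed here; the 1024 default and the `messages`/`contents` chat detection mirror the `From<&Value>` impl above):

use serde_json::{json, Value};

fn main() {
    // Chat-style request: a `messages` array marks it as a chat prompt
    let params: Value = json!({ "messages": [], "max_context": 2048 });
    assert!(params["messages"].is_array() || params["contents"].is_array());
    assert_eq!(params["max_context"].as_u64().unwrap_or(1024), 2048);

    // Completion-style request: no messages/contents, so defaults apply
    let params: Value = json!({});
    assert!(!(params["messages"].is_array() || params["contents"].is_array()));
    assert_eq!(params["max_context"].as_u64().unwrap_or(1024), 1024);
}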
@ -113,22 +113,16 @@ pub trait MemoryBackend {
    async fn init(&self) -> anyhow::Result<()> {
        Ok(())
    }
    async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
    async fn changed_text_document(
        &self,
        params: DidChangeTextDocumentParams,
    ) -> anyhow::Result<()>;
    async fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
    fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
    fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
    fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
    async fn build_prompt(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: &Value,
    ) -> anyhow::Result<Prompt>;
    async fn get_filter_text(
        &self,
        position: &TextDocumentPositionParams,
    ) -> anyhow::Result<String>;
}

impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
@ -137,7 +131,7 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
    fn try_from(configuration: Config) -> Result<Self, Self::Error> {
        match configuration.config.memory.clone() {
            ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new(
                file_store::FileStore::new(file_store_config, configuration),
                file_store::FileStore::new(file_store_config, configuration)?,
            )),
            ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
                postgresml::PostgresML::new(postgresml_config, configuration)?,
701
crates/lsp-ai/src/memory_backends/postgresml/mod.rs
Normal file
@ -0,0 +1,701 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use rand::{distributions::Alphanumeric, Rng};
use serde_json::{json, Value};
use std::{
    collections::HashSet,
    io::Read,
    sync::{
        mpsc::{self, Sender},
        Arc,
    },
    time::Duration,
};
use tokio::time;
use tracing::{error, instrument, warn};

use crate::{
    config::{self, Config},
    crawl::Crawl,
    splitters::{Chunk, Splitter},
    utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME},
};

use super::{
    file_store::{AdditionalFileStoreParams, FileStore},
    ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};

const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;

fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
    let path = match root_uri {
        Some(root_uri) => {
            if uri.starts_with(root_uri) {
                &uri[root_uri.chars().count()..]
            } else {
                uri
            }
        }
        None => uri,
    };
    format!(
        r#"--{path}--
{excerpt}
"#,
    )
}

fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
    json!({
        "id": chunk_to_id(uri, &chunk),
        "uri": uri,
        "text": format_file_excerpt(uri, &chunk.text, root_uri),
        "range": chunk.range
    })
}

async fn split_and_upsert_file(
    uri: &str,
    collection: &mut Collection,
    file_store: Arc<FileStore>,
    splitter: Arc<Box<dyn Splitter + Send + Sync>>,
    root_uri: Option<&str>,
) -> anyhow::Result<()> {
    // We need to make sure we don't hold the file_store lock while performing a network call
    let chunks = {
        file_store
            .file_map()
            .lock()
            .get(uri)
            .map(|f| splitter.split(f))
    };
    let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
    let documents = chunks
        .into_iter()
        .map(|chunk| chunk_to_document(uri, chunk, root_uri).into())
        .collect();
    collection
        .upsert_documents(documents, None)
        .await
        .context("PGML - Error upserting documents")
}
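For orientation, a minimal sketch of the document shape these helpers produce, assuming a hypothetical chunk covering bytes 0..10 of a file; the id format mirrors `utils::chunk_to_id` and the field names match `chunk_to_document` above:

use serde_json::json;

fn main() {
    // Hypothetical chunk: bytes 0..10 of a file under the workspace root
    let (uri, start_byte, end_byte) = ("file:///src/main.rs", 0, 10);
    let id = format!("{uri}#{start_byte}-{end_byte}"); // mirrors utils::chunk_to_id
    let document = json!({
        "id": id,
        "uri": uri,
        "text": "--/src/main.rs--\nfn main() \n",
        "range": { "start_byte": start_byte, "end_byte": end_byte }
    });
    println!("{document:#}");
}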

#[derive(Clone)]
pub struct PostgresML {
    config: Config,
    postgresml_config: config::PostgresML,
    file_store: Arc<FileStore>,
    collection: Collection,
    pipeline: Pipeline,
    debounce_tx: Sender<String>,
    crawl: Option<Arc<Mutex<Crawl>>>,
    splitter: Arc<Box<dyn Splitter + Send + Sync>>,
}

impl PostgresML {
    #[instrument]
    pub fn new(
        mut postgresml_config: config::PostgresML,
        configuration: Config,
    ) -> anyhow::Result<Self> {
        let crawl = postgresml_config
            .crawl
            .take()
            .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));

        let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
            Arc::new(postgresml_config.splitter.clone().try_into()?);

        let file_store = Arc::new(FileStore::new_with_params(
            config::FileStore::new_without_crawl(),
            configuration.clone(),
            AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
        )?);

        let database_url = if let Some(database_url) = postgresml_config.database_url.clone() {
            database_url
        } else {
            std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
        };

        // Build our pipeline schema
        let pipeline = match &postgresml_config.embedding_model {
            Some(embedding_model) => {
                json!({
                    "text": {
                        "semantic_search": {
                            "model": embedding_model.model,
                            "parameters": embedding_model.embed_parameters
                        }
                    }
                })
            }
            None => {
                json!({
                    "text": {
                        "semantic_search": {
                            "model": "intfloat/e5-small-v2",
                            "parameters": {
                                "prompt": "passage: "
                            }
                        }
                    }
                })
            }
        };

        // When building the collection name we include the Pipeline schema
        // If the user changes the Pipeline schema, it will take effect without them having to delete the old files
        let collection_name = match configuration.client_params.root_uri.clone() {
            Some(root_uri) => format!(
                "{:x}",
                md5::compute(
                    format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes()
                )
            ),
            None => {
                warn!("no root_uri provided in server configuration - generating random string for collection name");
                rand::thread_rng()
                    .sample_iter(&Alphanumeric)
                    .take(21)
                    .map(char::from)
                    .collect()
            }
        };
        let mut collection = Collection::new(&collection_name, Some(database_url))?;
        let mut pipeline = Pipeline::new("v1", Some(pipeline.into()))?;

        // Add the Pipeline to the Collection
        TOKIO_RUNTIME.block_on(async {
            collection
                .add_pipeline(&mut pipeline)
                .await
                .context("PGML - error adding pipeline to collection")
        })?;

        // Set up a debouncer for changed text documents
        let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
        let mut task_collection = collection.clone();
        let task_file_store = file_store.clone();
        let task_splitter = splitter.clone();
        let task_root_uri = configuration.client_params.root_uri.clone();
        TOKIO_RUNTIME.spawn(async move {
            let duration = Duration::from_millis(500);
            let mut file_uris = Vec::new();
            loop {
                time::sleep(duration).await;
                let new_uris: Vec<String> = debounce_rx.try_iter().collect();
                if !new_uris.is_empty() {
                    for uri in new_uris {
                        if !file_uris.iter().any(|p| *p == uri) {
                            file_uris.push(uri);
                        }
                    }
                } else {
                    if file_uris.is_empty() {
                        continue;
                    }
                    // Build the chunks for our changed files
                    let chunks: Vec<Vec<Chunk>> = match file_uris
                        .iter()
                        .map(|uri| {
                            let file_store = task_file_store.file_map().lock();
                            let file = file_store
                                .get(uri)
                                .with_context(|| format!("getting file for splitting: {uri}"))?;
                            anyhow::Ok(task_splitter.split(file))
                        })
                        .collect()
                    {
                        Ok(chunks) => chunks,
                        Err(e) => {
                            error!("{e:?}");
                            continue;
                        }
                    };
                    // Delete old chunks that no longer exist after the latest file changes
                    let delete_or_statements: Vec<Value> = file_uris
                        .iter()
                        .zip(&chunks)
                        .map(|(uri, chunks)| {
                            let ids: Vec<String> =
                                chunks.iter().map(|c| chunk_to_id(uri, c)).collect();
                            json!({
                                "$and": [
                                    {
                                        "uri": {
                                            "$eq": uri
                                        }
                                    },
                                    {
                                        "id": {
                                            "$nin": ids
                                        }
                                    }
                                ]
                            })
                        })
                        .collect();
                    if let Err(e) = task_collection
                        .delete_documents(
                            json!({
                                "$or": delete_or_statements
                            })
                            .into(),
                        )
                        .await
                        .context("PGML - error deleting documents")
                    {
                        error!("{e:?}");
                    }
                    // Prepare and upsert our new chunks
                    let documents: Vec<pgml::types::Json> = chunks
                        .into_iter()
                        .zip(&file_uris)
                        .map(|(chunks, uri)| {
                            chunks
                                .into_iter()
                                .map(|chunk| {
                                    chunk_to_document(&uri, chunk, task_root_uri.as_deref())
                                })
                                .collect::<Vec<Value>>()
                        })
                        .flatten()
                        .map(|f: Value| f.into())
                        .collect();
                    if let Err(e) = task_collection
                        .upsert_documents(documents, None)
                        .await
                        .context("PGML - error upserting changed files")
                    {
                        error!("{e:?}");
                        continue;
                    }

                    file_uris = Vec::new();
                }
            }
        });

        let s = Self {
            config: configuration,
            postgresml_config,
            file_store,
            collection,
            pipeline,
            debounce_tx,
            crawl,
            splitter,
        };

        // Resync our Collection
        let task_s = s.clone();
        TOKIO_RUNTIME.spawn(async move {
            if let Err(e) = task_s.resync().await {
                error!("{e:?}")
            }
        });

        if let Err(e) = s.maybe_do_crawl(None) {
            error!("{e:?}")
        }
        Ok(s)
    }
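The debouncer spawned in `new` batches change notifications: URIs accumulate while edits keep arriving, and the batch is flushed on the first 500ms tick with no new messages. A standalone sketch of that pattern using only std (a plain thread stands in for the tokio task, and the `break` exists only so the example terminates):

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = mpsc::channel::<String>();
    let worker = thread::spawn(move || {
        let mut pending: Vec<String> = Vec::new();
        loop {
            thread::sleep(Duration::from_millis(500));
            let new: Vec<String> = rx.try_iter().collect();
            if !new.is_empty() {
                // Still receiving edits: accumulate and dedupe
                for uri in new {
                    if !pending.contains(&uri) {
                        pending.push(uri);
                    }
                }
            } else if !pending.is_empty() {
                // One quiet tick: flush the batch (the upsert would happen here)
                println!("flushing {} uri(s): {pending:?}", pending.len());
                pending.clear();
                break; // the real loop runs forever
            }
        }
    });
    tx.send("file:///a.rs".into()).unwrap();
    tx.send("file:///a.rs".into()).unwrap(); // duplicate, deduped above
    tx.send("file:///b.rs".into()).unwrap();
    worker.join().unwrap();
}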

    async fn resync(&self) -> anyhow::Result<()> {
        let mut collection = self.collection.clone();

        let documents = collection
            .get_documents(Some(
                json!({
                    "limit": 100_000_000,
                    "keys": ["uri"]
                })
                .into(),
            ))
            .await?;

        let try_get_file_contents = |path: &std::path::Path| {
            // Open the file and see if it is small enough to read
            let mut f = std::fs::File::open(path)?;
            let metadata = f.metadata()?;
            if metadata.len() > RESYNC_MAX_FILE_SIZE {
                anyhow::bail!("file size is greater than: {RESYNC_MAX_FILE_SIZE}")
            }
            // Read the file contents
            let mut contents = vec![];
            f.read_to_end(&mut contents)?;
            anyhow::Ok(String::from_utf8(contents)?)
        };

        let mut documents_to_delete = vec![];
        let mut chunks_to_upsert = vec![];
        let mut current_chunks_bytes = 0;
        let mut checked_uris = HashSet::new();
        for document in documents.into_iter() {
            let uri = match document["document"]["uri"].as_str() {
                Some(uri) => uri,
                None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it
            };

            // Check if we have already loaded in this file
            if checked_uris.contains(uri) {
                continue;
            }
            checked_uris.insert(uri.to_string());

            let path = uri.replace("file://", "");
            let path = std::path::Path::new(&path);
            if !path.exists() {
                documents_to_delete.push(uri.to_string());
            } else {
                // Try to read the file. If we fail delete it
                let contents = match try_get_file_contents(path) {
                    Ok(contents) => contents,
                    Err(e) => {
                        error!("{e:?}");
                        documents_to_delete.push(uri.to_string());
                        continue;
                    }
                };
                // Split the file into chunks
                current_chunks_bytes += contents.len();
                let chunks: Vec<pgml::types::Json> = self
                    .splitter
                    .split_file_contents(&uri, &contents)
                    .into_iter()
                    .map(|chunk| {
                        chunk_to_document(
                            &uri,
                            chunk,
                            self.config.client_params.root_uri.as_deref(),
                        )
                        .into()
                    })
                    .collect();
                chunks_to_upsert.extend(chunks);
                // If we have over 10 megabytes of chunks do the upsert
                if current_chunks_bytes > 10_000_000 {
                    collection
                        .upsert_documents(chunks_to_upsert, None)
                        .await
                        .context("PGML - error upserting documents during resync")?;
                    chunks_to_upsert = vec![];
                    current_chunks_bytes = 0;
                }
            }
        }
        // Upsert any remaining chunks
        if chunks_to_upsert.len() > 0 {
            collection
                .upsert_documents(chunks_to_upsert, None)
                .await
                .context("PGML - error upserting documents during resync")?;
        }
        // Delete documents
        if !documents_to_delete.is_empty() {
            collection
                .delete_documents(
                    json!({
                        "uri": {
                            "$in": documents_to_delete
                        }
                    })
                    .into(),
                )
                .await
                .context("PGML - error deleting documents during resync")?;
        }
        Ok(())
    }

    fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
        if let Some(crawl) = &self.crawl {
            let mut documents = vec![];
            let mut total_bytes = 0;
            let mut current_bytes = 0;
            crawl
                .lock()
                .maybe_do_crawl(triggered_file, |config, path| {
                    // Break if total bytes is over the max crawl memory
                    if total_bytes as u64 >= config.max_crawl_memory {
                        warn!("Ending crawl early due to `max_crawl_memory` restraint");
                        return Ok(false);
                    }
                    // This means it has been opened before
                    let uri = format!("file://{path}");
                    if self.file_store.contains_file(&uri) {
                        return Ok(true);
                    }
                    // Open the file and see if it is small enough to read
                    let mut f = std::fs::File::open(path)?;
                    let metadata = f.metadata()?;
                    if metadata.len() > config.max_file_size {
                        warn!("Skipping file: {path} because it is too large");
                        return Ok(true);
                    }
                    // Read the file contents
                    let mut contents = vec![];
                    f.read_to_end(&mut contents)?;
                    let contents = String::from_utf8(contents)?;
                    current_bytes += contents.len();
                    total_bytes += contents.len();
                    let chunks: Vec<pgml::types::Json> = self
                        .splitter
                        .split_file_contents(&uri, &contents)
                        .into_iter()
                        .map(|chunk| {
                            chunk_to_document(
                                &uri,
                                chunk,
                                self.config.client_params.root_uri.as_deref(),
                            )
                            .into()
                        })
                        .collect();
                    documents.extend(chunks);
                    // If we have over 10 megabytes of data do the upsert
                    if current_bytes >= 10_000_000 || total_bytes as u64 >= config.max_crawl_memory
                    {
                        // Upsert the documents
                        let mut collection = self.collection.clone();
                        let to_upsert_documents = std::mem::take(&mut documents);
                        TOKIO_RUNTIME.spawn(async move {
                            if let Err(e) = collection
                                .upsert_documents(to_upsert_documents, None)
                                .await
                                .context("PGML - error upserting changed files")
                            {
                                error!("{e:?}");
                            }
                        });
                        // Reset everything
                        current_bytes = 0;
                        documents = vec![];
                    }
                    Ok(true)
                })?;
            // Upsert any remaining documents
            if documents.len() > 0 {
                let mut collection = self.collection.clone();
                TOKIO_RUNTIME.spawn(async move {
                    if let Err(e) = collection
                        .upsert_documents(documents, None)
                        .await
                        .context("PGML - error upserting changed files")
                    {
                        error!("{e:?}");
                    }
                });
            }
        }
        Ok(())
    }
}

#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
    #[instrument(skip(self))]
    fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
        self.file_store.get_filter_text(position)
    }

    #[instrument(skip(self))]
    async fn build_prompt(
        &self,
        position: &TextDocumentPositionParams,
        prompt_type: PromptType,
        params: &Value,
    ) -> anyhow::Result<Prompt> {
        let params: MemoryRunParams = params.try_into()?;
        let chunk_size = self.splitter.chunk_size();
        let total_allowed_characters = tokens_to_estimated_characters(params.max_context);

        // Build the query
        let query = self
            .file_store
            .get_characters_around_position(position, chunk_size)?;

        // Build the prompt
        let mut file_store_params = params.clone();
        file_store_params.max_context = chunk_size;
        let code = self
            .file_store
            .build_code(position, prompt_type, file_store_params, false)?;

        // Get the byte of the cursor
        let cursor_byte = self.file_store.position_to_byte(position)?;

        // Get the context
        let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
        let parameters = match self
            .postgresml_config
            .embedding_model
            .as_ref()
            .map(|m| m.query_parameters.clone())
            .flatten()
        {
            Some(query_parameters) => query_parameters,
            None => json!({
                "prompt": "query: "
            }),
        };
        let res = self
            .collection
            .vector_search_local(
                json!({
                    "query": {
                        "fields": {
                            "text": {
                                "query": query,
                                "parameters": parameters
                            }
                        },
                        "filter": {
                            "$or": [
                                {
                                    "uri": {
                                        "$ne": position.text_document.uri.to_string()
                                    }
                                },
                                {
                                    "range": {
                                        "start": {
                                            "$gt": cursor_byte
                                        },
                                    },
                                },
                                {
                                    "range": {
                                        "end": {
                                            "$lt": cursor_byte
                                        },
                                    }
                                }
                            ]
                        }
                    },
                    "limit": limit
                })
                .into(),
                &self.pipeline,
            )
            .await?;
        let context = res
            .into_iter()
            .map(|c| {
                c["chunk"]
                    .as_str()
                    .map(|t| t.to_owned())
                    .context("PGML - Error getting chunk from vector search")
            })
            .collect::<anyhow::Result<Vec<String>>>()?
            .join("\n\n");
        let context = &context[..(total_allowed_characters - chunk_size).min(context.len())];

        // Reconstruct the Prompts
        Ok(match code {
            Prompt::ContextAndCode(context_and_code) => {
                Prompt::ContextAndCode(ContextAndCodePrompt::new(
                    context.to_owned(),
                    format_file_excerpt(
                        &position.text_document.uri.to_string(),
                        &context_and_code.code,
                        self.config.client_params.root_uri.as_deref(),
                    ),
                ))
            }
            Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
                format!("{context}\n\n{}", fim.prompt),
                fim.suffix,
            )),
        })
    }

    #[instrument(skip(self))]
    fn opened_text_document(
        &self,
        params: lsp_types::DidOpenTextDocumentParams,
    ) -> anyhow::Result<()> {
        self.file_store.opened_text_document(params.clone())?;

        let saved_uri = params.text_document.uri.to_string();

        let mut collection = self.collection.clone();
        let file_store = self.file_store.clone();
        let splitter = self.splitter.clone();
        let root_uri = self.config.client_params.root_uri.clone();
        TOKIO_RUNTIME.spawn(async move {
            let uri = params.text_document.uri.to_string();
            if let Err(e) = split_and_upsert_file(
                &uri,
                &mut collection,
                file_store,
                splitter,
                root_uri.as_deref(),
            )
            .await
            {
                error!("{e:?}")
            }
        });

        if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) {
            error!("{e:?}")
        }
        Ok(())
    }

    #[instrument(skip(self))]
    fn changed_text_document(
        &self,
        params: lsp_types::DidChangeTextDocumentParams,
    ) -> anyhow::Result<()> {
        self.file_store.changed_text_document(params.clone())?;
        let uri = params.text_document.uri.to_string();
        self.debounce_tx.send(uri)?;
        Ok(())
    }

    #[instrument(skip(self))]
    fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
        self.file_store.renamed_files(params.clone())?;

        let mut collection = self.collection.clone();
        let file_store = self.file_store.clone();
        let splitter = self.splitter.clone();
        let root_uri = self.config.client_params.root_uri.clone();
        TOKIO_RUNTIME.spawn(async move {
            for file in params.files {
                if let Err(e) = collection
                    .delete_documents(
                        json!({
                            "uri": {
                                "$eq": file.old_uri
                            }
                        })
                        .into(),
                    )
                    .await
                {
                    error!("PGML - Error deleting file: {e:?}");
                }
                if let Err(e) = split_and_upsert_file(
                    &file.new_uri,
                    &mut collection,
                    file_store.clone(),
                    splitter.clone(),
                    root_uri.as_deref(),
                )
                .await
                {
                    error!("{e:?}")
                }
            }
        });
        Ok(())
    }
}
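The context budget in `build_prompt` is worth spelling out: `max_context` tokens are converted to an estimated character count, one chunk-sized slot is reserved for the code around the cursor, and the remainder caps both the vector-search `limit` and the final context slice. A small sketch of the arithmetic (the 4-characters-per-token ratio is an assumption standing in for `tokens_to_estimated_characters`):

fn main() {
    let max_context_tokens = 1024usize;
    let chunk_size = 512usize;

    // Assumed estimate: roughly 4 characters per token
    let total_allowed_characters = max_context_tokens * 4;

    // Reserve one chunk for the code surrounding the cursor;
    // the rest of the budget goes to retrieved context chunks
    let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
    let context_budget = total_allowed_characters - chunk_size;

    assert_eq!(limit, 7);
    assert_eq!(context_budget, 3584);
}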
@ -7,7 +7,10 @@ use lsp_types::{
use serde_json::Value;
use tracing::error;

use crate::memory_backends::{MemoryBackend, Prompt, PromptType};
use crate::{
    memory_backends::{MemoryBackend, Prompt, PromptType},
    utils::TOKIO_RUNTIME,
};

#[derive(Debug)]
pub struct PromptRequest {
@ -56,34 +59,45 @@ pub enum WorkerRequest {
    DidRenameFiles(RenameFilesParams),
}

async fn do_task(
async fn do_build_prompt(
    params: PromptRequest,
    memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
    let prompt = memory_backend
        .build_prompt(&params.position, params.prompt_type, &params.params)
        .await?;
    params
        .tx
        .send(prompt)
        .map_err(|_| anyhow::anyhow!("sending on channel failed"))
}

fn do_task(
    request: WorkerRequest,
    memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
    match request {
        WorkerRequest::FilterText(params) => {
            let filter_text = memory_backend.get_filter_text(&params.position).await?;
            let filter_text = memory_backend.get_filter_text(&params.position)?;
            params
                .tx
                .send(filter_text)
                .map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
        }
        WorkerRequest::Prompt(params) => {
            let prompt = memory_backend
                .build_prompt(&params.position, params.prompt_type, &params.params)
                .await?;
            params
                .tx
                .send(prompt)
                .map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
            TOKIO_RUNTIME.spawn(async move {
                if let Err(e) = do_build_prompt(params, memory_backend).await {
                    error!("error in memory worker building prompt: {e}")
                }
            });
        }
        WorkerRequest::DidOpenTextDocument(params) => {
            memory_backend.opened_text_document(params).await?;
            memory_backend.opened_text_document(params)?;
        }
        WorkerRequest::DidChangeTextDocument(params) => {
            memory_backend.changed_text_document(params).await?;
            memory_backend.changed_text_document(params)?;
        }
        WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params).await?,
        WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params)?,
    }
    anyhow::Ok(())
}
@ -93,18 +107,11 @@ fn do_run(
    rx: std::sync::mpsc::Receiver<WorkerRequest>,
) -> anyhow::Result<()> {
    let memory_backend = Arc::new(memory_backend);
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(4)
        .enable_all()
        .build()?;
    loop {
        let request = rx.recv()?;
        let thread_memory_backend = memory_backend.clone();
        runtime.spawn(async move {
            if let Err(e) = do_task(request, thread_memory_backend).await {
                error!("error in memory worker task: {e}")
            }
        });
        if let Err(e) = do_task(request, memory_backend.clone()) {
            error!("error in memory worker task: {e}")
        }
    }
}
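The net effect of the memory_worker change above: `do_run` no longer builds its own runtime, `do_task` is synchronous, and only prompt building, the one operation that still awaits the backend, is handed to the shared `TOKIO_RUNTIME`. A trimmed, hypothetical mirror of that dispatch shape:

// Hypothetical, reduced mirror of the dispatch above
enum WorkerRequest {
    FilterText(String), // cheap, handled inline on the worker thread
    Prompt(String),     // slow, handed off to the async runtime
}

fn do_task(request: WorkerRequest) {
    match request {
        WorkerRequest::FilterText(pos) => {
            // synchronous: no await needed after this commit
            println!("filter text at {pos}");
        }
        WorkerRequest::Prompt(pos) => {
            // in the real code: TOKIO_RUNTIME.spawn(do_build_prompt(...))
            println!("spawning async prompt build at {pos}");
        }
    }
}

fn main() {
    do_task(WorkerRequest::FilterText("1:1".into()));
    do_task(WorkerRequest::Prompt("2:10".into()));
}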

59
crates/lsp-ai/src/splitters/mod.rs
Normal file
@ -0,0 +1,59 @@
use serde::Serialize;

use crate::{config::ValidSplitter, memory_backends::file_store::File};

mod text_splitter;
mod tree_sitter;

#[derive(Serialize)]
pub struct ByteRange {
    pub start_byte: usize,
    pub end_byte: usize,
}

impl ByteRange {
    pub fn new(start_byte: usize, end_byte: usize) -> Self {
        Self {
            start_byte,
            end_byte,
        }
    }
}

#[derive(Serialize)]
pub struct Chunk {
    pub text: String,
    pub range: ByteRange,
}

impl Chunk {
    fn new(text: String, range: ByteRange) -> Self {
        Self { text, range }
    }
}

pub trait Splitter {
    fn split(&self, file: &File) -> Vec<Chunk>;
    fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk>;

    fn does_use_tree_sitter(&self) -> bool {
        false
    }

    fn chunk_size(&self) -> usize;
}

impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
    type Error = anyhow::Error;

    fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
        match value {
            ValidSplitter::TreeSitter(config) => {
                Ok(Box::new(tree_sitter::TreeSitter::new(config)?))
            }
            ValidSplitter::TextSplitter(config) => {
                Ok(Box::new(text_splitter::TextSplitter::new(config)))
            }
        }
    }
}
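Callers get a splitter by converting the config enum into a boxed trait object once at startup. A standalone sketch of that dispatch pattern with the types reduced to stubs (`anyhow` is assumed, as in the file above):

trait Splitter {
    fn chunk_size(&self) -> usize;
}

struct TextSplitter {
    chunk_size: usize,
}

impl Splitter for TextSplitter {
    fn chunk_size(&self) -> usize {
        self.chunk_size
    }
}

// Stand-in for the crate's config::ValidSplitter enum
enum ValidSplitter {
    TextSplitter { chunk_size: usize },
}

impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
    type Error = anyhow::Error;

    fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
        match value {
            ValidSplitter::TextSplitter { chunk_size } => {
                Ok(Box::new(TextSplitter { chunk_size }))
            }
        }
    }
}

fn main() -> anyhow::Result<()> {
    let splitter: Box<dyn Splitter + Send + Sync> =
        ValidSplitter::TextSplitter { chunk_size: 1500 }.try_into()?;
    assert_eq!(splitter.chunk_size(), 1500);
    Ok(())
}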
47
crates/lsp-ai/src/splitters/text_splitter.rs
Normal file
@ -0,0 +1,47 @@
use crate::{config, memory_backends::file_store::File};

use super::{ByteRange, Chunk, Splitter};

pub struct TextSplitter {
    chunk_size: usize,
    splitter: text_splitter::TextSplitter<text_splitter::Characters>,
}

impl TextSplitter {
    pub fn new(config: config::TextSplitter) -> Self {
        Self {
            chunk_size: config.chunk_size,
            splitter: text_splitter::TextSplitter::new(config.chunk_size),
        }
    }

    pub fn new_with_chunk_size(chunk_size: usize) -> Self {
        Self {
            chunk_size,
            splitter: text_splitter::TextSplitter::new(chunk_size),
        }
    }
}

impl Splitter for TextSplitter {
    fn split(&self, file: &File) -> Vec<Chunk> {
        self.split_file_contents("", &file.rope().to_string())
    }

    fn split_file_contents(&self, _uri: &str, contents: &str) -> Vec<Chunk> {
        self.splitter
            .chunk_indices(contents)
            .fold(vec![], |mut acc, (start_byte, text)| {
                let end_byte = start_byte + text.len();
                acc.push(Chunk::new(
                    text.to_string(),
                    ByteRange::new(start_byte, end_byte),
                ));
                acc
            })
    }

    fn chunk_size(&self) -> usize {
        self.chunk_size
    }
}
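A quick look at what `chunk_indices` yields (byte offset plus chunk text), using the same `text_splitter` crate API the impl above relies on; the sample string and chunk size of 10 are arbitrary:

fn main() {
    let splitter = text_splitter::TextSplitter::new(10);
    let contents = "fn main() { println!(\"hello\"); }";
    for (start_byte, text) in splitter.chunk_indices(contents) {
        // Each chunk carries its byte range in the original contents
        let end_byte = start_byte + text.len();
        println!("{start_byte}..{end_byte}: {text:?}");
    }
}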
84
crates/lsp-ai/src/splitters/tree_sitter.rs
Normal file
@ -0,0 +1,84 @@
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tracing::error;
use tree_sitter::Tree;

use crate::{config, memory_backends::file_store::File, utils::parse_tree};

use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter};

pub struct TreeSitter {
    chunk_size: usize,
    splitter: TreeSitterCodeSplitter,
    text_splitter: TextSplitter,
}

impl TreeSitter {
    pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
        let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size);
        Ok(Self {
            chunk_size: config.chunk_size,
            splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
            text_splitter,
        })
    }

    fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result<Vec<Chunk>> {
        Ok(self
            .splitter
            .split(tree, contents)?
            .into_iter()
            .map(|c| {
                Chunk::new(
                    c.text.to_owned(),
                    ByteRange::new(c.range.start_byte, c.range.end_byte),
                )
            })
            .collect())
    }
}

impl Splitter for TreeSitter {
    fn split(&self, file: &File) -> Vec<Chunk> {
        if let Some(tree) = file.tree() {
            match self.split_tree(tree, file.rope().to_string().as_bytes()) {
                Ok(chunks) => chunks,
                Err(e) => {
                    error!(
                        "Failed to parse tree for file with error: {e:?}. Falling back to default splitter.",
                    );
                    self.text_splitter.split(file)
                }
            }
        } else {
            self.text_splitter.split(file)
        }
    }

    fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk> {
        match parse_tree(uri, contents, None) {
            Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) {
                Ok(chunks) => chunks,
                Err(e) => {
                    error!(
                        "Failed to parse tree for file: {uri} with error: {e:?}. Falling back to default splitter.",
                    );
                    self.text_splitter.split_file_contents(uri, contents)
                }
            },
            Err(e) => {
                error!(
                    "Failed to parse tree for file {uri} with error: {e:?}. Falling back to default splitter.",
                );
                self.text_splitter.split_file_contents(uri, contents)
            }
        }
    }

    fn does_use_tree_sitter(&self) -> bool {
        true
    }

    fn chunk_size(&self) -> usize {
        self.chunk_size
    }
}
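The splitter degrades gracefully: a failed tree parse or tree split falls back to the character-based splitter instead of failing the request. A standalone sketch of that fallback shape, with closures standing in for the real splitters:

fn split_with_fallback(contents: &str) -> Vec<String> {
    // Stand-ins for the real splitters; the real code parses a tree-sitter
    // tree first and only falls back to the character splitter on error.
    let try_tree_split = |_c: &str| -> Result<Vec<String>, String> {
        Err("no grammar for this extension".into())
    };
    let text_split = |c: &str| vec![c.to_string()];

    match try_tree_split(contents) {
        Ok(chunks) => chunks,
        Err(e) => {
            eprintln!("tree split failed: {e}; falling back to text splitter");
            text_split(contents)
        }
    }
}

fn main() {
    let chunks = split_with_fallback("fn main() {}");
    assert_eq!(chunks.len(), 1);
}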
@ -67,11 +67,11 @@ impl Ollama {
    ) -> anyhow::Result<String> {
        let client = reqwest::Client::new();
        let res: OllamaCompletionsResponse = client
            .post(self
                .configuration
                .generate_endpoint
                .as_deref()
                .unwrap_or("http://localhost:11434/api/generate")
            .post(
                self.configuration
                    .generate_endpoint
                    .as_deref()
                    .unwrap_or("http://localhost:11434/api/generate"),
            )
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
@ -106,11 +106,11 @@ impl Ollama {
    ) -> anyhow::Result<String> {
        let client = reqwest::Client::new();
        let res: OllamaChatResponse = client
            .post(self
                .configuration
                .chat_endpoint
                .as_deref()
                .unwrap_or("http://localhost:11434/api/chat")
            .post(
                self.configuration
                    .chat_endpoint
                    .as_deref()
                    .unwrap_or("http://localhost:11434/api/chat"),
            )
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
@ -17,7 +17,7 @@ use crate::custom_requests::generation_stream::GenerationStreamParams;
use crate::memory_backends::Prompt;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
use crate::utils::{ToResponseError, TOKIO_RUNTIME};

#[derive(Clone, Debug)]
pub struct CompletionRequest {
@ -89,14 +89,14 @@ pub struct DoGenerationStreamResponse {
fn post_process_start(response: String, front: &str) -> String {
    let mut front_match = response.len();
    loop {
        if response.len() == 0 || front.ends_with(&response[..front_match]) {
        if response.is_empty() || front.ends_with(&response[..front_match]) {
            break;
        } else {
            front_match -= 1;
        }
    }
    if front_match > 0 {
        (&response[front_match..]).to_owned()
        response[front_match..].to_owned()
    } else {
        response
    }
@ -105,16 +105,14 @@ fn post_process_start(response: String, front: &str) -> String {
fn post_process_end(response: String, back: &str) -> String {
    let mut back_match = 0;
    loop {
        if back_match == response.len() {
            break;
        } else if back.starts_with(&response[back_match..]) {
        if back_match == response.len() || back.starts_with(&response[back_match..]) {
            break;
        } else {
            back_match += 1;
        }
    }
    if back_match > 0 {
        (&response[..back_match]).to_owned()
        response[..back_match].to_owned()
    } else {
        response
    }
@ -140,12 +138,10 @@ fn post_process_response(
        } else {
            response
        }
    } else if config.remove_duplicate_start {
        post_process_start(response, &context_and_code.code)
    } else {
        if config.remove_duplicate_start {
            post_process_start(response, &context_and_code.code)
        } else {
            response
        }
        response
    }
}
Prompt::FIM(fim) => {
@ -177,7 +173,7 @@ pub fn run(
    connection,
    config,
) {
    error!("error in transformer worker: {e}")
    error!("error in transformer worker: {e:?}")
}
}

@ -189,10 +185,6 @@ fn do_run(
    config: Config,
) -> anyhow::Result<()> {
    let transformer_backends = Arc::new(transformer_backends);
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(4)
        .enable_all()
        .build()?;

    // If they have disabled completions, this function will fail. We set it to MIN_POSITIVE to never process a completions request
    let max_requests_per_second = config
@ -206,7 +198,7 @@ fn do_run(
    let task_transformer_backends = transformer_backends.clone();
    let task_memory_backend_tx = memory_backend_tx.clone();
    let task_config = config.clone();
    runtime.spawn(async move {
    TOKIO_RUNTIME.spawn(async move {
        dispatch_request(
            request,
            task_connection,
@ -264,7 +256,7 @@ async fn dispatch_request(
    {
        Ok(response) => response,
        Err(e) => {
            error!("generating response: {e}");
            error!("generating response: {e:?}");
            Response {
                id: request.get_id(),
                result: None,
@ -274,7 +266,7 @@ async fn dispatch_request(
    };

    if let Err(e) = connection.sender.send(Message::Response(response)) {
        error!("sending response: {e}");
        error!("sending response: {e:?}");
    }
}

@ -293,14 +285,12 @@ async fn generate_response(
        .context("Completions is none")?;
    let transformer_backend = transformer_backends
        .get(&completion_config.model)
        .clone()
        .with_context(|| format!("can't find model: {}", &completion_config.model))?;
    do_completion(transformer_backend, memory_backend_tx, &request, &config).await
}
WorkerRequest::Generation(request) => {
    let transformer_backend = transformer_backends
        .get(&request.params.model)
        .clone()
        .with_context(|| format!("can't find model: {}", &request.params.model))?;
    do_generate(transformer_backend, memory_backend_tx, &request).await
}
@ -346,7 +336,6 @@ async fn do_completion(

    // Get the response
    let mut response = transformer_backend.do_completion(&prompt, params).await?;
    eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text);

    if let Some(post_process) = config.get_completions_post_process() {
        response.insert_text = post_process_response(response.insert_text, &prompt, &post_process);
@ -423,7 +412,103 @@ async fn do_generate(
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};
    use crate::memory_backends::{
        file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend,
    };
    use serde_json::json;
    use std::{sync::mpsc, thread};

    #[tokio::test]
    async fn test_do_completion() -> anyhow::Result<()> {
        let (memory_tx, memory_rx) = mpsc::channel();
        let memory_backend: Box<dyn MemoryBackend + Send + Sync> =
            Box::new(FileStore::default_with_filler_file()?);
        thread::spawn(move || memory_worker::run(memory_backend, memory_rx));

        let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
            config::ValidModel::Ollama(serde_json::from_value(
                json!({"model": "deepseek-coder:1.3b-base"}),
            )?)
            .try_into()?;
        let completion_request = CompletionRequest::new(
            serde_json::from_value(json!(0))?,
            serde_json::from_value(json!({
                "position": {"character":10, "line":2},
                "textDocument": {
                    "uri": "file:///filler.py"
                }
            }))?,
        );
        let mut config = config::Config::default_with_file_store_without_models();
        config.config.completion = Some(serde_json::from_value(json!({
            "model": "model1",
            "parameters": {
                "options": {
                    "temperature": 0
                }
            }
        }))?);

        let result = do_completion(
            &transformer_backend,
            memory_tx,
            &completion_request,
            &config,
        )
        .await?;

        assert_eq!(
            " x * y",
            result.result.clone().unwrap()["items"][0]["textEdit"]["newText"]
                .as_str()
                .unwrap()
        );
        assert_eq!(
            " return",
            result.result.unwrap()["items"][0]["filterText"]
                .as_str()
                .unwrap()
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_do_generate() -> anyhow::Result<()> {
        let (memory_tx, memory_rx) = mpsc::channel();
        let memory_backend: Box<dyn MemoryBackend + Send + Sync> =
            Box::new(FileStore::default_with_filler_file()?);
        thread::spawn(move || memory_worker::run(memory_backend, memory_rx));

        let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
            config::ValidModel::Ollama(serde_json::from_value(
                json!({"model": "deepseek-coder:1.3b-base"}),
            )?)
            .try_into()?;
        let generation_request = GenerationRequest::new(
            serde_json::from_value(json!(0))?,
            serde_json::from_value(json!({
                "position": {"character":10, "line":2},
                "textDocument": {
                    "uri": "file:///filler.py"
                },
                "model": "model1",
                "parameters": {
                    "options": {
                        "temperature": 0
                    }
                }
            }))?,
        );
        let result = do_generate(&transformer_backend, memory_tx, &generation_request).await?;

        assert_eq!(
            " x * y",
            result.result.unwrap()["generatedText"].as_str().unwrap()
        );

        Ok(())
    }

    #[test]
    fn test_post_process_fim() {
@ -1,6 +1,18 @@
use anyhow::Context;
use lsp_server::ResponseError;
use once_cell::sync::Lazy;
use tokio::runtime;
use tree_sitter::Tree;

use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt};
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk};

pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
    runtime::Builder::new_multi_thread()
        .worker_threads(4)
        .enable_all()
        .build()
        .expect("Error building tokio runtime")
});

pub trait ToResponseError {
    fn to_response_error(&self, code: i32) -> ResponseError;
@ -42,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String
pub fn format_context_code(context: &str, code: &str) -> String {
    format!("{context}\n\n{code}")
}

pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String {
    format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte)
}

pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result<Tree> {
    let path = std::path::Path::new(uri);
    let extension = path.extension().map(|x| x.to_string_lossy());
    let extension = extension.as_deref().unwrap_or("");
    let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?;
    parser
        .parse(&contents, old_tree)
        .with_context(|| format!("parsing tree failed for {uri}"))
}
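One property worth noting about `chunk_to_id`: ids are deterministic per (uri, byte range), so re-upserting an unchanged chunk overwrites rather than duplicates, and stale ranges can be swept with the `$nin` filter seen in the PostgresML backend. A standalone sketch with local stand-ins for the crate's `Chunk` and `ByteRange`:

// Local stand-ins for the crate's splitter types
struct ByteRange {
    start_byte: usize,
    end_byte: usize,
}

struct Chunk {
    range: ByteRange,
}

fn chunk_to_id(uri: &str, chunk: &Chunk) -> String {
    format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte)
}

fn main() {
    let chunk = Chunk {
        range: ByteRange { start_byte: 0, end_byte: 42 },
    };
    let a = chunk_to_id("file:///lib.rs", &chunk);
    let b = chunk_to_id("file:///lib.rs", &chunk);
    assert_eq!(a, b); // re-upserts overwrite instead of duplicating
    assert_eq!(a, "file:///lib.rs#0-42");
}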
289
crates/lsp-ai/tests/integration_tests.rs
Normal file
@ -0,0 +1,289 @@
use anyhow::Result;
use std::{
    io::{Read, Write},
    process::{ChildStdin, ChildStdout, Command, Stdio},
};

// Note if you get an empty response with no error, that typically means
// the language server died
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
    let mut content_length = None;
    let mut buf = vec![];
    loop {
        let mut buf2 = vec![0];
        stdout.read_exact(&mut buf2)?;
        buf.push(buf2[0]);
        if let Some(content_length) = content_length {
            if buf.len() == content_length {
                break;
            }
        } else {
            let len = buf.len();
            if len > 4
                && buf[len - 4] == 13
                && buf[len - 3] == 10
                && buf[len - 2] == 13
                && buf[len - 1] == 10
            {
                content_length =
                    Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::<usize>()?);
                buf = vec![];
            }
        }
    }
    Ok(String::from_utf8(buf)?)
}
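`read_response` is parsing standard LSP base-protocol framing: a header `Content-Length: N\r\n\r\n` followed by exactly N bytes of payload; the `buf[16..len - 4]` slice skips the 16-byte "Content-Length: " prefix and the trailing `\r\n\r\n`. A minimal round-trip of one framed message over an in-memory buffer:

use std::io::{Cursor, Read, Write};

fn main() -> std::io::Result<()> {
    let body = r#"{"jsonrpc":"2.0","id":1,"result":null}"#;

    // Frame the message the way send_message does
    let mut wire = Vec::new();
    write!(wire, "Content-Length: {}\r\n\r\n{}", body.len(), body)?;

    // Read it back: header up to the blank line, then exactly N bytes
    let mut reader = Cursor::new(wire);
    let mut header = Vec::new();
    loop {
        let mut b = [0u8; 1];
        reader.read_exact(&mut b)?;
        header.push(b[0]);
        if header.ends_with(b"\r\n\r\n") {
            break;
        }
    }
    let len: usize = std::str::from_utf8(&header[16..header.len() - 4])
        .unwrap()
        .parse()
        .unwrap();
    let mut payload = vec![0u8; len];
    reader.read_exact(&mut payload)?;
    assert_eq!(std::str::from_utf8(&payload).unwrap(), body);
    Ok(())
}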

fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> {
    stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len(),).as_bytes())?;
    stdin.write_all("\r\n".as_bytes())?;
    stdin.write_all(message.as_bytes())?;
    Ok(())
}

// This chat completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with llama3:8b-instruct-q4_0 and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
//
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// t
// u
// r
// n
// The sequence has:
// - 1 textDocument/didOpen notification
// - 7 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_chat_completion_sequence() -> Result<()> {
    let mut child = Command::new("cargo")
        .arg("run")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    let mut stdin = child.stdin.take().unwrap();
    let mut stdout = child.stdout.take().unwrap();

    let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"messages":[{"content":"Instructions:\n- You are an AI programming assistant.\n- Given a piece of code with the cursor location marked by \"<CURSOR>\", replace \"<CURSOR>\" with the correct code or comment.\n- First, think step-by-step.\n- Describe your plan for what to build in pseudocode, written out in great detail.\n- Then output the code replacing the \"<CURSOR>\"\n- Ensure that your completion fits within the language context of the provided code snippet (e.g., Python, JavaScript, Rust).\n\nRules:\n- Only respond with code or comments.\n- Only replace \"<CURSOR>\"; do not include any previously written code.\n- Never include \"<CURSOR>\" in your response\n- If the cursor is within a comment, complete the comment meaningfully.\n- Handle ambiguous cases by providing the most contextually appropriate completion.\n- Be consistent with your responses.","role":"system"},{"content":"def greet(name):\n print(f\"Hello, {<CURSOR>}\")","role":"user"},{"content":"name","role":"assistant"},{"content":"function sum(a, b) {\n return a + <CURSOR>;\n}","role":"user"},{"content":"b","role":"assistant"},{"content":"fn multiply(a: i32, b: i32) -> i32 {\n a * <CURSOR>\n}","role":"user"},{"content":"b","role":"assistant"},{"content":"# <CURSOR>\ndef add(a, b):\n return a + b","role":"user"},{"content":"Adds two numbers","role":"assistant"},{"content":"# This function checks if a number is even\n<CURSOR>","role":"user"},{"content":"def is_even(n):\n return n % 2 == 0","role":"assistant"},{"content":"{CODE}","role":"user"}],"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"llama3:8b-instruct-q4_0","type":"ollama"}}},"processId":66009,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
    send_message(&mut stdin, initialization_message)?;
    let _ = read_response(&mut stdout)?;

    send_message(
        &mut stdin,
        r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##,
    )?;
    send_message(
        &mut stdin,
        r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
    )?;

    let output = read_response(&mut stdout)?;
    assert_eq!(
        output,
        r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" return","kind":1,"label":"ai - x * y","textEdit":{"newText":"x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"##
    );

    child.kill()?;
    Ok(())
}

// This FIM completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
//
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// The sequence has:
// - 1 textDocument/didOpen notification
// - 3 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_fim_completion_sequence() -> Result<()> {
    let mut child = Command::new("cargo")
        .arg("run")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    let mut stdin = child.stdin.take().unwrap();
    let mut stdout = child.stdout.take().unwrap();

let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"fim":{"end":"<|fim▁end|>","middle":"<|fim▁hole|>","start":"<|fim▁begin|>"},"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":50347,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;

send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n    "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;

let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":"    re","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
);

child.kill()?;

Ok(())
}

// This completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// ```
// And has the following sequence of keystrokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// t
// u
// r
// n
// The sequence has:
// - 1 textDocument/didOpen notification
// - 7 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;

let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();

let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":62322,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;

send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n    "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;

let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":"    return","kind":1,"label":"ai - x * y","textEdit":{"newText":" x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"##
);

child.kill()?;

Ok(())
}
19
crates/splitter-tree-sitter/Cargo.toml
Normal file
@ -0,0 +1,19 @@
[package]
name = "splitter-tree-sitter"
version = "0.1.0"
description = "A code splitter utilizing Tree-sitter"

edition.workspace = true
repository.workspace = true
license.workspace = true

[dependencies]
thiserror = "1.0.61"
tree-sitter = "0.22"

[dev-dependencies]
tree-sitter-rust = "0.21"
tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig" }

[build-dependencies]
cc="*"
21
crates/splitter-tree-sitter/LICENSE
Normal file
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Silas Marvin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3
crates/splitter-tree-sitter/README.md
Normal file
@ -0,0 +1,3 @@
# splitter-tree-sitter

This is a code splitter that utilizes Tree-sitter to split code.
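
A minimal usage sketch, assuming the `tree-sitter-rust` grammar from this crate's dev-dependencies (the calls below mirror this crate's own tests):

```rust
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tree_sitter::Parser;

fn main() {
    // Chunks of at most 128 characters with no overlap between chunks.
    let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap();

    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .expect("Error loading Rust grammar");

    let source = "fn add(a: u32, b: u32) -> u32 { a + b }";
    let tree = parser.parse(source, None).unwrap();

    // Each chunk carries the text slice and its byte range in the source.
    for chunk in splitter.split(&tree, source.as_bytes()).unwrap() {
        println!("{}..{}: {}", chunk.range.start_byte, chunk.range.end_byte, chunk.text);
    }
}
```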
326
crates/splitter-tree-sitter/src/lib.rs
Normal file
@ -0,0 +1,326 @@
use thiserror::Error;
use tree_sitter::{Tree, TreeCursor};

#[derive(Error, Debug)]
pub enum NewError {
#[error("chunk_overlap must not be greater than chunk_size")]
SizeOverlapError,
}

#[derive(Error, Debug)]
pub enum SplitError {
#[error("converting utf8 to str")]
Utf8Error(#[from] core::str::Utf8Error),
}

pub struct TreeSitterCodeSplitter {
chunk_size: usize,
chunk_overlap: usize,
}

pub struct ByteRange {
pub start_byte: usize,
pub end_byte: usize,
}

impl ByteRange {
fn new(start_byte: usize, end_byte: usize) -> Self {
Self {
start_byte,
end_byte,
}
}
}

pub struct Chunk<'a> {
pub text: &'a str,
pub range: ByteRange,
}

impl<'a> Chunk<'a> {
fn new(text: &'a str, range: ByteRange) -> Self {
Self { text, range }
}
}

impl TreeSitterCodeSplitter {
pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, NewError> {
if chunk_overlap > chunk_size {
Err(NewError::SizeOverlapError)
} else {
Ok(Self {
chunk_size,
chunk_overlap,
})
}
}

pub fn split<'a, 'b, 'c>(
&'a self,
tree: &'b Tree,
utf8: &'c [u8],
) -> Result<Vec<Chunk<'c>>, SplitError> {
let cursor = tree.walk();
Ok(self
.split_recursive(cursor, utf8)?
.into_iter()
.rev()
// Let's combine some of our smaller chunks together
// We do this in reverse as it seems to make more sense to combine code slices from bottom to top
.try_fold(vec![], |mut acc, current| {
if acc.len() == 0 {
acc.push(current);
Ok::<_, SplitError>(acc)
} else {
if acc.last().as_ref().unwrap().text.len() + current.text.len()
< self.chunk_size
{
let last = acc.pop().unwrap();
let text = std::str::from_utf8(
&utf8[current.range.start_byte..last.range.end_byte],
)?;
acc.push(Chunk::new(
text,
ByteRange::new(current.range.start_byte, last.range.end_byte),
));
} else {
acc.push(current);
}
Ok(acc)
}
})?
.into_iter()
.rev()
.collect())
}

fn split_recursive<'a, 'b, 'c>(
&'a self,
mut cursor: TreeCursor<'b>,
utf8: &'c [u8],
) -> Result<Vec<Chunk<'c>>, SplitError> {
let node = cursor.node();
let text = node.utf8_text(utf8)?;

// There are three cases:
// 1. Is the current range of code smaller than the chunk_size? If so, return it
// 2. If not, does the current node have children? If so, recursively walk down
// 3. If not, we must split our current node
let mut out = if text.chars().count() <= self.chunk_size {
vec![Chunk::new(
text,
ByteRange::new(node.range().start_byte, node.range().end_byte),
)]
} else {
let mut cursor_copy = cursor.clone();
if cursor_copy.goto_first_child() {
self.split_recursive(cursor_copy, utf8)?
} else {
let mut current_range =
ByteRange::new(node.range().start_byte, node.range().end_byte);
let mut chunks = vec![];
let mut current_chunk = text;
loop {
if current_chunk.len() < self.chunk_size {
chunks.push(Chunk::new(current_chunk, current_range));
break;
} else {
let new_chunk = &current_chunk[0..self.chunk_size.min(current_chunk.len())];
let new_range = ByteRange::new(
current_range.start_byte,
current_range.start_byte + new_chunk.as_bytes().len(),
);
chunks.push(Chunk::new(new_chunk, new_range));
let new_current_chunk =
&current_chunk[self.chunk_size - self.chunk_overlap..];
let byte_diff =
current_chunk.as_bytes().len() - new_current_chunk.as_bytes().len();
current_range = ByteRange::new(
current_range.start_byte + byte_diff,
current_range.end_byte,
);
current_chunk = new_current_chunk;
}
}
chunks
}
};
if cursor.goto_next_sibling() {
out.append(&mut self.split_recursive(cursor, utf8)?);
}
Ok(out)
}
}

#[cfg(test)]
mod tests {
use super::*;
use tree_sitter::Parser;

#[test]
fn test_split_rust() {
let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap();

let mut parser = Parser::new();
parser
.set_language(&tree_sitter_rust::language())
.expect("Error loading Rust grammar");

let source_code = r#"
#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}

impl Rectangle {
fn area(&self) -> u32 {
self.width * self.height
}
}

fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};

println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}
"#;
let tree = parser.parse(source_code, None).unwrap();
let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();
assert_eq!(
chunks[0].text,
r#"#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}"#
);
assert_eq!(
chunks[1].text,
r#"impl Rectangle {
fn area(&self) -> u32 {
self.width * self.height
}
}"#
);
assert_eq!(
chunks[2].text,
r#"fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};"#
);
assert_eq!(
chunks[3].text,
r#"println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}"#
);
}

#[test]
fn test_split_zig() {
let splitter = TreeSitterCodeSplitter::new(128, 10).unwrap();

let mut parser = Parser::new();
parser
.set_language(&tree_sitter_rust::language())
.expect("Error loading Rust grammar");

let source_code = r#"
const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print("Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string 9 ...", .{});

test "parse integers" {
const input = "123 67 89,99";
const ally = std.testing.allocator;

var list = std.ArrayList(u32).init(ally);
// Ensure the list is freed at scope exit.
// Try commenting out this line!
defer list.deinit();

var it = std.mem.tokenizeAny(u8, input, " ,");
while (it.next()) |num| {
const n = try parseInt(u32, num, 10);
try list.append(n);
}

const expected = [_]u32{ 123, 67, 89, 99 };

for (expected, list.items) |exp, actual| {
try std.testing.expectEqual(exp, actual);
}
}
"#;
let tree = parser.parse(source_code, None).unwrap();
let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();

assert_eq!(
chunks[0].text,
r#"const std = @import("std");
const parseInt = std.fmt.parseInt;

std.debug.print(""#
);
assert_eq!(
chunks[1].text,
r#"Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long s"#
);
assert_eq!(
chunks[2].text,
r#"s a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string "#
);
assert_eq!(chunks[3].text, r#"ng string 9 ...", .{});"#);
assert_eq!(
chunks[4].text,
r#"test "parse integers" {
const input = "123 67 89,99";
const ally = std.testing.allocator;

var list = std.ArrayList"#
);
assert_eq!(
chunks[5].text,
r#"(u32).init(ally);
// Ensure the list is freed at scope exit.
// Try commenting out this line!"#
);
assert_eq!(
chunks[6].text,
r#"defer list.deinit();

var it = std.mem.tokenizeAny(u8, input, " ,");
while (it.next()) |num"#
);
assert_eq!(
chunks[7].text,
r#"| {
const n = try parseInt(u32, num, 10);
try list.append(n);
}

const expected = [_]u32{ 123, 67, 89,"#
);
assert_eq!(
chunks[8].text,
r#"99 };

for (expected, list.items) |exp, actual| {
try std.testing.expectEqual(exp, actual);
}
}"#
);
}
}
37
crates/utils-tree-sitter/Cargo.toml
Normal file
@ -0,0 +1,37 @@
[package]
name = "utils-tree-sitter"
version = "0.1.0"
description = "Utils for working with splitter-tree-sitter"

edition.workspace = true
repository.workspace = true
license.workspace = true

[dependencies]
thiserror = "1.0.61"
tree-sitter = "0.22"
tree-sitter-bash = { version = "0.21", optional = true }
tree-sitter-c = { version = "0.21", optional = true }
tree-sitter-cpp = { version = "0.22", optional = true }
tree-sitter-c-sharp = { version = "0.21", optional = true }
tree-sitter-css = { version = "0.21", optional = true }
tree-sitter-elixir = { version = "0.2", optional = true }
tree-sitter-erlang = { version = "0.6", optional = true }
tree-sitter-go = { version = "0.21", optional = true }
tree-sitter-html = { version = "0.20", optional = true }
tree-sitter-java = { version = "0.21", optional = true }
tree-sitter-javascript = { version = "0.21", optional = true }
tree-sitter-json = { version = "0.21", optional = true }
tree-sitter-haskell = { version = "0.21", optional = true }
tree-sitter-lua = { version = "0.1.0", optional = true }
tree-sitter-ocaml = { version = "0.22.0", optional = true }
tree-sitter-python = { version = "0.21", optional = true }
tree-sitter-rust = { version = "0.21", optional = true }
# tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig", optional = true }

[build-dependencies]
cc="*"

[features]
default = []
all = ["dep:tree-sitter-python", "dep:tree-sitter-bash", "dep:tree-sitter-c", "dep:tree-sitter-cpp", "dep:tree-sitter-c-sharp", "dep:tree-sitter-css", "dep:tree-sitter-elixir", "dep:tree-sitter-erlang", "dep:tree-sitter-go", "dep:tree-sitter-html", "dep:tree-sitter-java", "dep:tree-sitter-javascript", "dep:tree-sitter-json", "dep:tree-sitter-rust", "dep:tree-sitter-haskell", "dep:tree-sitter-lua", "dep:tree-sitter-ocaml"]
21
crates/utils-tree-sitter/LICENSE
Normal file
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Silas Marvin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3
crates/utils-tree-sitter/README.md
Normal file
@ -0,0 +1,3 @@
# utils-tree-sitter

Utils for working with Tree-sitter
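
A minimal usage sketch, assuming the crate is built with a feature that enables the Python grammar (e.g. `all`):

```rust
use utils_tree_sitter::get_parser_for_extension;

fn main() {
    // Map a file extension to a ready-to-use Tree-sitter parser.
    let mut parser = get_parser_for_extension("py").expect("no parser for extension");
    let tree = parser
        .parse("def add(a, b):\n    return a + b\n", None)
        .unwrap();
    println!("{}", tree.root_node().to_sexp());
}
```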
90
crates/utils-tree-sitter/src/lib.rs
Normal file
@ -0,0 +1,90 @@
use thiserror::Error;
use tree_sitter::{LanguageError, Parser};

#[derive(Error, Debug)]
pub enum GetParserError {
#[error("no parser found for extension")]
NoParserFoundForExtension(String),
#[error("no language found for extension")]
NoLanguageFoundForExtension(String),
#[error("loading grammar")]
LoadingGrammer(#[from] LanguageError),
}

fn get_language_for_extension(extension: &str) -> Result<String, GetParserError> {
Ok(match extension {
"py" => "Python",
"rs" => "Rust",
// "zig" => "Zig",
"sh" => "Bash",
"c" => "C",
"cpp" => "C++",
"cs" => "C#",
"css" => "CSS",
"ex" => "Elixir",
"erl" => "Erlang",
"go" => "Go",
"html" => "HTML",
"java" => "Java",
"js" => "JavaScript",
"json" => "JSON",
"hs" => "Haskell",
"lua" => "Lua",
"ml" => "OCaml",
_ => {
return Err(GetParserError::NoLanguageFoundForExtension(
extension.to_string(),
))
}
}
.to_string())
}

pub fn get_parser_for_extension(extension: &str) -> Result<Parser, GetParserError> {
let language = get_language_for_extension(extension)?;
let mut parser = Parser::new();
match language.as_str() {
#[cfg(any(feature = "all", feature = "python"))]
"Python" => parser.set_language(&tree_sitter_python::language())?,
#[cfg(any(feature = "all", feature = "rust"))]
"Rust" => parser.set_language(&tree_sitter_rust::language())?,
// #[cfg(any(feature = "all", feature = "zig"))]
// "Zig" => parser.set_language(&tree_sitter_zig::language())?,
#[cfg(any(feature = "all", feature = "bash"))]
"Bash" => parser.set_language(&tree_sitter_bash::language())?,
#[cfg(any(feature = "all", feature = "c"))]
"C" => parser.set_language(&tree_sitter_c::language())?,
#[cfg(any(feature = "all", feature = "cpp"))]
"C++" => parser.set_language(&tree_sitter_cpp::language())?,
#[cfg(any(feature = "all", feature = "c-sharp"))]
"C#" => parser.set_language(&tree_sitter_c_sharp::language())?,
#[cfg(any(feature = "all", feature = "css"))]
"CSS" => parser.set_language(&tree_sitter_css::language())?,
#[cfg(any(feature = "all", feature = "elixir"))]
"Elixir" => parser.set_language(&tree_sitter_elixir::language())?,
#[cfg(any(feature = "all", feature = "erlang"))]
"Erlang" => parser.set_language(&tree_sitter_erlang::language())?,
#[cfg(any(feature = "all", feature = "go"))]
"Go" => parser.set_language(&tree_sitter_go::language())?,
#[cfg(any(feature = "all", feature = "html"))]
"HTML" => parser.set_language(&tree_sitter_html::language())?,
#[cfg(any(feature = "all", feature = "java"))]
"Java" => parser.set_language(&tree_sitter_java::language())?,
#[cfg(any(feature = "all", feature = "javascript"))]
"JavaScript" => parser.set_language(&tree_sitter_javascript::language())?,
#[cfg(any(feature = "all", feature = "json"))]
"JSON" => parser.set_language(&tree_sitter_json::language())?,
#[cfg(any(feature = "all", feature = "haskell"))]
"Haskell" => parser.set_language(&tree_sitter_haskell::language())?,
#[cfg(any(feature = "all", feature = "lua"))]
"Lua" => parser.set_language(&tree_sitter_lua::language())?,
#[cfg(any(feature = "all", feature = "ocaml"))]
"OCaml" => parser.set_language(&tree_sitter_ocaml::language_ocaml())?,
_ => {
return Err(GetParserError::NoParserFoundForExtension(
language.to_string(),
))
}
}
Ok(parser)
}
@ -1,597 +0,0 @@
use anyhow::Context;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::collections::HashMap;
use tracing::instrument;

use crate::{
config::{self, Config},
utils::tokens_to_estimated_characters,
};

use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};

pub struct FileStore {
_crawl: bool,
_config: Config,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
}

impl FileStore {
pub fn new(file_store_config: config::FileStore, config: Config) -> Self {
Self {
_crawl: file_store_config.crawl,
_config: config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}

pub fn new_without_crawl(config: Config) -> Self {
Self {
_crawl: false,
_config: config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}

fn get_rope_for_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<(Rope, usize)> {
// Get the rope and set our initial cursor index
let current_document_uri = position.text_document.uri.to_string();
let mut rope = self
.file_map
.lock()
.get(&current_document_uri)
.context("Error file not found")?
.clone();
let mut cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
// Add to our rope if we need to
for file in self
.accessed_files
.lock()
.iter()
.filter(|f| **f != current_document_uri)
{
let needed = characters.saturating_sub(rope.len_chars() + 1);
if needed == 0 {
break;
}
let file_map = self.file_map.lock();
let r = file_map.get(file).context("Error file not found")?;
let slice_max = needed.min(r.len_chars() + 1);
let rope_str_slice = r
.get_slice(0..slice_max - 1)
.context("Error getting slice")?
.to_string();
rope.insert(0, "\n");
rope.insert(0, &rope_str_slice);
cursor_index += slice_max;
}
Ok((rope, cursor_index))
}

pub fn get_characters_around_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.saturating_sub(characters / 2);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
let rope_slice = rope
.get_slice(start..end)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}

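// A worked example for build_code below, taken from the tests in this file: with the
// document "Document Top\n..." and the cursor at line 0, character 10,
// PromptType::FIM splits around the cursor into prefix "Document T" and suffix "op\n...".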
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: MemoryRunParams,
) -> anyhow::Result<Prompt> {
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context_length)?;

Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.is_for_chat {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));

rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
}
}
PromptType::FIM => {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
let prefix = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
let suffix = rope
.get_slice(cursor_index..end)
.context("Error getting rope slice")?;
Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
}
})
}
}

#[async_trait::async_trait]
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let line = rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.slice(0..position.position.character as usize)
.to_string();
Ok(line)
}

#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
self.build_code(position, prompt_type, params)
}

#[instrument(skip(self))]
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
self.file_map.lock().insert(uri.clone(), rope);
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}

#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
let mut file_map = self.file_map.lock();
let rope = file_map
.get_mut(&uri)
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
let start_index =
rope.line_to_char(range.start.line as usize) + range.start.character as usize;
let end_index =
rope.line_to_char(range.end.line as usize) + range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
} else {
*rope = Rope::from_str(&change.text);
}
}
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}

#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
let mut file_map = self.file_map.lock();
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use lsp_types::{
DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams,
TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem,
VersionedTextDocumentIdentifier,
};
use serde_json::json;

fn generate_base_file_store() -> anyhow::Result<FileStore> {
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
Ok(FileStore::new(file_store_config, config))
}

fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem {
let uri = uri.unwrap_or("file://filler/");
let text = text.unwrap_or("Here is the document body");
TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "filler".to_string(),
version: 0,
text: text.to_string(),
}
}

#[tokio::test]
async fn can_open_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
Ok(())
}

#[tokio::test]
async fn can_rename_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;

let params = RenameFilesParams {
files: vec![FileRename {
old_uri: "file://filler/".to_string(),
new_uri: "file://filler2/".to_string(),
}],
};
file_store.renamed_files(params).await?;

let file = file_store
.file_map
.lock()
.get("file://filler2/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
Ok(())
}

#[tokio::test]
async fn can_change_document() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, None);

let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;

let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 1,
},
end: Position {
line: 0,
character: 3,
},
}),
range_length: None,
text: "a".to_string(),
}],
};
file_store.changed_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Hae is the document body");

let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "abc".to_string(),
}],
};
file_store.changed_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "abc");

Ok(())
}

#[tokio::test]
async fn can_build_prompt() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(
None,
Some(
r#"Document Top
Here is a more complicated document

Some text

The end with a trailing new line
"#,
),
);

// Test basic completion
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;

let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!("Document T", prompt.code);

// Test FIM
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::FIM,
&json!({}),
)
.await?;
let prompt: FIMPrompt = prompt.try_into()?;
assert_eq!(prompt.prompt, r#"Document T"#);
assert_eq!(
prompt.suffix,
r#"op
Here is a more complicated document

Some text

The end with a trailing new line
"#
);

// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({
"messages": []
}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"Document T<CURSOR>op
Here is a more complicated document

Some text

The end with a trailing new line
"#
.to_string();
assert_eq!(text, prompt.code);

// Test multi-file
let text_document2 = generate_filler_text_document(
Some("file://filler2"),
Some(
r#"Document Top2
Here is a more complicated document

Some text

The end with a trailing new line
"#,
),
);
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document2.clone(),
};
file_store.opened_text_document(params).await?;

let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code);

Ok(())
}

#[tokio::test]
async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, Some("test\n"));
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;

// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 1,
character: 0,
},
},
PromptType::ContextAndCode,
&json!({"messages": []}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"test
<CURSOR>"#
.to_string();
assert_eq!(text, prompt.code);

Ok(())
}

// #[tokio::test]
// async fn test_fim_placement_corner_cases() -> anyhow::Result<()> {
// let text_document = generate_filler_text_document(None, Some("test\n"));
// let params = lsp_types::DidOpenTextDocumentParams {
// text_document: text_document.clone(),
// };
// let file_store = generate_base_file_store()?;
// file_store.opened_text_document(params).await?;

// // Test FIM
// let params = json!({
// "fim": {
// "start": "SS",
// "middle": "MM",
// "end": "EE"
// }
// });
// let prompt = file_store
// .build_prompt(
// &TextDocumentPositionParams {
// text_document: TextDocumentIdentifier {
// uri: text_document.uri.clone(),
// },
// position: Position {
// line: 1,
// character: 0,
// },
// },
// params,
// )
// .await?;
// assert_eq!(prompt.context, "");
// let text = r#"test
// "#
// .to_string();
// assert_eq!(text, prompt.code);

// Ok(())
// }
}
@ -1,254 +0,0 @@
use std::{
sync::mpsc::{self, Sender},
time::Duration,
};

use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use tokio::time;
use tracing::instrument;

use crate::{
config::{self, Config},
utils::tokens_to_estimated_characters,
};

use super::{
file_store::FileStore, ContextAndCodePrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};

pub struct PostgresML {
_config: Config,
file_store: FileStore,
collection: Collection,
pipeline: Pipeline,
debounce_tx: Sender<String>,
added_pipeline: bool,
}

impl PostgresML {
pub fn new(
postgresml_config: config::PostgresML,
configuration: Config,
) -> anyhow::Result<Self> {
let file_store = FileStore::new_without_crawl(configuration.clone());
let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url
} else {
std::env::var("PGML_DATABASE_URL")?
};
// TODO: Think on the naming of the collection
// Maybe filter on metadata or I'm not sure
let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
// TODO: Review the pipeline
let pipeline = Pipeline::new(
"v1",
Some(
json!({
"text": {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 1500,
"chunk_overlap": 40
}
},
"semantic_search": {
"model": "intfloat/e5-small",
}
}
})
.into(),
),
)?;
// Set up a debouncer for changed text documents: paths accumulate and are
// re-read and upserted in a batch once no new changes arrive for a tick
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()?;
let mut task_collection = collection.clone();
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
runtime.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_paths = Vec::new();
loop {
time::sleep(duration).await;
let new_paths: Vec<String> = debounce_rx.try_iter().collect();
if !new_paths.is_empty() {
for path in new_paths {
if !file_paths.iter().any(|p| *p == path) {
file_paths.push(path);
}
}
} else {
if file_paths.is_empty() {
continue;
}
let documents = file_paths
.into_iter()
.map(|path| {
let text = std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Error reading path: {}", path));
json!({
"id": path,
"text": text
})
.into()
})
.collect();
task_collection
.upsert_documents(documents, None)
.await
.expect("PGML - Error upserting documents");
file_paths = Vec::new();
}
}
});
Ok(Self {
_config: configuration,
file_store,
collection,
pipeline,
debounce_tx,
added_pipeline: false,
})
}
}

#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
self.file_store.get_filter_text(position).await
}

#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
let query = self
.file_store
.get_characters_around_position(position, 512)?;
let res = self
.collection
.vector_search_local(
json!({
"query": {
"fields": {
"text": {
"query": query
}
},
},
"limit": 5
})
.into(),
&self.pipeline,
)
.await?;
let context = res
.into_iter()
.map(|c| {
c["chunk"]
.as_str()
.map(|t| t.to_owned())
.context("PGML - Error getting chunk from vector search")
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params)?;
let code: ContextAndCodePrompt = code.try_into()?;
let code = code.code;
let max_characters = tokens_to_estimated_characters(params.max_context_length);
let _context: String = context
.chars()
.take(max_characters - code.chars().count())
.collect();
// We need to redo this section to work with the new memory backend system
todo!()
// Ok(Prompt::new(context, code))
}

#[instrument(skip(self))]
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let text = params.text_document.text.clone();
let path = params.text_document.uri.path().to_owned();
let task_added_pipeline = self.added_pipeline;
let mut task_collection = self.collection.clone();
let mut task_pipeline = self.pipeline.clone();
if !task_added_pipeline {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.expect("PGML - Error adding pipeline to collection");
}
task_collection
.upsert_documents(
vec![json!({
"id": path,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
self.file_store.opened_text_document(params).await
}

#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let path = params.text_document.uri.path().to_owned();
self.debounce_tx.send(path)?;
self.file_store.changed_text_document(params).await
}

#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
let mut task_collection = self.collection.clone();
let task_params = params.clone();
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
}
self.file_store.renamed_files(params).await
}
}
@ -1,112 +0,0 @@
use anyhow::Result;
use std::{
io::{Read, Write},
process::{ChildStdin, ChildStdout, Command, Stdio},
};

// Note if you get an empty response with no error, that typically means
// the language server died
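// Messages are framed LSP-style: an ASCII header "Content-Length: <n>\r\n\r\n"
// followed by exactly <n> bytes of JSON-RPC payload (see send_message below).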
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
let mut content_length = None;
let mut buf = vec![];
loop {
let mut buf2 = vec![0];
stdout.read_exact(&mut buf2)?;
buf.push(buf2[0]);
if let Some(content_length) = content_length {
if buf.len() == content_length {
break;
}
} else {
let len = buf.len();
// Detect the end of the header: \r\n\r\n
if len > 4
&& buf[len - 4] == 13
&& buf[len - 3] == 10
&& buf[len - 2] == 13
&& buf[len - 1] == 10
{
// Strip the leading "Content-Length: " (16 bytes) and the trailing \r\n\r\n
content_length =
Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::<usize>()?);
buf = vec![];
}
}
}
Ok(String::from_utf8(buf)?)
}

fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> {
stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len()).as_bytes())?;
stdin.write_all("\r\n".as_bytes())?;
stdin.write_all(message.as_bytes())?;
Ok(())
}

// This completion sequence was created using helix with the lsp-ai analyzer and reading the logs
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
// ```
// And has the following sequence of keystrokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// The sequence has:
// - 1 textDocument/didOpen notification
// - 3 textDocument/didChange notifications
// - 1 textDocument/completion request
// This test can fail if the model gives a different response than normal, but that seems reasonably unlikely
// I guess we should hardcode the seed or something if we want to do more of these
#[test]
fn test_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;

let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();

let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"23.10 (f6021dd0)"},"processId":70007,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;

send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n    "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;

let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":"    re\n","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
);

child.kill()?;

Ok(())
}