diff --git a/Cargo.lock b/Cargo.lock index 524e12f..d9ef5db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -149,6 +149,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "auto_enums" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393" +dependencies = [ + "derive_utils", + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -356,7 +368,7 @@ version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "syn 2.0.52", @@ -662,6 +674,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_utils" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "difflib" version = "0.4.0" @@ -730,9 +753,9 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" [[package]] name = "either" -version = "1.10.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" dependencies = [ "serde", ] @@ -1056,6 +1079,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1364,6 +1393,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -1523,6 +1561,7 @@ dependencies = [ "anyhow", "assert_cmd", "async-trait", + "cc", "directories", "hf-hub", "ignore", @@ -1530,6 +1569,7 @@ dependencies = [ "llama-cpp-2", "lsp-server", "lsp-types", + "md5", "minijinja", "once_cell", "parking_lot", @@ -1539,10 +1579,14 @@ dependencies = [ "ropey", "serde", "serde_json", + "splitter-tree-sitter", + "text-splitter", "tokenizers", "tokio", "tracing", "tracing-subscriber", + "tree-sitter", + "utils-tree-sitter", "xxhash-rust", ] @@ -2196,9 +2240,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -2415,6 +2459,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.17" @@ -2475,7 +2525,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808" dependencies = [ "darling 0.14.4", - "heck", + "heck 
0.4.1", "quote", "syn 1.0.109", ] @@ -2498,7 +2548,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "syn 2.0.52", @@ -2756,6 +2806,17 @@ dependencies = [ "der", ] +[[package]] +name = "splitter-tree-sitter" +version = "0.1.0" +dependencies = [ + "cc", + "thiserror", + "tree-sitter", + "tree-sitter-rust", + "tree-sitter-zig", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2857,7 +2918,7 @@ checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" dependencies = [ "dotenvy", "either", - "heck", + "heck 0.4.1", "hex", "once_cell", "proc-macro2", @@ -3013,6 +3074,28 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.52", +] + [[package]] name = "subtle" version = "2.5.0" @@ -3087,19 +3170,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] -name = "thiserror" -version = "1.0.58" +name = "text-splitter" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "2ab9dc04b7cf08eb01c07c272bf699fa55679a326ddf7dd075e14094efc80fb9" +dependencies = [ + "ahash", + "auto_enums", + "either", + "itertools 0.13.0", + "once_cell", + "regex", + "strum", + "thiserror", + "unicode-segmentation", +] + +[[package]] +name = "thiserror" +version = "1.0.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", @@ -3339,6 +3439,195 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tree-sitter" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca" +dependencies = [ + "cc", + "regex", +] + +[[package]] +name = "tree-sitter-bash" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5244703ad2e08a616d859a0557d7aa290adcd5e0990188a692e628ffe9dce40" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-c" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f956d5351d62652864a4ff3ae861747e7a1940dc96c9998ae400ac0d3ce30427" +dependencies = [ + "cc", + 
"tree-sitter", +] + +[[package]] +name = "tree-sitter-c-sharp" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff899037068a1ffbb891891b7e94db1400ddf12c3d934b85b8c9e30be5cd18da" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-cpp" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537b7e0f0d8c89b8dd6f4d195814da94832f20720c09016c2a3ac3dc3c437993" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-css" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2f806f96136762b0121f5fdd7172a3dcd8f42d37a2f23ed7f11b35895e20eb4" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-elixir" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94bf7f057768b1cab2ee1f14812ed4ae33f9e04d09254043eeaa797db4ef70" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-erlang" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8db61152e6d8a5b3b5895ecbb85848f85d028f84b4633a2368075c35e5817b34" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-go" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cb318be5ccf75f44e054acf6898a5c95d59b53443eed578e16be0cd7ec037f" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-haskell" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25a7e6c73cc1cbe0c0b7dbd5406e7b3485b370bd61c5d8d852ae0781f9bf9a" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-html" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95b3492b08a786bf5cc79feb0ef2ff3b115d5174364e0ddfd7860e0b9b088b53" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-java" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33bc21adf831a773c075d9d00107ab43965e6a6ea7607b47fd9ec6f3db4b481b" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-javascript" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fced510d43e6627cd8e19adfd994ac9cfa3b1d71b0d522b41f74145de37feef" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-json" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b737dcb73c35d74b7d64a5f3dde158113c86a012bf3cee2bfdf2150d23b05db" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-lua" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9fe6fc87bd480e1943fc1fcb02453fb2da050e4e8ce0daa67d801544046856" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-ocaml" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2f2e8e848902d12ca6778d31d0e66b5709fc1ad0c84fd8b0c078472fff20dd2" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-python" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9" +dependencies = [ + "cc", + 
"tree-sitter", +] + +[[package]] +name = "tree-sitter-rust" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-zig" +version = "0.0.1" +source = "git+https://github.com/maxxnino/tree-sitter-zig#7c5a29b721d409be8842017351bf007d7e384401" +dependencies = [ + "cc", + "tree-sitter", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -3450,6 +3739,32 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "utils-tree-sitter" +version = "0.1.0" +dependencies = [ + "cc", + "thiserror", + "tree-sitter", + "tree-sitter-bash", + "tree-sitter-c", + "tree-sitter-c-sharp", + "tree-sitter-cpp", + "tree-sitter-css", + "tree-sitter-elixir", + "tree-sitter-erlang", + "tree-sitter-go", + "tree-sitter-haskell", + "tree-sitter-html", + "tree-sitter-java", + "tree-sitter-javascript", + "tree-sitter-json", + "tree-sitter-lua", + "tree-sitter-ocaml", + "tree-sitter-python", + "tree-sitter-rust", +] + [[package]] name = "uuid" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 18dfb33..0121f0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,44 +1,12 @@ -[package] -name = "lsp-ai" -version = "0.3.0" +[workspace] +members = [ + "crates/*", +] +resolver = "2" + +[workspace.package] edition = "2021" license = "MIT" description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them." repository = "https://github.com/SilasMarvin/lsp-ai" readme = "README.md" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -anyhow = "1.0.75" -lsp-server = "0.7.6" -lsp-types = "0.95.0" -ropey = "1.6.1" -serde = "1.0.190" -serde_json = "1.0.108" -hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" } -rand = "0.8.5" -tokenizers = "0.14.1" -parking_lot = "0.12.1" -once_cell = "1.19.0" -directories = "5.0.1" -llama-cpp-2 = { version = "0.1.55", optional = true } -minijinja = { version = "1.0.12", features = ["loader"] } -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing = "0.1.40" -xxhash-rust = { version = "0.8.5", features = ["xxh3"] } -reqwest = { version = "0.11.25", features = ["blocking", "json"] } -ignore = "0.4.22" -pgml = "1.0.4" -tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } -indexmap = "2.2.5" -async-trait = "0.1.78" - -[features] -default = [] -llama_cpp = ["dep:llama-cpp-2"] -metal = ["llama-cpp-2/metal"] -cuda = ["llama-cpp-2/cuda"] - -[dev-dependencies] -assert_cmd = "2.0.14" diff --git a/crates/lsp-ai/Cargo.toml b/crates/lsp-ai/Cargo.toml new file mode 100644 index 0000000..05ec30b --- /dev/null +++ b/crates/lsp-ai/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "lsp-ai" +version = "0.3.0" + +description.workspace = true +repository.workspace = true +readme.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow = "1.0.75" +lsp-server = "0.7.6" +lsp-types = "0.95.0" +ropey = "1.6.1" +serde = "1.0.190" +serde_json = "1.0.108" +hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" } +rand = "0.8.5" +tokenizers = "0.14.1" +parking_lot = "0.12.1" +once_cell = "1.19.0" +directories = 
"5.0.1" +llama-cpp-2 = { version = "0.1.55", optional = true } +minijinja = { version = "1.0.12", features = ["loader"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +tracing = "0.1.40" +xxhash-rust = { version = "0.8.5", features = ["xxh3"] } +reqwest = { version = "0.11.25", features = ["blocking", "json"] } +ignore = "0.4.22" +pgml = "1.0.4" +tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } +indexmap = "2.2.5" +async-trait = "0.1.78" +tree-sitter = "0.22" +utils-tree-sitter = { path = "../utils-tree-sitter", features = ["all"], version = "0.1.0" } +splitter-tree-sitter = { path = "../splitter-tree-sitter", version = "0.1.0" } +text-splitter = { version = "0.13.3" } +md5 = "0.7.0" + +[build-dependencies] +cc="*" + +[features] +default = [] +llama_cpp = ["dep:llama-cpp-2"] +metal = ["llama-cpp-2/metal"] +cuda = ["llama-cpp-2/cuda"] + +[dev-dependencies] +assert_cmd = "2.0.14" diff --git a/src/config.rs b/crates/lsp-ai/src/config.rs similarity index 87% rename from src/config.rs rename to crates/lsp-ai/src/config.rs index 8b7b394..ea9631f 100644 --- a/src/config.rs +++ b/crates/lsp-ai/src/config.rs @@ -24,6 +24,51 @@ impl Default for PostProcess { } } +#[derive(Debug, Clone, Deserialize)] +pub enum ValidSplitter { + #[serde(rename = "tree_sitter")] + TreeSitter(TreeSitter), + #[serde(rename = "text_sitter")] + TextSplitter(TextSplitter), +} + +impl Default for ValidSplitter { + fn default() -> Self { + ValidSplitter::TreeSitter(TreeSitter::default()) + } +} + +const fn chunk_size_default() -> usize { + 1500 +} + +const fn chunk_overlap_default() -> usize { + 0 +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TreeSitter { + #[serde(default = "chunk_size_default")] + pub chunk_size: usize, + #[serde(default = "chunk_overlap_default")] + pub chunk_overlap: usize, +} + +impl Default for TreeSitter { + fn default() -> Self { + Self { + chunk_size: 1500, + chunk_overlap: 0, + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TextSplitter { + #[serde(default = "chunk_size_default")] + pub chunk_size: usize, +} + #[derive(Debug, Clone, Deserialize)] pub enum ValidMemoryBackend { #[serde(rename = "file_store")] @@ -67,15 +112,6 @@ impl ChatMessage { } } -#[derive(Debug, Clone, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct Chat { - pub completion: Option>, - pub generation: Option>, - pub chat_template: Option, - pub chat_format: Option, -} - #[derive(Clone, Debug, Deserialize)] #[allow(clippy::upper_case_acronyms)] #[serde(deny_unknown_fields)] @@ -85,27 +121,52 @@ pub struct FIM { pub end: String, } +const fn max_crawl_memory_default() -> u64 { + 100_000_000 +} + +const fn max_crawl_file_size_default() -> u64 { + 10_000_000 +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Crawl { + #[serde(default = "max_crawl_file_size_default")] + pub max_file_size: u64, + #[serde(default = "max_crawl_memory_default")] + pub max_crawl_memory: u64, + #[serde(default)] + pub all_files: bool, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PostgresMLEmbeddingModel { + pub model: String, + pub embed_parameters: Option, + pub query_parameters: Option, +} + #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub struct PostgresML { pub database_url: Option, + pub crawl: Option, #[serde(default)] - pub crawl: bool, + pub splitter: ValidSplitter, + pub embedding_model: Option, } #[derive(Clone, Debug, Deserialize, Default)] #[serde(deny_unknown_fields)] pub struct FileStore { - 
#[serde(default)] - pub crawl: bool, + pub crawl: Option, } -const fn n_gpu_layers_default() -> u32 { - 1000 -} - -const fn n_ctx_default() -> u32 { - 1000 +impl FileStore { + pub fn new_without_crawl() -> Self { + Self { crawl: None } + } } #[derive(Clone, Debug, Deserialize)] @@ -137,6 +198,17 @@ pub struct MistralFIM { pub max_requests_per_second: f32, } +#[cfg(feature = "llama_cpp")] +const fn n_gpu_layers_default() -> u32 { + 1000 +} + +#[cfg(feature = "llama_cpp")] +const fn n_ctx_default() -> u32 { + 1000 +} + +#[cfg(feature = "llama_cpp")] #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub struct LLaMACPP { @@ -230,15 +302,14 @@ pub struct ValidConfig { #[derive(Clone, Debug, Deserialize, Default)] pub struct ValidClientParams { - #[serde(alias = "rootURI")] - _root_uri: Option, - _workspace_folders: Option>, + #[serde(alias = "rootUri")] + pub root_uri: Option, } #[derive(Clone, Debug)] pub struct Config { pub config: ValidConfig, - _client_params: ValidClientParams, + pub client_params: ValidClientParams, } impl Config { @@ -255,7 +326,7 @@ impl Config { let client_params: ValidClientParams = serde_json::from_value(args)?; Ok(Self { config: valid_args, - _client_params: client_params, + client_params, }) } @@ -300,20 +371,17 @@ impl Config { } } -// This makes testing much easier. +// For teesting use only #[cfg(test)] impl Config { pub fn default_with_file_store_without_models() -> Self { Self { config: ValidConfig { - memory: ValidMemoryBackend::FileStore(FileStore { crawl: false }), + memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }), models: HashMap::new(), completion: None, }, - _client_params: ValidClientParams { - _root_uri: None, - _workspace_folders: None, - }, + client_params: ValidClientParams { root_uri: None }, } } } diff --git a/crates/lsp-ai/src/crawl.rs b/crates/lsp-ai/src/crawl.rs new file mode 100644 index 0000000..2dc1721 --- /dev/null +++ b/crates/lsp-ai/src/crawl.rs @@ -0,0 +1,103 @@ +use ignore::WalkBuilder; +use std::collections::HashSet; +use tracing::{error, instrument}; + +use crate::config::{self, Config}; + +pub struct Crawl { + crawl_config: config::Crawl, + config: Config, + crawled_file_types: HashSet, + crawled_all: bool, +} + +impl Crawl { + pub fn new(crawl_config: config::Crawl, config: Config) -> Self { + Self { + crawl_config, + config, + crawled_file_types: HashSet::new(), + crawled_all: false, + } + } + + #[instrument(skip(self, f))] + pub fn maybe_do_crawl( + &mut self, + triggered_file: Option, + mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result, + ) -> anyhow::Result<()> { + if self.crawled_all { + return Ok(()); + } + + if let Some(root_uri) = &self.config.client_params.root_uri { + if !root_uri.starts_with("file://") { + anyhow::bail!("Skipping crawling as root_uri does not begin with file://") + } + + let extension_to_match = triggered_file + .map(|tf| { + let path = std::path::Path::new(&tf); + path.extension().map(|f| f.to_str().map(|f| f.to_owned())) + }) + .flatten() + .flatten(); + + if let Some(extension_to_match) = &extension_to_match { + if self.crawled_file_types.contains(extension_to_match) { + return Ok(()); + } + } + + if !self.crawl_config.all_files && extension_to_match.is_none() { + return Ok(()); + } + + for result in WalkBuilder::new(&root_uri[7..]).build() { + let result = result?; + let path = result.path(); + if !path.is_dir() { + if let Some(path_str) = path.to_str() { + if self.crawl_config.all_files { + match f(&self.crawl_config, path_str) { + Ok(c) => { + if !c { + 
break; + } + } + Err(e) => error!("{e:?}"), + } + } else { + match ( + path.extension().map(|pe| pe.to_str()).flatten(), + &extension_to_match, + ) { + (Some(path_extension), Some(extension_to_match)) => { + if path_extension == extension_to_match { + match f(&self.crawl_config, path_str) { + Ok(c) => { + if !c { + break; + } + } + Err(e) => error!("{e:?}"), + } + } + } + _ => continue, + } + } + } + } + } + + if let Some(extension_to_match) = extension_to_match { + self.crawled_file_types.insert(extension_to_match); + } else { + self.crawled_all = true + } + } + Ok(()) + } +} diff --git a/src/custom_requests/generation.rs b/crates/lsp-ai/src/custom_requests/generation.rs similarity index 100% rename from src/custom_requests/generation.rs rename to crates/lsp-ai/src/custom_requests/generation.rs diff --git a/src/custom_requests/generation_stream.rs b/crates/lsp-ai/src/custom_requests/generation_stream.rs similarity index 100% rename from src/custom_requests/generation_stream.rs rename to crates/lsp-ai/src/custom_requests/generation_stream.rs diff --git a/src/custom_requests/mod.rs b/crates/lsp-ai/src/custom_requests/mod.rs similarity index 100% rename from src/custom_requests/mod.rs rename to crates/lsp-ai/src/custom_requests/mod.rs diff --git a/src/main.rs b/crates/lsp-ai/src/main.rs similarity index 94% rename from src/main.rs rename to crates/lsp-ai/src/main.rs index 37b1f3f..106be0a 100644 --- a/src/main.rs +++ b/crates/lsp-ai/src/main.rs @@ -14,9 +14,11 @@ use tracing::error; use tracing_subscriber::{EnvFilter, FmtSubscriber}; mod config; +mod crawl; mod custom_requests; mod memory_backends; mod memory_worker; +mod splitters; #[cfg(feature = "llama_cpp")] mod template; mod transformer_backends; @@ -50,15 +52,19 @@ where req.extract(R::METHOD) } -fn main() -> Result<()> { - // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable - // If the variables value is malformed or missing, sets the default log level to ERROR +// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable +// If the variables value is malformed or missing, sets the default log level to ERROR +fn init_logger() { FmtSubscriber::builder() .with_writer(std::io::stderr) .with_ansi(false) .without_time() .with_env_filter(EnvFilter::from_env("LSP_AI_LOG")) .init(); +} + +fn main() -> Result<()> { + init_logger(); let (connection, io_threads) = Connection::stdio(); let server_capabilities = serde_json::to_value(ServerCapabilities { @@ -83,7 +89,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { let connection = Arc::new(connection); // Our channel we use to communicate with our transformer worker - // let last_worker_request = Arc::new(Mutex::new(None)); let (transformer_tx, transformer_rx) = mpsc::channel(); // The channel we use to communicate with our memory worker @@ -94,8 +99,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> { thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); // Setup our transformer worker - // let transformer_backend: Box = - // config.clone().try_into()?; let transformer_backends: HashMap> = config .config .models diff --git a/crates/lsp-ai/src/memory_backends/file_store.rs b/crates/lsp-ai/src/memory_backends/file_store.rs new file mode 100644 index 0000000..45abb7e --- /dev/null +++ b/crates/lsp-ai/src/memory_backends/file_store.rs @@ -0,0 +1,924 @@ +use anyhow::Context; +use indexmap::IndexSet; +use lsp_types::TextDocumentPositionParams; +use parking_lot::Mutex; +use ropey::Rope; 
+use serde_json::Value; +use std::{collections::HashMap, io::Read}; +use tracing::{error, instrument, warn}; +use tree_sitter::{InputEdit, Point, Tree}; + +use crate::{ + config::{self, Config}, + crawl::Crawl, + utils::{parse_tree, tokens_to_estimated_characters}, +}; + +use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType}; + +#[derive(Default)] +pub struct AdditionalFileStoreParams { + build_tree: bool, +} + +impl AdditionalFileStoreParams { + pub fn new(build_tree: bool) -> Self { + Self { build_tree } + } +} + +#[derive(Clone)] +pub struct File { + rope: Rope, + tree: Option, +} + +impl File { + fn new(rope: Rope, tree: Option) -> Self { + Self { rope, tree } + } + + pub fn rope(&self) -> &Rope { + &self.rope + } + + pub fn tree(&self) -> Option<&Tree> { + self.tree.as_ref() + } +} + +pub struct FileStore { + params: AdditionalFileStoreParams, + file_map: Mutex>, + accessed_files: Mutex>, + crawl: Option>, +} + +impl FileStore { + pub fn new(mut file_store_config: config::FileStore, config: Config) -> anyhow::Result { + let crawl = file_store_config + .crawl + .take() + .map(|x| Mutex::new(Crawl::new(x, config.clone()))); + let s = Self { + params: AdditionalFileStoreParams::default(), + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), + crawl, + }; + if let Err(e) = s.maybe_do_crawl(None) { + error!("{e:?}") + } + Ok(s) + } + + pub fn new_with_params( + mut file_store_config: config::FileStore, + config: Config, + params: AdditionalFileStoreParams, + ) -> anyhow::Result { + let crawl = file_store_config + .crawl + .take() + .map(|x| Mutex::new(Crawl::new(x, config.clone()))); + let s = Self { + params, + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), + crawl, + }; + if let Err(e) = s.maybe_do_crawl(None) { + error!("{e:?}") + } + Ok(s) + } + + fn add_new_file(&self, uri: &str, contents: String) { + let tree = if self.params.build_tree { + match parse_tree(uri, &contents, None) { + Ok(tree) => Some(tree), + Err(e) => { + error!( + "Failed to parse tree for {uri} with error {e}, falling back to no tree" + ); + None + } + } + } else { + None + }; + self.file_map + .lock() + .insert(uri.to_string(), File::new(Rope::from_str(&contents), tree)); + self.accessed_files.lock().insert(uri.to_string()); + } + + fn maybe_do_crawl(&self, triggered_file: Option) -> anyhow::Result<()> { + let mut total_bytes = 0; + let mut current_bytes = 0; + if let Some(crawl) = &self.crawl { + crawl + .lock() + .maybe_do_crawl(triggered_file, |config, path| { + // Break if total bytes is over the max crawl memory + if total_bytes as u64 >= config.max_crawl_memory { + warn!("Ending crawl early due to `max_crawl_memory` resetraint"); + return Ok(false); + } + // This means it has been opened before + let insert_uri = format!("file:///{path}"); + if self.file_map.lock().contains_key(&insert_uri) { + return Ok(true); + } + // Open the file and see if it is small enough to read + let mut f = std::fs::File::open(path)?; + let metadata = f.metadata()?; + if metadata.len() > config.max_file_size { + warn!("Skipping file: {path} because it is too large"); + return Ok(true); + } + // Read the file contents + let mut contents = vec![]; + f.read_to_end(&mut contents)?; + let contents = String::from_utf8(contents)?; + current_bytes += contents.len(); + total_bytes += contents.len(); + self.add_new_file(&insert_uri, contents); + Ok(true) + })?; + } + Ok(()) + } + + fn get_rope_for_position( + &self, + position: 
&TextDocumentPositionParams, + characters: usize, + pull_from_multiple_files: bool, + ) -> anyhow::Result<(Rope, usize)> { + // Get the rope and set our initial cursor index + let current_document_uri = position.text_document.uri.to_string(); + let mut rope = self + .file_map + .lock() + .get(¤t_document_uri) + .context("Error file not found")? + .rope + .clone(); + let mut cursor_index = rope.line_to_char(position.position.line as usize) + + position.position.character as usize; + // Add to our rope if we need to + for file in self + .accessed_files + .lock() + .iter() + .filter(|f| **f != current_document_uri) + { + let needed = characters.saturating_sub(rope.len_chars() + 1); + if needed == 0 || !pull_from_multiple_files { + break; + } + let file_map = self.file_map.lock(); + let r = &file_map.get(file).context("Error file not found")?.rope; + let slice_max = needed.min(r.len_chars() + 1); + let rope_str_slice = r + .get_slice(0..slice_max - 1) + .context("Error getting slice")? + .to_string(); + rope.insert(0, "\n"); + rope.insert(0, &rope_str_slice); + cursor_index += slice_max; + } + Ok((rope, cursor_index)) + } + + pub fn get_characters_around_position( + &self, + position: &TextDocumentPositionParams, + characters: usize, + ) -> anyhow::Result { + let rope = self + .file_map + .lock() + .get(position.text_document.uri.as_str()) + .context("Error file not found")? + .rope + .clone(); + let cursor_index = rope.line_to_char(position.position.line as usize) + + position.position.character as usize; + let start = cursor_index.saturating_sub(characters / 2); + let end = rope + .len_chars() + .min(cursor_index + (characters - (cursor_index - start))); + let rope_slice = rope + .get_slice(start..end) + .context("Error getting rope slice")?; + Ok(rope_slice.to_string()) + } + + pub fn build_code( + &self, + position: &TextDocumentPositionParams, + prompt_type: PromptType, + params: MemoryRunParams, + pull_from_multiple_files: bool, + ) -> anyhow::Result { + let (mut rope, cursor_index) = + self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?; + + Ok(match prompt_type { + PromptType::ContextAndCode => { + if params.is_for_chat { + let max_length = tokens_to_estimated_characters(params.max_context); + let start = cursor_index.saturating_sub(max_length / 2); + let end = rope + .len_chars() + .min(cursor_index + (max_length - (cursor_index - start))); + + rope.insert(cursor_index, ""); + let rope_slice = rope + .get_slice(start..end + "".chars().count()) + .context("Error getting rope slice")?; + Prompt::ContextAndCode(ContextAndCodePrompt::new( + "".to_string(), + rope_slice.to_string(), + )) + } else { + let start = cursor_index + .saturating_sub(tokens_to_estimated_characters(params.max_context)); + let rope_slice = rope + .get_slice(start..cursor_index) + .context("Error getting rope slice")?; + Prompt::ContextAndCode(ContextAndCodePrompt::new( + "".to_string(), + rope_slice.to_string(), + )) + } + } + PromptType::FIM => { + let max_length = tokens_to_estimated_characters(params.max_context); + let start = cursor_index.saturating_sub(max_length / 2); + let end = rope + .len_chars() + .min(cursor_index + (max_length - (cursor_index - start))); + let prefix = rope + .get_slice(start..cursor_index) + .context("Error getting rope slice")?; + let suffix = rope + .get_slice(cursor_index..end) + .context("Error getting rope slice")?; + Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string())) + } + }) + } + + pub fn file_map(&self) -> &Mutex> { + 
&self.file_map + } + + pub fn contains_file(&self, uri: &str) -> bool { + self.file_map.lock().contains_key(uri) + } + + pub fn position_to_byte(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + let file_map = self.file_map.lock(); + let uri = position.text_document.uri.to_string(); + let file = file_map + .get(&uri) + .with_context(|| format!("trying to get file that does not exist {uri}"))?; + let line_char_index = file + .rope + .try_line_to_char(position.position.line as usize)?; + Ok(file + .rope + .try_char_to_byte(line_char_index + position.position.character as usize)?) + } +} + +#[async_trait::async_trait] +impl MemoryBackend for FileStore { + #[instrument(skip(self))] + fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + let rope = self + .file_map + .lock() + .get(position.text_document.uri.as_str()) + .context("Error file not found")? + .rope + .clone(); + let line = rope + .get_line(position.position.line as usize) + .context("Error getting filter text")? + .get_slice(0..position.position.character as usize) + .context("Error getting filter text")? + .to_string(); + Ok(line) + } + + #[instrument(skip(self))] + async fn build_prompt( + &self, + position: &TextDocumentPositionParams, + prompt_type: PromptType, + params: &Value, + ) -> anyhow::Result { + let params: MemoryRunParams = params.try_into()?; + self.build_code(position, prompt_type, params, true) + } + + #[instrument(skip(self))] + fn opened_text_document( + &self, + params: lsp_types::DidOpenTextDocumentParams, + ) -> anyhow::Result<()> { + let uri = params.text_document.uri.to_string(); + self.add_new_file(&uri, params.text_document.text); + if let Err(e) = self.maybe_do_crawl(Some(uri)) { + error!("{e:?}") + } + Ok(()) + } + + #[instrument(skip(self))] + fn changed_text_document( + &self, + params: lsp_types::DidChangeTextDocumentParams, + ) -> anyhow::Result<()> { + let uri = params.text_document.uri.to_string(); + let mut file_map = self.file_map.lock(); + let file = file_map + .get_mut(&uri) + .with_context(|| format!("Trying to get file that does not exist {uri}"))?; + for change in params.content_changes { + // If range is ommitted, text is the new text of the document + if let Some(range) = change.range { + // Record old positions + let (old_end_position, old_end_byte) = { + let last_line_index = file.rope.len_lines() - 1; + ( + file.rope + .get_line(last_line_index) + .context("getting last line for edit") + .map(|last_line| Point::new(last_line_index, last_line.len_chars())), + file.rope.bytes().count(), + ) + }; + // Update the document + let start_index = file.rope.line_to_char(range.start.line as usize) + + range.start.character as usize; + let end_index = + file.rope.line_to_char(range.end.line as usize) + range.end.character as usize; + file.rope.remove(start_index..end_index); + file.rope.insert(start_index, &change.text); + // Set new end positions + let (new_end_position, new_end_byte) = { + let last_line_index = file.rope.len_lines() - 1; + ( + file.rope + .get_line(last_line_index) + .context("getting last line for edit") + .map(|last_line| Point::new(last_line_index, last_line.len_chars())), + file.rope.bytes().count(), + ) + }; + // Update the tree + if self.params.build_tree { + let mut old_tree = file.tree.take(); + let start_byte = file + .rope + .try_line_to_char(range.start.line as usize) + .and_then(|start_char| { + file.rope + .try_char_to_byte(start_char + range.start.character as usize) + }) + .map_err(anyhow::Error::msg); + if let 
Some(old_tree) = &mut old_tree { + match (start_byte, old_end_position, new_end_position) { + (Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => { + old_tree.edit(&InputEdit { + start_byte, + old_end_byte, + new_end_byte, + start_position: Point::new( + range.start.line as usize, + range.start.character as usize, + ), + old_end_position, + new_end_position, + }); + file.tree = match parse_tree( + &uri, + &file.rope.to_string(), + Some(old_tree), + ) { + Ok(tree) => Some(tree), + Err(e) => { + error!("failed to edit tree: {e:?}"); + None + } + }; + } + (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => { + error!("failed to build tree edit: {e:?}"); + } + } + } + } + } else { + file.rope = Rope::from_str(&change.text); + if self.params.build_tree { + file.tree = match parse_tree(&uri, &change.text, None) { + Ok(tree) => Some(tree), + Err(e) => { + error!("failed to parse new tree: {e:?}"); + None + } + }; + } + } + } + self.accessed_files.lock().shift_insert(0, uri); + Ok(()) + } + + #[instrument(skip(self))] + fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + for file_rename in params.files { + let mut file_map = self.file_map.lock(); + if let Some(rope) = file_map.remove(&file_rename.old_uri) { + file_map.insert(file_rename.new_uri, rope); + } + } + Ok(()) + } +} + +// For teesting use only +#[cfg(test)] +impl FileStore { + pub fn default_with_filler_file() -> anyhow::Result { + let config = Config::default_with_file_store_without_models(); + let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = + config.config.memory.clone() + { + file_store_config + } else { + anyhow::bail!("requires a file_store_config") + }; + let f = FileStore::new(file_store_config, config)?; + + let uri = "file:///filler.py"; + let text = r#"# Multiplies two numbers +def multiply_two_numbers(x, y): + return + +# A singular test +assert multiply_two_numbers(2, 3) == 6 +"#; + let params = lsp_types::DidOpenTextDocumentParams { + text_document: lsp_types::TextDocumentItem { + uri: reqwest::Url::parse(uri).unwrap(), + language_id: "filler".to_string(), + version: 0, + text: text.to_string(), + }, + }; + f.opened_text_document(params)?; + + Ok(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use lsp_types::{ + DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams, + TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem, + VersionedTextDocumentIdentifier, + }; + use serde_json::json; + + fn generate_base_file_store() -> anyhow::Result { + let config = Config::default_with_file_store_without_models(); + let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = + config.config.memory.clone() + { + file_store_config + } else { + anyhow::bail!("requires a file_store_config") + }; + FileStore::new(file_store_config, config) + } + + fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem { + let uri = uri.unwrap_or("file:///filler/"); + let text = text.unwrap_or("Here is the document body"); + TextDocumentItem { + uri: reqwest::Url::parse(uri).unwrap(), + language_id: "filler".to_string(), + version: 0, + text: text.to_string(), + } + } + + #[test] + fn can_open_document() -> anyhow::Result<()> { + let params = lsp_types::DidOpenTextDocumentParams { + text_document: generate_filler_text_document(None, None), + }; + let file_store = generate_base_file_store()?; + file_store.opened_text_document(params)?; + let file = file_store 
+ .file_map + .lock() + .get("file:///filler/") + .unwrap() + .clone(); + assert_eq!(file.rope.to_string(), "Here is the document body"); + Ok(()) + } + + #[test] + fn can_rename_document() -> anyhow::Result<()> { + let params = lsp_types::DidOpenTextDocumentParams { + text_document: generate_filler_text_document(None, None), + }; + let file_store = generate_base_file_store()?; + file_store.opened_text_document(params)?; + + let params = RenameFilesParams { + files: vec![FileRename { + old_uri: "file:///filler/".to_string(), + new_uri: "file:///filler2/".to_string(), + }], + }; + file_store.renamed_files(params)?; + + let file = file_store + .file_map + .lock() + .get("file:///filler2/") + .unwrap() + .clone(); + assert_eq!(file.rope.to_string(), "Here is the document body"); + Ok(()) + } + + #[test] + fn can_change_document() -> anyhow::Result<()> { + let text_document = generate_filler_text_document(None, None); + + let params = DidOpenTextDocumentParams { + text_document: text_document.clone(), + }; + let file_store = generate_base_file_store()?; + file_store.opened_text_document(params)?; + + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri.clone(), + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 1, + }, + end: Position { + line: 0, + character: 3, + }, + }), + range_length: None, + text: "a".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store + .file_map + .lock() + .get("file:///filler/") + .unwrap() + .clone(); + assert_eq!(file.rope.to_string(), "Hae is the document body"); + + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri, + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: None, + range_length: None, + text: "abc".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store + .file_map + .lock() + .get("file:///filler/") + .unwrap() + .clone(); + assert_eq!(file.rope.to_string(), "abc"); + + Ok(()) + } + + #[tokio::test] + async fn can_build_prompt() -> anyhow::Result<()> { + let text_document = generate_filler_text_document( + None, + Some( + r#"Document Top +Here is a more complicated document + +Some text + +The end with a trailing new line +"#, + ), + ); + + // Test basic completion + let params = lsp_types::DidOpenTextDocumentParams { + text_document: text_document.clone(), + }; + let file_store = generate_base_file_store()?; + file_store.opened_text_document(params)?; + + let prompt = file_store + .build_prompt( + &TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: text_document.uri.clone(), + }, + position: Position { + line: 0, + character: 10, + }, + }, + PromptType::ContextAndCode, + &json!({}), + ) + .await?; + let prompt: ContextAndCodePrompt = prompt.try_into()?; + assert_eq!(prompt.context, ""); + assert_eq!("Document T", prompt.code); + + // Test FIM + let prompt = file_store + .build_prompt( + &TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: text_document.uri.clone(), + }, + position: Position { + line: 0, + character: 10, + }, + }, + PromptType::FIM, + &json!({}), + ) + .await?; + let prompt: FIMPrompt = prompt.try_into()?; + assert_eq!(prompt.prompt, r#"Document T"#); + assert_eq!( + prompt.suffix, + r#"op +Here is a more complicated document + 
+Some text + +The end with a trailing new line +"# + ); + + // Test chat + let prompt = file_store + .build_prompt( + &TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: text_document.uri.clone(), + }, + position: Position { + line: 0, + character: 10, + }, + }, + PromptType::ContextAndCode, + &json!({ + "messages": [] + }), + ) + .await?; + let prompt: ContextAndCodePrompt = prompt.try_into()?; + assert_eq!(prompt.context, ""); + let text = r#"Document Top +Here is a more complicated document + +Some text + +The end with a trailing new line +"# + .to_string(); + assert_eq!(text, prompt.code); + + // Test multi-file + let text_document2 = generate_filler_text_document( + Some("file:///filler2"), + Some( + r#"Document Top2 +Here is a more complicated document + +Some text + +The end with a trailing new line +"#, + ), + ); + let params = lsp_types::DidOpenTextDocumentParams { + text_document: text_document2.clone(), + }; + file_store.opened_text_document(params)?; + + let prompt = file_store + .build_prompt( + &TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: text_document.uri.clone(), + }, + position: Position { + line: 0, + character: 10, + }, + }, + PromptType::ContextAndCode, + &json!({}), + ) + .await?; + let prompt: ContextAndCodePrompt = prompt.try_into()?; + assert_eq!(prompt.context, ""); + assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code); + + Ok(()) + } + + #[tokio::test] + async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> { + let text_document = generate_filler_text_document(None, Some("test\n")); + let params = lsp_types::DidOpenTextDocumentParams { + text_document: text_document.clone(), + }; + let file_store = generate_base_file_store()?; + file_store.opened_text_document(params)?; + + // Test chat + let prompt = file_store + .build_prompt( + &TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: text_document.uri.clone(), + }, + position: Position { + line: 1, + character: 0, + }, + }, + PromptType::ContextAndCode, + &json!({"messages": []}), + ) + .await?; + let prompt: ContextAndCodePrompt = prompt.try_into()?; + assert_eq!(prompt.context, ""); + let text = r#"test +"# + .to_string(); + assert_eq!(text, prompt.code); + + Ok(()) + } + + #[test] + fn test_file_store_tree_sitter() -> anyhow::Result<()> { + crate::init_logger(); + + let config = Config::default_with_file_store_without_models(); + let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = + config.config.memory.clone() + { + file_store_config + } else { + anyhow::bail!("requires a file_store_config") + }; + let params = AdditionalFileStoreParams { build_tree: true }; + let file_store = FileStore::new_with_params(file_store_config, config, params)?; + + let uri = "file:///filler/test.rs"; + let text = r#"#[derive(Debug)] +struct Rectangle { + width: u32, + height: u32, +} + +impl Rectangle { + fn area(&self) -> u32 { + + } +} + +fn main() { + let rect1 = Rectangle { + width: 30, + height: 50, + }; + + println!( + "The area of the rectangle is {} square pixels.", + rect1.area() + ); +}"#; + let text_document = TextDocumentItem { + uri: reqwest::Url::parse(uri).unwrap(), + language_id: "".to_string(), + version: 0, + text: text.to_string(), + }; + let params = DidOpenTextDocumentParams { + text_document: text_document.clone(), + }; + + file_store.opened_text_document(params)?; + + // Test insert + let params = lsp_types::DidChangeTextDocumentParams { + 
text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri.clone(), + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 8, + character: 0, + }, + end: Position { + line: 8, + character: 0, + }, + }), + range_length: None, + text: " self.width * self.height".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))"); + + // Test delete + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri.clone(), + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 12, + character: 0, + }, + }), + range_length: None, + text: "".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))"); + + // Test replace + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri, + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: None, + range_length: None, + text: "fn main() {}".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))"); + + Ok(()) + } +} diff --git a/src/memory_backends/mod.rs b/crates/lsp-ai/src/memory_backends/mod.rs similarity index 88% rename from src/memory_backends/mod.rs rename to crates/lsp-ai/src/memory_backends/mod.rs index 
52a8974..9d6fcc5 100644 --- a/src/memory_backends/mod.rs +++ b/crates/lsp-ai/src/memory_backends/mod.rs @@ -18,13 +18,13 @@ pub enum PromptType { #[derive(Clone)] pub struct MemoryRunParams { pub is_for_chat: bool, - pub max_context_length: usize, + pub max_context: usize, } impl From<&Value> for MemoryRunParams { fn from(value: &Value) -> Self { Self { - max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize, + max_context: value["max_context"].as_u64().unwrap_or(1024) as usize, // messages are for most backends, contents are for Gemini is_for_chat: value["messages"].is_array() || value["contents"].is_array(), } @@ -113,22 +113,16 @@ pub trait MemoryBackend { async fn init(&self) -> anyhow::Result<()> { Ok(()) } - async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; - async fn changed_text_document( - &self, - params: DidChangeTextDocumentParams, - ) -> anyhow::Result<()>; - async fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>; + fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>; + fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>; + fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>; + fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result; async fn build_prompt( &self, position: &TextDocumentPositionParams, prompt_type: PromptType, params: &Value, ) -> anyhow::Result; - async fn get_filter_text( - &self, - position: &TextDocumentPositionParams, - ) -> anyhow::Result; } impl TryFrom for Box { @@ -137,7 +131,7 @@ impl TryFrom for Box { fn try_from(configuration: Config) -> Result { match configuration.config.memory.clone() { ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new( - file_store::FileStore::new(file_store_config, configuration), + file_store::FileStore::new(file_store_config, configuration)?, )), ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new( postgresml::PostgresML::new(postgresml_config, configuration)?, diff --git a/crates/lsp-ai/src/memory_backends/postgresml/mod.rs b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs new file mode 100644 index 0000000..2c08065 --- /dev/null +++ b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs @@ -0,0 +1,701 @@ +use anyhow::Context; +use lsp_types::TextDocumentPositionParams; +use parking_lot::Mutex; +use pgml::{Collection, Pipeline}; +use rand::{distributions::Alphanumeric, Rng}; +use serde_json::{json, Value}; +use std::{ + collections::HashSet, + io::Read, + sync::{ + mpsc::{self, Sender}, + Arc, + }, + time::Duration, +}; +use tokio::time; +use tracing::{error, instrument, warn}; + +use crate::{ + config::{self, Config}, + crawl::Crawl, + splitters::{Chunk, Splitter}, + utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME}, +}; + +use super::{ + file_store::{AdditionalFileStoreParams, FileStore}, + ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType, +}; + +const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000; + +fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String { + let path = match root_uri { + Some(root_uri) => { + if uri.starts_with(root_uri) { + &uri[root_uri.chars().count()..] 
+ } else { + uri + } + } + None => uri, + }; + format!( + r#"--{path}-- +{excerpt} +"#, + ) +} + +fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value { + json!({ + "id": chunk_to_id(uri, &chunk), + "uri": uri, + "text": format_file_excerpt(uri, &chunk.text, root_uri), + "range": chunk.range + }) +} + +async fn split_and_upsert_file( + uri: &str, + collection: &mut Collection, + file_store: Arc, + splitter: Arc>, + root_uri: Option<&str>, +) -> anyhow::Result<()> { + // We need to make sure we don't hold the file_store lock while performing a network call + let chunks = { + file_store + .file_map() + .lock() + .get(uri) + .map(|f| splitter.split(f)) + }; + let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?; + let documents = chunks + .into_iter() + .map(|chunk| chunk_to_document(uri, chunk, root_uri).into()) + .collect(); + collection + .upsert_documents(documents, None) + .await + .context("PGML - Error upserting documents") +} + +#[derive(Clone)] +pub struct PostgresML { + config: Config, + postgresml_config: config::PostgresML, + file_store: Arc, + collection: Collection, + pipeline: Pipeline, + debounce_tx: Sender, + crawl: Option>>, + splitter: Arc>, +} + +impl PostgresML { + #[instrument] + pub fn new( + mut postgresml_config: config::PostgresML, + configuration: Config, + ) -> anyhow::Result { + let crawl = postgresml_config + .crawl + .take() + .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone())))); + + let splitter: Arc> = + Arc::new(postgresml_config.splitter.clone().try_into()?); + + let file_store = Arc::new(FileStore::new_with_params( + config::FileStore::new_without_crawl(), + configuration.clone(), + AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()), + )?); + + let database_url = if let Some(database_url) = postgresml_config.database_url.clone() { + database_url + } else { + std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")? 
+ }; + + // Build our pipeline schema + let pipeline = match &postgresml_config.embedding_model { + Some(embedding_model) => { + json!({ + "text": { + "semantic_search": { + "model": embedding_model.model, + "parameters": embedding_model.embed_parameters + } + } + }) + } + None => { + json!({ + "text": { + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + } + } + }) + } + }; + + // When building the collection name we include the Pipeline schema + // If the user changes the Pipeline schema, it will take affect without them having to delete the old files + let collection_name = match configuration.client_params.root_uri.clone() { + Some(root_uri) => format!( + "{:x}", + md5::compute( + format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes() + ) + ), + None => { + warn!("no root_uri provided in server configuration - generating random string for collection name"); + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(21) + .map(char::from) + .collect() + } + }; + let mut collection = Collection::new(&collection_name, Some(database_url))?; + let mut pipeline = Pipeline::new("v1", Some(pipeline.into()))?; + + // Add the Pipeline to the Collection + TOKIO_RUNTIME.block_on(async { + collection + .add_pipeline(&mut pipeline) + .await + .context("PGML - error adding pipeline to collection") + })?; + + // Setup up a debouncer for changed text documents + let (debounce_tx, debounce_rx) = mpsc::channel::(); + let mut task_collection = collection.clone(); + let task_file_store = file_store.clone(); + let task_splitter = splitter.clone(); + let task_root_uri = configuration.client_params.root_uri.clone(); + TOKIO_RUNTIME.spawn(async move { + let duration = Duration::from_millis(500); + let mut file_uris = Vec::new(); + loop { + time::sleep(duration).await; + let new_uris: Vec = debounce_rx.try_iter().collect(); + if !new_uris.is_empty() { + for uri in new_uris { + if !file_uris.iter().any(|p| *p == uri) { + file_uris.push(uri); + } + } + } else { + if file_uris.is_empty() { + continue; + } + // Build the chunks for our changed files + let chunks: Vec> = match file_uris + .iter() + .map(|uri| { + let file_store = task_file_store.file_map().lock(); + let file = file_store + .get(uri) + .with_context(|| format!("getting file for splitting: {uri}"))?; + anyhow::Ok(task_splitter.split(file)) + }) + .collect() + { + Ok(chunks) => chunks, + Err(e) => { + error!("{e:?}"); + continue; + } + }; + // Delete old chunks that no longer exist after the latest file changes + let delete_or_statements: Vec = file_uris + .iter() + .zip(&chunks) + .map(|(uri, chunks)| { + let ids: Vec = + chunks.iter().map(|c| chunk_to_id(uri, c)).collect(); + json!({ + "$and": [ + { + "uri": { + "$eq": uri + } + }, + { + "id": { + "$nin": ids + } + } + ] + }) + }) + .collect(); + if let Err(e) = task_collection + .delete_documents( + json!({ + "$or": delete_or_statements + }) + .into(), + ) + .await + .context("PGML - error deleting documents") + { + error!("{e:?}"); + } + // Prepare and upsert our new chunks + let documents: Vec = chunks + .into_iter() + .zip(&file_uris) + .map(|(chunks, uri)| { + chunks + .into_iter() + .map(|chunk| { + chunk_to_document(&uri, chunk, task_root_uri.as_deref()) + }) + .collect::>() + }) + .flatten() + .map(|f: Value| f.into()) + .collect(); + if let Err(e) = task_collection + .upsert_documents(documents, None) + .await + .context("PGML - error upserting changed files") + { + error!("{e:?}"); + continue; + } + + file_uris = 
Vec::new(); + } + } + }); + + let s = Self { + config: configuration, + postgresml_config, + file_store, + collection, + pipeline, + debounce_tx, + crawl, + splitter, + }; + + // Resync our Collection + let task_s = s.clone(); + TOKIO_RUNTIME.spawn(async move { + if let Err(e) = task_s.resync().await { + error!("{e:?}") + } + }); + + if let Err(e) = s.maybe_do_crawl(None) { + error!("{e:?}") + } + Ok(s) + } + + async fn resync(&self) -> anyhow::Result<()> { + let mut collection = self.collection.clone(); + + let documents = collection + .get_documents(Some( + json!({ + "limit": 100_000_000, + "keys": ["uri"] + }) + .into(), + )) + .await?; + + let try_get_file_contents = |path: &std::path::Path| { + // Open the file and see if it is small enough to read + let mut f = std::fs::File::open(path)?; + let metadata = f.metadata()?; + if metadata.len() > RESYNC_MAX_FILE_SIZE { + anyhow::bail!("file size is greater than: {RESYNC_MAX_FILE_SIZE}") + } + // Read the file contents + let mut contents = vec![]; + f.read_to_end(&mut contents)?; + anyhow::Ok(String::from_utf8(contents)?) + }; + + let mut documents_to_delete = vec![]; + let mut chunks_to_upsert = vec![]; + let mut current_chunks_bytes = 0; + let mut checked_uris = HashSet::new(); + for document in documents.into_iter() { + let uri = match document["document"]["uri"].as_str() { + Some(uri) => uri, + None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it + }; + + // Check if we have already loaded in this file + if checked_uris.contains(uri) { + continue; + } + checked_uris.insert(uri.to_string()); + + let path = uri.replace("file://", ""); + let path = std::path::Path::new(&path); + if !path.exists() { + documents_to_delete.push(uri.to_string()); + } else { + // Try to read the file. 
If we fail delete it + let contents = match try_get_file_contents(path) { + Ok(contents) => contents, + Err(e) => { + error!("{e:?}"); + documents_to_delete.push(uri.to_string()); + continue; + } + }; + // Split the file into chunks + current_chunks_bytes += contents.len(); + let chunks: Vec = self + .splitter + .split_file_contents(&uri, &contents) + .into_iter() + .map(|chunk| { + chunk_to_document( + &uri, + chunk, + self.config.client_params.root_uri.as_deref(), + ) + .into() + }) + .collect(); + chunks_to_upsert.extend(chunks); + // If we have over 10 mega bytes of chunks do the upsert + if current_chunks_bytes > 10_000_000 { + collection + .upsert_documents(chunks_to_upsert, None) + .await + .context("PGML - error upserting documents during resync")?; + chunks_to_upsert = vec![]; + current_chunks_bytes = 0; + } + } + } + // Upsert any remaining chunks + if chunks_to_upsert.len() > 0 { + collection + .upsert_documents(chunks_to_upsert, None) + .await + .context("PGML - error upserting documents during resync")?; + } + // Delete documents + if !documents_to_delete.is_empty() { + collection + .delete_documents( + json!({ + "uri": { + "$in": documents_to_delete + } + }) + .into(), + ) + .await + .context("PGML - error deleting documents during resync")?; + } + Ok(()) + } + + fn maybe_do_crawl(&self, triggered_file: Option) -> anyhow::Result<()> { + if let Some(crawl) = &self.crawl { + let mut documents = vec![]; + let mut total_bytes = 0; + let mut current_bytes = 0; + crawl + .lock() + .maybe_do_crawl(triggered_file, |config, path| { + // Break if total bytes is over the max crawl memory + if total_bytes as u64 >= config.max_crawl_memory { + warn!("Ending crawl early due to `max_crawl_memory` restraint"); + return Ok(false); + } + // This means it has been opened before + let uri = format!("file://{path}"); + if self.file_store.contains_file(&uri) { + return Ok(true); + } + // Open the file and see if it is small enough to read + let mut f = std::fs::File::open(path)?; + let metadata = f.metadata()?; + if metadata.len() > config.max_file_size { + warn!("Skipping file: {path} because it is too large"); + return Ok(true); + } + // Read the file contents + let mut contents = vec![]; + f.read_to_end(&mut contents)?; + let contents = String::from_utf8(contents)?; + current_bytes += contents.len(); + total_bytes += contents.len(); + let chunks: Vec = self + .splitter + .split_file_contents(&uri, &contents) + .into_iter() + .map(|chunk| { + chunk_to_document( + &uri, + chunk, + self.config.client_params.root_uri.as_deref(), + ) + .into() + }) + .collect(); + documents.extend(chunks); + // If we have over 10 mega bytes of data do the upsert + if current_bytes >= 10_000_000 || total_bytes as u64 >= config.max_crawl_memory + { + // Upsert the documents + let mut collection = self.collection.clone(); + let to_upsert_documents = std::mem::take(&mut documents); + TOKIO_RUNTIME.spawn(async move { + if let Err(e) = collection + .upsert_documents(to_upsert_documents, None) + .await + .context("PGML - error upserting changed files") + { + error!("{e:?}"); + } + }); + // Reset everything + current_bytes = 0; + documents = vec![]; + } + Ok(true) + })?; + // Upsert any remaining documents + if documents.len() > 0 { + let mut collection = self.collection.clone(); + TOKIO_RUNTIME.spawn(async move { + if let Err(e) = collection + .upsert_documents(documents, None) + .await + .context("PGML - error upserting changed files") + { + error!("{e:?}"); + } + }); + } + } + Ok(()) + } +} + +#[async_trait::async_trait] 
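// Sketch of the document shape upserted above (field names taken from
// chunk_to_document/chunk_to_id; the path and byte range are hypothetical):
//   { "id": "file:///src/main.rs#0-1500", "uri": "file:///src/main.rs",
//     "text": "--src/main.rs--\n...", "range": { "start_byte": 0, "end_byte": 1500 } }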
+impl MemoryBackend for PostgresML { + #[instrument(skip(self))] + fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result { + self.file_store.get_filter_text(position) + } + + #[instrument(skip(self))] + async fn build_prompt( + &self, + position: &TextDocumentPositionParams, + prompt_type: PromptType, + params: &Value, + ) -> anyhow::Result { + let params: MemoryRunParams = params.try_into()?; + let chunk_size = self.splitter.chunk_size(); + let total_allowed_characters = tokens_to_estimated_characters(params.max_context); + + // Build the query + let query = self + .file_store + .get_characters_around_position(position, chunk_size)?; + + // Build the prompt + let mut file_store_params = params.clone(); + file_store_params.max_context = chunk_size; + let code = self + .file_store + .build_code(position, prompt_type, file_store_params, false)?; + + // Get the byte of the cursor + let cursor_byte = self.file_store.position_to_byte(position)?; + + // Get the context + let limit = (total_allowed_characters / chunk_size).saturating_sub(1); + let parameters = match self + .postgresml_config + .embedding_model + .as_ref() + .map(|m| m.query_parameters.clone()) + .flatten() + { + Some(query_parameters) => query_parameters, + None => json!({ + "prompt": "query: " + }), + }; + let res = self + .collection + .vector_search_local( + json!({ + "query": { + "fields": { + "text": { + "query": query, + "parameters": parameters + } + }, + "filter": { + "$or": [ + { + "uri": { + "$ne": position.text_document.uri.to_string() + } + }, + { + "range": { + "start": { + "$gt": cursor_byte + }, + }, + }, + { + "range": { + "end": { + "$lt": cursor_byte + }, + } + } + ] + } + }, + "limit": limit + }) + .into(), + &self.pipeline, + ) + .await?; + let context = res + .into_iter() + .map(|c| { + c["chunk"] + .as_str() + .map(|t| t.to_owned()) + .context("PGML - Error getting chunk from vector search") + }) + .collect::>>()? 
+ .join("\n\n"); + let context = &context[..(total_allowed_characters - chunk_size).min(context.len())]; + + // Reconstruct the Prompts + Ok(match code { + Prompt::ContextAndCode(context_and_code) => { + Prompt::ContextAndCode(ContextAndCodePrompt::new( + context.to_owned(), + format_file_excerpt( + &position.text_document.uri.to_string(), + &context_and_code.code, + self.config.client_params.root_uri.as_deref(), + ), + )) + } + Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new( + format!("{context}\n\n{}", fim.prompt), + fim.suffix, + )), + }) + } + + #[instrument(skip(self))] + fn opened_text_document( + &self, + params: lsp_types::DidOpenTextDocumentParams, + ) -> anyhow::Result<()> { + self.file_store.opened_text_document(params.clone())?; + + let saved_uri = params.text_document.uri.to_string(); + + let mut collection = self.collection.clone(); + let file_store = self.file_store.clone(); + let splitter = self.splitter.clone(); + let root_uri = self.config.client_params.root_uri.clone(); + TOKIO_RUNTIME.spawn(async move { + let uri = params.text_document.uri.to_string(); + if let Err(e) = split_and_upsert_file( + &uri, + &mut collection, + file_store, + splitter, + root_uri.as_deref(), + ) + .await + { + error!("{e:?}") + } + }); + + if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) { + error!("{e:?}") + } + Ok(()) + } + + #[instrument(skip(self))] + fn changed_text_document( + &self, + params: lsp_types::DidChangeTextDocumentParams, + ) -> anyhow::Result<()> { + self.file_store.changed_text_document(params.clone())?; + let uri = params.text_document.uri.to_string(); + self.debounce_tx.send(uri)?; + Ok(()) + } + + #[instrument(skip(self))] + fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { + self.file_store.renamed_files(params.clone())?; + + let mut collection = self.collection.clone(); + let file_store = self.file_store.clone(); + let splitter = self.splitter.clone(); + let root_uri = self.config.client_params.root_uri.clone(); + TOKIO_RUNTIME.spawn(async move { + for file in params.files { + if let Err(e) = collection + .delete_documents( + json!({ + "uri": { + "$eq": file.old_uri + } + }) + .into(), + ) + .await + { + error!("PGML - Error deleting file: {e:?}"); + } + if let Err(e) = split_and_upsert_file( + &file.new_uri, + &mut collection, + file_store.clone(), + splitter.clone(), + root_uri.as_deref(), + ) + .await + { + error!("{e:?}") + } + } + }); + Ok(()) + } +} diff --git a/src/memory_worker.rs b/crates/lsp-ai/src/memory_worker.rs similarity index 70% rename from src/memory_worker.rs rename to crates/lsp-ai/src/memory_worker.rs index 39cad6c..1b7a481 100644 --- a/src/memory_worker.rs +++ b/crates/lsp-ai/src/memory_worker.rs @@ -7,7 +7,10 @@ use lsp_types::{ use serde_json::Value; use tracing::error; -use crate::memory_backends::{MemoryBackend, Prompt, PromptType}; +use crate::{ + memory_backends::{MemoryBackend, Prompt, PromptType}, + utils::TOKIO_RUNTIME, +}; #[derive(Debug)] pub struct PromptRequest { @@ -56,34 +59,45 @@ pub enum WorkerRequest { DidRenameFiles(RenameFilesParams), } -async fn do_task( +async fn do_build_prompt( + params: PromptRequest, + memory_backend: Arc>, +) -> anyhow::Result<()> { + let prompt = memory_backend + .build_prompt(¶ms.position, params.prompt_type, ¶ms.params) + .await?; + params + .tx + .send(prompt) + .map_err(|_| anyhow::anyhow!("sending on channel failed")) +} + +fn do_task( request: WorkerRequest, memory_backend: Arc>, ) -> anyhow::Result<()> { match request { WorkerRequest::FilterText(params) => { 
- let filter_text = memory_backend.get_filter_text(¶ms.position).await?; + let filter_text = memory_backend.get_filter_text(¶ms.position)?; params .tx .send(filter_text) .map_err(|_| anyhow::anyhow!("sending on channel failed"))?; } WorkerRequest::Prompt(params) => { - let prompt = memory_backend - .build_prompt(¶ms.position, params.prompt_type, ¶ms.params) - .await?; - params - .tx - .send(prompt) - .map_err(|_| anyhow::anyhow!("sending on channel failed"))?; + TOKIO_RUNTIME.spawn(async move { + if let Err(e) = do_build_prompt(params, memory_backend).await { + error!("error in memory worker building prompt: {e}") + } + }); } WorkerRequest::DidOpenTextDocument(params) => { - memory_backend.opened_text_document(params).await?; + memory_backend.opened_text_document(params)?; } WorkerRequest::DidChangeTextDocument(params) => { - memory_backend.changed_text_document(params).await?; + memory_backend.changed_text_document(params)?; } - WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params).await?, + WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params)?, } anyhow::Ok(()) } @@ -93,18 +107,11 @@ fn do_run( rx: std::sync::mpsc::Receiver, ) -> anyhow::Result<()> { let memory_backend = Arc::new(memory_backend); - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build()?; loop { let request = rx.recv()?; - let thread_memory_backend = memory_backend.clone(); - runtime.spawn(async move { - if let Err(e) = do_task(request, thread_memory_backend).await { - error!("error in memory worker task: {e}") - } - }); + if let Err(e) = do_task(request, memory_backend.clone()) { + error!("error in memory worker task: {e}") + } } } diff --git a/crates/lsp-ai/src/splitters/mod.rs b/crates/lsp-ai/src/splitters/mod.rs new file mode 100644 index 0000000..72db6b7 --- /dev/null +++ b/crates/lsp-ai/src/splitters/mod.rs @@ -0,0 +1,59 @@ +use serde::Serialize; + +use crate::{config::ValidSplitter, memory_backends::file_store::File}; + +mod text_splitter; +mod tree_sitter; + +#[derive(Serialize)] +pub struct ByteRange { + pub start_byte: usize, + pub end_byte: usize, +} + +impl ByteRange { + pub fn new(start_byte: usize, end_byte: usize) -> Self { + Self { + start_byte, + end_byte, + } + } +} + +#[derive(Serialize)] +pub struct Chunk { + pub text: String, + pub range: ByteRange, +} + +impl Chunk { + fn new(text: String, range: ByteRange) -> Self { + Self { text, range } + } +} + +pub trait Splitter { + fn split(&self, file: &File) -> Vec; + fn split_file_contents(&self, uri: &str, contents: &str) -> Vec; + + fn does_use_tree_sitter(&self) -> bool { + false + } + + fn chunk_size(&self) -> usize; +} + +impl TryFrom for Box { + type Error = anyhow::Error; + + fn try_from(value: ValidSplitter) -> Result { + match value { + ValidSplitter::TreeSitter(config) => { + Ok(Box::new(tree_sitter::TreeSitter::new(config)?)) + } + ValidSplitter::TextSplitter(config) => { + Ok(Box::new(text_splitter::TextSplitter::new(config))) + } + } + } +} diff --git a/crates/lsp-ai/src/splitters/text_splitter.rs b/crates/lsp-ai/src/splitters/text_splitter.rs new file mode 100644 index 0000000..9b280a1 --- /dev/null +++ b/crates/lsp-ai/src/splitters/text_splitter.rs @@ -0,0 +1,47 @@ +use crate::{config, memory_backends::file_store::File}; + +use super::{ByteRange, Chunk, Splitter}; + +pub struct TextSplitter { + chunk_size: usize, + splitter: text_splitter::TextSplitter, +} + +impl TextSplitter { + pub fn new(config: config::TextSplitter) -> Self { + Self { + 
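// Sketch of intent: the text_splitter crate's splitter built here caps chunks
// by character count (config.chunk_size); its chunk_indices() used further
// below yields (byte offset, chunk text) pairs that become ByteRange-tagged Chunks.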
chunk_size: config.chunk_size, + splitter: text_splitter::TextSplitter::new(config.chunk_size), + } + } + + pub fn new_with_chunk_size(chunk_size: usize) -> Self { + Self { + chunk_size, + splitter: text_splitter::TextSplitter::new(chunk_size), + } + } +} + +impl Splitter for TextSplitter { + fn split(&self, file: &File) -> Vec { + self.split_file_contents("", &file.rope().to_string()) + } + + fn split_file_contents(&self, _uri: &str, contents: &str) -> Vec { + self.splitter + .chunk_indices(contents) + .fold(vec![], |mut acc, (start_byte, text)| { + let end_byte = start_byte + text.len(); + acc.push(Chunk::new( + text.to_string(), + ByteRange::new(start_byte, end_byte), + )); + acc + }) + } + + fn chunk_size(&self) -> usize { + self.chunk_size + } +} diff --git a/crates/lsp-ai/src/splitters/tree_sitter.rs b/crates/lsp-ai/src/splitters/tree_sitter.rs new file mode 100644 index 0000000..dbbb9ce --- /dev/null +++ b/crates/lsp-ai/src/splitters/tree_sitter.rs @@ -0,0 +1,84 @@ +use splitter_tree_sitter::TreeSitterCodeSplitter; +use tracing::error; +use tree_sitter::Tree; + +use crate::{config, memory_backends::file_store::File, utils::parse_tree}; + +use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter}; + +pub struct TreeSitter { + chunk_size: usize, + splitter: TreeSitterCodeSplitter, + text_splitter: TextSplitter, +} + +impl TreeSitter { + pub fn new(config: config::TreeSitter) -> anyhow::Result { + let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size); + Ok(Self { + chunk_size: config.chunk_size, + splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?, + text_splitter, + }) + } + + fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result> { + Ok(self + .splitter + .split(tree, contents)? + .into_iter() + .map(|c| { + Chunk::new( + c.text.to_owned(), + ByteRange::new(c.range.start_byte, c.range.end_byte), + ) + }) + .collect()) + } +} + +impl Splitter for TreeSitter { + fn split(&self, file: &File) -> Vec { + if let Some(tree) = file.tree() { + match self.split_tree(tree, file.rope().to_string().as_bytes()) { + Ok(chunks) => chunks, + Err(e) => { + error!( + "Failed to parse tree for file with error: {e:?}. Falling back to default splitter.", + ); + self.text_splitter.split(file) + } + } + } else { + self.text_splitter.split(file) + } + } + + fn split_file_contents(&self, uri: &str, contents: &str) -> Vec { + match parse_tree(uri, contents, None) { + Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) { + Ok(chunks) => chunks, + Err(e) => { + error!( + "Failed to parse tree for file: {uri} with error: {e:?}. Falling back to default splitter.", + ); + self.text_splitter.split_file_contents(uri, contents) + } + }, + Err(e) => { + error!( + "Failed to parse tree for file {uri} with error: {e:?}. 
Falling back to default splitter.", + ); + self.text_splitter.split_file_contents(uri, contents) + } + } + } + + fn does_use_tree_sitter(&self) -> bool { + true + } + + fn chunk_size(&self) -> usize { + self.chunk_size + } +} diff --git a/src/template.rs b/crates/lsp-ai/src/template.rs similarity index 100% rename from src/template.rs rename to crates/lsp-ai/src/template.rs diff --git a/src/transformer_backends/anthropic.rs b/crates/lsp-ai/src/transformer_backends/anthropic.rs similarity index 100% rename from src/transformer_backends/anthropic.rs rename to crates/lsp-ai/src/transformer_backends/anthropic.rs diff --git a/src/transformer_backends/gemini.rs b/crates/lsp-ai/src/transformer_backends/gemini.rs similarity index 100% rename from src/transformer_backends/gemini.rs rename to crates/lsp-ai/src/transformer_backends/gemini.rs diff --git a/src/transformer_backends/llama_cpp/mod.rs b/crates/lsp-ai/src/transformer_backends/llama_cpp/mod.rs similarity index 100% rename from src/transformer_backends/llama_cpp/mod.rs rename to crates/lsp-ai/src/transformer_backends/llama_cpp/mod.rs diff --git a/src/transformer_backends/llama_cpp/model.rs b/crates/lsp-ai/src/transformer_backends/llama_cpp/model.rs similarity index 100% rename from src/transformer_backends/llama_cpp/model.rs rename to crates/lsp-ai/src/transformer_backends/llama_cpp/model.rs diff --git a/src/transformer_backends/mistral_fim.rs b/crates/lsp-ai/src/transformer_backends/mistral_fim.rs similarity index 100% rename from src/transformer_backends/mistral_fim.rs rename to crates/lsp-ai/src/transformer_backends/mistral_fim.rs diff --git a/src/transformer_backends/mod.rs b/crates/lsp-ai/src/transformer_backends/mod.rs similarity index 100% rename from src/transformer_backends/mod.rs rename to crates/lsp-ai/src/transformer_backends/mod.rs diff --git a/src/transformer_backends/ollama.rs b/crates/lsp-ai/src/transformer_backends/ollama.rs similarity index 94% rename from src/transformer_backends/ollama.rs rename to crates/lsp-ai/src/transformer_backends/ollama.rs index 6f1a6b1..16486bf 100644 --- a/src/transformer_backends/ollama.rs +++ b/crates/lsp-ai/src/transformer_backends/ollama.rs @@ -67,11 +67,11 @@ impl Ollama { ) -> anyhow::Result { let client = reqwest::Client::new(); let res: OllamaCompletionsResponse = client - .post(self - .configuration - .generate_endpoint - .as_deref() - .unwrap_or("http://localhost:11434/api/generate") + .post( + self.configuration + .generate_endpoint + .as_deref() + .unwrap_or("http://localhost:11434/api/generate"), ) .header("Content-Type", "application/json") .header("Accept", "application/json") @@ -106,11 +106,11 @@ impl Ollama { ) -> anyhow::Result { let client = reqwest::Client::new(); let res: OllamaChatResponse = client - .post(self - .configuration - .chat_endpoint - .as_deref() - .unwrap_or("http://localhost:11434/api/chat") + .post( + self.configuration + .chat_endpoint + .as_deref() + .unwrap_or("http://localhost:11434/api/chat"), ) .header("Content-Type", "application/json") .header("Accept", "application/json") diff --git a/src/transformer_backends/open_ai/mod.rs b/crates/lsp-ai/src/transformer_backends/open_ai/mod.rs similarity index 100% rename from src/transformer_backends/open_ai/mod.rs rename to crates/lsp-ai/src/transformer_backends/open_ai/mod.rs diff --git a/src/transformer_worker.rs b/crates/lsp-ai/src/transformer_worker.rs similarity index 78% rename from src/transformer_worker.rs rename to crates/lsp-ai/src/transformer_worker.rs index 196447b..7766a11 100644 --- 
a/src/transformer_worker.rs +++ b/crates/lsp-ai/src/transformer_worker.rs @@ -17,7 +17,7 @@ use crate::custom_requests::generation_stream::GenerationStreamParams; use crate::memory_backends::Prompt; use crate::memory_worker::{self, FilterRequest, PromptRequest}; use crate::transformer_backends::TransformerBackend; -use crate::utils::ToResponseError; +use crate::utils::{ToResponseError, TOKIO_RUNTIME}; #[derive(Clone, Debug)] pub struct CompletionRequest { @@ -89,14 +89,14 @@ pub struct DoGenerationStreamResponse { fn post_process_start(response: String, front: &str) -> String { let mut front_match = response.len(); loop { - if response.len() == 0 || front.ends_with(&response[..front_match]) { + if response.is_empty() || front.ends_with(&response[..front_match]) { break; } else { front_match -= 1; } } if front_match > 0 { - (&response[front_match..]).to_owned() + response[front_match..].to_owned() } else { response } @@ -105,16 +105,14 @@ fn post_process_start(response: String, front: &str) -> String { fn post_process_end(response: String, back: &str) -> String { let mut back_match = 0; loop { - if back_match == response.len() { - break; - } else if back.starts_with(&response[back_match..]) { + if back_match == response.len() || back.starts_with(&response[back_match..]) { break; } else { back_match += 1; } } if back_match > 0 { - (&response[..back_match]).to_owned() + response[..back_match].to_owned() } else { response } @@ -140,12 +138,10 @@ fn post_process_response( } else { response } + } else if config.remove_duplicate_start { + post_process_start(response, &context_and_code.code) } else { - if config.remove_duplicate_start { - post_process_start(response, &context_and_code.code) - } else { - response - } + response } } Prompt::FIM(fim) => { @@ -177,7 +173,7 @@ pub fn run( connection, config, ) { - error!("error in transformer worker: {e}") + error!("error in transformer worker: {e:?}") } } @@ -189,10 +185,6 @@ fn do_run( config: Config, ) -> anyhow::Result<()> { let transformer_backends = Arc::new(transformer_backends); - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build()?; // If they have disabled completions, this function will fail. 
We set it to MIN_POSITIVE to never process a completions request let max_requests_per_second = config @@ -206,7 +198,7 @@ fn do_run( let task_transformer_backends = transformer_backends.clone(); let task_memory_backend_tx = memory_backend_tx.clone(); let task_config = config.clone(); - runtime.spawn(async move { + TOKIO_RUNTIME.spawn(async move { dispatch_request( request, task_connection, @@ -264,7 +256,7 @@ async fn dispatch_request( { Ok(response) => response, Err(e) => { - error!("generating response: {e}"); + error!("generating response: {e:?}"); Response { id: request.get_id(), result: None, @@ -274,7 +266,7 @@ async fn dispatch_request( }; if let Err(e) = connection.sender.send(Message::Response(response)) { - error!("sending response: {e}"); + error!("sending response: {e:?}"); } } @@ -293,14 +285,12 @@ async fn generate_response( .context("Completions is none")?; let transformer_backend = transformer_backends .get(&completion_config.model) - .clone() .with_context(|| format!("can't find model: {}", &completion_config.model))?; do_completion(transformer_backend, memory_backend_tx, &request, &config).await } WorkerRequest::Generation(request) => { let transformer_backend = transformer_backends .get(&request.params.model) - .clone() .with_context(|| format!("can't find model: {}", &request.params.model))?; do_generate(transformer_backend, memory_backend_tx, &request).await } @@ -346,7 +336,6 @@ async fn do_completion( // Get the response let mut response = transformer_backend.do_completion(&prompt, params).await?; - eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text); if let Some(post_process) = config.get_completions_post_process() { response.insert_text = post_process_response(response.insert_text, &prompt, &post_process); @@ -423,7 +412,103 @@ async fn do_generate( #[cfg(test)] mod tests { use super::*; - use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt}; + use crate::memory_backends::{ + file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend, + }; + use serde_json::json; + use std::{sync::mpsc, thread}; + + #[tokio::test] + async fn test_do_completion() -> anyhow::Result<()> { + let (memory_tx, memory_rx) = mpsc::channel(); + let memory_backend: Box = + Box::new(FileStore::default_with_filler_file()?); + thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); + + let transformer_backend: Box = + config::ValidModel::Ollama(serde_json::from_value( + json!({"model": "deepseek-coder:1.3b-base"}), + )?) 
+ .try_into()?; + let completion_request = CompletionRequest::new( + serde_json::from_value(json!(0))?, + serde_json::from_value(json!({ + "position": {"character":10, "line":2}, + "textDocument": { + "uri": "file:///filler.py" + } + }))?, + ); + let mut config = config::Config::default_with_file_store_without_models(); + config.config.completion = Some(serde_json::from_value(json!({ + "model": "model1", + "parameters": { + "options": { + "temperature": 0 + } + } + }))?); + + let result = do_completion( + &transformer_backend, + memory_tx, + &completion_request, + &config, + ) + .await?; + + assert_eq!( + " x * y", + result.result.clone().unwrap()["items"][0]["textEdit"]["newText"] + .as_str() + .unwrap() + ); + assert_eq!( + " return", + result.result.unwrap()["items"][0]["filterText"] + .as_str() + .unwrap() + ); + + Ok(()) + } + + #[tokio::test] + async fn test_do_generate() -> anyhow::Result<()> { + let (memory_tx, memory_rx) = mpsc::channel(); + let memory_backend: Box = + Box::new(FileStore::default_with_filler_file()?); + thread::spawn(move || memory_worker::run(memory_backend, memory_rx)); + + let transformer_backend: Box = + config::ValidModel::Ollama(serde_json::from_value( + json!({"model": "deepseek-coder:1.3b-base"}), + )?) + .try_into()?; + let generation_request = GenerationRequest::new( + serde_json::from_value(json!(0))?, + serde_json::from_value(json!({ + "position": {"character":10, "line":2}, + "textDocument": { + "uri": "file:///filler.py" + }, + "model": "model1", + "parameters": { + "options": { + "temperature": 0 + } + } + }))?, + ); + let result = do_generate(&transformer_backend, memory_tx, &generation_request).await?; + + assert_eq!( + " x * y", + result.result.unwrap()["generatedText"].as_str().unwrap() + ); + + Ok(()) + } #[test] fn test_post_process_fim() { diff --git a/src/utils.rs b/crates/lsp-ai/src/utils.rs similarity index 54% rename from src/utils.rs rename to crates/lsp-ai/src/utils.rs index ea5d652..8b5b8b4 100644 --- a/src/utils.rs +++ b/crates/lsp-ai/src/utils.rs @@ -1,6 +1,18 @@ +use anyhow::Context; use lsp_server::ResponseError; +use once_cell::sync::Lazy; +use tokio::runtime; +use tree_sitter::Tree; -use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt}; +use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk}; + +pub static TOKIO_RUNTIME: Lazy = Lazy::new(|| { + runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .expect("Error building tokio runtime") +}); pub trait ToResponseError { fn to_response_error(&self, code: i32) -> ResponseError; @@ -42,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String pub fn format_context_code(context: &str, code: &str) -> String { format!("{context}\n\n{code}") } + +pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String { + format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte) +} + +pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result { + let path = std::path::Path::new(uri); + let extension = path.extension().map(|x| x.to_string_lossy()); + let extension = extension.as_deref().unwrap_or(""); + let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?; + parser + .parse(&contents, old_tree) + .with_context(|| format!("parsing tree failed for {uri}")) +} diff --git a/crates/lsp-ai/tests/integration_tests.rs b/crates/lsp-ai/tests/integration_tests.rs new file mode 100644 index 0000000..b523e12 --- /dev/null +++ 
b/crates/lsp-ai/tests/integration_tests.rs @@ -0,0 +1,289 @@ +use anyhow::Result; +use std::{ + io::{Read, Write}, + process::{ChildStdin, ChildStdout, Command, Stdio}, +}; + +// Note if you get an empty response with no error, that typically means +// the language server died +fn read_response(stdout: &mut ChildStdout) -> Result { + let mut content_length = None; + let mut buf = vec![]; + loop { + let mut buf2 = vec![0]; + stdout.read_exact(&mut buf2)?; + buf.push(buf2[0]); + if let Some(content_length) = content_length { + if buf.len() == content_length { + break; + } + } else { + let len = buf.len(); + if len > 4 + && buf[len - 4] == 13 + && buf[len - 3] == 10 + && buf[len - 2] == 13 + && buf[len - 1] == 10 + { + content_length = + Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::()?); + buf = vec![]; + } + } + } + Ok(String::from_utf8(buf)?) +} + +fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> { + stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len(),).as_bytes())?; + stdin.write_all("\r\n".as_bytes())?; + stdin.write_all(message.as_bytes())?; + Ok(()) +} + +// This chat completion sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with llama3:8b-instruct-q4_0 and a temperature of 0 +// It starts with a Python file: +// ``` +// # Multiplies two numbers +// def multiply_two_numbers(x, y): +// +// # A singular test +// assert multiply_two_numbers(2, 3) == 6 +// +// ``` +// And has the following sequence of key strokes: +// o on line 2 (this creates an indented new line and enters insert mode) +// r +// e +// t +// u +// r +// n +// The sequence has: +// - 1 textDocument/DidOpen notification +// - 7 textDocument/didChange notifications +// - 1 textDocument/completion requests +#[test] +fn test_chat_completion_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = 
r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"messages":[{"content":"Instructions:\n- You are an AI programming assistant.\n- Given a piece of code with the cursor location marked by \"\", replace \"\" with the correct code or comment.\n- First, think step-by-step.\n- Describe your plan for what to build in pseudocode, written out in great detail.\n- Then output the code replacing the \"\"\n- Ensure that your completion fits within the language context of the provided code snippet (e.g., Python, JavaScript, Rust).\n\nRules:\n- Only respond with code or comments.\n- Only replace \"\"; do not include any previously written code.\n- Never include \"\" in your response\n- If the cursor is within a comment, complete the comment meaningfully.\n- Handle ambiguous cases by providing the most contextually appropriate completion.\n- Be consistent with your responses.","role":"system"},{"content":"def greet(name):\n print(f\"Hello, {}\")","role":"user"},{"content":"name","role":"assistant"},{"content":"function sum(a, b) {\n return a + ;\n}","role":"user"},{"content":"b","role":"assistant"},{"content":"fn multiply(a: i32, b: i32) -> i32 {\n a * \n}","role":"user"},{"content":"b","role":"assistant"},{"content":"# \ndef add(a, b):\n return a + b","role":"user"},{"content":"Adds two numbers","role":"assistant"},{"content":"# This function checks if a number is even\n","role":"user"},{"content":"def is_even(n):\n return n % 2 == 0","role":"assistant"},{"content":"{CODE}","role":"user"}],"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"llama3:8b-instruct-q4_0","type":"ollama"}}},"processId":66009,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##; + 
send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, + )?; + + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" return","kind":1,"label":"ai - x * y","textEdit":{"newText":"x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"## + ); + + child.kill()?; + Ok(()) +} + +// This FIM completion sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0 +// It starts with a Python file: +// ``` +// # Multiplies two numbers +// def multiply_two_numbers(x, y): +// +// # A singular test +// assert multiply_two_numbers(2, 3) == 6 +// +// ``` +// And has the following sequence of key strokes: +// o on line 2 (this creates an indented new line and enters insert mode) +// r +// e +// The sequence has: +// - 1 textDocument/DidOpen notification +// - 3 textDocument/didChange notifications +// - 1 textDocument/completion 
requests +#[test] +fn test_fim_completion_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"fim":{"end":"<|fim▁end|>","middle":"<|fim▁hole|>","start":"<|fim▁begin|>"},"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":50347,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##; + send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##, + )?; + 
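// Each keystroke from the sequence above arrives as a zero-width didChange
// (range start == end) at the cursor: "r" at line 2, character 4 in the
// version-2 change above, then "e" at character 5 in version 3 below.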
send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, + )?; + + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"## + ); + + child.kill()?; + Ok(()) +} + +// This completion sequence was created using helix with lsp-ai and reading the logs +// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0 +// It starts with a Python file: +// ``` +// # Multiplies two numbers +// def multiply_two_numbers(x, y): +// +// ``` +// And has the following sequence of key strokes: +// o on line 2 (this creates an indented new line and enters insert mode) +// r +// e +// t +// u +// r +// n +// The sequence has: +// - 1 textDocument/DidOpen notification +// - 7 textDocument/didChange notifications +// - 1 textDocument/completion requests +#[test] +fn test_completion_sequence() -> Result<()> { + let mut child = Command::new("cargo") + .arg("run") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = child.stdout.take().unwrap(); + + let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 
(beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":62322,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##; + send_message(&mut stdin, initialization_message)?; + let _ = read_response(&mut stdout)?; + + send_message( + &mut stdin, + r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n","uri":"file:///fake.py","version":0}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##, + )?; + send_message( + &mut stdin, + r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, + )?; + + let output = read_response(&mut stdout)?; + assert_eq!( + output, + r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" return","kind":1,"label":"ai - x * y","textEdit":{"newText":" x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"## + ); + + child.kill()?; + Ok(()) +} diff --git a/crates/splitter-tree-sitter/Cargo.toml b/crates/splitter-tree-sitter/Cargo.toml new file mode 100644 index 0000000..2502006 --- /dev/null +++ b/crates/splitter-tree-sitter/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "splitter-tree-sitter" +version = "0.1.0" +description = "A code splitter utilizing Tree-sitter" + 
+edition.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +thiserror = "1.0.61" +tree-sitter = "0.22" + +[dev-dependencies] +tree-sitter-rust = "0.21" +tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig" } + +[build-dependencies] +cc="*" diff --git a/crates/splitter-tree-sitter/LICENSE b/crates/splitter-tree-sitter/LICENSE new file mode 100644 index 0000000..19e1809 --- /dev/null +++ b/crates/splitter-tree-sitter/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Silas Marvin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/splitter-tree-sitter/README.md b/crates/splitter-tree-sitter/README.md new file mode 100644 index 0000000..5432bd6 --- /dev/null +++ b/crates/splitter-tree-sitter/README.md @@ -0,0 +1,3 @@ +# tree-sitter-splitter + +This is a code splitter that utilizes Tree-sitter to split code. diff --git a/crates/splitter-tree-sitter/src/lib.rs b/crates/splitter-tree-sitter/src/lib.rs new file mode 100644 index 0000000..49bf1a0 --- /dev/null +++ b/crates/splitter-tree-sitter/src/lib.rs @@ -0,0 +1,326 @@ +use thiserror::Error; +use tree_sitter::{Tree, TreeCursor}; + +#[derive(Error, Debug)] +pub enum NewError { + #[error("chunk_size must be greater than chunk_overlap")] + SizeOverlapError, +} + +#[derive(Error, Debug)] +pub enum SplitError { + #[error("converting utf8 to str")] + Utf8Error(#[from] core::str::Utf8Error), +} + +pub struct TreeSitterCodeSplitter { + chunk_size: usize, + chunk_overlap: usize, +} + +pub struct ByteRange { + pub start_byte: usize, + pub end_byte: usize, +} + +impl ByteRange { + fn new(start_byte: usize, end_byte: usize) -> Self { + Self { + start_byte, + end_byte, + } + } +} + +pub struct Chunk<'a> { + pub text: &'a str, + pub range: ByteRange, +} + +impl<'a> Chunk<'a> { + fn new(text: &'a str, range: ByteRange) -> Self { + Self { text, range } + } +} + +impl TreeSitterCodeSplitter { + pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result { + if chunk_overlap > chunk_size { + Err(NewError::SizeOverlapError) + } else { + Ok(Self { + chunk_size, + chunk_overlap, + }) + } + } + + pub fn split<'a, 'b, 'c>( + &'a self, + tree: &'b Tree, + utf8: &'c [u8], + ) -> Result>, SplitError> { + let cursor = tree.walk(); + Ok(self + .split_recursive(cursor, utf8)? 
+ .into_iter() + .rev() + // Let's combine some of our smaller chunks together + // We also want to do this in reverse as it (seems) to make more sense to combine code slices from bottom to top + .try_fold(vec![], |mut acc, current| { + if acc.len() == 0 { + acc.push(current); + Ok::<_, SplitError>(acc) + } else { + if acc.last().as_ref().unwrap().text.len() + current.text.len() + < self.chunk_size + { + let last = acc.pop().unwrap(); + let text = std::str::from_utf8( + &utf8[current.range.start_byte..last.range.end_byte], + )?; + acc.push(Chunk::new( + text, + ByteRange::new(current.range.start_byte, last.range.end_byte), + )); + } else { + acc.push(current); + } + Ok(acc) + } + })? + .into_iter() + .rev() + .collect()) + } + + fn split_recursive<'a, 'b, 'c>( + &'a self, + mut cursor: TreeCursor<'b>, + utf8: &'c [u8], + ) -> Result>, SplitError> { + let node = cursor.node(); + let text = node.utf8_text(utf8)?; + + // There are three cases: + // 1. Is the current range of code smaller than the chunk_size? If so, return it + // 2. If not, does the current node have children? If so, recursively walk down + // 3. If not, we must split our current node + let mut out = if text.chars().count() <= self.chunk_size { + vec![Chunk::new( + text, + ByteRange::new(node.range().start_byte, node.range().end_byte), + )] + } else { + let mut cursor_copy = cursor.clone(); + if cursor_copy.goto_first_child() { + self.split_recursive(cursor_copy, utf8)? + } else { + let mut current_range = + ByteRange::new(node.range().start_byte, node.range().end_byte); + let mut chunks = vec![]; + let mut current_chunk = text; + loop { + if current_chunk.len() < self.chunk_size { + chunks.push(Chunk::new(current_chunk, current_range)); + break; + } else { + let new_chunk = ¤t_chunk[0..self.chunk_size.min(current_chunk.len())]; + let new_range = ByteRange::new( + current_range.start_byte, + current_range.start_byte + new_chunk.as_bytes().len(), + ); + chunks.push(Chunk::new(new_chunk, new_range)); + let new_current_chunk = + ¤t_chunk[self.chunk_size - self.chunk_overlap..]; + let byte_diff = + current_chunk.as_bytes().len() - new_current_chunk.as_bytes().len(); + current_range = ByteRange::new( + current_range.start_byte + byte_diff, + current_range.end_byte, + ); + current_chunk = new_current_chunk + } + } + chunks + } + }; + if cursor.goto_next_sibling() { + out.append(&mut self.split_recursive(cursor, utf8)?); + } + Ok(out) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tree_sitter::Parser; + + #[test] + fn test_split_rust() { + let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap(); + + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_rust::language()) + .expect("Error loading Rust grammar"); + + let source_code = r#" +#[derive(Debug)] +struct Rectangle { + width: u32, + height: u32, +} + +impl Rectangle { + fn area(&self) -> u32 { + self.width * self.height + } +} + +fn main() { + let rect1 = Rectangle { + width: 30, + height: 50, + }; + + println!( + "The area of the rectangle is {} square pixels.", + rect1.area() + ); +} +"#; + let tree = parser.parse(source_code, None).unwrap(); + let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap(); + assert_eq!( + chunks[0].text, + r#"#[derive(Debug)] +struct Rectangle { + width: u32, + height: u32, +}"# + ); + assert_eq!( + chunks[1].text, + r#"impl Rectangle { + fn area(&self) -> u32 { + self.width * self.height + } +}"# + ); + assert_eq!( + chunks[2].text, + r#"fn main() { + let rect1 = Rectangle { + width: 30, + height: 50, + 
};"# + ); + assert_eq!( + chunks[3].text, + r#"println!( + "The area of the rectangle is {} square pixels.", + rect1.area() + ); +}"# + ); + } + + #[test] + fn test_split_zig() { + let splitter = TreeSitterCodeSplitter::new(128, 10).unwrap(); + + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_rust::language()) + .expect("Error loading Rust grammar"); + + let source_code = r#" +const std = @import("std"); +const parseInt = std.fmt.parseInt; + +std.debug.print("Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string 9 ...", .{}); + +test "parse integers" { + const input = "123 67 89,99"; + const ally = std.testing.allocator; + + var list = std.ArrayList(u32).init(ally); + // Ensure the list is freed at scope exit. + // Try commenting out this line! + defer list.deinit(); + + var it = std.mem.tokenizeAny(u8, input, " ,"); + while (it.next()) |num| { + const n = try parseInt(u32, num, 10); + try list.append(n); + } + + const expected = [_]u32{ 123, 67, 89, 99 }; + + for (expected, list.items) |exp, actual| { + try std.testing.expectEqual(exp, actual); + } +} +"#; + let tree = parser.parse(source_code, None).unwrap(); + let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap(); + + assert_eq!( + chunks[0].text, + r#"const std = @import("std"); +const parseInt = std.fmt.parseInt; + +std.debug.print(""# + ); + assert_eq!( + chunks[1].text, + r#"Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long s"# + ); + assert_eq!( + chunks[2].text, + r#"s a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string "# + ); + assert_eq!(chunks[3].text, r#"ng string 9 ...", .{});"#); + assert_eq!( + chunks[4].text, + r#"test "parse integers" { + const input = "123 67 89,99"; + const ally = std.testing.allocator; + + var list = std.ArrayList"# + ); + assert_eq!( + chunks[5].text, + r#"(u32).init(ally); + // Ensure the list is freed at scope exit. 
+    // Try commenting out this line!"#
+        );
+        assert_eq!(
+            chunks[6].text,
+            r#"defer list.deinit();
+
+    var it = std.mem.tokenizeAny(u8, input, " ,");
+    while (it.next()) |num"#
+        );
+        assert_eq!(
+            chunks[7].text,
+            r#"| {
+        const n = try parseInt(u32, num, 10);
+        try list.append(n);
+    }
+
+    const expected = [_]u32{ 123, 67, 89,"#
+        );
+        assert_eq!(
+            chunks[8].text,
+            r#"99 };
+
+    for (expected, list.items) |exp, actual| {
+        try std.testing.expectEqual(exp, actual);
+    }
+}"#
+        );
+    }
+}
diff --git a/crates/utils-tree-sitter/Cargo.toml b/crates/utils-tree-sitter/Cargo.toml
new file mode 100644
index 0000000..22ebc26
--- /dev/null
+++ b/crates/utils-tree-sitter/Cargo.toml
@@ -0,0 +1,37 @@
+[package]
+name = "utils-tree-sitter"
+version = "0.1.0"
+description = "Utils for working with splitter-tree-sitter"
+
+edition.workspace = true
+repository.workspace = true
+license.workspace = true
+
+[dependencies]
+thiserror = "1.0.61"
+tree-sitter = "0.22"
+tree-sitter-bash = { version = "0.21", optional = true }
+tree-sitter-c = { version = "0.21", optional = true }
+tree-sitter-cpp = { version = "0.22", optional = true }
+tree-sitter-c-sharp = { version = "0.21", optional = true }
+tree-sitter-css = { version = "0.21", optional = true }
+tree-sitter-elixir = { version = "0.2", optional = true }
+tree-sitter-erlang = { version = "0.6", optional = true }
+tree-sitter-go = { version = "0.21", optional = true }
+tree-sitter-html = { version = "0.20", optional = true }
+tree-sitter-java = { version = "0.21", optional = true }
+tree-sitter-javascript = { version = "0.21", optional = true }
+tree-sitter-json = { version = "0.21", optional = true }
+tree-sitter-haskell = { version = "0.21", optional = true }
+tree-sitter-lua = { version = "0.1.0", optional = true }
+tree-sitter-ocaml = { version = "0.22.0", optional = true }
+tree-sitter-python = { version = "0.21", optional = true }
+tree-sitter-rust = { version = "0.21", optional = true }
+# tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig", optional = true }
+
+[build-dependencies]
+cc = "*"
+
+[features]
+default = []
+all = ["dep:tree-sitter-python", "dep:tree-sitter-bash", "dep:tree-sitter-c", "dep:tree-sitter-cpp", "dep:tree-sitter-c-sharp", "dep:tree-sitter-css", "dep:tree-sitter-elixir", "dep:tree-sitter-erlang", "dep:tree-sitter-go", "dep:tree-sitter-html", "dep:tree-sitter-java", "dep:tree-sitter-javascript", "dep:tree-sitter-json", "dep:tree-sitter-rust", "dep:tree-sitter-haskell", "dep:tree-sitter-lua", "dep:tree-sitter-ocaml"]
diff --git a/crates/utils-tree-sitter/LICENSE b/crates/utils-tree-sitter/LICENSE
new file mode 100644
index 0000000..19e1809
--- /dev/null
+++ b/crates/utils-tree-sitter/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Silas Marvin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/crates/utils-tree-sitter/README.md b/crates/utils-tree-sitter/README.md
new file mode 100644
index 0000000..97f75c8
--- /dev/null
+++ b/crates/utils-tree-sitter/README.md
@@ -0,0 +1,3 @@
+# utils-tree-sitter
+
+Utils for working with Tree-sitter
diff --git a/crates/utils-tree-sitter/src/lib.rs b/crates/utils-tree-sitter/src/lib.rs
new file mode 100644
index 0000000..7facd50
--- /dev/null
+++ b/crates/utils-tree-sitter/src/lib.rs
@@ -0,0 +1,90 @@
+use thiserror::Error;
+use tree_sitter::{LanguageError, Parser};
+
+#[derive(Error, Debug)]
+pub enum GetParserError {
+    #[error("no parser found for extension")]
+    NoParserFoundForExtension(String),
+    #[error("no language found for extension")]
+    NoLanguageFoundForExtension(String),
+    #[error("loading grammar")]
+    LoadingGrammer(#[from] LanguageError),
+}
+
+fn get_extension_for_language(extension: &str) -> Result<String, GetParserError> {
+    Ok(match extension {
+        "py" => "Python",
+        "rs" => "Rust",
+        // "zig" => "Zig",
+        "sh" => "Bash",
+        "c" => "C",
+        "cpp" => "C++",
+        "cs" => "C#",
+        "css" => "CSS",
+        "ex" => "Elixir",
+        "erl" => "Erlang",
+        "go" => "Go",
+        "html" => "HTML",
+        "java" => "Java",
+        "js" => "JavaScript",
+        "json" => "JSON",
+        "hs" => "Haskell",
+        "lua" => "Lua",
+        "ml" => "OCaml",
+        _ => {
+            return Err(GetParserError::NoLanguageFoundForExtension(
+                extension.to_string(),
+            ))
+        }
+    }
+    .to_string())
+}
+
+pub fn get_parser_for_extension(extension: &str) -> Result<Parser, GetParserError> {
+    let language = get_extension_for_language(extension)?;
+    let mut parser = Parser::new();
+    match language.as_str() {
+        #[cfg(any(feature = "all", feature = "python"))]
+        "Python" => parser.set_language(&tree_sitter_python::language())?,
+        #[cfg(any(feature = "all", feature = "rust"))]
+        "Rust" => parser.set_language(&tree_sitter_rust::language())?,
+        // #[cfg(any(feature = "all", feature = "zig"))]
+        // "Zig" => parser.set_language(&tree_sitter_zig::language())?,
+        #[cfg(any(feature = "all", feature = "bash"))]
+        "Bash" => parser.set_language(&tree_sitter_bash::language())?,
+        #[cfg(any(feature = "all", feature = "c"))]
+        "C" => parser.set_language(&tree_sitter_c::language())?,
+        #[cfg(any(feature = "all", feature = "cpp"))]
+        "C++" => parser.set_language(&tree_sitter_cpp::language())?,
+        #[cfg(any(feature = "all", feature = "c-sharp"))]
+        "C#" => parser.set_language(&tree_sitter_c_sharp::language())?,
+        #[cfg(any(feature = "all", feature = "css"))]
+        "CSS" => parser.set_language(&tree_sitter_css::language())?,
+        #[cfg(any(feature = "all", feature = "elixir"))]
+        "Elixir" => parser.set_language(&tree_sitter_elixir::language())?,
+        #[cfg(any(feature = "all", feature = "erlang"))]
+        "Erlang" => parser.set_language(&tree_sitter_erlang::language())?,
+        #[cfg(any(feature = "all", feature = "go"))]
+        "Go" => parser.set_language(&tree_sitter_go::language())?,
+        #[cfg(any(feature = "all", feature = "html"))]
+        "HTML" => parser.set_language(&tree_sitter_html::language())?,
+        #[cfg(any(feature = "all", feature = "java"))]
+        "Java" => parser.set_language(&tree_sitter_java::language())?,
+        #[cfg(any(feature = "all", feature = "javascript"))]
"javascript"))] + "JavaScript" => parser.set_language(&tree_sitter_javascript::language())?, + #[cfg(any(feature = "all", feature = "json"))] + "JSON" => parser.set_language(&tree_sitter_json::language())?, + #[cfg(any(feature = "all", feature = "haskell"))] + "Haskell" => parser.set_language(&tree_sitter_haskell::language())?, + #[cfg(any(feature = "all", feature = "lua"))] + "Lua" => parser.set_language(&tree_sitter_lua::language())?, + #[cfg(any(feature = "all", feature = "ocaml"))] + "OCaml" => parser.set_language(&tree_sitter_ocaml::language_ocaml())?, + _ => { + return Err(GetParserError::NoParserFoundForExtension( + language.to_string(), + )) + } + } + Ok(parser) +} diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs deleted file mode 100644 index 4d70509..0000000 --- a/src/memory_backends/file_store.rs +++ /dev/null @@ -1,597 +0,0 @@ -use anyhow::Context; -use indexmap::IndexSet; -use lsp_types::TextDocumentPositionParams; -use parking_lot::Mutex; -use ropey::Rope; -use serde_json::Value; -use std::collections::HashMap; -use tracing::instrument; - -use crate::{ - config::{self, Config}, - utils::tokens_to_estimated_characters, -}; - -use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType}; - -pub struct FileStore { - _crawl: bool, - _config: Config, - file_map: Mutex>, - accessed_files: Mutex>, -} - -impl FileStore { - pub fn new(file_store_config: config::FileStore, config: Config) -> Self { - Self { - _crawl: file_store_config.crawl, - _config: config, - file_map: Mutex::new(HashMap::new()), - accessed_files: Mutex::new(IndexSet::new()), - } - } - - pub fn new_without_crawl(config: Config) -> Self { - Self { - _crawl: false, - _config: config, - file_map: Mutex::new(HashMap::new()), - accessed_files: Mutex::new(IndexSet::new()), - } - } - - fn get_rope_for_position( - &self, - position: &TextDocumentPositionParams, - characters: usize, - ) -> anyhow::Result<(Rope, usize)> { - // Get the rope and set our initial cursor index - let current_document_uri = position.text_document.uri.to_string(); - let mut rope = self - .file_map - .lock() - .get(¤t_document_uri) - .context("Error file not found")? - .clone(); - let mut cursor_index = rope.line_to_char(position.position.line as usize) - + position.position.character as usize; - // Add to our rope if we need to - for file in self - .accessed_files - .lock() - .iter() - .filter(|f| **f != current_document_uri) - { - let needed = characters.saturating_sub(rope.len_chars() + 1); - if needed == 0 { - break; - } - let file_map = self.file_map.lock(); - let r = file_map.get(file).context("Error file not found")?; - let slice_max = needed.min(r.len_chars() + 1); - let rope_str_slice = r - .get_slice(0..slice_max - 1) - .context("Error getting slice")? - .to_string(); - rope.insert(0, "\n"); - rope.insert(0, &rope_str_slice); - cursor_index += slice_max; - } - Ok((rope, cursor_index)) - } - - pub fn get_characters_around_position( - &self, - position: &TextDocumentPositionParams, - characters: usize, - ) -> anyhow::Result { - let rope = self - .file_map - .lock() - .get(position.text_document.uri.as_str()) - .context("Error file not found")? 
- .clone(); - let cursor_index = rope.line_to_char(position.position.line as usize) - + position.position.character as usize; - let start = cursor_index.saturating_sub(characters / 2); - let end = rope - .len_chars() - .min(cursor_index + (characters - (cursor_index - start))); - let rope_slice = rope - .get_slice(start..end) - .context("Error getting rope slice")?; - Ok(rope_slice.to_string()) - } - - pub fn build_code( - &self, - position: &TextDocumentPositionParams, - prompt_type: PromptType, - params: MemoryRunParams, - ) -> anyhow::Result { - let (mut rope, cursor_index) = - self.get_rope_for_position(position, params.max_context_length)?; - - Ok(match prompt_type { - PromptType::ContextAndCode => { - if params.is_for_chat { - let max_length = tokens_to_estimated_characters(params.max_context_length); - let start = cursor_index.saturating_sub(max_length / 2); - let end = rope - .len_chars() - .min(cursor_index + (max_length - (cursor_index - start))); - - rope.insert(cursor_index, ""); - let rope_slice = rope - .get_slice(start..end + "".chars().count()) - .context("Error getting rope slice")?; - Prompt::ContextAndCode(ContextAndCodePrompt::new( - "".to_string(), - rope_slice.to_string(), - )) - } else { - let start = cursor_index - .saturating_sub(tokens_to_estimated_characters(params.max_context_length)); - let rope_slice = rope - .get_slice(start..cursor_index) - .context("Error getting rope slice")?; - Prompt::ContextAndCode(ContextAndCodePrompt::new( - "".to_string(), - rope_slice.to_string(), - )) - } - } - PromptType::FIM => { - let max_length = tokens_to_estimated_characters(params.max_context_length); - let start = cursor_index.saturating_sub(max_length / 2); - let end = rope - .len_chars() - .min(cursor_index + (max_length - (cursor_index - start))); - let prefix = rope - .get_slice(start..cursor_index) - .context("Error getting rope slice")?; - let suffix = rope - .get_slice(cursor_index..end) - .context("Error getting rope slice")?; - Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string())) - } - }) - } -} - -#[async_trait::async_trait] -impl MemoryBackend for FileStore { - #[instrument(skip(self))] - async fn get_filter_text( - &self, - position: &TextDocumentPositionParams, - ) -> anyhow::Result { - let rope = self - .file_map - .lock() - .get(position.text_document.uri.as_str()) - .context("Error file not found")? - .clone(); - let line = rope - .get_line(position.position.line as usize) - .context("Error getting filter_text")? 
- .slice(0..position.position.character as usize) - .to_string(); - Ok(line) - } - - #[instrument(skip(self))] - async fn build_prompt( - &self, - position: &TextDocumentPositionParams, - prompt_type: PromptType, - params: &Value, - ) -> anyhow::Result { - let params: MemoryRunParams = params.try_into()?; - self.build_code(position, prompt_type, params) - } - - #[instrument(skip(self))] - async fn opened_text_document( - &self, - params: lsp_types::DidOpenTextDocumentParams, - ) -> anyhow::Result<()> { - let rope = Rope::from_str(¶ms.text_document.text); - let uri = params.text_document.uri.to_string(); - self.file_map.lock().insert(uri.clone(), rope); - self.accessed_files.lock().shift_insert(0, uri); - Ok(()) - } - - #[instrument(skip(self))] - async fn changed_text_document( - &self, - params: lsp_types::DidChangeTextDocumentParams, - ) -> anyhow::Result<()> { - let uri = params.text_document.uri.to_string(); - let mut file_map = self.file_map.lock(); - let rope = file_map - .get_mut(&uri) - .context("Error trying to get file that does not exist")?; - for change in params.content_changes { - // If range is ommitted, text is the new text of the document - if let Some(range) = change.range { - let start_index = - rope.line_to_char(range.start.line as usize) + range.start.character as usize; - let end_index = - rope.line_to_char(range.end.line as usize) + range.end.character as usize; - rope.remove(start_index..end_index); - rope.insert(start_index, &change.text); - } else { - *rope = Rope::from_str(&change.text); - } - } - self.accessed_files.lock().shift_insert(0, uri); - Ok(()) - } - - #[instrument(skip(self))] - async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { - for file_rename in params.files { - let mut file_map = self.file_map.lock(); - if let Some(rope) = file_map.remove(&file_rename.old_uri) { - file_map.insert(file_rename.new_uri, rope); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use lsp_types::{ - DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams, - TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem, - VersionedTextDocumentIdentifier, - }; - use serde_json::json; - - fn generate_base_file_store() -> anyhow::Result { - let config = Config::default_with_file_store_without_models(); - let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = - config.config.memory.clone() - { - file_store_config - } else { - anyhow::bail!("requires a file_store_config") - }; - Ok(FileStore::new(file_store_config, config)) - } - - fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem { - let uri = uri.unwrap_or("file://filler/"); - let text = text.unwrap_or("Here is the document body"); - TextDocumentItem { - uri: reqwest::Url::parse(uri).unwrap(), - language_id: "filler".to_string(), - version: 0, - text: text.to_string(), - } - } - - #[tokio::test] - async fn can_open_document() -> anyhow::Result<()> { - let params = lsp_types::DidOpenTextDocumentParams { - text_document: generate_filler_text_document(None, None), - }; - let file_store = generate_base_file_store()?; - file_store.opened_text_document(params).await?; - let file = file_store - .file_map - .lock() - .get("file://filler/") - .unwrap() - .clone(); - assert_eq!(file.to_string(), "Here is the document body"); - Ok(()) - } - - #[tokio::test] - async fn can_rename_document() -> anyhow::Result<()> { - let params = lsp_types::DidOpenTextDocumentParams { - 
text_document: generate_filler_text_document(None, None), - }; - let file_store = generate_base_file_store()?; - file_store.opened_text_document(params).await?; - - let params = RenameFilesParams { - files: vec![FileRename { - old_uri: "file://filler/".to_string(), - new_uri: "file://filler2/".to_string(), - }], - }; - file_store.renamed_files(params).await?; - - let file = file_store - .file_map - .lock() - .get("file://filler2/") - .unwrap() - .clone(); - assert_eq!(file.to_string(), "Here is the document body"); - Ok(()) - } - - #[tokio::test] - async fn can_change_document() -> anyhow::Result<()> { - let text_document = generate_filler_text_document(None, None); - - let params = DidOpenTextDocumentParams { - text_document: text_document.clone(), - }; - let file_store = generate_base_file_store()?; - file_store.opened_text_document(params).await?; - - let params = lsp_types::DidChangeTextDocumentParams { - text_document: VersionedTextDocumentIdentifier { - uri: text_document.uri.clone(), - version: 1, - }, - content_changes: vec![TextDocumentContentChangeEvent { - range: Some(Range { - start: Position { - line: 0, - character: 1, - }, - end: Position { - line: 0, - character: 3, - }, - }), - range_length: None, - text: "a".to_string(), - }], - }; - file_store.changed_text_document(params).await?; - let file = file_store - .file_map - .lock() - .get("file://filler/") - .unwrap() - .clone(); - assert_eq!(file.to_string(), "Hae is the document body"); - - let params = lsp_types::DidChangeTextDocumentParams { - text_document: VersionedTextDocumentIdentifier { - uri: text_document.uri, - version: 1, - }, - content_changes: vec![TextDocumentContentChangeEvent { - range: None, - range_length: None, - text: "abc".to_string(), - }], - }; - file_store.changed_text_document(params).await?; - let file = file_store - .file_map - .lock() - .get("file://filler/") - .unwrap() - .clone(); - assert_eq!(file.to_string(), "abc"); - - Ok(()) - } - - #[tokio::test] - async fn can_build_prompt() -> anyhow::Result<()> { - let text_document = generate_filler_text_document( - None, - Some( - r#"Document Top -Here is a more complicated document - -Some text - -The end with a trailing new line -"#, - ), - ); - - // Test basic completion - let params = lsp_types::DidOpenTextDocumentParams { - text_document: text_document.clone(), - }; - let file_store = generate_base_file_store()?; - file_store.opened_text_document(params).await?; - - let prompt = file_store - .build_prompt( - &TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: text_document.uri.clone(), - }, - position: Position { - line: 0, - character: 10, - }, - }, - PromptType::ContextAndCode, - &json!({}), - ) - .await?; - let prompt: ContextAndCodePrompt = prompt.try_into()?; - assert_eq!(prompt.context, ""); - assert_eq!("Document T", prompt.code); - - // Test FIM - let prompt = file_store - .build_prompt( - &TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: text_document.uri.clone(), - }, - position: Position { - line: 0, - character: 10, - }, - }, - PromptType::FIM, - &json!({}), - ) - .await?; - let prompt: FIMPrompt = prompt.try_into()?; - assert_eq!(prompt.prompt, r#"Document T"#); - assert_eq!( - prompt.suffix, - r#"op -Here is a more complicated document - -Some text - -The end with a trailing new line -"# - ); - - // Test chat - let prompt = file_store - .build_prompt( - &TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: text_document.uri.clone(), - }, - position: 
Position { - line: 0, - character: 10, - }, - }, - PromptType::ContextAndCode, - &json!({ - "messages": [] - }), - ) - .await?; - let prompt: ContextAndCodePrompt = prompt.try_into()?; - assert_eq!(prompt.context, ""); - let text = r#"Document Top -Here is a more complicated document - -Some text - -The end with a trailing new line -"# - .to_string(); - assert_eq!(text, prompt.code); - - // Test multi-file - let text_document2 = generate_filler_text_document( - Some("file://filler2"), - Some( - r#"Document Top2 -Here is a more complicated document - -Some text - -The end with a trailing new line -"#, - ), - ); - let params = lsp_types::DidOpenTextDocumentParams { - text_document: text_document2.clone(), - }; - file_store.opened_text_document(params).await?; - - let prompt = file_store - .build_prompt( - &TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: text_document.uri.clone(), - }, - position: Position { - line: 0, - character: 10, - }, - }, - PromptType::ContextAndCode, - &json!({}), - ) - .await?; - let prompt: ContextAndCodePrompt = prompt.try_into()?; - assert_eq!(prompt.context, ""); - assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code); - - Ok(()) - } - - #[tokio::test] - async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> { - let text_document = generate_filler_text_document(None, Some("test\n")); - let params = lsp_types::DidOpenTextDocumentParams { - text_document: text_document.clone(), - }; - let file_store = generate_base_file_store()?; - file_store.opened_text_document(params).await?; - - // Test chat - let prompt = file_store - .build_prompt( - &TextDocumentPositionParams { - text_document: TextDocumentIdentifier { - uri: text_document.uri.clone(), - }, - position: Position { - line: 1, - character: 0, - }, - }, - PromptType::ContextAndCode, - &json!({"messages": []}), - ) - .await?; - let prompt: ContextAndCodePrompt = prompt.try_into()?; - assert_eq!(prompt.context, ""); - let text = r#"test -"# - .to_string(); - assert_eq!(text, prompt.code); - - Ok(()) - } - - // #[tokio::test] - // async fn test_fim_placement_corner_cases() -> anyhow::Result<()> { - // let text_document = generate_filler_text_document(None, Some("test\n")); - // let params = lsp_types::DidOpenTextDocumentParams { - // text_document: text_document.clone(), - // }; - // let file_store = generate_base_file_store()?; - // file_store.opened_text_document(params).await?; - - // // Test FIM - // let params = json!({ - // "fim": { - // "start": "SS", - // "middle": "MM", - // "end": "EE" - // } - // }); - // let prompt = file_store - // .build_prompt( - // &TextDocumentPositionParams { - // text_document: TextDocumentIdentifier { - // uri: text_document.uri.clone(), - // }, - // position: Position { - // line: 1, - // character: 0, - // }, - // }, - // params, - // ) - // .await?; - // assert_eq!(prompt.context, ""); - // let text = r#"test - // "# - // .to_string(); - // assert_eq!(text, prompt.code); - - // Ok(()) - // } -} diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs deleted file mode 100644 index 8b007ab..0000000 --- a/src/memory_backends/postgresml/mod.rs +++ /dev/null @@ -1,254 +0,0 @@ -use std::{ - sync::mpsc::{self, Sender}, - time::Duration, -}; - -use anyhow::Context; -use lsp_types::TextDocumentPositionParams; -use pgml::{Collection, Pipeline}; -use serde_json::{json, Value}; -use tokio::time; -use tracing::instrument; - -use crate::{ - config::{self, Config}, - 
utils::tokens_to_estimated_characters, -}; - -use super::{ - file_store::FileStore, ContextAndCodePrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType, -}; - -pub struct PostgresML { - _config: Config, - file_store: FileStore, - collection: Collection, - pipeline: Pipeline, - debounce_tx: Sender, - added_pipeline: bool, -} - -impl PostgresML { - pub fn new( - postgresml_config: config::PostgresML, - configuration: Config, - ) -> anyhow::Result { - let file_store = FileStore::new_without_crawl(configuration.clone()); - let database_url = if let Some(database_url) = postgresml_config.database_url { - database_url - } else { - std::env::var("PGML_DATABASE_URL")? - }; - // TODO: Think on the naming of the collection - // Maybe filter on metadata or I'm not sure - let collection = Collection::new("test-lsp-ai-3", Some(database_url))?; - // TODO: Review the pipeline - let pipeline = Pipeline::new( - "v1", - Some( - json!({ - "text": { - "splitter": { - "model": "recursive_character", - "parameters": { - "chunk_size": 1500, - "chunk_overlap": 40 - } - }, - "semantic_search": { - "model": "intfloat/e5-small", - } - } - }) - .into(), - ), - )?; - // Setup up a debouncer for changed text documents - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(2) - .enable_all() - .build()?; - let mut task_collection = collection.clone(); - let (debounce_tx, debounce_rx) = mpsc::channel::(); - runtime.spawn(async move { - let duration = Duration::from_millis(500); - let mut file_paths = Vec::new(); - loop { - time::sleep(duration).await; - let new_paths: Vec = debounce_rx.try_iter().collect(); - if !new_paths.is_empty() { - for path in new_paths { - if !file_paths.iter().any(|p| *p == path) { - file_paths.push(path); - } - } - } else { - if file_paths.is_empty() { - continue; - } - let documents = file_paths - .into_iter() - .map(|path| { - let text = std::fs::read_to_string(&path) - .unwrap_or_else(|_| panic!("Error reading path: {}", path)); - json!({ - "id": path, - "text": text - }) - .into() - }) - .collect(); - task_collection - .upsert_documents(documents, None) - .await - .expect("PGML - Error adding pipeline to collection"); - file_paths = Vec::new(); - } - } - }); - Ok(Self { - _config: configuration, - file_store, - collection, - pipeline, - debounce_tx, - added_pipeline: false, - }) - } -} - -#[async_trait::async_trait] -impl MemoryBackend for PostgresML { - #[instrument(skip(self))] - async fn get_filter_text( - &self, - position: &TextDocumentPositionParams, - ) -> anyhow::Result { - self.file_store.get_filter_text(position).await - } - - #[instrument(skip(self))] - async fn build_prompt( - &self, - position: &TextDocumentPositionParams, - prompt_type: PromptType, - params: &Value, - ) -> anyhow::Result { - let params: MemoryRunParams = params.try_into()?; - let query = self - .file_store - .get_characters_around_position(position, 512)?; - let res = self - .collection - .vector_search_local( - json!({ - "query": { - "fields": { - "text": { - "query": query - } - }, - }, - "limit": 5 - }) - .into(), - &self.pipeline, - ) - .await?; - let context = res - .into_iter() - .map(|c| { - c["chunk"] - .as_str() - .map(|t| t.to_owned()) - .context("PGML - Error getting chunk from vector search") - }) - .collect::>>()? 
- .join("\n\n"); - let mut file_store_params = params.clone(); - file_store_params.max_context_length = 512; - let code = self - .file_store - .build_code(position, prompt_type, file_store_params)?; - let code: ContextAndCodePrompt = code.try_into()?; - let code = code.code; - let max_characters = tokens_to_estimated_characters(params.max_context_length); - let _context: String = context - .chars() - .take(max_characters - code.chars().count()) - .collect(); - // We need to redo this section to work with the new memory backend system - todo!() - // Ok(Prompt::new(context, code)) - } - - #[instrument(skip(self))] - async fn opened_text_document( - &self, - params: lsp_types::DidOpenTextDocumentParams, - ) -> anyhow::Result<()> { - let text = params.text_document.text.clone(); - let path = params.text_document.uri.path().to_owned(); - let task_added_pipeline = self.added_pipeline; - let mut task_collection = self.collection.clone(); - let mut task_pipeline = self.pipeline.clone(); - if !task_added_pipeline { - task_collection - .add_pipeline(&mut task_pipeline) - .await - .expect("PGML - Error adding pipeline to collection"); - } - task_collection - .upsert_documents( - vec![json!({ - "id": path, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error upserting documents"); - self.file_store.opened_text_document(params).await - } - - #[instrument(skip(self))] - async fn changed_text_document( - &self, - params: lsp_types::DidChangeTextDocumentParams, - ) -> anyhow::Result<()> { - let path = params.text_document.uri.path().to_owned(); - self.debounce_tx.send(path)?; - self.file_store.changed_text_document(params).await - } - - #[instrument(skip(self))] - async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { - let mut task_collection = self.collection.clone(); - let task_params = params.clone(); - for file in task_params.files { - task_collection - .delete_documents( - json!({ - "id": file.old_uri - }) - .into(), - ) - .await - .expect("PGML - Error deleting file"); - let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); - task_collection - .upsert_documents( - vec![json!({ - "id": file.new_uri, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error adding pipeline to collection"); - } - self.file_store.renamed_files(params).await - } -} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs deleted file mode 100644 index fcfa410..0000000 --- a/tests/integration_tests.rs +++ /dev/null @@ -1,112 +0,0 @@ -use anyhow::Result; -use std::{ - io::{Read, Write}, - process::{ChildStdin, ChildStdout, Command, Stdio}, -}; - -// Note if you get an empty response with no error, that typically means -// the language server died -fn read_response(stdout: &mut ChildStdout) -> Result { - let mut content_length = None; - let mut buf = vec![]; - loop { - let mut buf2 = vec![0]; - stdout.read_exact(&mut buf2)?; - buf.push(buf2[0]); - if let Some(content_length) = content_length { - if buf.len() == content_length { - break; - } - } else { - let len = buf.len(); - if len > 4 - && buf[len - 4] == 13 - && buf[len - 3] == 10 - && buf[len - 2] == 13 - && buf[len - 1] == 10 - { - content_length = - Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::()?); - buf = vec![]; - } - } - } - Ok(String::from_utf8(buf)?) 
-} - -fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> { - stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len(),).as_bytes())?; - stdin.write_all("\r\n".as_bytes())?; - stdin.write_all(message.as_bytes())?; - Ok(()) -} - -// This completion sequence was created using helix with the lsp-ai analyzer and reading the logs -// It starts with a Python file: -// ``` -// # Multiplies two numbers -// def multiply_two_numbers(x, y): -// -// # A singular test -// assert multiply_two_numbers(2, 3) == 6 -// ``` -// And has the following sequence of key strokes: -// o on line 2 (this creates an indented new line and enters insert mode) -// r -// e -// The sequence has: -// - 1 textDocument/DidOpen notification -// - 3 textDocument/didChange notifications -// - 1 textDocument/completion requests -// This test can fail if the model gives a different response than normal, but that seems reasonably unlikely -// I guess we should hardcode the seed or something if we want to do more of these -#[test] -fn test_completion_sequence() -> Result<()> { - let mut child = Command::new("cargo") - .arg("run") - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - let mut stdin = child.stdin.take().unwrap(); - let mut stdout = child.stdout.take().unwrap(); - - let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"23.10 (f6021dd0)"},"processId":70007,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##; - send_message(&mut stdin, initialization_message)?; - let _ = read_response(&mut stdout)?; - - send_message( - &mut stdin, - r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#, - )?; - send_message( - &mut stdin, - r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular 
test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##, - )?; - send_message( - &mut stdin, - r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##, - )?; - send_message( - &mut stdin, - r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##, - )?; - send_message( - &mut stdin, - r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##, - )?; - send_message( - &mut stdin, - r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##, - )?; - - let output = read_response(&mut stdout)?; - assert_eq!( - output, - r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re\n","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"## - ); - - child.kill()?; - Ok(()) -}