Merge pull request #31 from SilasMarvin/silas-rag-force

Introduce RAG and PostgresML support
This commit is contained in:
Silas Marvin 2024-06-24 21:26:25 -07:00 committed by GitHub
commit 17ea67a6a7
39 changed files with 3402 additions and 1121 deletions
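As orientation for the diff below, here is a minimal sketch (not part of the commit) of client initializationOptions that would exercise the new RAG path: the PostgresML memory backend with a tree-sitter splitter and crawl limits. Key names mirror the new config structs in crates/lsp-ai/src/config.rs further down; the `postgres_ml` tag, the database URL, and the parameter values are assumptions or placeholders.

use serde_json::json;

fn main() {
    // Hypothetical initializationOptions; field names follow the Crawl, ValidSplitter,
    // PostgresML, and PostgresMLEmbeddingModel structs introduced in this diff.
    let init_options = json!({
        "memory": {
            "postgres_ml": {
                // May be omitted if the PGML_DATABASE_URL environment variable is set.
                "database_url": "postgres://user:password@localhost:5432/lsp_ai",
                "crawl": {
                    "max_file_size": 10_000_000u64,
                    "max_crawl_memory": 100_000_000u64,
                    "all_files": false
                },
                "splitter": {
                    "tree_sitter": { "chunk_size": 1500, "chunk_overlap": 0 }
                },
                "embedding_model": {
                    "model": "intfloat/e5-small-v2",
                    "embed_parameters": { "prompt": "passage: " },
                    "query_parameters": { "prompt": "query: " }
                }
            }
        },
        "models": {}
    });
    println!("{}", serde_json::to_string_pretty(&init_options).unwrap());
}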

Cargo.lock (generated)

@ -149,6 +149,18 @@ dependencies = [
"num-traits",
]
[[package]]
name = "auto_enums"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393"
dependencies = [
"derive_utils",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@ -356,7 +368,7 @@ version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47"
dependencies = [
"heck",
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.52",
@ -662,6 +674,17 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_utils"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "difflib"
version = "0.4.0"
@ -730,9 +753,9 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
[[package]]
name = "either"
version = "1.10.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b"
dependencies = [
"serde",
]
@ -1056,6 +1079,12 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.9"
@ -1364,6 +1393,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.10"
@ -1523,6 +1561,7 @@ dependencies = [
"anyhow",
"assert_cmd",
"async-trait",
"cc",
"directories",
"hf-hub",
"ignore",
@ -1530,6 +1569,7 @@ dependencies = [
"llama-cpp-2",
"lsp-server",
"lsp-types",
"md5",
"minijinja",
"once_cell",
"parking_lot",
@ -1539,10 +1579,14 @@ dependencies = [
"ropey",
"serde",
"serde_json",
"splitter-tree-sitter",
"text-splitter",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
"tree-sitter",
"utils-tree-sitter",
"xxhash-rust",
]
@ -2196,9 +2240,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.3"
version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [
"aho-corasick",
"memchr",
@ -2415,6 +2459,12 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
[[package]]
name = "ryu"
version = "1.0.17"
@ -2475,7 +2525,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808"
dependencies = [
"darling 0.14.4",
"heck",
"heck 0.4.1",
"quote",
"syn 1.0.109",
]
@ -2498,7 +2548,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc"
dependencies = [
"heck",
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.52",
@ -2756,6 +2806,17 @@ dependencies = [
"der",
]
[[package]]
name = "splitter-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
"tree-sitter-rust",
"tree-sitter-zig",
]
[[package]]
name = "spm_precompiled"
version = "0.1.4"
@ -2857,7 +2918,7 @@ checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
dependencies = [
"dotenvy",
"either",
"heck",
"heck 0.4.1",
"hex",
"once_cell",
"proc-macro2",
@ -3013,6 +3074,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.52",
]
[[package]]
name = "subtle"
version = "2.5.0"
@ -3087,19 +3170,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]]
name = "thiserror"
version = "1.0.58"
name = "text-splitter"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
checksum = "2ab9dc04b7cf08eb01c07c272bf699fa55679a326ddf7dd075e14094efc80fb9"
dependencies = [
"ahash",
"auto_enums",
"either",
"itertools 0.13.0",
"once_cell",
"regex",
"strum",
"thiserror",
"unicode-segmentation",
]
[[package]]
name = "thiserror"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.58"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
@ -3339,6 +3439,195 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "tree-sitter"
version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca"
dependencies = [
"cc",
"regex",
]
[[package]]
name = "tree-sitter-bash"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5244703ad2e08a616d859a0557d7aa290adcd5e0990188a692e628ffe9dce40"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-c"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f956d5351d62652864a4ff3ae861747e7a1940dc96c9998ae400ac0d3ce30427"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-c-sharp"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff899037068a1ffbb891891b7e94db1400ddf12c3d934b85b8c9e30be5cd18da"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-cpp"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "537b7e0f0d8c89b8dd6f4d195814da94832f20720c09016c2a3ac3dc3c437993"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-css"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2f806f96136762b0121f5fdd7172a3dcd8f42d37a2f23ed7f11b35895e20eb4"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-elixir"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94bf7f057768b1cab2ee1f14812ed4ae33f9e04d09254043eeaa797db4ef70"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-erlang"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8db61152e6d8a5b3b5895ecbb85848f85d028f84b4633a2368075c35e5817b34"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-go"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cb318be5ccf75f44e054acf6898a5c95d59b53443eed578e16be0cd7ec037f"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-haskell"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef25a7e6c73cc1cbe0c0b7dbd5406e7b3485b370bd61c5d8d852ae0781f9bf9a"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-html"
version = "0.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95b3492b08a786bf5cc79feb0ef2ff3b115d5174364e0ddfd7860e0b9b088b53"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-java"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33bc21adf831a773c075d9d00107ab43965e6a6ea7607b47fd9ec6f3db4b481b"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-javascript"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fced510d43e6627cd8e19adfd994ac9cfa3b1d71b0d522b41f74145de37feef"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-json"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b737dcb73c35d74b7d64a5f3dde158113c86a012bf3cee2bfdf2150d23b05db"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-lua"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9fe6fc87bd480e1943fc1fcb02453fb2da050e4e8ce0daa67d801544046856"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-ocaml"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2f2e8e848902d12ca6778d31d0e66b5709fc1ad0c84fd8b0c078472fff20dd2"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-python"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-rust"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-zig"
version = "0.0.1"
source = "git+https://github.com/maxxnino/tree-sitter-zig#7c5a29b721d409be8842017351bf007d7e384401"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "try-lock"
version = "0.2.5"
@ -3450,6 +3739,32 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "utils-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
"tree-sitter-bash",
"tree-sitter-c",
"tree-sitter-c-sharp",
"tree-sitter-cpp",
"tree-sitter-css",
"tree-sitter-elixir",
"tree-sitter-erlang",
"tree-sitter-go",
"tree-sitter-haskell",
"tree-sitter-html",
"tree-sitter-java",
"tree-sitter-javascript",
"tree-sitter-json",
"tree-sitter-lua",
"tree-sitter-ocaml",
"tree-sitter-python",
"tree-sitter-rust",
]
[[package]]
name = "uuid"
version = "1.7.0"

Cargo.toml

@ -1,44 +1,12 @@
[package]
name = "lsp-ai"
version = "0.3.0"
[workspace]
members = [
"crates/*",
]
resolver = "2"
[workspace.package]
edition = "2021"
license = "MIT"
description = "LSP-AI is an open-source language server that serves as a backend for AI-powered functionality, designed to assist and empower software engineers, not replace them."
repository = "https://github.com/SilasMarvin/lsp-ai"
readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.6"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = { version = "0.1.55", optional = true }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
reqwest = { version = "0.11.25", features = ["blocking", "json"] }
ignore = "0.4.22"
pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"
[features]
default = []
llama_cpp = ["dep:llama-cpp-2"]
metal = ["llama-cpp-2/metal"]
cuda = ["llama-cpp-2/cuda"]
[dev-dependencies]
assert_cmd = "2.0.14"

crates/lsp-ai/Cargo.toml (new file)

@ -0,0 +1,51 @@
[package]
name = "lsp-ai"
version = "0.3.0"
description.workspace = true
repository.workspace = true
readme.workspace = true
edition.workspace = true
license.workspace = true
[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.6"
lsp-types = "0.95.0"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
llama-cpp-2 = { version = "0.1.55", optional = true }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
reqwest = { version = "0.11.25", features = ["blocking", "json"] }
ignore = "0.4.22"
pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"
tree-sitter = "0.22"
utils-tree-sitter = { path = "../utils-tree-sitter", features = ["all"], version = "0.1.0" }
splitter-tree-sitter = { path = "../splitter-tree-sitter", version = "0.1.0" }
text-splitter = { version = "0.13.3" }
md5 = "0.7.0"
[build-dependencies]
cc="*"
[features]
default = []
llama_cpp = ["dep:llama-cpp-2"]
metal = ["llama-cpp-2/metal"]
cuda = ["llama-cpp-2/cuda"]
[dev-dependencies]
assert_cmd = "2.0.14"

crates/lsp-ai/src/config.rs

@ -24,6 +24,51 @@ impl Default for PostProcess {
}
}
#[derive(Debug, Clone, Deserialize)]
pub enum ValidSplitter {
#[serde(rename = "tree_sitter")]
TreeSitter(TreeSitter),
#[serde(rename = "text_sitter")]
TextSplitter(TextSplitter),
}
impl Default for ValidSplitter {
fn default() -> Self {
ValidSplitter::TreeSitter(TreeSitter::default())
}
}
const fn chunk_size_default() -> usize {
1500
}
const fn chunk_overlap_default() -> usize {
0
}
#[derive(Debug, Clone, Deserialize)]
pub struct TreeSitter {
#[serde(default = "chunk_size_default")]
pub chunk_size: usize,
#[serde(default = "chunk_overlap_default")]
pub chunk_overlap: usize,
}
impl Default for TreeSitter {
fn default() -> Self {
Self {
chunk_size: 1500,
chunk_overlap: 0,
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct TextSplitter {
#[serde(default = "chunk_size_default")]
pub chunk_size: usize,
}
#[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend {
#[serde(rename = "file_store")]
@ -67,15 +112,6 @@ impl ChatMessage {
}
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Chat {
pub completion: Option<Vec<ChatMessage>>,
pub generation: Option<Vec<ChatMessage>>,
pub chat_template: Option<String>,
pub chat_format: Option<String>,
}
#[derive(Clone, Debug, Deserialize)]
#[allow(clippy::upper_case_acronyms)]
#[serde(deny_unknown_fields)]
@ -85,27 +121,52 @@ pub struct FIM {
pub end: String,
}
const fn max_crawl_memory_default() -> u64 {
100_000_000
}
const fn max_crawl_file_size_default() -> u64 {
10_000_000
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Crawl {
#[serde(default = "max_crawl_file_size_default")]
pub max_file_size: u64,
#[serde(default = "max_crawl_memory_default")]
pub max_crawl_memory: u64,
#[serde(default)]
pub all_files: bool,
}
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresMLEmbeddingModel {
pub model: String,
pub embed_parameters: Option<Value>,
pub query_parameters: Option<Value>,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PostgresML {
pub database_url: Option<String>,
pub crawl: Option<Crawl>,
#[serde(default)]
pub crawl: bool,
pub splitter: ValidSplitter,
pub embedding_model: Option<PostgresMLEmbeddingModel>,
}
#[derive(Clone, Debug, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct FileStore {
#[serde(default)]
pub crawl: bool,
pub crawl: Option<Crawl>,
}
const fn n_gpu_layers_default() -> u32 {
1000
}
const fn n_ctx_default() -> u32 {
1000
impl FileStore {
pub fn new_without_crawl() -> Self {
Self { crawl: None }
}
}
#[derive(Clone, Debug, Deserialize)]
@ -137,6 +198,17 @@ pub struct MistralFIM {
pub max_requests_per_second: f32,
}
#[cfg(feature = "llama_cpp")]
const fn n_gpu_layers_default() -> u32 {
1000
}
#[cfg(feature = "llama_cpp")]
const fn n_ctx_default() -> u32 {
1000
}
#[cfg(feature = "llama_cpp")]
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
@ -230,15 +302,14 @@ pub struct ValidConfig {
#[derive(Clone, Debug, Deserialize, Default)]
pub struct ValidClientParams {
#[serde(alias = "rootURI")]
_root_uri: Option<String>,
_workspace_folders: Option<Vec<String>>,
#[serde(alias = "rootUri")]
pub root_uri: Option<String>,
}
#[derive(Clone, Debug)]
pub struct Config {
pub config: ValidConfig,
_client_params: ValidClientParams,
pub client_params: ValidClientParams,
}
impl Config {
@ -255,7 +326,7 @@ impl Config {
let client_params: ValidClientParams = serde_json::from_value(args)?;
Ok(Self {
config: valid_args,
_client_params: client_params,
client_params,
})
}
@ -300,20 +371,17 @@ impl Config {
}
}
// This makes testing much easier.
// For testing use only
#[cfg(test)]
impl Config {
pub fn default_with_file_store_without_models() -> Self {
Self {
config: ValidConfig {
memory: ValidMemoryBackend::FileStore(FileStore { crawl: false }),
memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }),
models: HashMap::new(),
completion: None,
},
_client_params: ValidClientParams {
_root_uri: None,
_workspace_folders: None,
},
client_params: ValidClientParams { root_uri: None },
}
}
}

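As a standalone illustration of the crawl rework above, where FileStore's `crawl: bool` becomes an optional `Crawl` block carrying size limits, here is a small sketch using local stand-in structs trimmed from the definitions in this diff; it is not the crate's code.

use serde::Deserialize;
use serde_json::json;

const fn max_crawl_file_size_default() -> u64 { 10_000_000 }
const fn max_crawl_memory_default() -> u64 { 100_000_000 }

#[derive(Debug, Deserialize)]
struct Crawl {
    #[serde(default = "max_crawl_file_size_default")]
    max_file_size: u64,
    #[serde(default = "max_crawl_memory_default")]
    max_crawl_memory: u64,
    #[serde(default)]
    all_files: bool,
}

#[derive(Debug, Deserialize, Default)]
struct FileStore {
    // None means "do not crawl"; Some(..) carries the limits above.
    crawl: Option<Crawl>,
}

fn main() -> Result<(), serde_json::Error> {
    let with_crawl: FileStore = serde_json::from_value(json!({
        "crawl": { "all_files": true }
    }))?;
    let without_crawl: FileStore = serde_json::from_value(json!({}))?;
    assert!(with_crawl.crawl.is_some());
    assert!(without_crawl.crawl.is_none());
    println!("{with_crawl:?}\n{without_crawl:?}");
    Ok(())
}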
crates/lsp-ai/src/crawl.rs (new file)

@ -0,0 +1,103 @@
use ignore::WalkBuilder;
use std::collections::HashSet;
use tracing::{error, instrument};
use crate::config::{self, Config};
pub struct Crawl {
crawl_config: config::Crawl,
config: Config,
crawled_file_types: HashSet<String>,
crawled_all: bool,
}
impl Crawl {
pub fn new(crawl_config: config::Crawl, config: Config) -> Self {
Self {
crawl_config,
config,
crawled_file_types: HashSet::new(),
crawled_all: false,
}
}
#[instrument(skip(self, f))]
pub fn maybe_do_crawl(
&mut self,
triggered_file: Option<String>,
mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<bool>,
) -> anyhow::Result<()> {
if self.crawled_all {
return Ok(());
}
if let Some(root_uri) = &self.config.client_params.root_uri {
if !root_uri.starts_with("file://") {
anyhow::bail!("Skipping crawling as root_uri does not begin with file://")
}
let extension_to_match = triggered_file
.map(|tf| {
let path = std::path::Path::new(&tf);
path.extension().map(|f| f.to_str().map(|f| f.to_owned()))
})
.flatten()
.flatten();
if let Some(extension_to_match) = &extension_to_match {
if self.crawled_file_types.contains(extension_to_match) {
return Ok(());
}
}
if !self.crawl_config.all_files && extension_to_match.is_none() {
return Ok(());
}
for result in WalkBuilder::new(&root_uri[7..]).build() {
let result = result?;
let path = result.path();
if !path.is_dir() {
if let Some(path_str) = path.to_str() {
if self.crawl_config.all_files {
match f(&self.crawl_config, path_str) {
Ok(c) => {
if !c {
break;
}
}
Err(e) => error!("{e:?}"),
}
} else {
match (
path.extension().map(|pe| pe.to_str()).flatten(),
&extension_to_match,
) {
(Some(path_extension), Some(extension_to_match)) => {
if path_extension == extension_to_match {
match f(&self.crawl_config, path_str) {
Ok(c) => {
if !c {
break;
}
}
Err(e) => error!("{e:?}"),
}
}
}
_ => continue,
}
}
}
}
}
if let Some(extension_to_match) = extension_to_match {
self.crawled_file_types.insert(extension_to_match);
} else {
self.crawled_all = true
}
}
Ok(())
}
}

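The crawler above is built on the `ignore` crate's WalkBuilder plus a callback that returns Ok(false) to end the walk early. A minimal, self-contained sketch of that walk-with-a-byte-budget pattern follows; the root path and budget are placeholders rather than values taken from the commit.

use ignore::WalkBuilder;

fn main() -> anyhow::Result<()> {
    let max_crawl_memory = 100_000_000u64; // placeholder budget, mirrors max_crawl_memory_default
    let mut total_bytes = 0u64;
    // WalkBuilder respects .gitignore and hidden-file rules by default.
    for entry in WalkBuilder::new(".").build() {
        let entry = entry?;
        let path = entry.path();
        if path.is_dir() {
            continue;
        }
        total_bytes += entry.metadata()?.len();
        if total_bytes >= max_crawl_memory {
            eprintln!("ending crawl early: byte budget exceeded");
            break;
        }
        println!("would index {}", path.display());
    }
    Ok(())
}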
crates/lsp-ai/src/main.rs

@ -14,9 +14,11 @@ use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod config;
mod crawl;
mod custom_requests;
mod memory_backends;
mod memory_worker;
mod splitters;
#[cfg(feature = "llama_cpp")]
mod template;
mod transformer_backends;
@ -50,15 +52,19 @@ where
req.extract(R::METHOD)
}
fn main() -> Result<()> {
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
fn init_logger() {
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_ansi(false)
.without_time()
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init();
}
fn main() -> Result<()> {
init_logger();
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(ServerCapabilities {
@ -83,7 +89,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let connection = Arc::new(connection);
// Our channel we use to communicate with our transformer worker
// let last_worker_request = Arc::new(Mutex::new(None));
let (transformer_tx, transformer_rx) = mpsc::channel();
// The channel we use to communicate with our memory worker
@ -94,8 +99,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
// Setup our transformer worker
// let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
// config.clone().try_into()?;
let transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>> = config
.config
.models

crates/lsp-ai/src/memory_backends/file_store.rs (new file)

@ -0,0 +1,924 @@
use anyhow::Context;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::{collections::HashMap, io::Read};
use tracing::{error, instrument, warn};
use tree_sitter::{InputEdit, Point, Tree};
use crate::{
config::{self, Config},
crawl::Crawl,
utils::{parse_tree, tokens_to_estimated_characters},
};
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
#[derive(Default)]
pub struct AdditionalFileStoreParams {
build_tree: bool,
}
impl AdditionalFileStoreParams {
pub fn new(build_tree: bool) -> Self {
Self { build_tree }
}
}
#[derive(Clone)]
pub struct File {
rope: Rope,
tree: Option<Tree>,
}
impl File {
fn new(rope: Rope, tree: Option<Tree>) -> Self {
Self { rope, tree }
}
pub fn rope(&self) -> &Rope {
&self.rope
}
pub fn tree(&self) -> Option<&Tree> {
self.tree.as_ref()
}
}
pub struct FileStore {
params: AdditionalFileStoreParams,
file_map: Mutex<HashMap<String, File>>,
accessed_files: Mutex<IndexSet<String>>,
crawl: Option<Mutex<Crawl>>,
}
impl FileStore {
pub fn new(mut file_store_config: config::FileStore, config: Config) -> anyhow::Result<Self> {
let crawl = file_store_config
.crawl
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
params: AdditionalFileStoreParams::default(),
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e:?}")
}
Ok(s)
}
pub fn new_with_params(
mut file_store_config: config::FileStore,
config: Config,
params: AdditionalFileStoreParams,
) -> anyhow::Result<Self> {
let crawl = file_store_config
.crawl
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
params,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e:?}")
}
Ok(s)
}
fn add_new_file(&self, uri: &str, contents: String) {
let tree = if self.params.build_tree {
match parse_tree(uri, &contents, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!(
"Failed to parse tree for {uri} with error {e}, falling back to no tree"
);
None
}
}
} else {
None
};
self.file_map
.lock()
.insert(uri.to_string(), File::new(Rope::from_str(&contents), tree));
self.accessed_files.lock().insert(uri.to_string());
}
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
let mut total_bytes = 0;
let mut current_bytes = 0;
if let Some(crawl) = &self.crawl {
crawl
.lock()
.maybe_do_crawl(triggered_file, |config, path| {
// Break if total bytes is over the max crawl memory
if total_bytes as u64 >= config.max_crawl_memory {
warn!("Ending crawl early due to `max_crawl_memory` resetraint");
return Ok(false);
}
// This means it has been opened before
let insert_uri = format!("file:///{path}");
if self.file_map.lock().contains_key(&insert_uri) {
return Ok(true);
}
// Open the file and see if it is small enough to read
let mut f = std::fs::File::open(path)?;
let metadata = f.metadata()?;
if metadata.len() > config.max_file_size {
warn!("Skipping file: {path} because it is too large");
return Ok(true);
}
// Read the file contents
let mut contents = vec![];
f.read_to_end(&mut contents)?;
let contents = String::from_utf8(contents)?;
current_bytes += contents.len();
total_bytes += contents.len();
self.add_new_file(&insert_uri, contents);
Ok(true)
})?;
}
Ok(())
}
fn get_rope_for_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
pull_from_multiple_files: bool,
) -> anyhow::Result<(Rope, usize)> {
// Get the rope and set our initial cursor index
let current_document_uri = position.text_document.uri.to_string();
let mut rope = self
.file_map
.lock()
.get(&current_document_uri)
.context("Error file not found")?
.rope
.clone();
let mut cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
// Add to our rope if we need to
for file in self
.accessed_files
.lock()
.iter()
.filter(|f| **f != current_document_uri)
{
let needed = characters.saturating_sub(rope.len_chars() + 1);
if needed == 0 || !pull_from_multiple_files {
break;
}
let file_map = self.file_map.lock();
let r = &file_map.get(file).context("Error file not found")?.rope;
let slice_max = needed.min(r.len_chars() + 1);
let rope_str_slice = r
.get_slice(0..slice_max - 1)
.context("Error getting slice")?
.to_string();
rope.insert(0, "\n");
rope.insert(0, &rope_str_slice);
cursor_index += slice_max;
}
Ok((rope, cursor_index))
}
pub fn get_characters_around_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.rope
.clone();
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.saturating_sub(characters / 2);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
let rope_slice = rope
.get_slice(start..end)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: MemoryRunParams,
pull_from_multiple_files: bool,
) -> anyhow::Result<Prompt> {
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?;
Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.is_for_chat {
let max_length = tokens_to_estimated_characters(params.max_context);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
}
}
PromptType::FIM => {
let max_length = tokens_to_estimated_characters(params.max_context);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
let prefix = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
let suffix = rope
.get_slice(cursor_index..end)
.context("Error getting rope slice")?;
Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
}
})
}
pub fn file_map(&self) -> &Mutex<HashMap<String, File>> {
&self.file_map
}
pub fn contains_file(&self, uri: &str) -> bool {
self.file_map.lock().contains_key(uri)
}
pub fn position_to_byte(&self, position: &TextDocumentPositionParams) -> anyhow::Result<usize> {
let file_map = self.file_map.lock();
let uri = position.text_document.uri.to_string();
let file = file_map
.get(&uri)
.with_context(|| format!("trying to get file that does not exist {uri}"))?;
let line_char_index = file
.rope
.try_line_to_char(position.position.line as usize)?;
Ok(file
.rope
.try_char_to_byte(line_char_index + position.position.character as usize)?)
}
}
#[async_trait::async_trait]
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.rope
.clone();
let line = rope
.get_line(position.position.line as usize)
.context("Error getting filter text")?
.get_slice(0..position.position.character as usize)
.context("Error getting filter text")?
.to_string();
Ok(line)
}
#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
self.build_code(position, prompt_type, params, true)
}
#[instrument(skip(self))]
fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
self.add_new_file(&uri, params.text_document.text);
if let Err(e) = self.maybe_do_crawl(Some(uri)) {
error!("{e:?}")
}
Ok(())
}
#[instrument(skip(self))]
fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
let mut file_map = self.file_map.lock();
let file = file_map
.get_mut(&uri)
.with_context(|| format!("Trying to get file that does not exist {uri}"))?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
// Record old positions
let (old_end_position, old_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the document
let start_index = file.rope.line_to_char(range.start.line as usize)
+ range.start.character as usize;
let end_index =
file.rope.line_to_char(range.end.line as usize) + range.end.character as usize;
file.rope.remove(start_index..end_index);
file.rope.insert(start_index, &change.text);
// Set new end positions
let (new_end_position, new_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the tree
if self.params.build_tree {
let mut old_tree = file.tree.take();
let start_byte = file
.rope
.try_line_to_char(range.start.line as usize)
.and_then(|start_char| {
file.rope
.try_char_to_byte(start_char + range.start.character as usize)
})
.map_err(anyhow::Error::msg);
if let Some(old_tree) = &mut old_tree {
match (start_byte, old_end_position, new_end_position) {
(Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => {
old_tree.edit(&InputEdit {
start_byte,
old_end_byte,
new_end_byte,
start_position: Point::new(
range.start.line as usize,
range.start.character as usize,
),
old_end_position,
new_end_position,
});
file.tree = match parse_tree(
&uri,
&file.rope.to_string(),
Some(old_tree),
) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to edit tree: {e:?}");
None
}
};
}
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
error!("failed to build tree edit: {e:?}");
}
}
}
}
} else {
file.rope = Rope::from_str(&change.text);
if self.params.build_tree {
file.tree = match parse_tree(&uri, &change.text, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to parse new tree: {e:?}");
None
}
};
}
}
}
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}
#[instrument(skip(self))]
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
let mut file_map = self.file_map.lock();
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
Ok(())
}
}
// For testing use only
#[cfg(test)]
impl FileStore {
pub fn default_with_filler_file() -> anyhow::Result<Self> {
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
let f = FileStore::new(file_store_config, config)?;
let uri = "file:///filler.py";
let text = r#"# Multiplies two numbers
def multiply_two_numbers(x, y):
return
# A singular test
assert multiply_two_numbers(2, 3) == 6
"#;
let params = lsp_types::DidOpenTextDocumentParams {
text_document: lsp_types::TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "filler".to_string(),
version: 0,
text: text.to_string(),
},
};
f.opened_text_document(params)?;
Ok(f)
}
}
#[cfg(test)]
mod tests {
use super::*;
use lsp_types::{
DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams,
TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem,
VersionedTextDocumentIdentifier,
};
use serde_json::json;
fn generate_base_file_store() -> anyhow::Result<FileStore> {
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
FileStore::new(file_store_config, config)
}
fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem {
let uri = uri.unwrap_or("file:///filler/");
let text = text.unwrap_or("Here is the document body");
TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "filler".to_string(),
version: 0,
text: text.to_string(),
}
}
#[test]
fn can_open_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params)?;
let file = file_store
.file_map
.lock()
.get("file:///filler/")
.unwrap()
.clone();
assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(())
}
#[test]
fn can_rename_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params)?;
let params = RenameFilesParams {
files: vec![FileRename {
old_uri: "file:///filler/".to_string(),
new_uri: "file:///filler2/".to_string(),
}],
};
file_store.renamed_files(params)?;
let file = file_store
.file_map
.lock()
.get("file:///filler2/")
.unwrap()
.clone();
assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(())
}
#[test]
fn can_change_document() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, None);
let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params)?;
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 1,
},
end: Position {
line: 0,
character: 3,
},
}),
range_length: None,
text: "a".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store
.file_map
.lock()
.get("file:///filler/")
.unwrap()
.clone();
assert_eq!(file.rope.to_string(), "Hae is the document body");
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "abc".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store
.file_map
.lock()
.get("file:///filler/")
.unwrap()
.clone();
assert_eq!(file.rope.to_string(), "abc");
Ok(())
}
#[tokio::test]
async fn can_build_prompt() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(
None,
Some(
r#"Document Top
Here is a more complicated document
Some text
The end with a trailing new line
"#,
),
);
// Test basic completion
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params)?;
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!("Document T", prompt.code);
// Test FIM
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::FIM,
&json!({}),
)
.await?;
let prompt: FIMPrompt = prompt.try_into()?;
assert_eq!(prompt.prompt, r#"Document T"#);
assert_eq!(
prompt.suffix,
r#"op
Here is a more complicated document
Some text
The end with a trailing new line
"#
);
// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({
"messages": []
}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"Document T<CURSOR>op
Here is a more complicated document
Some text
The end with a trailing new line
"#
.to_string();
assert_eq!(text, prompt.code);
// Test multi-file
let text_document2 = generate_filler_text_document(
Some("file:///filler2"),
Some(
r#"Document Top2
Here is a more complicated document
Some text
The end with a trailing new line
"#,
),
);
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document2.clone(),
};
file_store.opened_text_document(params)?;
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code);
Ok(())
}
#[tokio::test]
async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, Some("test\n"));
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params)?;
// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 1,
character: 0,
},
},
PromptType::ContextAndCode,
&json!({"messages": []}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"test
<CURSOR>"#
.to_string();
assert_eq!(text, prompt.code);
Ok(())
}
#[test]
fn test_file_store_tree_sitter() -> anyhow::Result<()> {
crate::init_logger();
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
let params = AdditionalFileStoreParams { build_tree: true };
let file_store = FileStore::new_with_params(file_store_config, config, params)?;
let uri = "file:///filler/test.rs";
let text = r#"#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}
impl Rectangle {
fn area(&self) -> u32 {
}
}
fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};
println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}"#;
let text_document = TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "".to_string(),
version: 0,
text: text.to_string(),
};
let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
file_store.opened_text_document(params)?;
// Test insert
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 8,
character: 0,
},
end: Position {
line: 8,
character: 0,
},
}),
range_length: None,
text: " self.width * self.height".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test delete
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 12,
character: 0,
},
}),
range_length: None,
text: "".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test replace
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "fn main() {}".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))");
Ok(())
}
}

crates/lsp-ai/src/memory_backends/mod.rs

@ -18,13 +18,13 @@ pub enum PromptType {
#[derive(Clone)]
pub struct MemoryRunParams {
pub is_for_chat: bool,
pub max_context_length: usize,
pub max_context: usize,
}
impl From<&Value> for MemoryRunParams {
fn from(value: &Value) -> Self {
Self {
max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
// messages are for most backends, contents are for Gemini
is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
}
@ -113,22 +113,16 @@ pub trait MemoryBackend {
async fn init(&self) -> anyhow::Result<()> {
Ok(())
}
async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
async fn changed_text_document(
&self,
params: DidChangeTextDocumentParams,
) -> anyhow::Result<()>;
async fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt>;
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String>;
}
impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
@ -137,7 +131,7 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
fn try_from(configuration: Config) -> Result<Self, Self::Error> {
match configuration.config.memory.clone() {
ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new(
file_store::FileStore::new(file_store_config, configuration),
file_store::FileStore::new(file_store_config, configuration)?,
)),
ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new(
postgresml::PostgresML::new(postgresml_config, configuration)?,

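To make the parameter handling above concrete, here is a standalone sketch with a local stand-in for MemoryRunParams, mirroring the renamed `max_context` field and the chat detection from the From<&Value> impl in this diff.

use serde_json::{json, Value};

#[derive(Debug)]
struct MemoryRunParams {
    is_for_chat: bool,
    max_context: usize,
}

impl From<&Value> for MemoryRunParams {
    fn from(value: &Value) -> Self {
        Self {
            max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
            // messages are for most backends, contents are for Gemini
            is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
        }
    }
}

fn main() {
    let completion_params = json!({ "max_context": 2048 });
    let chat_params = json!({ "max_context": 2048, "messages": [] });
    println!("{:?}", MemoryRunParams::from(&completion_params)); // is_for_chat: false
    println!("{:?}", MemoryRunParams::from(&chat_params));       // is_for_chat: true
}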
crates/lsp-ai/src/memory_backends/postgresml/mod.rs (new file)

@ -0,0 +1,701 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use rand::{distributions::Alphanumeric, Rng};
use serde_json::{json, Value};
use std::{
collections::HashSet,
io::Read,
sync::{
mpsc::{self, Sender},
Arc,
},
time::Duration,
};
use tokio::time;
use tracing::{error, instrument, warn};
use crate::{
config::{self, Config},
crawl::Crawl,
splitters::{Chunk, Splitter},
utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME},
};
use super::{
file_store::{AdditionalFileStoreParams, FileStore},
ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};
const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
let path = match root_uri {
Some(root_uri) => {
if uri.starts_with(root_uri) {
&uri[root_uri.chars().count()..]
} else {
uri
}
}
None => uri,
};
format!(
r#"--{path}--
{excerpt}
"#,
)
}
fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
json!({
"id": chunk_to_id(uri, &chunk),
"uri": uri,
"text": format_file_excerpt(uri, &chunk.text, root_uri),
"range": chunk.range
})
}
async fn split_and_upsert_file(
uri: &str,
collection: &mut Collection,
file_store: Arc<FileStore>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
root_uri: Option<&str>,
) -> anyhow::Result<()> {
// We need to make sure we don't hold the file_store lock while performing a network call
let chunks = {
file_store
.file_map()
.lock()
.get(uri)
.map(|f| splitter.split(f))
};
let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
let documents = chunks
.into_iter()
.map(|chunk| chunk_to_document(uri, chunk, root_uri).into())
.collect();
collection
.upsert_documents(documents, None)
.await
.context("PGML - Error upserting documents")
}
#[derive(Clone)]
pub struct PostgresML {
config: Config,
postgresml_config: config::PostgresML,
file_store: Arc<FileStore>,
collection: Collection,
pipeline: Pipeline,
debounce_tx: Sender<String>,
crawl: Option<Arc<Mutex<Crawl>>>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
}
impl PostgresML {
#[instrument]
pub fn new(
mut postgresml_config: config::PostgresML,
configuration: Config,
) -> anyhow::Result<Self> {
let crawl = postgresml_config
.crawl
.take()
.map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
Arc::new(postgresml_config.splitter.clone().try_into()?);
let file_store = Arc::new(FileStore::new_with_params(
config::FileStore::new_without_crawl(),
configuration.clone(),
AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
)?);
let database_url = if let Some(database_url) = postgresml_config.database_url.clone() {
database_url
} else {
std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
};
// Build our pipeline schema
let pipeline = match &postgresml_config.embedding_model {
Some(embedding_model) => {
json!({
"text": {
"semantic_search": {
"model": embedding_model.model,
"parameters": embedding_model.embed_parameters
}
}
})
}
None => {
json!({
"text": {
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
}
}
})
}
};
// When building the collection name we include the Pipeline schema
// If the user changes the Pipeline schema, it will take effect without them having to delete the old files
let collection_name = match configuration.client_params.root_uri.clone() {
Some(root_uri) => format!(
"{:x}",
md5::compute(
format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes()
)
),
None => {
warn!("no root_uri provided in server configuration - generating random string for collection name");
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(21)
.map(char::from)
.collect()
}
};
let mut collection = Collection::new(&collection_name, Some(database_url))?;
let mut pipeline = Pipeline::new("v1", Some(pipeline.into()))?;
// Add the Pipeline to the Collection
TOKIO_RUNTIME.block_on(async {
collection
.add_pipeline(&mut pipeline)
.await
.context("PGML - error adding pipeline to collection")
})?;
// Set up a debouncer for changed text documents
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
let mut task_collection = collection.clone();
let task_file_store = file_store.clone();
let task_splitter = splitter.clone();
let task_root_uri = configuration.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_uris = Vec::new();
loop {
time::sleep(duration).await;
let new_uris: Vec<String> = debounce_rx.try_iter().collect();
if !new_uris.is_empty() {
for uri in new_uris {
if !file_uris.iter().any(|p| *p == uri) {
file_uris.push(uri);
}
}
} else {
if file_uris.is_empty() {
continue;
}
// Build the chunks for our changed files
let chunks: Vec<Vec<Chunk>> = match file_uris
.iter()
.map(|uri| {
let file_store = task_file_store.file_map().lock();
let file = file_store
.get(uri)
.with_context(|| format!("getting file for splitting: {uri}"))?;
anyhow::Ok(task_splitter.split(file))
})
.collect()
{
Ok(chunks) => chunks,
Err(e) => {
error!("{e:?}");
continue;
}
};
// Delete old chunks that no longer exist after the latest file changes
let delete_or_statements: Vec<Value> = file_uris
.iter()
.zip(&chunks)
.map(|(uri, chunks)| {
let ids: Vec<String> =
chunks.iter().map(|c| chunk_to_id(uri, c)).collect();
json!({
"$and": [
{
"uri": {
"$eq": uri
}
},
{
"id": {
"$nin": ids
}
}
]
})
})
.collect();
if let Err(e) = task_collection
.delete_documents(
json!({
"$or": delete_or_statements
})
.into(),
)
.await
.context("PGML - error deleting documents")
{
error!("{e:?}");
}
// Prepare and upsert our new chunks
let documents: Vec<pgml::types::Json> = chunks
.into_iter()
.zip(&file_uris)
.map(|(chunks, uri)| {
chunks
.into_iter()
.map(|chunk| {
chunk_to_document(&uri, chunk, task_root_uri.as_deref())
})
.collect::<Vec<Value>>()
})
.flatten()
.map(|f: Value| f.into())
.collect();
if let Err(e) = task_collection
.upsert_documents(documents, None)
.await
.context("PGML - error upserting changed files")
{
error!("{e:?}");
continue;
}
file_uris = Vec::new();
}
}
});
let s = Self {
config: configuration,
postgresml_config,
file_store,
collection,
pipeline,
debounce_tx,
crawl,
splitter,
};
// Resync our Collection
let task_s = s.clone();
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = task_s.resync().await {
error!("{e:?}")
}
});
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e:?}")
}
Ok(s)
}
async fn resync(&self) -> anyhow::Result<()> {
let mut collection = self.collection.clone();
let documents = collection
.get_documents(Some(
json!({
"limit": 100_000_000,
"keys": ["uri"]
})
.into(),
))
.await?;
let try_get_file_contents = |path: &std::path::Path| {
// Open the file and see if it is small enough to read
let mut f = std::fs::File::open(path)?;
let metadata = f.metadata()?;
if metadata.len() > RESYNC_MAX_FILE_SIZE {
anyhow::bail!("file size is greater than: {RESYNC_MAX_FILE_SIZE}")
}
// Read the file contents
let mut contents = vec![];
f.read_to_end(&mut contents)?;
anyhow::Ok(String::from_utf8(contents)?)
};
let mut documents_to_delete = vec![];
let mut chunks_to_upsert = vec![];
let mut current_chunks_bytes = 0;
let mut checked_uris = HashSet::new();
for document in documents.into_iter() {
let uri = match document["document"]["uri"].as_str() {
Some(uri) => uri,
None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it
};
// Check if we have already loaded in this file
if checked_uris.contains(uri) {
continue;
}
checked_uris.insert(uri.to_string());
let path = uri.replace("file://", "");
let path = std::path::Path::new(&path);
if !path.exists() {
documents_to_delete.push(uri.to_string());
} else {
// Try to read the file. If we fail delete it
let contents = match try_get_file_contents(path) {
Ok(contents) => contents,
Err(e) => {
error!("{e:?}");
documents_to_delete.push(uri.to_string());
continue;
}
};
// Split the file into chunks
current_chunks_bytes += contents.len();
let chunks: Vec<pgml::types::Json> = self
.splitter
.split_file_contents(&uri, &contents)
.into_iter()
.map(|chunk| {
chunk_to_document(
&uri,
chunk,
self.config.client_params.root_uri.as_deref(),
)
.into()
})
.collect();
chunks_to_upsert.extend(chunks);
// If we have over 10 megabytes of chunks, do the upsert
if current_chunks_bytes > 10_000_000 {
collection
.upsert_documents(chunks_to_upsert, None)
.await
.context("PGML - error upserting documents during resync")?;
chunks_to_upsert = vec![];
current_chunks_bytes = 0;
}
}
}
// Upsert any remaining chunks
if chunks_to_upsert.len() > 0 {
collection
.upsert_documents(chunks_to_upsert, None)
.await
.context("PGML - error upserting documents during resync")?;
}
// Delete documents
if !documents_to_delete.is_empty() {
collection
.delete_documents(
json!({
"uri": {
"$in": documents_to_delete
}
})
.into(),
)
.await
.context("PGML - error deleting documents during resync")?;
}
Ok(())
}
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl {
let mut documents = vec![];
let mut total_bytes = 0;
let mut current_bytes = 0;
crawl
.lock()
.maybe_do_crawl(triggered_file, |config, path| {
// Stop early if the total bytes read exceed the max crawl memory
if total_bytes as u64 >= config.max_crawl_memory {
warn!("Ending crawl early due to `max_crawl_memory` limit");
return Ok(false);
}
// Skip files the file store already holds; they have been opened before
let uri = format!("file://{path}");
if self.file_store.contains_file(&uri) {
return Ok(true);
}
// Open the file and see if it is small enough to read
let mut f = std::fs::File::open(path)?;
let metadata = f.metadata()?;
if metadata.len() > config.max_file_size {
warn!("Skipping file: {path} because it is too large");
return Ok(true);
}
// Read the file contents
let mut contents = vec![];
f.read_to_end(&mut contents)?;
let contents = String::from_utf8(contents)?;
current_bytes += contents.len();
total_bytes += contents.len();
let chunks: Vec<pgml::types::Json> = self
.splitter
.split_file_contents(&uri, &contents)
.into_iter()
.map(|chunk| {
chunk_to_document(
&uri,
chunk,
self.config.client_params.root_uri.as_deref(),
)
.into()
})
.collect();
documents.extend(chunks);
// If we have over 10 megabytes of data, do the upsert
if current_bytes >= 10_000_000 || total_bytes as u64 >= config.max_crawl_memory
{
// Upsert the documents
let mut collection = self.collection.clone();
let to_upsert_documents = std::mem::take(&mut documents);
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = collection
.upsert_documents(to_upsert_documents, None)
.await
.context("PGML - error upserting changed files")
{
error!("{e:?}");
}
});
// Reset everything
current_bytes = 0;
documents = vec![];
}
Ok(true)
})?;
// Upsert any remaining documents
if !documents.is_empty() {
let mut collection = self.collection.clone();
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = collection
.upsert_documents(documents, None)
.await
.context("PGML - error upserting changed files")
{
error!("{e:?}");
}
});
}
}
Ok(())
}
}
#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
self.file_store.get_filter_text(position)
}
#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
let chunk_size = self.splitter.chunk_size();
let total_allowed_characters = tokens_to_estimated_characters(params.max_context);
// Build the query
let query = self
.file_store
.get_characters_around_position(position, chunk_size)?;
// Build the prompt
let mut file_store_params = params.clone();
file_store_params.max_context = chunk_size;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params, false)?;
// Get the byte of the cursor
let cursor_byte = self.file_store.position_to_byte(position)?;
// Get the context
let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
let parameters = match self
.postgresml_config
.embedding_model
.as_ref()
.and_then(|m| m.query_parameters.clone())
{
Some(query_parameters) => query_parameters,
None => json!({
"prompt": "query: "
}),
};
let res = self
.collection
.vector_search_local(
json!({
"query": {
"fields": {
"text": {
"query": query,
"parameters": parameters
}
},
"filter": {
"$or": [
{
"uri": {
"$ne": position.text_document.uri.to_string()
}
},
{
"range": {
"start": {
"$gt": cursor_byte
},
},
},
{
"range": {
"end": {
"$lt": cursor_byte
},
}
}
]
}
},
"limit": limit
})
.into(),
&self.pipeline,
)
.await?;
let context = res
.into_iter()
.map(|c| {
c["chunk"]
.as_str()
.map(|t| t.to_owned())
.context("PGML - Error getting chunk from vector search")
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
let context = &context[..(total_allowed_characters.saturating_sub(chunk_size)).min(context.len())];
// Reconstruct the Prompts
Ok(match code {
Prompt::ContextAndCode(context_and_code) => {
Prompt::ContextAndCode(ContextAndCodePrompt::new(
context.to_owned(),
format_file_excerpt(
&position.text_document.uri.to_string(),
&context_and_code.code,
self.config.client_params.root_uri.as_deref(),
),
))
}
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
format!("{context}\n\n{}", fim.prompt),
fim.suffix,
)),
})
}
#[instrument(skip(self))]
fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
self.file_store.opened_text_document(params.clone())?;
let saved_uri = params.text_document.uri.to_string();
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
let root_uri = self.config.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
let uri = params.text_document.uri.to_string();
if let Err(e) = split_and_upsert_file(
&uri,
&mut collection,
file_store,
splitter,
root_uri.as_deref(),
)
.await
{
error!("{e:?}")
}
});
if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) {
error!("{e:?}")
}
Ok(())
}
#[instrument(skip(self))]
fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
self.file_store.changed_text_document(params.clone())?;
let uri = params.text_document.uri.to_string();
self.debounce_tx.send(uri)?;
Ok(())
}
#[instrument(skip(self))]
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
self.file_store.renamed_files(params.clone())?;
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
let root_uri = self.config.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
for file in params.files {
if let Err(e) = collection
.delete_documents(
json!({
"uri": {
"$eq": file.old_uri
}
})
.into(),
)
.await
{
error!("PGML - Error deleting file: {e:?}");
}
if let Err(e) = split_and_upsert_file(
&file.new_uri,
&mut collection,
file_store.clone(),
splitter.clone(),
root_uri.as_deref(),
)
.await
{
error!("{e:?}")
}
}
});
Ok(())
}
}

View File

@ -7,7 +7,10 @@ use lsp_types::{
use serde_json::Value;
use tracing::error;
use crate::memory_backends::{MemoryBackend, Prompt, PromptType};
use crate::{
memory_backends::{MemoryBackend, Prompt, PromptType},
utils::TOKIO_RUNTIME,
};
#[derive(Debug)]
pub struct PromptRequest {
@ -56,34 +59,45 @@ pub enum WorkerRequest {
DidRenameFiles(RenameFilesParams),
}
async fn do_task(
async fn do_build_prompt(
params: PromptRequest,
memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
let prompt = memory_backend
.build_prompt(&params.position, params.prompt_type, &params.params)
.await?;
params
.tx
.send(prompt)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))
}
fn do_task(
request: WorkerRequest,
memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
match request {
WorkerRequest::FilterText(params) => {
let filter_text = memory_backend.get_filter_text(&params.position).await?;
let filter_text = memory_backend.get_filter_text(&params.position)?;
params
.tx
.send(filter_text)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
}
WorkerRequest::Prompt(params) => {
let prompt = memory_backend
.build_prompt(&params.position, params.prompt_type, &params.params)
.await?;
params
.tx
.send(prompt)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = do_build_prompt(params, memory_backend).await {
error!("error in memory worker building prompt: {e}")
}
});
}
WorkerRequest::DidOpenTextDocument(params) => {
memory_backend.opened_text_document(params).await?;
memory_backend.opened_text_document(params)?;
}
WorkerRequest::DidChangeTextDocument(params) => {
memory_backend.changed_text_document(params).await?;
memory_backend.changed_text_document(params)?;
}
WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params).await?,
WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params)?,
}
anyhow::Ok(())
}
@ -93,18 +107,11 @@ fn do_run(
rx: std::sync::mpsc::Receiver<WorkerRequest>,
) -> anyhow::Result<()> {
let memory_backend = Arc::new(memory_backend);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
loop {
let request = rx.recv()?;
let thread_memory_backend = memory_backend.clone();
runtime.spawn(async move {
if let Err(e) = do_task(request, thread_memory_backend).await {
error!("error in memory worker task: {e}")
}
});
if let Err(e) = do_task(request, memory_backend.clone()) {
error!("error in memory worker task: {e}")
}
}
}

View File

@ -0,0 +1,59 @@
use serde::Serialize;
use crate::{config::ValidSplitter, memory_backends::file_store::File};
mod text_splitter;
mod tree_sitter;
#[derive(Serialize)]
pub struct ByteRange {
pub start_byte: usize,
pub end_byte: usize,
}
impl ByteRange {
pub fn new(start_byte: usize, end_byte: usize) -> Self {
Self {
start_byte,
end_byte,
}
}
}
#[derive(Serialize)]
pub struct Chunk {
pub text: String,
pub range: ByteRange,
}
impl Chunk {
fn new(text: String, range: ByteRange) -> Self {
Self { text, range }
}
}
pub trait Splitter {
fn split(&self, file: &File) -> Vec<Chunk>;
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk>;
fn does_use_tree_sitter(&self) -> bool {
false
}
fn chunk_size(&self) -> usize;
}
impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
type Error = anyhow::Error;
fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
match value {
ValidSplitter::TreeSitter(config) => {
Ok(Box::new(tree_sitter::TreeSitter::new(config)?))
}
ValidSplitter::TextSplitter(config) => {
Ok(Box::new(text_splitter::TextSplitter::new(config)))
}
}
}
}

View File

@ -0,0 +1,47 @@
use crate::{config, memory_backends::file_store::File};
use super::{ByteRange, Chunk, Splitter};
pub struct TextSplitter {
chunk_size: usize,
splitter: text_splitter::TextSplitter<text_splitter::Characters>,
}
impl TextSplitter {
pub fn new(config: config::TextSplitter) -> Self {
Self {
chunk_size: config.chunk_size,
splitter: text_splitter::TextSplitter::new(config.chunk_size),
}
}
pub fn new_with_chunk_size(chunk_size: usize) -> Self {
Self {
chunk_size,
splitter: text_splitter::TextSplitter::new(chunk_size),
}
}
}
impl Splitter for TextSplitter {
fn split(&self, file: &File) -> Vec<Chunk> {
self.split_file_contents("", &file.rope().to_string())
}
fn split_file_contents(&self, _uri: &str, contents: &str) -> Vec<Chunk> {
self.splitter
.chunk_indices(contents)
.fold(vec![], |mut acc, (start_byte, text)| {
let end_byte = start_byte + text.len();
acc.push(Chunk::new(
text.to_string(),
ByteRange::new(start_byte, end_byte),
));
acc
})
}
fn chunk_size(&self) -> usize {
self.chunk_size
}
}

View File

@ -0,0 +1,84 @@
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tracing::error;
use tree_sitter::Tree;
use crate::{config, memory_backends::file_store::File, utils::parse_tree};
use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter};
pub struct TreeSitter {
chunk_size: usize,
splitter: TreeSitterCodeSplitter,
text_splitter: TextSplitter,
}
impl TreeSitter {
pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size);
Ok(Self {
chunk_size: config.chunk_size,
splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
text_splitter,
})
}
fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result<Vec<Chunk>> {
Ok(self
.splitter
.split(tree, contents)?
.into_iter()
.map(|c| {
Chunk::new(
c.text.to_owned(),
ByteRange::new(c.range.start_byte, c.range.end_byte),
)
})
.collect())
}
}
impl Splitter for TreeSitter {
fn split(&self, file: &File) -> Vec<Chunk> {
if let Some(tree) = file.tree() {
match self.split_tree(tree, file.rope().to_string().as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file with error: {e:?}. Falling back to default splitter.",
);
self.text_splitter.split(file)
}
}
} else {
self.text_splitter.split(file)
}
}
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk> {
match parse_tree(uri, contents, None) {
Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file: {uri} with error: {e:?}. Falling back to default splitter.",
);
self.text_splitter.split_file_contents(uri, contents)
}
},
Err(e) => {
error!(
"Failed to parse tree for file {uri} with error: {e:?}. Falling back to default splitter.",
);
self.text_splitter.split_file_contents(uri, contents)
}
}
}
fn does_use_tree_sitter(&self) -> bool {
true
}
fn chunk_size(&self) -> usize {
self.chunk_size
}
}

View File

@ -67,11 +67,11 @@ impl Ollama {
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let res: OllamaCompletionsResponse = client
.post(self
.configuration
.generate_endpoint
.as_deref()
.unwrap_or("http://localhost:11434/api/generate")
.post(
self.configuration
.generate_endpoint
.as_deref()
.unwrap_or("http://localhost:11434/api/generate"),
)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
@ -106,11 +106,11 @@ impl Ollama {
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let res: OllamaChatResponse = client
.post(self
.configuration
.chat_endpoint
.as_deref()
.unwrap_or("http://localhost:11434/api/chat")
.post(
self.configuration
.chat_endpoint
.as_deref()
.unwrap_or("http://localhost:11434/api/chat"),
)
.header("Content-Type", "application/json")
.header("Accept", "application/json")

View File

@ -17,7 +17,7 @@ use crate::custom_requests::generation_stream::GenerationStreamParams;
use crate::memory_backends::Prompt;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
use crate::utils::{ToResponseError, TOKIO_RUNTIME};
#[derive(Clone, Debug)]
pub struct CompletionRequest {
@ -89,14 +89,14 @@ pub struct DoGenerationStreamResponse {
fn post_process_start(response: String, front: &str) -> String {
let mut front_match = response.len();
loop {
if response.len() == 0 || front.ends_with(&response[..front_match]) {
if response.is_empty() || front.ends_with(&response[..front_match]) {
break;
} else {
front_match -= 1;
}
}
if front_match > 0 {
(&response[front_match..]).to_owned()
response[front_match..].to_owned()
} else {
response
}
@ -105,16 +105,14 @@ fn post_process_start(response: String, front: &str) -> String {
fn post_process_end(response: String, back: &str) -> String {
let mut back_match = 0;
loop {
if back_match == response.len() {
break;
} else if back.starts_with(&response[back_match..]) {
if back_match == response.len() || back.starts_with(&response[back_match..]) {
break;
} else {
back_match += 1;
}
}
if back_match > 0 {
(&response[..back_match]).to_owned()
response[..back_match].to_owned()
} else {
response
}
@ -140,12 +138,10 @@ fn post_process_response(
} else {
response
}
} else if config.remove_duplicate_start {
post_process_start(response, &context_and_code.code)
} else {
if config.remove_duplicate_start {
post_process_start(response, &context_and_code.code)
} else {
response
}
response
}
}
Prompt::FIM(fim) => {
@ -177,7 +173,7 @@ pub fn run(
connection,
config,
) {
error!("error in transformer worker: {e}")
error!("error in transformer worker: {e:?}")
}
}
@ -189,10 +185,6 @@ fn do_run(
config: Config,
) -> anyhow::Result<()> {
let transformer_backends = Arc::new(transformer_backends);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
// If they have disabled completions, this function will fail. We set it to MIN_POSITIVE to never process a completions request
let max_requests_per_second = config
@ -206,7 +198,7 @@ fn do_run(
let task_transformer_backends = transformer_backends.clone();
let task_memory_backend_tx = memory_backend_tx.clone();
let task_config = config.clone();
runtime.spawn(async move {
TOKIO_RUNTIME.spawn(async move {
dispatch_request(
request,
task_connection,
@ -264,7 +256,7 @@ async fn dispatch_request(
{
Ok(response) => response,
Err(e) => {
error!("generating response: {e}");
error!("generating response: {e:?}");
Response {
id: request.get_id(),
result: None,
@ -274,7 +266,7 @@ async fn dispatch_request(
};
if let Err(e) = connection.sender.send(Message::Response(response)) {
error!("sending response: {e}");
error!("sending response: {e:?}");
}
}
@ -293,14 +285,12 @@ async fn generate_response(
.context("Completions is none")?;
let transformer_backend = transformer_backends
.get(&completion_config.model)
.clone()
.with_context(|| format!("can't find model: {}", &completion_config.model))?;
do_completion(transformer_backend, memory_backend_tx, &request, &config).await
}
WorkerRequest::Generation(request) => {
let transformer_backend = transformer_backends
.get(&request.params.model)
.clone()
.with_context(|| format!("can't find model: {}", &request.params.model))?;
do_generate(transformer_backend, memory_backend_tx, &request).await
}
@ -346,7 +336,6 @@ async fn do_completion(
// Get the response
let mut response = transformer_backend.do_completion(&prompt, params).await?;
eprintln!("\n\n\n\nGOT RESPONSE: {}\n\n\n\n", response.insert_text);
if let Some(post_process) = config.get_completions_post_process() {
response.insert_text = post_process_response(response.insert_text, &prompt, &post_process);
@ -423,7 +412,103 @@ async fn do_generate(
#[cfg(test)]
mod tests {
use super::*;
use crate::memory_backends::{ContextAndCodePrompt, FIMPrompt};
use crate::memory_backends::{
file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend,
};
use serde_json::json;
use std::{sync::mpsc, thread};
#[tokio::test]
async fn test_do_completion() -> anyhow::Result<()> {
let (memory_tx, memory_rx) = mpsc::channel();
let memory_backend: Box<dyn MemoryBackend + Send + Sync> =
Box::new(FileStore::default_with_filler_file()?);
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
config::ValidModel::Ollama(serde_json::from_value(
json!({"model": "deepseek-coder:1.3b-base"}),
)?)
.try_into()?;
let completion_request = CompletionRequest::new(
serde_json::from_value(json!(0))?,
serde_json::from_value(json!({
"position": {"character":10, "line":2},
"textDocument": {
"uri": "file:///filler.py"
}
}))?,
);
let mut config = config::Config::default_with_file_store_without_models();
config.config.completion = Some(serde_json::from_value(json!({
"model": "model1",
"parameters": {
"options": {
"temperature": 0
}
}
}))?);
let result = do_completion(
&transformer_backend,
memory_tx,
&completion_request,
&config,
)
.await?;
assert_eq!(
" x * y",
result.result.clone().unwrap()["items"][0]["textEdit"]["newText"]
.as_str()
.unwrap()
);
assert_eq!(
" return",
result.result.unwrap()["items"][0]["filterText"]
.as_str()
.unwrap()
);
Ok(())
}
#[tokio::test]
async fn test_do_generate() -> anyhow::Result<()> {
let (memory_tx, memory_rx) = mpsc::channel();
let memory_backend: Box<dyn MemoryBackend + Send + Sync> =
Box::new(FileStore::default_with_filler_file()?);
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
config::ValidModel::Ollama(serde_json::from_value(
json!({"model": "deepseek-coder:1.3b-base"}),
)?)
.try_into()?;
let generation_request = GenerationRequest::new(
serde_json::from_value(json!(0))?,
serde_json::from_value(json!({
"position": {"character":10, "line":2},
"textDocument": {
"uri": "file:///filler.py"
},
"model": "model1",
"parameters": {
"options": {
"temperature": 0
}
}
}))?,
);
let result = do_generate(&transformer_backend, memory_tx, &generation_request).await?;
assert_eq!(
" x * y",
result.result.unwrap()["generatedText"].as_str().unwrap()
);
Ok(())
}
#[test]
fn test_post_process_fim() {

View File

@ -1,6 +1,18 @@
use anyhow::Context;
use lsp_server::ResponseError;
use once_cell::sync::Lazy;
use tokio::runtime;
use tree_sitter::Tree;
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt};
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk};
pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.expect("Error building tokio runtime")
});
pub trait ToResponseError {
fn to_response_error(&self, code: i32) -> ResponseError;
@ -42,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String
pub fn format_context_code(context: &str, code: &str) -> String {
format!("{context}\n\n{code}")
}
pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String {
format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte)
}
pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result<Tree> {
let path = std::path::Path::new(uri);
let extension = path.extension().map(|x| x.to_string_lossy());
let extension = extension.as_deref().unwrap_or("");
let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?;
parser
.parse(&contents, old_tree)
.with_context(|| format!("parsing tree failed for {uri}"))
}

View File

@ -0,0 +1,289 @@
use anyhow::Result;
use std::{
io::{Read, Write},
process::{ChildStdin, ChildStdout, Command, Stdio},
};
// Note: if you get an empty response with no error, it typically means
// the language server has died
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
let mut content_length = None;
let mut buf = vec![];
loop {
let mut buf2 = vec![0];
stdout.read_exact(&mut buf2)?;
buf.push(buf2[0]);
if let Some(content_length) = content_length {
if buf.len() == content_length {
break;
}
} else {
let len = buf.len();
// The header section ends with "\r\n\r\n" (bytes 13, 10, 13, 10)
if len > 4
&& buf[len - 4] == 13
&& buf[len - 3] == 10
&& buf[len - 2] == 13
&& buf[len - 1] == 10
{
// Strip the leading "Content-Length: " (16 bytes) and the trailing "\r\n\r\n",
// then parse the remaining digits as the body length
content_length =
Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::<usize>()?);
buf = vec![];
}
}
}
Ok(String::from_utf8(buf)?)
}
fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> {
stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len(),).as_bytes())?;
stdin.write_all("\r\n".as_bytes())?;
stdin.write_all(message.as_bytes())?;
Ok(())
}
// This chat completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with llama3:8b-instruct-q4_0 and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
//
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// t
// u
// r
// n
// The sequence has:
// - 1 textDocument/didOpen notification
// - 7 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_chat_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();
let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"messages":[{"content":"Instructions:\n- You are an AI programming assistant.\n- Given a piece of code with the cursor location marked by \"<CURSOR>\", replace \"<CURSOR>\" with the correct code or comment.\n- First, think step-by-step.\n- Describe your plan for what to build in pseudocode, written out in great detail.\n- Then output the code replacing the \"<CURSOR>\"\n- Ensure that your completion fits within the language context of the provided code snippet (e.g., Python, JavaScript, Rust).\n\nRules:\n- Only respond with code or comments.\n- Only replace \"<CURSOR>\"; do not include any previously written code.\n- Never include \"<CURSOR>\" in your response\n- If the cursor is within a comment, complete the comment meaningfully.\n- Handle ambiguous cases by providing the most contextually appropriate completion.\n- Be consistent with your responses.","role":"system"},{"content":"def greet(name):\n print(f\"Hello, {<CURSOR>}\")","role":"user"},{"content":"name","role":"assistant"},{"content":"function sum(a, b) {\n return a + <CURSOR>;\n}","role":"user"},{"content":"b","role":"assistant"},{"content":"fn multiply(a: i32, b: i32) -> i32 {\n a * <CURSOR>\n}","role":"user"},{"content":"b","role":"assistant"},{"content":"# <CURSOR>\ndef add(a, b):\n return a + b","role":"user"},{"content":"Adds two numbers","role":"assistant"},{"content":"# This function checks if a number is even\n<CURSOR>","role":"user"},{"content":"def is_even(n):\n return n % 2 == 
0","role":"assistant"},{"content":"{CODE}","role":"user"}],"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"llama3:8b-instruct-q4_0","type":"ollama"}}},"processId":66009,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;
send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;
let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" return","kind":1,"label":"ai - x * y","textEdit":{"newText":"x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"##
);
child.kill()?;
Ok(())
}
// This FIM completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
//
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// The sequence has:
// - 1 textDocument/didOpen notification
// - 3 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_fim_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();
let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"fim":{"end":"<fim▁end>","middle":"<fim▁hole>","start":"<fim▁begin>"},"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":50347,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;
send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;
let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
);
child.kill()?;
Ok(())
}
// This completion sequence was created using helix with lsp-ai and reading the logs
// It utilizes Ollama with deepseek-coder:1.3b-base and a temperature of 0
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// t
// u
// r
// n
// The sequence has:
// - 1 textDocument/didOpen notification
// - 7 textDocument/didChange notifications
// - 1 textDocument/completion request
#[test]
fn test_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();
let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"tagSupport":{"valueSet":[1,2]},"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"fileOperations":{"didRename":true,"willRename":true},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"24.3 (beb5afcb)"},"initializationOptions":{"completion":{"model":"model1","parameters":{"max_context":1024,"options":{"num_predict":32,"temperature":0}}},"memory":{"file_store":{}},"models":{"model1":{"model":"deepseek-coder:1.3b-base","type":"ollama"}}},"processId":62322,"rootPath":"/home/silas/Projects/test","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;
send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}},"text":"t"}],"textDocument":{"uri":"file:///fake.py","version":4}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":7,"line":2},"start":{"character":7,"line":2}},"text":"u"}],"textDocument":{"uri":"file:///fake.py","version":5}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":8,"line":2},"start":{"character":8,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":6}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":9,"line":2},"start":{"character":9,"line":2}},"text":"n"}],"textDocument":{"uri":"file:///fake.py","version":7}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":10,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;
let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" return","kind":1,"label":"ai - x * y","textEdit":{"newText":" x * y","range":{"end":{"character":10,"line":2},"start":{"character":10,"line":2}}}}]}}"##
);
child.kill()?;
Ok(())
}

View File

@ -0,0 +1,19 @@
[package]
name = "splitter-tree-sitter"
version = "0.1.0"
description = "A code splitter utilizing Tree-sitter"
edition.workspace = true
repository.workspace = true
license.workspace = true
[dependencies]
thiserror = "1.0.61"
tree-sitter = "0.22"
[dev-dependencies]
tree-sitter-rust = "0.21"
tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig" }
[build-dependencies]
cc="*"

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Silas Marvin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,3 @@
# splitter-tree-sitter
A code splitter that uses Tree-sitter to split code along syntax-tree boundaries, falling back to character-based splitting for nodes larger than the chunk size.
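A minimal usage sketch (not shipped with this PR), assuming the `tree-sitter` and `tree-sitter-rust` crates listed in the dev-dependencies:

```rust
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tree_sitter::Parser;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 128-character chunks with no overlap between adjacent chunks
    let splitter = TreeSitterCodeSplitter::new(128, 0)?;

    // Parse some Rust source with Tree-sitter
    let mut parser = Parser::new();
    parser.set_language(&tree_sitter_rust::language())?;
    let source = "fn add(a: u32, b: u32) -> u32 {\n    a + b\n}\n";
    let tree = parser.parse(source, None).expect("failed to parse source");

    // Split along the syntax tree; each chunk carries its byte range in the source
    for chunk in splitter.split(&tree, source.as_bytes())? {
        println!("{}..{}: {:?}", chunk.range.start_byte, chunk.range.end_byte, chunk.text);
    }
    Ok(())
}
```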

View File

@ -0,0 +1,326 @@
use thiserror::Error;
use tree_sitter::{Tree, TreeCursor};
#[derive(Error, Debug)]
pub enum NewError {
#[error("chunk_size must be greater than chunk_overlap")]
SizeOverlapError,
}
#[derive(Error, Debug)]
pub enum SplitError {
#[error("converting utf8 to str")]
Utf8Error(#[from] core::str::Utf8Error),
}
pub struct TreeSitterCodeSplitter {
chunk_size: usize,
chunk_overlap: usize,
}
pub struct ByteRange {
pub start_byte: usize,
pub end_byte: usize,
}
impl ByteRange {
fn new(start_byte: usize, end_byte: usize) -> Self {
Self {
start_byte,
end_byte,
}
}
}
pub struct Chunk<'a> {
pub text: &'a str,
pub range: ByteRange,
}
impl<'a> Chunk<'a> {
fn new(text: &'a str, range: ByteRange) -> Self {
Self { text, range }
}
}
impl TreeSitterCodeSplitter {
pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, NewError> {
if chunk_overlap > chunk_size {
Err(NewError::SizeOverlapError)
} else {
Ok(Self {
chunk_size,
chunk_overlap,
})
}
}
pub fn split<'a, 'b, 'c>(
&'a self,
tree: &'b Tree,
utf8: &'c [u8],
) -> Result<Vec<Chunk<'c>>, SplitError> {
let cursor = tree.walk();
Ok(self
.split_recursive(cursor, utf8)?
.into_iter()
.rev()
// Combine adjacent small chunks into larger ones
// We do this in reverse because it seems more natural to merge code slices from the bottom up
.try_fold(vec![], |mut acc, current| {
if acc.is_empty() {
acc.push(current);
Ok::<_, SplitError>(acc)
} else {
if acc.last().unwrap().text.len() + current.text.len()
< self.chunk_size
{
let last = acc.pop().unwrap();
let text = std::str::from_utf8(
&utf8[current.range.start_byte..last.range.end_byte],
)?;
acc.push(Chunk::new(
text,
ByteRange::new(current.range.start_byte, last.range.end_byte),
));
} else {
acc.push(current);
}
Ok(acc)
}
})?
.into_iter()
.rev()
.collect())
}
fn split_recursive<'a, 'b, 'c>(
&'a self,
mut cursor: TreeCursor<'b>,
utf8: &'c [u8],
) -> Result<Vec<Chunk<'c>>, SplitError> {
let node = cursor.node();
let text = node.utf8_text(utf8)?;
// There are three cases:
// 1. Is the current range of code smaller than the chunk_size? If so, return it
// 2. If not, does the current node have children? If so, recursively walk down
// 3. If not, we must split our current node
let mut out = if text.chars().count() <= self.chunk_size {
vec![Chunk::new(
text,
ByteRange::new(node.range().start_byte, node.range().end_byte),
)]
} else {
let mut cursor_copy = cursor.clone();
if cursor_copy.goto_first_child() {
self.split_recursive(cursor_copy, utf8)?
} else {
let mut current_range =
ByteRange::new(node.range().start_byte, node.range().end_byte);
let mut chunks = vec![];
let mut current_chunk = text;
loop {
if current_chunk.len() < self.chunk_size {
chunks.push(Chunk::new(current_chunk, current_range));
break;
} else {
let new_chunk = &current_chunk[0..self.chunk_size.min(current_chunk.len())];
let new_range = ByteRange::new(
current_range.start_byte,
current_range.start_byte + new_chunk.as_bytes().len(),
);
chunks.push(Chunk::new(new_chunk, new_range));
let new_current_chunk =
&current_chunk[self.chunk_size - self.chunk_overlap..];
let byte_diff =
current_chunk.as_bytes().len() - new_current_chunk.as_bytes().len();
current_range = ByteRange::new(
current_range.start_byte + byte_diff,
current_range.end_byte,
);
current_chunk = new_current_chunk
}
}
chunks
}
};
if cursor.goto_next_sibling() {
out.append(&mut self.split_recursive(cursor, utf8)?);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tree_sitter::Parser;
#[test]
fn test_split_rust() {
let splitter = TreeSitterCodeSplitter::new(128, 0).unwrap();
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_rust::language())
.expect("Error loading Rust grammar");
let source_code = r#"
#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}
impl Rectangle {
fn area(&self) -> u32 {
self.width * self.height
}
}
fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};
println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}
"#;
let tree = parser.parse(source_code, None).unwrap();
let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();
assert_eq!(
chunks[0].text,
r#"#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}"#
);
assert_eq!(
chunks[1].text,
r#"impl Rectangle {
fn area(&self) -> u32 {
self.width * self.height
}
}"#
);
assert_eq!(
chunks[2].text,
r#"fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};"#
);
assert_eq!(
chunks[3].text,
r#"println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}"#
);
}
#[test]
fn test_split_zig() {
let splitter = TreeSitterCodeSplitter::new(128, 10).unwrap();
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_rust::language())
.expect("Error loading Rust grammar");
let source_code = r#"
const std = @import("std");
const parseInt = std.fmt.parseInt;
std.debug.print("Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string 9 ...", .{});
test "parse integers" {
const input = "123 67 89,99";
const ally = std.testing.allocator;
var list = std.ArrayList(u32).init(ally);
// Ensure the list is freed at scope exit.
// Try commenting out this line!
defer list.deinit();
var it = std.mem.tokenizeAny(u8, input, " ,");
while (it.next()) |num| {
const n = try parseInt(u32, num, 10);
try list.append(n);
}
const expected = [_]u32{ 123, 67, 89, 99 };
for (expected, list.items) |exp, actual| {
try std.testing.expectEqual(exp, actual);
}
}
"#;
let tree = parser.parse(source_code, None).unwrap();
let chunks = splitter.split(&tree, source_code.as_bytes()).unwrap();
assert_eq!(
chunks[0].text,
r#"const std = @import("std");
const parseInt = std.fmt.parseInt;
std.debug.print(""#
);
assert_eq!(
chunks[1].text,
r#"Here is a long string 1 ... Here is a long string 2 ... Here is a long string 3 ... Here is a long string 4 ... Here is a long s"#
);
assert_eq!(
chunks[2].text,
r#"s a long string 5 ... Here is a long string 6 ... Here is a long string 7 ... Here is a long string 8 ... Here is a long string "#
);
assert_eq!(chunks[3].text, r#"ng string 9 ...", .{});"#);
assert_eq!(
chunks[4].text,
r#"test "parse integers" {
const input = "123 67 89,99";
const ally = std.testing.allocator;
var list = std.ArrayList"#
);
assert_eq!(
chunks[5].text,
r#"(u32).init(ally);
// Ensure the list is freed at scope exit.
// Try commenting out this line!"#
);
assert_eq!(
chunks[6].text,
r#"defer list.deinit();
var it = std.mem.tokenizeAny(u8, input, " ,");
while (it.next()) |num"#
);
assert_eq!(
chunks[7].text,
r#"| {
const n = try parseInt(u32, num, 10);
try list.append(n);
}
const expected = [_]u32{ 123, 67, 89,"#
);
assert_eq!(
chunks[8].text,
r#"99 };
for (expected, list.items) |exp, actual| {
try std.testing.expectEqual(exp, actual);
}
}"#
);
}
}

View File

@ -0,0 +1,37 @@
[package]
name = "utils-tree-sitter"
version = "0.1.0"
description = "Utils for working with splitter-tree-sitter"
edition.workspace = true
repository.workspace = true
license.workspace = true
[dependencies]
thiserror = "1.0.61"
tree-sitter = "0.22"
tree-sitter-bash = { version = "0.21", optional = true }
tree-sitter-c = { version = "0.21", optional = true }
tree-sitter-cpp = { version = "0.22", optional = true }
tree-sitter-c-sharp = { version = "0.21", optional = true }
tree-sitter-css = { version = "0.21", optional = true }
tree-sitter-elixir = { version = "0.2", optional = true }
tree-sitter-erlang = { version = "0.6", optional = true }
tree-sitter-go = { version = "0.21", optional = true }
tree-sitter-html = { version = "0.20", optional = true }
tree-sitter-java = { version = "0.21", optional = true }
tree-sitter-javascript = { version = "0.21", optional = true }
tree-sitter-json = { version = "0.21", optional = true }
tree-sitter-haskell = { version = "0.21", optional = true }
tree-sitter-lua = { version = "0.1.0", optional = true }
tree-sitter-ocaml = { version = "0.22.0", optional = true }
tree-sitter-python = { version = "0.21", optional = true }
tree-sitter-rust = { version = "0.21", optional = true }
# tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig", optional = true }
[build-dependencies]
cc="*"
[features]
default = []
all = ["dep:tree-sitter-python", "dep:tree-sitter-bash", "dep:tree-sitter-c", "dep:tree-sitter-cpp", "dep:tree-sitter-c-sharp", "dep:tree-sitter-css", "dep:tree-sitter-elixir", "dep:tree-sitter-erlang", "dep:tree-sitter-go", "dep:tree-sitter-html", "dep:tree-sitter-java", "dep:tree-sitter-javascript", "dep:tree-sitter-json", "dep:tree-sitter-rust", "dep:tree-sitter-haskell", "dep:tree-sitter-lua", "dep:tree-sitter-ocaml"]

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Silas Marvin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,3 @@
# utils-tree-sitter
Utilities for working with Tree-sitter, primarily for mapping file extensions to ready-to-use Tree-sitter parsers.
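A minimal usage sketch, assuming the crate is built with the `all` feature so the Python grammar is compiled in:

```rust
use utils_tree_sitter::get_parser_for_extension;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Look up a parser by file extension ("py" maps to the Python grammar)
    let mut parser = get_parser_for_extension("py")?;
    let tree = parser
        .parse("def add(a, b):\n    return a + b\n", None)
        .expect("failed to parse source");
    println!("root node kind: {}", tree.root_node().kind());
    Ok(())
}
```

Unsupported or unconfigured extensions return a `GetParserError`, which lets callers fall back to a plain text splitter.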

View File

@ -0,0 +1,90 @@
use thiserror::Error;
use tree_sitter::{LanguageError, Parser};
#[derive(Error, Debug)]
pub enum GetParserError {
#[error("no parser found for extension")]
NoParserFoundForExtension(String),
#[error("no parser found for extension")]
NoLanguageFoundForExtension(String),
#[error("loading grammer")]
LoadingGrammer(#[from] LanguageError),
}
fn get_language_for_extension(extension: &str) -> Result<String, GetParserError> {
Ok(match extension {
"py" => "Python",
"rs" => "Rust",
// "zig" => "Zig",
"sh" => "Bash",
"c" => "C",
"cpp" => "C++",
"cs" => "C#",
"css" => "CSS",
"ex" => "Elixir",
"erl" => "Erlang",
"go" => "Go",
"html" => "HTML",
"java" => "Java",
"js" => "JavaScript",
"json" => "JSON",
"hs" => "Haskell",
"lua" => "Lua",
"ml" => "OCaml",
_ => {
return Err(GetParserError::NoLanguageFoundForExtension(
extension.to_string(),
))
}
}
.to_string())
}
pub fn get_parser_for_extension(extension: &str) -> Result<Parser, GetParserError> {
let language = get_language_for_extension(extension)?;
let mut parser = Parser::new();
match language.as_str() {
#[cfg(any(feature = "all", feature = "python"))]
"Python" => parser.set_language(&tree_sitter_python::language())?,
#[cfg(any(feature = "all", feature = "rust"))]
"Rust" => parser.set_language(&tree_sitter_rust::language())?,
// #[cfg(any(feature = "all", feature = "zig"))]
// "Zig" => parser.set_language(&tree_sitter_zig::language())?,
#[cfg(any(feature = "all", feature = "bash"))]
"Bash" => parser.set_language(&tree_sitter_bash::language())?,
#[cfg(any(feature = "all", feature = "c"))]
"C" => parser.set_language(&tree_sitter_c::language())?,
#[cfg(any(feature = "all", feature = "cpp"))]
"C++" => parser.set_language(&tree_sitter_cpp::language())?,
#[cfg(any(feature = "all", feature = "c-sharp"))]
"C#" => parser.set_language(&tree_sitter_c_sharp::language())?,
#[cfg(any(feature = "all", feature = "css"))]
"CSS" => parser.set_language(&tree_sitter_css::language())?,
#[cfg(any(feature = "all", feature = "elixir"))]
"Elixir" => parser.set_language(&tree_sitter_elixir::language())?,
#[cfg(any(feature = "all", feature = "erlang"))]
"Erlang" => parser.set_language(&tree_sitter_erlang::language())?,
#[cfg(any(feature = "all", feature = "go"))]
"Go" => parser.set_language(&tree_sitter_go::language())?,
#[cfg(any(feature = "all", feature = "html"))]
"HTML" => parser.set_language(&tree_sitter_html::language())?,
#[cfg(any(feature = "all", feature = "java"))]
"Java" => parser.set_language(&tree_sitter_java::language())?,
#[cfg(any(feature = "all", feature = "javascript"))]
"JavaScript" => parser.set_language(&tree_sitter_javascript::language())?,
#[cfg(any(feature = "all", feature = "json"))]
"JSON" => parser.set_language(&tree_sitter_json::language())?,
#[cfg(any(feature = "all", feature = "haskell"))]
"Haskell" => parser.set_language(&tree_sitter_haskell::language())?,
#[cfg(any(feature = "all", feature = "lua"))]
"Lua" => parser.set_language(&tree_sitter_lua::language())?,
#[cfg(any(feature = "all", feature = "ocaml"))]
"OCaml" => parser.set_language(&tree_sitter_ocaml::language_ocaml())?,
_ => {
return Err(GetParserError::NoParserFoundForExtension(
language.to_string(),
))
}
}
Ok(parser)
}

View File

@ -1,597 +0,0 @@
use anyhow::Context;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::collections::HashMap;
use tracing::instrument;
use crate::{
config::{self, Config},
utils::tokens_to_estimated_characters,
};
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
pub struct FileStore {
_crawl: bool,
_config: Config,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
}
impl FileStore {
pub fn new(file_store_config: config::FileStore, config: Config) -> Self {
Self {
_crawl: file_store_config.crawl,
_config: config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}
pub fn new_without_crawl(config: Config) -> Self {
Self {
_crawl: false,
_config: config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}
fn get_rope_for_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<(Rope, usize)> {
// Get the rope and set our initial cursor index
let current_document_uri = position.text_document.uri.to_string();
let mut rope = self
.file_map
.lock()
.get(&current_document_uri)
.context("Error file not found")?
.clone();
let mut cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
// Add to our rope if we need to
for file in self
.accessed_files
.lock()
.iter()
.filter(|f| **f != current_document_uri)
{
let needed = characters.saturating_sub(rope.len_chars() + 1);
if needed == 0 {
break;
}
let file_map = self.file_map.lock();
let r = file_map.get(file).context("Error file not found")?;
let slice_max = needed.min(r.len_chars() + 1);
let rope_str_slice = r
.get_slice(0..slice_max - 1)
.context("Error getting slice")?
.to_string();
rope.insert(0, "\n");
rope.insert(0, &rope_str_slice);
cursor_index += slice_max;
}
Ok((rope, cursor_index))
}
pub fn get_characters_around_position(
&self,
position: &TextDocumentPositionParams,
characters: usize,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index.saturating_sub(characters / 2);
let end = rope
.len_chars()
.min(cursor_index + (characters - (cursor_index - start)));
let rope_slice = rope
.get_slice(start..end)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: MemoryRunParams,
) -> anyhow::Result<Prompt> {
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context_length)?;
Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.is_for_chat {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
}
}
PromptType::FIM => {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
let prefix = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
let suffix = rope
.get_slice(cursor_index..end)
.context("Error getting rope slice")?;
Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
}
})
}
}
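// Minimal standalone sketch of the character-window arithmetic used by
// `build_code` above (illustrative only; this helper is not part of this PR):
fn context_window(cursor: usize, len_chars: usize, max_chars: usize) -> (usize, usize) {
    // Center the window on the cursor, clamping both ends to the document.
    let start = cursor.saturating_sub(max_chars / 2);
    let end = len_chars.min(cursor + (max_chars - (cursor - start)));
    (start, end)
}
// For example, with len_chars = 100 and max_chars = 20, a cursor at 5 yields
// (0, 20), while a cursor at 95 yields (85, 100); note the window is not
// extended backwards when the end is clamped to the document length.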
#[async_trait::async_trait]
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let line = rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.slice(0..position.position.character as usize)
.to_string();
Ok(line)
}
#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
self.build_code(position, prompt_type, params)
}
#[instrument(skip(self))]
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
self.file_map.lock().insert(uri.clone(), rope);
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}
#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
let mut file_map = self.file_map.lock();
let rope = file_map
.get_mut(&uri)
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
let start_index =
rope.line_to_char(range.start.line as usize) + range.start.character as usize;
let end_index =
rope.line_to_char(range.end.line as usize) + range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
} else {
*rope = Rope::from_str(&change.text);
}
}
self.accessed_files.lock().shift_insert(0, uri);
Ok(())
}
#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
let mut file_map = self.file_map.lock();
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use lsp_types::{
DidOpenTextDocumentParams, FileRename, Position, Range, RenameFilesParams,
TextDocumentContentChangeEvent, TextDocumentIdentifier, TextDocumentItem,
VersionedTextDocumentIdentifier,
};
use serde_json::json;
fn generate_base_file_store() -> anyhow::Result<FileStore> {
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
Ok(FileStore::new(file_store_config, config))
}
fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem {
let uri = uri.unwrap_or("file://filler/");
let text = text.unwrap_or("Here is the document body");
TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "filler".to_string(),
version: 0,
text: text.to_string(),
}
}
#[tokio::test]
async fn can_open_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
Ok(())
}
#[tokio::test]
async fn can_rename_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let params = RenameFilesParams {
files: vec![FileRename {
old_uri: "file://filler/".to_string(),
new_uri: "file://filler2/".to_string(),
}],
};
file_store.renamed_files(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler2/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
Ok(())
}
#[tokio::test]
async fn can_change_document() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, None);
let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 1,
},
end: Position {
line: 0,
character: 3,
},
}),
range_length: None,
text: "a".to_string(),
}],
};
file_store.changed_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Hae is the document body");
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "abc".to_string(),
}],
};
file_store.changed_text_document(params).await?;
let file = file_store
.file_map
.lock()
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "abc");
Ok(())
}
#[tokio::test]
async fn can_build_prompt() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(
None,
Some(
r#"Document Top
Here is a more complicated document
Some text
The end with a trailing new line
"#,
),
);
// Test basic completion
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!("Document T", prompt.code);
// Test FIM
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::FIM,
&json!({}),
)
.await?;
let prompt: FIMPrompt = prompt.try_into()?;
assert_eq!(prompt.prompt, r#"Document T"#);
assert_eq!(
prompt.suffix,
r#"op
Here is a more complicated document
Some text
The end with a trailing new line
"#
);
// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({
"messages": []
}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"Document T<CURSOR>op
Here is a more complicated document
Some text
The end with a trailing new line
"#
.to_string();
assert_eq!(text, prompt.code);
// Test multi-file
let text_document2 = generate_filler_text_document(
Some("file://filler2"),
Some(
r#"Document Top2
Here is a more complicated document
Some text
The end with a trailing new line
"#,
),
);
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document2.clone(),
};
file_store.opened_text_document(params).await?;
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 0,
character: 10,
},
},
PromptType::ContextAndCode,
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code);
Ok(())
}
#[tokio::test]
async fn test_document_cursor_placement_corner_cases() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, Some("test\n"));
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
// Test chat
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 1,
character: 0,
},
},
PromptType::ContextAndCode,
&json!({"messages": []}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"test
<CURSOR>"#
.to_string();
assert_eq!(text, prompt.code);
Ok(())
}
// #[tokio::test]
// async fn test_fim_placement_corner_cases() -> anyhow::Result<()> {
// let text_document = generate_filler_text_document(None, Some("test\n"));
// let params = lsp_types::DidOpenTextDocumentParams {
// text_document: text_document.clone(),
// };
// let file_store = generate_base_file_store()?;
// file_store.opened_text_document(params).await?;
// // Test FIM
// let params = json!({
// "fim": {
// "start": "SS",
// "middle": "MM",
// "end": "EE"
// }
// });
// let prompt = file_store
// .build_prompt(
// &TextDocumentPositionParams {
// text_document: TextDocumentIdentifier {
// uri: text_document.uri.clone(),
// },
// position: Position {
// line: 1,
// character: 0,
// },
// },
// params,
// )
// .await?;
// assert_eq!(prompt.context, "");
// let text = r#"test
// "#
// .to_string();
// assert_eq!(text, prompt.code);
// Ok(())
// }
}

View File

@ -1,254 +0,0 @@
use std::{
sync::mpsc::{self, Sender},
time::Duration,
};
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use tokio::time;
use tracing::instrument;
use crate::{
config::{self, Config},
utils::tokens_to_estimated_characters,
};
use super::{
file_store::FileStore, ContextAndCodePrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};
pub struct PostgresML {
_config: Config,
file_store: FileStore,
collection: Collection,
pipeline: Pipeline,
debounce_tx: Sender<String>,
added_pipeline: bool,
}
impl PostgresML {
pub fn new(
postgresml_config: config::PostgresML,
configuration: Config,
) -> anyhow::Result<Self> {
let file_store = FileStore::new_without_crawl(configuration.clone());
let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url
} else {
std::env::var("PGML_DATABASE_URL")?
};
// TODO: Rethink the collection naming; filtering on metadata may be a better approach
let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
// TODO: Review the pipeline
let pipeline = Pipeline::new(
"v1",
Some(
json!({
"text": {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 1500,
"chunk_overlap": 40
}
},
"semantic_search": {
"model": "intfloat/e5-small",
}
}
})
.into(),
),
)?;
// Set up a debouncer for changed text documents
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()?;
let mut task_collection = collection.clone();
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
runtime.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_paths = Vec::new();
loop {
time::sleep(duration).await;
let new_paths: Vec<String> = debounce_rx.try_iter().collect();
if !new_paths.is_empty() {
for path in new_paths {
if !file_paths.iter().any(|p| *p == path) {
file_paths.push(path);
}
}
} else {
if file_paths.is_empty() {
continue;
}
let documents = file_paths
.into_iter()
.map(|path| {
let text = std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Error reading path: {}", path));
json!({
"id": path,
"text": text
})
.into()
})
.collect();
task_collection
.upsert_documents(documents, None)
.await
.expect("PGML - Error adding pipeline to collection");
file_paths = Vec::new();
}
}
});
Ok(Self {
_config: configuration,
file_store,
collection,
pipeline,
debounce_tx,
added_pipeline: false,
})
}
}
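// Minimal, self-contained sketch of the debounce loop spawned in `new` above
// (illustrative only; the flush callback is an assumption, not part of this PR).
// Changed paths are drained from the channel every 500ms tick and flushed as a
// single batch once a full tick passes with no new changes:
async fn debounce_changed_paths<F: FnMut(Vec<String>)>(
    rx: std::sync::mpsc::Receiver<String>,
    mut flush: F,
) {
    let tick = std::time::Duration::from_millis(500);
    let mut pending: Vec<String> = Vec::new();
    loop {
        tokio::time::sleep(tick).await;
        let new_paths: Vec<String> = rx.try_iter().collect();
        if !new_paths.is_empty() {
            // Edits are still arriving; deduplicate and keep accumulating.
            for path in new_paths {
                if !pending.contains(&path) {
                    pending.push(path);
                }
            }
        } else if !pending.is_empty() {
            // One quiet tick: flush the accumulated batch.
            flush(std::mem::take(&mut pending));
        }
    }
}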
#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
self.file_store.get_filter_text(position).await
}
#[instrument(skip(self))]
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
let query = self
.file_store
.get_characters_around_position(position, 512)?;
let res = self
.collection
.vector_search_local(
json!({
"query": {
"fields": {
"text": {
"query": query
}
},
},
"limit": 5
})
.into(),
&self.pipeline,
)
.await?;
let context = res
.into_iter()
.map(|c| {
c["chunk"]
.as_str()
.map(|t| t.to_owned())
.context("PGML - Error getting chunk from vector search")
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params)?;
let code: ContextAndCodePrompt = code.try_into()?;
let code = code.code;
let max_characters = tokens_to_estimated_characters(params.max_context_length);
let _context: String = context
.chars()
.take(max_characters - code.chars().count())
.collect();
// We need to redo this section to work with the new memory backend system
todo!()
// Ok(Prompt::new(context, code))
}
#[instrument(skip(self))]
async fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let text = params.text_document.text.clone();
let path = params.text_document.uri.path().to_owned();
let task_added_pipeline = self.added_pipeline;
let mut task_collection = self.collection.clone();
let mut task_pipeline = self.pipeline.clone();
if !task_added_pipeline {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.expect("PGML - Error adding pipeline to collection");
}
task_collection
.upsert_documents(
vec![json!({
"id": path,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
self.file_store.opened_text_document(params).await
}
#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let path = params.text_document.uri.path().to_owned();
self.debounce_tx.send(path)?;
self.file_store.changed_text_document(params).await
}
#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
let mut task_collection = self.collection.clone();
let task_params = params.clone();
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error adding pipeline to collection");
}
self.file_store.renamed_files(params).await
}
}

View File

@ -1,112 +0,0 @@
use anyhow::Result;
use std::{
io::{Read, Write},
process::{ChildStdin, ChildStdout, Command, Stdio},
};
// Note: if you get an empty response with no error, that typically means
// the language server died
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
let mut content_length = None;
let mut buf = vec![];
loop {
let mut buf2 = vec![0];
stdout.read_exact(&mut buf2)?;
buf.push(buf2[0]);
if let Some(content_length) = content_length {
if buf.len() == content_length {
break;
}
} else {
let len = buf.len();
if len > 4
&& buf[len - 4] == 13
&& buf[len - 3] == 10
&& buf[len - 2] == 13
&& buf[len - 1] == 10
{
content_length =
Some(String::from_utf8(buf[16..len - 4].to_vec())?.parse::<usize>()?);
buf = vec![];
}
}
}
Ok(String::from_utf8(buf)?)
}
fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> {
stdin.write_all(format!("Content-Length: {}\r\n", message.as_bytes().len(),).as_bytes())?;
stdin.write_all("\r\n".as_bytes())?;
stdin.write_all(message.as_bytes())?;
Ok(())
}
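// For reference, a framed message on the wire looks like this (illustrative):
//
//   Content-Length: 52\r\n
//   \r\n
//   {"jsonrpc":"2.0","method":"initialized","params":{}}
//
// `read_response` above assumes the header is exactly `Content-Length: `
// (16 bytes), which is why it slices `buf[16..len - 4]` to recover the length.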
// This completion sequence was created using helix with the lsp-ai analyzer and reading the logs
// It starts with a Python file:
// ```
// # Multiplies two numbers
// def multiply_two_numbers(x, y):
//
// # A singular test
// assert multiply_two_numbers(2, 3) == 6
// ```
// And has the following sequence of key strokes:
// o on line 2 (this creates an indented new line and enters insert mode)
// r
// e
// The sequence has:
// - 1 textDocument/didOpen notification
// - 3 textDocument/didChange notifications
// - 1 textDocument/completion request
// This test can fail if the model gives a different response than usual, though that seems unlikely
// If we add more tests like this we should probably hard-code the seed
#[test]
fn test_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();
let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"23.10 (f6021dd0)"},"processId":70007,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;
send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;
let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re\n","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
);
child.kill()?;
Ok(())
}