From 58192c4182bddfe33d8e87591055de3bb7036c1b Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sun, 16 Jun 2024 09:49:16 -0700 Subject: [PATCH] Periodic commit --- Cargo.lock | 76 ++++- Cargo.toml | 8 + src/config.rs | 49 +++- src/crawl.rs | 10 +- src/main.rs | 11 +- src/memory_backends/file_store.rs | 384 +++++++++++++++++++++----- src/memory_backends/postgresml/mod.rs | 294 ++++++++++++++------ src/splitters/mod.rs | 53 ++++ src/splitters/tree_sitter.rs | 77 ++++++ src/utils.rs | 18 +- 10 files changed, 810 insertions(+), 170 deletions(-) create mode 100644 src/splitters/mod.rs create mode 100644 src/splitters/tree_sitter.rs diff --git a/Cargo.lock b/Cargo.lock index 524e12f..cd6001c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1523,6 +1523,7 @@ dependencies = [ "anyhow", "assert_cmd", "async-trait", + "cc", "directories", "hf-hub", "ignore", @@ -1539,10 +1540,13 @@ dependencies = [ "ropey", "serde", "serde_json", + "splitter-tree-sitter", "tokenizers", "tokio", "tracing", "tracing-subscriber", + "tree-sitter", + "utils-tree-sitter", "xxhash-rust", ] @@ -2196,9 +2200,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -2756,6 +2760,15 @@ dependencies = [ "der", ] +[[package]] +name = "splitter-tree-sitter" +version = "0.1.0" +dependencies = [ + "cc", + "thiserror", + "tree-sitter", +] + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -3088,18 +3101,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", @@ -3339,6 +3352,45 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tree-sitter" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca" +dependencies = [ + "cc", + "regex", +] + +[[package]] +name = "tree-sitter-python" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-rust" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93" +dependencies = [ + "cc", + "tree-sitter", +] + +[[package]] +name = "tree-sitter-zig" +version = "0.0.1" +source = "git+https://github.com/SilasMarvin/tree-sitter-zig?branch=silas-update-tree-sitter-version#2eedab3ff6dda88aedddf0bb32a14f81bb709a73" +dependencies = [ + "cc", + "tree-sitter", +] + 
[[package]] name = "try-lock" version = "0.2.5" @@ -3450,6 +3502,18 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "utils-tree-sitter" +version = "0.1.0" +dependencies = [ + "cc", + "thiserror", + "tree-sitter", + "tree-sitter-python", + "tree-sitter-rust", + "tree-sitter-zig", +] + [[package]] name = "uuid" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 18dfb33..657589a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,14 @@ pgml = "1.0.4" tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } indexmap = "2.2.5" async-trait = "0.1.78" +tree-sitter = "0.22" +# splitter-tree-sitter = { git = "https://github.com/SilasMarvin/splitter-tree-sitter" } +splitter-tree-sitter = { path = "../../splitter-tree-sitter" } +# utils-tree-sitter = { git = "https://github.com/SilasMarvin/utils-tree-sitter" } +utils-tree-sitter = { path = "../../utils-tree-sitter", features = ["all"] } + +[build-dependencies] +cc="*" [features] default = [] diff --git a/src/config.rs b/src/config.rs index db1cf63..49f8e54 100644 --- a/src/config.rs +++ b/src/config.rs @@ -24,6 +24,43 @@ impl Default for PostProcess { } } +#[derive(Debug, Clone, Deserialize)] +pub enum ValidSplitter { + #[serde(rename = "tree_sitter")] + TreeSitter(TreeSitter), +} + +impl Default for ValidSplitter { + fn default() -> Self { + ValidSplitter::TreeSitter(TreeSitter::default()) + } +} + +const fn chunk_size_default() -> usize { + 1500 +} + +const fn chunk_overlap_default() -> usize { + 0 +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TreeSitter { + #[serde(default = "chunk_size_default")] + pub chunk_size: usize, + #[serde(default = "chunk_overlap_default")] + pub chunk_overlap: usize, +} + +impl Default for TreeSitter { + fn default() -> Self { + Self { + chunk_size: 1500, + chunk_overlap: 0, + } + } +} + #[derive(Debug, Clone, Deserialize)] pub enum ValidMemoryBackend { #[serde(rename = "file_store")] @@ -85,15 +122,21 @@ pub struct FIM { pub end: String, } -const fn max_crawl_memory_default() -> u32 { +const fn max_crawl_memory_default() -> u64 { 42 } +const fn max_crawl_file_size_default() -> u64 { + 10_000_000 +} + #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub struct Crawl { + #[serde(default = "max_crawl_file_size_default")] + pub max_file_size: u64, #[serde(default = "max_crawl_memory_default")] - pub max_crawl_memory: u32, + pub max_crawl_memory: u64, #[serde(default)] pub all_files: bool, } @@ -103,6 +146,8 @@ pub struct Crawl { pub struct PostgresML { pub database_url: Option, pub crawl: Option, + #[serde(default)] + pub splitter: ValidSplitter, } #[derive(Clone, Debug, Deserialize, Default)] diff --git a/src/crawl.rs b/src/crawl.rs index 4a860e2..191d869 100644 --- a/src/crawl.rs +++ b/src/crawl.rs @@ -18,10 +18,14 @@ impl Crawl { } } + pub fn crawl_config(&self) -> &config::Crawl { + &self.crawl_config + } + pub fn maybe_do_crawl( &mut self, triggered_file: Option, - mut f: impl FnMut(&str) -> anyhow::Result<()>, + mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<()>, ) -> anyhow::Result<()> { if let Some(root_uri) = &self.config.client_params.root_uri { if !root_uri.starts_with("file://") { @@ -52,7 +56,7 @@ impl Crawl { if !path.is_dir() { if let Some(path_str) = path.to_str() { if self.crawl_config.all_files { - f(path_str)?; + f(&self.crawl_config, path_str)?; } else { match ( path.extension().map(|pe| 
pe.to_str()).flatten(), @@ -60,7 +64,7 @@ impl Crawl { ) { (Some(path_extension), Some(extension_to_match)) => { if path_extension == extension_to_match { - f(path_str)?; + f(&self.crawl_config, path_str)?; } } _ => continue, diff --git a/src/main.rs b/src/main.rs index 82ef732..106be0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ mod crawl; mod custom_requests; mod memory_backends; mod memory_worker; +mod splitters; #[cfg(feature = "llama_cpp")] mod template; mod transformer_backends; @@ -51,15 +52,19 @@ where req.extract(R::METHOD) } -fn main() -> Result<()> { - // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable - // If the variables value is malformed or missing, sets the default log level to ERROR +// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable +// If the variables value is malformed or missing, sets the default log level to ERROR +fn init_logger() { FmtSubscriber::builder() .with_writer(std::io::stderr) .with_ansi(false) .without_time() .with_env_filter(EnvFilter::from_env("LSP_AI_LOG")) .init(); +} + +fn main() -> Result<()> { + init_logger(); let (connection, io_threads) = Connection::stdio(); let server_capabilities = serde_json::to_value(ServerCapabilities { diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index e1f4ff2..e93cb06 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -6,17 +6,50 @@ use ropey::Rope; use serde_json::Value; use std::collections::HashMap; use tracing::{error, instrument}; +use tree_sitter::{InputEdit, Point, Tree}; use crate::{ config::{self, Config}, crawl::Crawl, - utils::tokens_to_estimated_characters, + utils::{parse_tree, tokens_to_estimated_characters}, }; use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType}; +#[derive(Default)] +pub struct AdditionalFileStoreParams { + build_tree: bool, +} + +impl AdditionalFileStoreParams { + pub fn new(build_tree: bool) -> Self { + Self { build_tree } + } +} + +#[derive(Clone)] +pub struct File { + rope: Rope, + tree: Option, +} + +impl File { + fn new(rope: Rope, tree: Option) -> Self { + Self { rope, tree } + } + + pub fn rope(&self) -> &Rope { + &self.rope + } + + pub fn tree(&self) -> Option<&Tree> { + self.tree.as_ref() + } +} + pub struct FileStore { - file_map: Mutex>, + params: AdditionalFileStoreParams, + file_map: Mutex>, accessed_files: Mutex>, crawl: Option>, } @@ -28,29 +61,72 @@ impl FileStore { .take() .map(|x| Mutex::new(Crawl::new(x, config.clone()))); let s = Self { + params: AdditionalFileStoreParams::default(), file_map: Mutex::new(HashMap::new()), accessed_files: Mutex::new(IndexSet::new()), crawl, }; if let Err(e) = s.maybe_do_crawl(None) { - error!("{e}") + error!("{e:?}") } Ok(s) } + pub fn new_with_params( + mut file_store_config: config::FileStore, + config: Config, + params: AdditionalFileStoreParams, + ) -> anyhow::Result { + let crawl = file_store_config + .crawl + .take() + .map(|x| Mutex::new(Crawl::new(x, config.clone()))); + let s = Self { + params, + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), + crawl, + }; + if let Err(e) = s.maybe_do_crawl(None) { + error!("{e:?}") + } + Ok(s) + } + + fn add_new_file(&self, uri: &str, contents: String) { + let tree = if self.params.build_tree { + match parse_tree(uri, &contents, None) { + Ok(tree) => Some(tree), + Err(e) => { + error!( + "Failed to parse tree for {uri} with error {e}, falling back to no tree" + ); + None + } 
+ } + } else { + None + }; + self.file_map + .lock() + .insert(uri.to_string(), File::new(Rope::from_str(&contents), tree)); + self.accessed_files.lock().insert(uri.to_string()); + } + fn maybe_do_crawl(&self, triggered_file: Option) -> anyhow::Result<()> { if let Some(crawl) = &self.crawl { - crawl.lock().maybe_do_crawl(triggered_file, |path| { - let insert_uri = format!("file://{path}"); - if self.file_map.lock().contains_key(&insert_uri) { - return Ok(()); - } - let contents = std::fs::read_to_string(path)?; - self.file_map - .lock() - .insert(insert_uri, Rope::from_str(&contents)); - Ok(()) - })?; + crawl + .lock() + .maybe_do_crawl(triggered_file, |config, path| { + let insert_uri = format!("file://{path}"); + if self.file_map.lock().contains_key(&insert_uri) { + return Ok(()); + } + // TODO: actually limit files based on config + let contents = std::fs::read_to_string(path)?; + self.add_new_file(&insert_uri, contents); + Ok(()) + })?; } Ok(()) } @@ -67,6 +143,7 @@ impl FileStore { .lock() .get(¤t_document_uri) .context("Error file not found")? + .rope .clone(); let mut cursor_index = rope.line_to_char(position.position.line as usize) + position.position.character as usize; @@ -82,7 +159,7 @@ impl FileStore { break; } let file_map = self.file_map.lock(); - let r = file_map.get(file).context("Error file not found")?; + let r = &file_map.get(file).context("Error file not found")?.rope; let slice_max = needed.min(r.len_chars() + 1); let rope_str_slice = r .get_slice(0..slice_max - 1) @@ -105,6 +182,7 @@ impl FileStore { .lock() .get(position.text_document.uri.as_str()) .context("Error file not found")? + .rope .clone(); let cursor_index = rope.line_to_char(position.position.line as usize) + position.position.character as usize; @@ -173,8 +251,8 @@ impl FileStore { }) } - pub fn get_file_contents(&self, uri: &str) -> Option { - self.file_map.lock().get(uri).clone().map(|x| x.to_string()) + pub fn file_map(&self) -> &Mutex> { + &self.file_map } pub fn contains_file(&self, uri: &str) -> bool { @@ -191,6 +269,7 @@ impl MemoryBackend for FileStore { .lock() .get(position.text_document.uri.as_str()) .context("Error file not found")? 
+ .rope .clone(); let line = rope .get_line(position.position.line as usize) @@ -217,12 +296,10 @@ impl MemoryBackend for FileStore { &self, params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { - let rope = Rope::from_str(¶ms.text_document.text); let uri = params.text_document.uri.to_string(); - self.file_map.lock().insert(uri.clone(), rope); - self.accessed_files.lock().shift_insert(0, uri.clone()); + self.add_new_file(&uri, params.text_document.text); if let Err(e) = self.maybe_do_crawl(Some(uri)) { - error!("{e}") + error!("{e:?}") } Ok(()) } @@ -234,20 +311,95 @@ impl MemoryBackend for FileStore { ) -> anyhow::Result<()> { let uri = params.text_document.uri.to_string(); let mut file_map = self.file_map.lock(); - let rope = file_map + let file = file_map .get_mut(&uri) - .context("Error trying to get file that does not exist")?; + .with_context(|| format!("Trying to get file that does not exist {uri}"))?; for change in params.content_changes { // If range is ommitted, text is the new text of the document if let Some(range) = change.range { - let start_index = - rope.line_to_char(range.start.line as usize) + range.start.character as usize; + // Record old positions + let (old_end_position, old_end_byte) = { + let last_line_index = file.rope.len_lines() - 1; + ( + file.rope + .get_line(last_line_index) + .context("getting last line for edit") + .map(|last_line| Point::new(last_line_index, last_line.len_chars())), + file.rope.bytes().count(), + ) + }; + // Update the document + let start_index = file.rope.line_to_char(range.start.line as usize) + + range.start.character as usize; let end_index = - rope.line_to_char(range.end.line as usize) + range.end.character as usize; - rope.remove(start_index..end_index); - rope.insert(start_index, &change.text); + file.rope.line_to_char(range.end.line as usize) + range.end.character as usize; + file.rope.remove(start_index..end_index); + file.rope.insert(start_index, &change.text); + // Set new end positions + let (new_end_position, new_end_byte) = { + let last_line_index = file.rope.len_lines() - 1; + ( + file.rope + .get_line(last_line_index) + .context("getting last line for edit") + .map(|last_line| Point::new(last_line_index, last_line.len_chars())), + file.rope.bytes().count(), + ) + }; + // Update the tree + if self.params.build_tree { + let mut old_tree = file.tree.take(); + let start_byte = file + .rope + .try_line_to_char(range.start.line as usize) + .and_then(|start_char| { + file.rope + .try_char_to_byte(start_char + range.start.character as usize) + }) + .map_err(anyhow::Error::msg); + if let Some(old_tree) = &mut old_tree { + match (start_byte, old_end_position, new_end_position) { + (Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => { + old_tree.edit(&InputEdit { + start_byte, + old_end_byte, + new_end_byte, + start_position: Point::new( + range.start.line as usize, + range.start.character as usize, + ), + old_end_position, + new_end_position, + }); + file.tree = match parse_tree( + &uri, + &file.rope.to_string(), + Some(old_tree), + ) { + Ok(tree) => Some(tree), + Err(e) => { + error!("failed to edit tree: {e:?}"); + None + } + }; + } + (Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => { + error!("failed to build tree edit: {e:?}"); + } + } + } + } } else { - *rope = Rope::from_str(&change.text); + file.rope = Rope::from_str(&change.text); + if self.params.build_tree { + file.tree = match parse_tree(&uri, &change.text, None) { + Ok(tree) => Some(tree), + Err(e) => { + error!("failed to parse new tree: 
{e:?}"); + None + } + }; + } } } self.accessed_files.lock().shift_insert(0, uri); @@ -299,8 +451,8 @@ mod tests { } } - #[tokio::test] - async fn can_open_document() -> anyhow::Result<()> { + #[test] + fn can_open_document() -> anyhow::Result<()> { let params = lsp_types::DidOpenTextDocumentParams { text_document: generate_filler_text_document(None, None), }; @@ -312,12 +464,12 @@ mod tests { .get("file://filler/") .unwrap() .clone(); - assert_eq!(file.to_string(), "Here is the document body"); + assert_eq!(file.rope.to_string(), "Here is the document body"); Ok(()) } - #[tokio::test] - async fn can_rename_document() -> anyhow::Result<()> { + #[test] + fn can_rename_document() -> anyhow::Result<()> { let params = lsp_types::DidOpenTextDocumentParams { text_document: generate_filler_text_document(None, None), }; @@ -338,12 +490,12 @@ mod tests { .get("file://filler2/") .unwrap() .clone(); - assert_eq!(file.to_string(), "Here is the document body"); + assert_eq!(file.rope.to_string(), "Here is the document body"); Ok(()) } - #[tokio::test] - async fn can_change_document() -> anyhow::Result<()> { + #[test] + fn can_change_document() -> anyhow::Result<()> { let text_document = generate_filler_text_document(None, None); let params = DidOpenTextDocumentParams { @@ -379,7 +531,7 @@ mod tests { .get("file://filler/") .unwrap() .clone(); - assert_eq!(file.to_string(), "Hae is the document body"); + assert_eq!(file.rope.to_string(), "Hae is the document body"); let params = lsp_types::DidChangeTextDocumentParams { text_document: VersionedTextDocumentIdentifier { @@ -399,7 +551,7 @@ mod tests { .get("file://filler/") .unwrap() .clone(); - assert_eq!(file.to_string(), "abc"); + assert_eq!(file.rope.to_string(), "abc"); Ok(()) } @@ -579,43 +731,123 @@ The end with a trailing new line Ok(()) } - // #[tokio::test] - // async fn test_fim_placement_corner_cases() -> anyhow::Result<()> { - // let text_document = generate_filler_text_document(None, Some("test\n")); - // let params = lsp_types::DidOpenTextDocumentParams { - // text_document: text_document.clone(), - // }; - // let file_store = generate_base_file_store()?; - // file_store.opened_text_document(params).await?; + #[test] + fn test_file_store_tree_sitter() -> anyhow::Result<()> { + crate::init_logger(); - // // Test FIM - // let params = json!({ - // "fim": { - // "start": "SS", - // "middle": "MM", - // "end": "EE" - // } - // }); - // let prompt = file_store - // .build_prompt( - // &TextDocumentPositionParams { - // text_document: TextDocumentIdentifier { - // uri: text_document.uri.clone(), - // }, - // position: Position { - // line: 1, - // character: 0, - // }, - // }, - // params, - // ) - // .await?; - // assert_eq!(prompt.context, ""); - // let text = r#"test - // "# - // .to_string(); - // assert_eq!(text, prompt.code); + let config = Config::default_with_file_store_without_models(); + let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) = + config.config.memory.clone() + { + file_store_config + } else { + anyhow::bail!("requires a file_store_config") + }; + let params = AdditionalFileStoreParams { build_tree: true }; + let file_store = FileStore::new_with_params(file_store_config, config, params)?; - // Ok(()) - // } + let uri = "file://filler/test.rs"; + let text = r#"#[derive(Debug)] +struct Rectangle { + width: u32, + height: u32, +} + +impl Rectangle { + fn area(&self) -> u32 { + + } +} + +fn main() { + let rect1 = Rectangle { + width: 30, + height: 50, + }; + + println!( + "The area of the 
rectangle is {} square pixels.", + rect1.area() + ); +}"#; + let text_document = TextDocumentItem { + uri: reqwest::Url::parse(uri).unwrap(), + language_id: "".to_string(), + version: 0, + text: text.to_string(), + }; + let params = DidOpenTextDocumentParams { + text_document: text_document.clone(), + }; + + file_store.opened_text_document(params)?; + + // Test insert + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri.clone(), + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 8, + character: 0, + }, + end: Position { + line: 8, + character: 0, + }, + }), + range_length: None, + text: " self.width * self.height".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))"); + + // Test delete + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri.clone(), + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 12, + character: 0, + }, + }), + range_length: None, + text: "".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))"); + + // Test replace + let params = lsp_types::DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: text_document.uri, + version: 1, + }, + content_changes: vec![TextDocumentContentChangeEvent { + range: None, + range_length: None, + text: "fn main() {}".to_string(), + }], + }; + file_store.changed_text_document(params)?; + let 
file = file_store.file_map.lock().get(uri).unwrap().clone(); + assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))"); + + Ok(()) + } } diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index d94f9d2..2c091e1 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -1,30 +1,65 @@ +use anyhow::Context; +use lsp_types::TextDocumentPositionParams; +use parking_lot::Mutex; +use pgml::{Collection, Pipeline}; +use serde_json::{json, Value}; use std::{ + io::Read, sync::{ mpsc::{self, Sender}, Arc, }, time::Duration, }; - -use anyhow::Context; -use lsp_types::TextDocumentPositionParams; -use parking_lot::Mutex; -use pgml::{Collection, Pipeline}; -use serde_json::{json, Value}; use tokio::time; -use tracing::{error, instrument}; +use tracing::{error, instrument, warn}; use crate::{ config::{self, Config}, crawl::Crawl, - utils::{tokens_to_estimated_characters, TOKIO_RUNTIME}, + splitters::{Chunk, Splitter}, + utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME}, }; use super::{ - file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, - PromptType, + file_store::{AdditionalFileStoreParams, FileStore}, + ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType, }; +fn chunk_to_document(uri: &str, chunk: Chunk) -> Value { + json!({ + "id": chunk_to_id(uri, &chunk), + "uri": uri, + "text": chunk.text, + "range": chunk.range + }) +} + +async fn split_and_upsert_file( + uri: &str, + collection: &mut Collection, + file_store: Arc, + splitter: Arc>, +) -> anyhow::Result<()> { + // We need to make sure we don't hold the file_store lock while performing a network call + let chunks = { + file_store + .file_map() + .lock() + .get(uri) + .map(|f| splitter.split(f)) + }; + let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?; + let documents = chunks + .into_iter() + .map(|chunk| chunk_to_document(uri, chunk).into()) + .collect(); + collection + .upsert_documents(documents, None) + .await + .context("PGML - Error upserting documents") +} + #[derive(Clone)] pub struct PostgresML { _config: Config, @@ -33,6 +68,7 @@ pub struct PostgresML { pipeline: Pipeline, debounce_tx: Sender, crawl: Option>>, + splitter: Arc>, } impl PostgresML { @@ -45,10 +81,16 @@ impl PostgresML { .crawl .take() .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone())))); - let file_store = Arc::new(FileStore::new( + + let splitter: Arc> = + Arc::new(postgresml_config.splitter.try_into()?); + + let file_store = Arc::new(FileStore::new_with_params( config::FileStore::new_without_crawl(), configuration.clone(), + AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()), )?); + let database_url = if let Some(database_url) = postgresml_config.database_url { database_url } else { @@ -86,6 +128,7 @@ impl PostgresML { let (debounce_tx, debounce_rx) = mpsc::channel::(); let mut task_collection = collection.clone(); let task_file_store = file_store.clone(); + let task_splitter = splitter.clone(); TOKIO_RUNTIME.spawn(async move { let duration = Duration::from_millis(500); let mut file_uris = Vec::new(); @@ -102,36 +145,83 @@ impl PostgresML { if file_uris.is_empty() { continue; } - let documents = match file_uris + + // Build the chunks for our changed files + let chunks: Vec> = match file_uris .iter() .map(|uri| { - let text = task_file_store - 
.get_file_contents(&uri) - .context("Error reading file contents from file_store")?; - anyhow::Ok( - json!({ - "id": uri, - "text": text - }) - .into(), - ) + let file_store = task_file_store.file_map().lock(); + let file = file_store + .get(uri) + .with_context(|| format!("getting file for splitting: {uri}"))?; + anyhow::Ok(task_splitter.split(file)) }) .collect() { - Ok(documents) => documents, + Ok(chunks) => chunks, Err(e) => { error!("{e}"); continue; } }; + + // Delete old chunks that no longer exist after the latest file changes + let delete_or_statements: Vec = file_uris + .iter() + .zip(&chunks) + .map(|(uri, chunks)| { + let ids: Vec = + chunks.iter().map(|c| chunk_to_id(uri, c)).collect(); + json!({ + "$and": [ + { + "uri": { + "$eq": uri + } + }, + { + "id": { + "$nin": ids + } + } + ] + }) + }) + .collect(); + if let Err(e) = task_collection + .delete_documents( + json!({ + "$or": delete_or_statements + }) + .into(), + ) + .await + { + error!("PGML - Error deleting file: {e:?}"); + } + + // Prepare and upsert our new chunks + let documents: Vec = chunks + .into_iter() + .zip(&file_uris) + .map(|(chunks, uri)| { + chunks + .into_iter() + .map(|chunk| chunk_to_document(&uri, chunk)) + .collect::>() + }) + .flatten() + .map(|f: Value| f.into()) + .collect(); if let Err(e) = task_collection .upsert_documents(documents, None) .await - .context("PGML - Error adding pipeline to collection") + .context("PGML - Error upserting changed files") { error!("{e}"); continue; } + file_uris = Vec::new(); } } @@ -144,6 +234,7 @@ impl PostgresML { pipeline, debounce_tx, crawl, + splitter, }; if let Err(e) = s.maybe_do_crawl(None) { @@ -154,28 +245,73 @@ impl PostgresML { fn maybe_do_crawl(&self, triggered_file: Option) -> anyhow::Result<()> { if let Some(crawl) = &self.crawl { - let mut _collection = self.collection.clone(); - let mut _pipeline = self.pipeline.clone(); - let mut documents: Vec = vec![]; - crawl.lock().maybe_do_crawl(triggered_file, |path| { - let uri = format!("file://{path}"); - // This means it has been opened before - if self.file_store.contains_file(&uri) { - return Ok(()); - } - // Get the contents, split, and upsert it - let contents = std::fs::read_to_string(path)?; - documents.push( - json!({ - "id": uri, - "text": contents - }) - .into(), - ); - // Track the size of the documents we have - // If it is over some amount in bytes, upsert it - Ok(()) - })?; + let mut documents: Vec<(String, Vec)> = vec![]; + let mut total_bytes = 0; + let mut current_bytes = 0; + crawl + .lock() + .maybe_do_crawl(triggered_file, |config, path| { + let uri = format!("file://{path}"); + // This means it has been opened before + if self.file_store.contains_file(&uri) { + return Ok(()); + } + // Open the file and see if it is small enough to read + let mut f = std::fs::File::open(path)?; + if f.metadata() + .map(|m| m.len() > config.max_file_size) + .unwrap_or(true) + { + warn!("Skipping file because it is too large: {path}"); + return Ok(()); + } + // Read the file contents + let mut contents = vec![]; + f.read_to_end(&mut contents); + if let Ok(contents) = String::from_utf8(contents) { + current_bytes += contents.len(); + total_bytes += contents.len(); + let chunks = self.splitter.split_file_contents(&uri, &contents); + documents.push((uri, chunks)); + } + // If we have over 100 mega bytes of data do the upsert + if current_bytes >= 100_000_000 || total_bytes as u64 >= config.max_crawl_memory + { + // Prepare our chunks + let to_upsert_documents: Vec = + std::mem::take(&mut documents) + 
.into_iter() + .map(|(uri, chunks)| { + chunks + .into_iter() + .map(|chunk| chunk_to_document(&uri, chunk)) + .collect::>() + }) + .flatten() + .map(|f: Value| f.into()) + .collect(); + // Do the upsert + let mut collection = self.collection.clone(); + TOKIO_RUNTIME.spawn(async move { + if let Err(e) = collection + .upsert_documents(to_upsert_documents, None) + .await + .context("PGML - Error upserting changed files") + { + error!("{e}"); + } + }); + // Reset everything + current_bytes = 0; + documents = vec![]; + } + // Break if total bytes is over the max crawl memory + if total_bytes as u64 >= config.max_crawl_memory { + warn!("Ending crawl eraly do to max_crawl_memory"); + return Ok(()); + } + Ok(()) + })?; } Ok(()) } @@ -263,25 +399,22 @@ impl MemoryBackend for PostgresML { params: lsp_types::DidOpenTextDocumentParams, ) -> anyhow::Result<()> { self.file_store.opened_text_document(params.clone())?; - let mut task_collection = self.collection.clone(); + let saved_uri = params.text_document.uri.to_string(); + + let mut collection = self.collection.clone(); + let file_store = self.file_store.clone(); + let splitter = self.splitter.clone(); TOKIO_RUNTIME.spawn(async move { - let text = params.text_document.text.clone(); let uri = params.text_document.uri.to_string(); - task_collection - .upsert_documents( - vec![json!({ - "id": uri, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error upserting documents"); + if let Err(e) = split_and_upsert_file(&uri, &mut collection, file_store, splitter).await + { + error!("{e:?}") + } }); + if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) { - error!("{e}") + error!("{e:?}") } Ok(()) } @@ -300,32 +433,35 @@ impl MemoryBackend for PostgresML { #[instrument(skip(self))] fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { self.file_store.renamed_files(params.clone())?; - let mut task_collection = self.collection.clone(); - let task_params = params.clone(); + + let mut collection = self.collection.clone(); + let file_store = self.file_store.clone(); + let splitter = self.splitter.clone(); TOKIO_RUNTIME.spawn(async move { - for file in task_params.files { - task_collection + for file in params.files { + if let Err(e) = collection .delete_documents( json!({ - "id": file.old_uri + "uri": { + "$eq": file.old_uri + } }) .into(), ) .await - .expect("PGML - Error deleting file"); - let text = - std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); - task_collection - .upsert_documents( - vec![json!({ - "id": file.new_uri, - "text": text - }) - .into()], - None, - ) - .await - .expect("PGML - Error adding pipeline to collection"); + { + error!("PGML - Error deleting file: {e:?}"); + } + if let Err(e) = split_and_upsert_file( + &file.new_uri, + &mut collection, + file_store.clone(), + splitter.clone(), + ) + .await + { + error!("{e:?}") + } } }); Ok(()) diff --git a/src/splitters/mod.rs b/src/splitters/mod.rs new file mode 100644 index 0000000..ed5c15a --- /dev/null +++ b/src/splitters/mod.rs @@ -0,0 +1,53 @@ +use serde::Serialize; + +use crate::{config::ValidSplitter, memory_backends::file_store::File}; + +mod tree_sitter; + +#[derive(Serialize)] +pub struct ByteRange { + pub start_byte: usize, + pub end_byte: usize, +} + +impl ByteRange { + pub fn new(start_byte: usize, end_byte: usize) -> Self { + Self { + start_byte, + end_byte, + } + } +} + +#[derive(Serialize)] +pub struct Chunk { + pub text: String, + pub range: ByteRange, +} + +impl Chunk { + fn new(text: String, range: 
ByteRange) -> Self { + Self { text, range } + } +} + +pub trait Splitter { + fn split(&self, file: &File) -> Vec<Chunk>; + fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk>; + + fn does_use_tree_sitter(&self) -> bool { + false + } +} + +impl TryFrom<ValidSplitter> for Box<dyn Splitter> { + type Error = anyhow::Error; + + fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> { + match value { + ValidSplitter::TreeSitter(config) => { + Ok(Box::new(tree_sitter::TreeSitter::new(config)?)) + } + } + } +} diff --git a/src/splitters/tree_sitter.rs b/src/splitters/tree_sitter.rs new file mode 100644 index 0000000..e8fb309 --- /dev/null +++ b/src/splitters/tree_sitter.rs @@ -0,0 +1,77 @@ +use splitter_tree_sitter::TreeSitterCodeSplitter; +use tracing::error; +use tree_sitter::Tree; + +use crate::{config, memory_backends::file_store::File, utils::parse_tree}; + +use super::{ByteRange, Chunk, Splitter}; + +pub struct TreeSitter { + _config: config::TreeSitter, + splitter: TreeSitterCodeSplitter, +} + +impl TreeSitter { + pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> { + Ok(Self { + splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?, + _config: config, + }) + } + + fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result<Vec<Chunk>> { + Ok(self + .splitter + .split(tree, contents)? + .into_iter() + .map(|c| { + Chunk::new( + c.text.to_owned(), + ByteRange::new(c.range.start_byte, c.range.end_byte), + ) + }) + .collect()) + } +} + +impl Splitter for TreeSitter { + fn split(&self, file: &File) -> Vec<Chunk> { + if let Some(tree) = file.tree() { + match self.split_tree(tree, file.rope().to_string().as_bytes()) { + Ok(chunks) => chunks, + Err(e) => { + error!( + "Failed to parse tree for file with error {e:?}. Falling back to default splitter.", + ); + todo!() + } + } + } else { + panic!("TreeSitter splitter requires a tree to split") + } + } + + fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk> { + match parse_tree(uri, contents, None) { + Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) { + Ok(chunks) => chunks, + Err(e) => { + error!( + "Failed to parse tree for file: {uri} with error {e:?}. Falling back to default splitter.", + ); + todo!() + } + }, + Err(e) => { + error!( + "Failed to parse tree for file {uri} with error {e:?}. 
Falling back to default splitter.", + ); + todo!() + } + } + } + + fn does_use_tree_sitter(&self) -> bool { + true + } +} diff --git a/src/utils.rs b/src/utils.rs index 29afd71..8b5b8b4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1,10 @@ +use anyhow::Context; use lsp_server::ResponseError; use once_cell::sync::Lazy; use tokio::runtime; +use tree_sitter::Tree; -use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt}; +use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk}; pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| { runtime::Builder::new_multi_thread() @@ -52,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String pub fn format_context_code(context: &str, code: &str) -> String { format!("{context}\n\n{code}") } + +pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String { + format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte) +} + +pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result<Tree> { + let path = std::path::Path::new(uri); + let extension = path.extension().map(|x| x.to_string_lossy()); + let extension = extension.as_deref().unwrap_or(""); + let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?; + parser + .parse(&contents, old_tree) + .with_context(|| format!("parsing tree failed for {uri}")) +}
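
Note (illustrative, not part of the patch): the file_store.rs hunk above handles an LSP didChange by applying the edit to the rope, recording the old and new end byte/Point, calling Tree::edit with an InputEdit, and then re-parsing while passing the edited old tree back to the parser. Below is a minimal standalone sketch of that same incremental re-parse pattern, assuming the tree-sitter 0.22 and tree-sitter-rust 0.21 crates pinned in Cargo.lock; the source string, byte offsets, and names are invented for the example.

use tree_sitter::{InputEdit, Parser, Point};

fn main() {
    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .expect("Error loading Rust grammar");

    // Parse the initial document once, as FileStore does when a file is opened.
    let mut source = String::from("fn area(w: u32, h: u32) -> u32 {}");
    let mut tree = parser.parse(&source, None).expect("Error parsing source");

    // Simulate a didChange event: insert a body inside the empty block.
    let insert_at = 32; // byte offset of the closing brace (single ASCII line, so column == byte)
    let inserted = " w * h ";
    source.insert_str(insert_at, inserted);

    // Tell tree-sitter how the byte ranges moved before re-parsing.
    tree.edit(&InputEdit {
        start_byte: insert_at,
        old_end_byte: insert_at,
        new_end_byte: insert_at + inserted.len(),
        start_position: Point::new(0, insert_at),
        old_end_position: Point::new(0, insert_at),
        new_end_position: Point::new(0, insert_at + inserted.len()),
    });

    // Hand the edited tree back so unchanged subtrees are reused.
    let tree = parser
        .parse(&source, Some(&tree))
        .expect("Error re-parsing source");
    println!("{}", tree.root_node().to_sexp());
}

Threading the old tree through the parse, as the patch does in changed_text_document, is what keeps frequent small edits cheap: only the region described by the InputEdit has to be re-analyzed rather than the whole document.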