Periodic commit

Silas Marvin 2024-06-16 09:49:16 -07:00
parent 3eb5102399
commit 58192c4182
10 changed files with 810 additions and 170 deletions

Cargo.lock (generated)

@ -1523,6 +1523,7 @@ dependencies = [
"anyhow", "anyhow",
"assert_cmd", "assert_cmd",
"async-trait", "async-trait",
"cc",
"directories", "directories",
"hf-hub", "hf-hub",
"ignore", "ignore",
@ -1539,10 +1540,13 @@ dependencies = [
"ropey", "ropey",
"serde", "serde",
"serde_json", "serde_json",
"splitter-tree-sitter",
"tokenizers", "tokenizers",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"tree-sitter",
"utils-tree-sitter",
"xxhash-rust", "xxhash-rust",
] ]
@ -2196,9 +2200,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.10.3" version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -2756,6 +2760,15 @@ dependencies = [
"der", "der",
] ]
[[package]]
name = "splitter-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
]
[[package]] [[package]]
name = "spm_precompiled" name = "spm_precompiled"
version = "0.1.4" version = "0.1.4"
@ -3088,18 +3101,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.58" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.58" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -3339,6 +3352,45 @@ dependencies = [
"tracing-serde", "tracing-serde",
] ]
[[package]]
name = "tree-sitter"
version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca"
dependencies = [
"cc",
"regex",
]
[[package]]
name = "tree-sitter-python"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-rust"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-zig"
version = "0.0.1"
source = "git+https://github.com/SilasMarvin/tree-sitter-zig?branch=silas-update-tree-sitter-version#2eedab3ff6dda88aedddf0bb32a14f81bb709a73"
dependencies = [
"cc",
"tree-sitter",
]
[[package]] [[package]]
name = "try-lock" name = "try-lock"
version = "0.2.5" version = "0.2.5"
@ -3450,6 +3502,18 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "utils-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-zig",
]
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.7.0" version = "1.7.0"


@ -33,6 +33,14 @@ pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] } tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5" indexmap = "2.2.5"
async-trait = "0.1.78" async-trait = "0.1.78"
tree-sitter = "0.22"
# splitter-tree-sitter = { git = "https://github.com/SilasMarvin/splitter-tree-sitter" }
splitter-tree-sitter = { path = "../../splitter-tree-sitter" }
# utils-tree-sitter = { git = "https://github.com/SilasMarvin/utils-tree-sitter" }
utils-tree-sitter = { path = "../../utils-tree-sitter", features = ["all"] }
[build-dependencies]
cc="*"
[features] [features]
default = [] default = []


@ -24,6 +24,43 @@ impl Default for PostProcess {
} }
} }
#[derive(Debug, Clone, Deserialize)]
pub enum ValidSplitter {
#[serde(rename = "tree_sitter")]
TreeSitter(TreeSitter),
}
impl Default for ValidSplitter {
fn default() -> Self {
ValidSplitter::TreeSitter(TreeSitter::default())
}
}
const fn chunk_size_default() -> usize {
1500
}
const fn chunk_overlap_default() -> usize {
0
}
#[derive(Debug, Clone, Deserialize)]
pub struct TreeSitter {
#[serde(default = "chunk_size_default")]
pub chunk_size: usize,
#[serde(default = "chunk_overlap_default")]
pub chunk_overlap: usize,
}
impl Default for TreeSitter {
fn default() -> Self {
Self {
chunk_size: 1500,
chunk_overlap: 0,
}
}
}
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend { pub enum ValidMemoryBackend {
#[serde(rename = "file_store")] #[serde(rename = "file_store")]
@ -85,15 +122,21 @@ pub struct FIM {
pub end: String, pub end: String,
} }
const fn max_crawl_memory_default() -> u32 { const fn max_crawl_memory_default() -> u64 {
42 42
} }
const fn max_crawl_file_size_default() -> u64 {
10_000_000
}
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Crawl { pub struct Crawl {
#[serde(default = "max_crawl_file_size_default")]
pub max_file_size: u64,
#[serde(default = "max_crawl_memory_default")] #[serde(default = "max_crawl_memory_default")]
pub max_crawl_memory: u32, pub max_crawl_memory: u64,
#[serde(default)] #[serde(default)]
pub all_files: bool, pub all_files: bool,
} }
@ -103,6 +146,8 @@ pub struct Crawl {
pub struct PostgresML { pub struct PostgresML {
pub database_url: Option<String>, pub database_url: Option<String>,
pub crawl: Option<Crawl>, pub crawl: Option<Crawl>,
#[serde(default)]
pub splitter: ValidSplitter,
} }
#[derive(Clone, Debug, Deserialize, Default)] #[derive(Clone, Debug, Deserialize, Default)]
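The new splitter option above is an externally tagged serde enum, so a `tree_sitter` key selects the variant and the chunk fields fall back to 1500/0. A minimal, self-contained sketch of that deserialization using stand-in copies of the structs (the real ones live in the config module; the example values and the serde/serde_json/anyhow dependencies are assumptions for illustration only):

use serde::Deserialize;

// Stand-in copies of the types added above; real definitions live in the config module.
#[derive(Debug, Clone, Deserialize)]
pub enum ValidSplitter {
    #[serde(rename = "tree_sitter")]
    TreeSitter(TreeSitter),
}

const fn chunk_size_default() -> usize {
    1500
}

const fn chunk_overlap_default() -> usize {
    0
}

#[derive(Debug, Clone, Deserialize)]
pub struct TreeSitter {
    #[serde(default = "chunk_size_default")]
    pub chunk_size: usize,
    #[serde(default = "chunk_overlap_default")]
    pub chunk_overlap: usize,
}

fn main() -> anyhow::Result<()> {
    // A hypothetical `"splitter": {"tree_sitter": {"chunk_size": 1024}}` section:
    let raw = serde_json::json!({ "tree_sitter": { "chunk_size": 1024 } });
    let splitter: ValidSplitter = serde_json::from_value(raw)?;
    // chunk_overlap was omitted, so it falls back to its default of 0.
    println!("{splitter:?}");
    Ok(())
}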


@ -18,10 +18,14 @@ impl Crawl {
} }
} }
pub fn crawl_config(&self) -> &config::Crawl {
&self.crawl_config
}
pub fn maybe_do_crawl( pub fn maybe_do_crawl(
&mut self, &mut self,
triggered_file: Option<String>, triggered_file: Option<String>,
mut f: impl FnMut(&str) -> anyhow::Result<()>, mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<()>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if let Some(root_uri) = &self.config.client_params.root_uri { if let Some(root_uri) = &self.config.client_params.root_uri {
if !root_uri.starts_with("file://") { if !root_uri.starts_with("file://") {
@ -52,7 +56,7 @@ impl Crawl {
if !path.is_dir() { if !path.is_dir() {
if let Some(path_str) = path.to_str() { if let Some(path_str) = path.to_str() {
if self.crawl_config.all_files { if self.crawl_config.all_files {
f(path_str)?; f(&self.crawl_config, path_str)?;
} else { } else {
match ( match (
path.extension().map(|pe| pe.to_str()).flatten(), path.extension().map(|pe| pe.to_str()).flatten(),
@ -60,7 +64,7 @@ impl Crawl {
) { ) {
(Some(path_extension), Some(extension_to_match)) => { (Some(path_extension), Some(extension_to_match)) => {
if path_extension == extension_to_match { if path_extension == extension_to_match {
f(path_str)?; f(&self.crawl_config, path_str)?;
} }
} }
_ => continue, _ => continue,


@ -18,6 +18,7 @@ mod crawl;
mod custom_requests; mod custom_requests;
mod memory_backends; mod memory_backends;
mod memory_worker; mod memory_worker;
mod splitters;
#[cfg(feature = "llama_cpp")] #[cfg(feature = "llama_cpp")]
mod template; mod template;
mod transformer_backends; mod transformer_backends;
@ -51,15 +52,19 @@ where
req.extract(R::METHOD) req.extract(R::METHOD)
} }
fn main() -> Result<()> { // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable // If the variable's value is malformed or missing, sets the default log level to ERROR
// If the variable's value is malformed or missing, sets the default log level to ERROR fn init_logger() {
FmtSubscriber::builder() FmtSubscriber::builder()
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_ansi(false) .with_ansi(false)
.without_time() .without_time()
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG")) .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init(); .init();
}
fn main() -> Result<()> {
init_logger();
let (connection, io_threads) = Connection::stdio(); let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(ServerCapabilities { let server_capabilities = serde_json::to_value(ServerCapabilities {
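init_logger is the old inline subscriber setup pulled into a function so it can also be called from tests (the new tree-sitter test below calls crate::init_logger()). A standalone sketch of the same setup, assuming the tracing and tracing-subscriber crates from the lockfile with the env-filter feature enabled (illustrative only):

use tracing::{error, info};
use tracing_subscriber::{EnvFilter, FmtSubscriber};

fn main() {
    // Same shape as init_logger: log to stderr, no ANSI or timestamps, level taken from
    // LSP_AI_LOG (missing or malformed values fall back to ERROR, per the comment above).
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .without_time()
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();

    info!("only shown when LSP_AI_LOG allows the INFO level, e.g. LSP_AI_LOG=info");
    error!("shown at the default filter");
}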


@ -6,17 +6,50 @@ use ropey::Rope;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use tracing::{error, instrument}; use tracing::{error, instrument};
use tree_sitter::{InputEdit, Point, Tree};
use crate::{ use crate::{
config::{self, Config}, config::{self, Config},
crawl::Crawl, crawl::Crawl,
utils::tokens_to_estimated_characters, utils::{parse_tree, tokens_to_estimated_characters},
}; };
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType}; use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
#[derive(Default)]
pub struct AdditionalFileStoreParams {
build_tree: bool,
}
impl AdditionalFileStoreParams {
pub fn new(build_tree: bool) -> Self {
Self { build_tree }
}
}
#[derive(Clone)]
pub struct File {
rope: Rope,
tree: Option<Tree>,
}
impl File {
fn new(rope: Rope, tree: Option<Tree>) -> Self {
Self { rope, tree }
}
pub fn rope(&self) -> &Rope {
&self.rope
}
pub fn tree(&self) -> Option<&Tree> {
self.tree.as_ref()
}
}
pub struct FileStore { pub struct FileStore {
file_map: Mutex<HashMap<String, Rope>>, params: AdditionalFileStoreParams,
file_map: Mutex<HashMap<String, File>>,
accessed_files: Mutex<IndexSet<String>>, accessed_files: Mutex<IndexSet<String>>,
crawl: Option<Mutex<Crawl>>, crawl: Option<Mutex<Crawl>>,
} }
@ -28,29 +61,72 @@ impl FileStore {
.take() .take()
.map(|x| Mutex::new(Crawl::new(x, config.clone()))); .map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self { let s = Self {
params: AdditionalFileStoreParams::default(),
file_map: Mutex::new(HashMap::new()), file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()), accessed_files: Mutex::new(IndexSet::new()),
crawl, crawl,
}; };
if let Err(e) = s.maybe_do_crawl(None) { if let Err(e) = s.maybe_do_crawl(None) {
error!("{e}") error!("{e:?}")
} }
Ok(s) Ok(s)
} }
pub fn new_with_params(
mut file_store_config: config::FileStore,
config: Config,
params: AdditionalFileStoreParams,
) -> anyhow::Result<Self> {
let crawl = file_store_config
.crawl
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
params,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e:?}")
}
Ok(s)
}
fn add_new_file(&self, uri: &str, contents: String) {
let tree = if self.params.build_tree {
match parse_tree(uri, &contents, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!(
"Failed to parse tree for {uri} with error {e}, falling back to no tree"
);
None
}
}
} else {
None
};
self.file_map
.lock()
.insert(uri.to_string(), File::new(Rope::from_str(&contents), tree));
self.accessed_files.lock().insert(uri.to_string());
}
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> { fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl { if let Some(crawl) = &self.crawl {
crawl.lock().maybe_do_crawl(triggered_file, |path| { crawl
let insert_uri = format!("file://{path}"); .lock()
if self.file_map.lock().contains_key(&insert_uri) { .maybe_do_crawl(triggered_file, |config, path| {
return Ok(()); let insert_uri = format!("file://{path}");
} if self.file_map.lock().contains_key(&insert_uri) {
let contents = std::fs::read_to_string(path)?; return Ok(());
self.file_map }
.lock() // TODO: actually limit files based on config
.insert(insert_uri, Rope::from_str(&contents)); let contents = std::fs::read_to_string(path)?;
Ok(()) self.add_new_file(&insert_uri, contents);
})?; Ok(())
})?;
} }
Ok(()) Ok(())
} }
@ -67,6 +143,7 @@ impl FileStore {
.lock() .lock()
.get(&current_document_uri) .get(&current_document_uri)
.context("Error file not found")? .context("Error file not found")?
.rope
.clone(); .clone();
let mut cursor_index = rope.line_to_char(position.position.line as usize) let mut cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize; + position.position.character as usize;
@ -82,7 +159,7 @@ impl FileStore {
break; break;
} }
let file_map = self.file_map.lock(); let file_map = self.file_map.lock();
let r = file_map.get(file).context("Error file not found")?; let r = &file_map.get(file).context("Error file not found")?.rope;
let slice_max = needed.min(r.len_chars() + 1); let slice_max = needed.min(r.len_chars() + 1);
let rope_str_slice = r let rope_str_slice = r
.get_slice(0..slice_max - 1) .get_slice(0..slice_max - 1)
@ -105,6 +182,7 @@ impl FileStore {
.lock() .lock()
.get(position.text_document.uri.as_str()) .get(position.text_document.uri.as_str())
.context("Error file not found")? .context("Error file not found")?
.rope
.clone(); .clone();
let cursor_index = rope.line_to_char(position.position.line as usize) let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize; + position.position.character as usize;
@ -173,8 +251,8 @@ impl FileStore {
}) })
} }
pub fn get_file_contents(&self, uri: &str) -> Option<String> { pub fn file_map(&self) -> &Mutex<HashMap<String, File>> {
self.file_map.lock().get(uri).clone().map(|x| x.to_string()) &self.file_map
} }
pub fn contains_file(&self, uri: &str) -> bool { pub fn contains_file(&self, uri: &str) -> bool {
@ -191,6 +269,7 @@ impl MemoryBackend for FileStore {
.lock() .lock()
.get(position.text_document.uri.as_str()) .get(position.text_document.uri.as_str())
.context("Error file not found")? .context("Error file not found")?
.rope
.clone(); .clone();
let line = rope let line = rope
.get_line(position.position.line as usize) .get_line(position.position.line as usize)
@ -217,12 +296,10 @@ impl MemoryBackend for FileStore {
&self, &self,
params: lsp_types::DidOpenTextDocumentParams, params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
self.file_map.lock().insert(uri.clone(), rope); self.add_new_file(&uri, params.text_document.text);
self.accessed_files.lock().shift_insert(0, uri.clone());
if let Err(e) = self.maybe_do_crawl(Some(uri)) { if let Err(e) = self.maybe_do_crawl(Some(uri)) {
error!("{e}") error!("{e:?}")
} }
Ok(()) Ok(())
} }
@ -234,20 +311,95 @@ impl MemoryBackend for FileStore {
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
let mut file_map = self.file_map.lock(); let mut file_map = self.file_map.lock();
let rope = file_map let file = file_map
.get_mut(&uri) .get_mut(&uri)
.context("Error trying to get file that does not exist")?; .with_context(|| format!("Trying to get file that does not exist {uri}"))?;
for change in params.content_changes { for change in params.content_changes {
// If range is omitted, text is the new text of the document // If range is omitted, text is the new text of the document
if let Some(range) = change.range { if let Some(range) = change.range {
let start_index = // Record old positions
rope.line_to_char(range.start.line as usize) + range.start.character as usize; let (old_end_position, old_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the document
let start_index = file.rope.line_to_char(range.start.line as usize)
+ range.start.character as usize;
let end_index = let end_index =
rope.line_to_char(range.end.line as usize) + range.end.character as usize; file.rope.line_to_char(range.end.line as usize) + range.end.character as usize;
rope.remove(start_index..end_index); file.rope.remove(start_index..end_index);
rope.insert(start_index, &change.text); file.rope.insert(start_index, &change.text);
// Set new end positions
let (new_end_position, new_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the tree
if self.params.build_tree {
let mut old_tree = file.tree.take();
let start_byte = file
.rope
.try_line_to_char(range.start.line as usize)
.and_then(|start_char| {
file.rope
.try_char_to_byte(start_char + range.start.character as usize)
})
.map_err(anyhow::Error::msg);
if let Some(old_tree) = &mut old_tree {
match (start_byte, old_end_position, new_end_position) {
(Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => {
old_tree.edit(&InputEdit {
start_byte,
old_end_byte,
new_end_byte,
start_position: Point::new(
range.start.line as usize,
range.start.character as usize,
),
old_end_position,
new_end_position,
});
file.tree = match parse_tree(
&uri,
&file.rope.to_string(),
Some(old_tree),
) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to edit tree: {e:?}");
None
}
};
}
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
error!("failed to build tree edit: {e:?}");
}
}
}
}
} else { } else {
*rope = Rope::from_str(&change.text); file.rope = Rope::from_str(&change.text);
if self.params.build_tree {
file.tree = match parse_tree(&uri, &change.text, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to parse new tree: {e:?}");
None
}
};
}
} }
} }
self.accessed_files.lock().shift_insert(0, uri); self.accessed_files.lock().shift_insert(0, uri);
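The range-edit branch above follows tree-sitter's incremental-parsing recipe: describe the change's byte and row/column extent with an InputEdit, apply it to the old tree, then reparse while passing that tree so unchanged nodes are reused. A minimal, self-contained sketch using the tree-sitter and tree-sitter-rust crates from the lockfile (illustrative, not the server's code):

use tree_sitter::{InputEdit, Parser, Point};

fn main() {
    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .expect("loading Rust grammar");

    let old_source = "fn main() {}";
    let mut tree = parser.parse(old_source, None).expect("parse failed");

    // Insert " let x = 1; " (12 bytes) right after the `{` at byte 10.
    let new_source = "fn main() { let x = 1; }";
    tree.edit(&InputEdit {
        start_byte: 11,
        old_end_byte: 11,
        new_end_byte: 23,
        start_position: Point::new(0, 11),
        old_end_position: Point::new(0, 11),
        new_end_position: Point::new(0, 23),
    });

    // The edited old tree lets the parser reuse everything outside the edit.
    let new_tree = parser.parse(new_source, Some(&tree)).expect("reparse failed");
    println!("{}", new_tree.root_node().to_sexp());
}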
@ -299,8 +451,8 @@ mod tests {
} }
} }
#[tokio::test] #[test]
async fn can_open_document() -> anyhow::Result<()> { fn can_open_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams { let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None), text_document: generate_filler_text_document(None, None),
}; };
@ -312,12 +464,12 @@ mod tests {
.get("file://filler/") .get("file://filler/")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(file.to_string(), "Here is the document body"); assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(()) Ok(())
} }
#[tokio::test] #[test]
async fn can_rename_document() -> anyhow::Result<()> { fn can_rename_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams { let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None), text_document: generate_filler_text_document(None, None),
}; };
@ -338,12 +490,12 @@ mod tests {
.get("file://filler2/") .get("file://filler2/")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(file.to_string(), "Here is the document body"); assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(()) Ok(())
} }
#[tokio::test] #[test]
async fn can_change_document() -> anyhow::Result<()> { fn can_change_document() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, None); let text_document = generate_filler_text_document(None, None);
let params = DidOpenTextDocumentParams { let params = DidOpenTextDocumentParams {
@ -379,7 +531,7 @@ mod tests {
.get("file://filler/") .get("file://filler/")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(file.to_string(), "Hae is the document body"); assert_eq!(file.rope.to_string(), "Hae is the document body");
let params = lsp_types::DidChangeTextDocumentParams { let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier { text_document: VersionedTextDocumentIdentifier {
@ -399,7 +551,7 @@ mod tests {
.get("file://filler/") .get("file://filler/")
.unwrap() .unwrap()
.clone(); .clone();
assert_eq!(file.to_string(), "abc"); assert_eq!(file.rope.to_string(), "abc");
Ok(()) Ok(())
} }
@ -579,43 +731,123 @@ The end with a trailing new line
Ok(()) Ok(())
} }
// #[tokio::test] #[test]
// async fn test_fim_placement_corner_cases() -> anyhow::Result<()> { fn test_file_store_tree_sitter() -> anyhow::Result<()> {
// let text_document = generate_filler_text_document(None, Some("test\n")); crate::init_logger();
// let params = lsp_types::DidOpenTextDocumentParams {
// text_document: text_document.clone(),
// };
// let file_store = generate_base_file_store()?;
// file_store.opened_text_document(params).await?;
// // Test FIM let config = Config::default_with_file_store_without_models();
// let params = json!({ let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
// "fim": { config.config.memory.clone()
// "start": "SS", {
// "middle": "MM", file_store_config
// "end": "EE" } else {
// } anyhow::bail!("requires a file_store_config")
// }); };
// let prompt = file_store let params = AdditionalFileStoreParams { build_tree: true };
// .build_prompt( let file_store = FileStore::new_with_params(file_store_config, config, params)?;
// &TextDocumentPositionParams {
// text_document: TextDocumentIdentifier {
// uri: text_document.uri.clone(),
// },
// position: Position {
// line: 1,
// character: 0,
// },
// },
// params,
// )
// .await?;
// assert_eq!(prompt.context, "");
// let text = r#"test
// "#
// .to_string();
// assert_eq!(text, prompt.code);
// Ok(()) let uri = "file://filler/test.rs";
// } let text = r#"#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}
impl Rectangle {
fn area(&self) -> u32 {
}
}
fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};
println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}"#;
let text_document = TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "".to_string(),
version: 0,
text: text.to_string(),
};
let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
file_store.opened_text_document(params)?;
// Test insert
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 8,
character: 0,
},
end: Position {
line: 8,
character: 0,
},
}),
range_length: None,
text: " self.width * self.height".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test delete
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 12,
character: 0,
},
}),
range_length: None,
text: "".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test replace
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "fn main() {}".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))");
Ok(())
}
} }
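The assertions above compare whole parse trees by their s-expression form. When updating such a test, the expected string can be regenerated with a few lines (same tree-sitter/tree-sitter-rust crates as before; illustrative):

use tree_sitter::Parser;

fn main() {
    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .expect("loading Rust grammar");
    let tree = parser.parse("fn main() {}", None).expect("parse failed");
    // Prints: (source_file (function_item name: (identifier) parameters: (parameters) body: (block)))
    println!("{}", tree.root_node().to_sexp());
}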


@ -1,30 +1,65 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use std::{ use std::{
io::Read,
sync::{ sync::{
mpsc::{self, Sender}, mpsc::{self, Sender},
Arc, Arc,
}, },
time::Duration, time::Duration,
}; };
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use tokio::time; use tokio::time;
use tracing::{error, instrument}; use tracing::{error, instrument, warn};
use crate::{ use crate::{
config::{self, Config}, config::{self, Config},
crawl::Crawl, crawl::Crawl,
utils::{tokens_to_estimated_characters, TOKIO_RUNTIME}, splitters::{Chunk, Splitter},
utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME},
}; };
use super::{ use super::{
file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, file_store::{AdditionalFileStoreParams, FileStore},
PromptType, ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
}; };
fn chunk_to_document(uri: &str, chunk: Chunk) -> Value {
json!({
"id": chunk_to_id(uri, &chunk),
"uri": uri,
"text": chunk.text,
"range": chunk.range
})
}
async fn split_and_upsert_file(
uri: &str,
collection: &mut Collection,
file_store: Arc<FileStore>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
) -> anyhow::Result<()> {
// We need to make sure we don't hold the file_store lock while performing a network call
let chunks = {
file_store
.file_map()
.lock()
.get(uri)
.map(|f| splitter.split(f))
};
let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
let documents = chunks
.into_iter()
.map(|chunk| chunk_to_document(uri, chunk).into())
.collect();
collection
.upsert_documents(documents, None)
.await
.context("PGML - Error upserting documents")
}
#[derive(Clone)] #[derive(Clone)]
pub struct PostgresML { pub struct PostgresML {
_config: Config, _config: Config,
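Each chunk becomes one pgml document whose id embeds the source uri plus the chunk's byte range (see chunk_to_id in utils.rs further down); the stale-chunk deletion below keys off exactly those ids. A rough, self-contained sketch with stand-in Chunk/ByteRange types (the real ones are added in src/splitters/mod.rs; the sample uri and values are made up):

use serde::Serialize;
use serde_json::{json, Value};

// Stand-ins for the splitter types; real definitions are in src/splitters/mod.rs.
#[derive(Serialize)]
struct ByteRange { start_byte: usize, end_byte: usize }

#[derive(Serialize)]
struct Chunk { text: String, range: ByteRange }

// Mirrors chunk_to_id + chunk_to_document from the diff.
fn chunk_to_document(uri: &str, chunk: Chunk) -> Value {
    json!({
        "id": format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte),
        "uri": uri,
        "text": chunk.text,
        "range": chunk.range
    })
}

fn main() {
    let chunk = Chunk {
        text: "fn main() {}".to_string(),
        range: ByteRange { start_byte: 0, end_byte: 12 },
    };
    // The id comes out as "file:///tmp/main.rs#0-12".
    println!("{}", chunk_to_document("file:///tmp/main.rs", chunk));
}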
@ -33,6 +68,7 @@ pub struct PostgresML {
pipeline: Pipeline, pipeline: Pipeline,
debounce_tx: Sender<String>, debounce_tx: Sender<String>,
crawl: Option<Arc<Mutex<Crawl>>>, crawl: Option<Arc<Mutex<Crawl>>>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
} }
impl PostgresML { impl PostgresML {
@ -45,10 +81,16 @@ impl PostgresML {
.crawl .crawl
.take() .take()
.map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone())))); .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
let file_store = Arc::new(FileStore::new(
let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
Arc::new(postgresml_config.splitter.try_into()?);
let file_store = Arc::new(FileStore::new_with_params(
config::FileStore::new_without_crawl(), config::FileStore::new_without_crawl(),
configuration.clone(), configuration.clone(),
AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
)?); )?);
let database_url = if let Some(database_url) = postgresml_config.database_url { let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url database_url
} else { } else {
@ -86,6 +128,7 @@ impl PostgresML {
let (debounce_tx, debounce_rx) = mpsc::channel::<String>(); let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
let mut task_collection = collection.clone(); let mut task_collection = collection.clone();
let task_file_store = file_store.clone(); let task_file_store = file_store.clone();
let task_splitter = splitter.clone();
TOKIO_RUNTIME.spawn(async move { TOKIO_RUNTIME.spawn(async move {
let duration = Duration::from_millis(500); let duration = Duration::from_millis(500);
let mut file_uris = Vec::new(); let mut file_uris = Vec::new();
@ -102,36 +145,83 @@ impl PostgresML {
if file_uris.is_empty() { if file_uris.is_empty() {
continue; continue;
} }
let documents = match file_uris
// Build the chunks for our changed files
let chunks: Vec<Vec<Chunk>> = match file_uris
.iter() .iter()
.map(|uri| { .map(|uri| {
let text = task_file_store let file_store = task_file_store.file_map().lock();
.get_file_contents(&uri) let file = file_store
.context("Error reading file contents from file_store")?; .get(uri)
anyhow::Ok( .with_context(|| format!("getting file for splitting: {uri}"))?;
json!({ anyhow::Ok(task_splitter.split(file))
"id": uri,
"text": text
})
.into(),
)
}) })
.collect() .collect()
{ {
Ok(documents) => documents, Ok(chunks) => chunks,
Err(e) => { Err(e) => {
error!("{e}"); error!("{e}");
continue; continue;
} }
}; };
// Delete old chunks that no longer exist after the latest file changes
let delete_or_statements: Vec<Value> = file_uris
.iter()
.zip(&chunks)
.map(|(uri, chunks)| {
let ids: Vec<String> =
chunks.iter().map(|c| chunk_to_id(uri, c)).collect();
json!({
"$and": [
{
"uri": {
"$eq": uri
}
},
{
"id": {
"$nin": ids
}
}
]
})
})
.collect();
if let Err(e) = task_collection
.delete_documents(
json!({
"$or": delete_or_statements
})
.into(),
)
.await
{
error!("PGML - Error deleting file: {e:?}");
}
// Prepare and upsert our new chunks
let documents: Vec<pgml::types::Json> = chunks
.into_iter()
.zip(&file_uris)
.map(|(chunks, uri)| {
chunks
.into_iter()
.map(|chunk| chunk_to_document(&uri, chunk))
.collect::<Vec<Value>>()
})
.flatten()
.map(|f: Value| f.into())
.collect();
if let Err(e) = task_collection if let Err(e) = task_collection
.upsert_documents(documents, None) .upsert_documents(documents, None)
.await .await
.context("PGML - Error adding pipeline to collection") .context("PGML - Error upserting changed files")
{ {
error!("{e}"); error!("{e}");
continue; continue;
} }
file_uris = Vec::new(); file_uris = Vec::new();
} }
} }
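The worker above is a debounce loop: changed-file uris arrive on an mpsc channel and are only re-chunked and upserted once roughly 500ms of quiet has passed, so a burst of edits collapses into a single upsert. A generic, self-contained sketch of that pattern (plain std threads instead of the Tokio task used here, and the processing step is just a println; purely illustrative):

use std::{sync::mpsc, thread, time::Duration};

fn main() {
    let (tx, rx) = mpsc::channel::<String>();

    let worker = thread::spawn(move || {
        let window = Duration::from_millis(500);
        let mut pending: Vec<String> = Vec::new();
        loop {
            match rx.recv_timeout(window) {
                // Collect and de-duplicate uris while changes keep arriving.
                Ok(uri) => {
                    if !pending.contains(&uri) {
                        pending.push(uri);
                    }
                }
                // The channel has been quiet for the whole window: process the batch.
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    if !pending.is_empty() {
                        println!("re-chunking and upserting {} file(s)", pending.len());
                        pending.clear();
                    }
                }
                Err(mpsc::RecvTimeoutError::Disconnected) => {
                    if !pending.is_empty() {
                        println!("re-chunking and upserting {} file(s)", pending.len());
                    }
                    break;
                }
            }
        }
    });

    for _ in 0..3 {
        tx.send("file:///tmp/main.rs".to_string()).unwrap();
    }
    drop(tx);
    worker.join().unwrap();
}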
@ -144,6 +234,7 @@ impl PostgresML {
pipeline, pipeline,
debounce_tx, debounce_tx,
crawl, crawl,
splitter,
}; };
if let Err(e) = s.maybe_do_crawl(None) { if let Err(e) = s.maybe_do_crawl(None) {
@ -154,28 +245,73 @@ impl PostgresML {
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> { fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl { if let Some(crawl) = &self.crawl {
let mut _collection = self.collection.clone(); let mut documents: Vec<(String, Vec<Chunk>)> = vec![];
let mut _pipeline = self.pipeline.clone(); let mut total_bytes = 0;
let mut documents: Vec<pgml::types::Json> = vec![]; let mut current_bytes = 0;
crawl.lock().maybe_do_crawl(triggered_file, |path| { crawl
let uri = format!("file://{path}"); .lock()
// This means it has been opened before .maybe_do_crawl(triggered_file, |config, path| {
if self.file_store.contains_file(&uri) { let uri = format!("file://{path}");
return Ok(()); // This means it has been opened before
} if self.file_store.contains_file(&uri) {
// Get the contents, split, and upsert it return Ok(());
let contents = std::fs::read_to_string(path)?; }
documents.push( // Open the file and see if it is small enough to read
json!({ let mut f = std::fs::File::open(path)?;
"id": uri, if f.metadata()
"text": contents .map(|m| m.len() > config.max_file_size)
}) .unwrap_or(true)
.into(), {
); warn!("Skipping file because it is too large: {path}");
// Track the size of the documents we have return Ok(());
// If it is over some amount in bytes, upsert it }
Ok(()) // Read the file contents
})?; let mut contents = vec![];
f.read_to_end(&mut contents)?;
if let Ok(contents) = String::from_utf8(contents) {
current_bytes += contents.len();
total_bytes += contents.len();
let chunks = self.splitter.split_file_contents(&uri, &contents);
documents.push((uri, chunks));
}
// If we have over 100 mega bytes of data do the upsert
if current_bytes >= 100_000_000 || total_bytes as u64 >= config.max_crawl_memory
{
// Prepare our chunks
let to_upsert_documents: Vec<pgml::types::Json> =
std::mem::take(&mut documents)
.into_iter()
.map(|(uri, chunks)| {
chunks
.into_iter()
.map(|chunk| chunk_to_document(&uri, chunk))
.collect::<Vec<Value>>()
})
.flatten()
.map(|f: Value| f.into())
.collect();
// Do the upsert
let mut collection = self.collection.clone();
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = collection
.upsert_documents(to_upsert_documents, None)
.await
.context("PGML - Error upserting changed files")
{
error!("{e}");
}
});
// Reset everything
current_bytes = 0;
documents = vec![];
}
// Break if total bytes is over the max crawl memory
if total_bytes as u64 >= config.max_crawl_memory {
warn!("Ending crawl early due to max_crawl_memory");
return Ok(());
}
Ok(())
})?;
} }
Ok(()) Ok(())
} }
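The crawl path above streams files through the same chunking but batches the upserts: it flushes whenever the in-flight batch passes a size threshold (100 MB here) and stops entirely once max_crawl_memory is spent. A toy restatement of that budgeting with made-up sizes (illustrative; the real loop also skips files over max_file_size and chunks the file contents before upserting):

const FLUSH_BYTES: usize = 100_000_000;

fn crawl_batches(files: &[(&str, usize)], max_crawl_memory: usize) {
    let mut batch: Vec<&str> = Vec::new();
    let (mut current_bytes, mut total_bytes) = (0usize, 0usize);
    for &(uri, len) in files {
        current_bytes += len;
        total_bytes += len;
        batch.push(uri);
        // Flush once the in-flight batch is big enough or the overall budget is reached.
        if current_bytes >= FLUSH_BYTES || total_bytes >= max_crawl_memory {
            println!("upserting batch of {} file(s)", batch.len());
            batch.clear();
            current_bytes = 0;
        }
        // Stop crawling entirely once the memory budget is spent.
        if total_bytes >= max_crawl_memory {
            println!("ending crawl early due to max_crawl_memory");
            return;
        }
    }
    println!("crawl finished with {} unflushed file(s)", batch.len());
}

fn main() {
    crawl_batches(&[("file:///a.rs", 40), ("file:///b.rs", 70), ("file:///c.rs", 10)], 100);
}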
@ -263,25 +399,22 @@ impl MemoryBackend for PostgresML {
params: lsp_types::DidOpenTextDocumentParams, params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.file_store.opened_text_document(params.clone())?; self.file_store.opened_text_document(params.clone())?;
let mut task_collection = self.collection.clone();
let saved_uri = params.text_document.uri.to_string(); let saved_uri = params.text_document.uri.to_string();
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
TOKIO_RUNTIME.spawn(async move { TOKIO_RUNTIME.spawn(async move {
let text = params.text_document.text.clone();
let uri = params.text_document.uri.to_string(); let uri = params.text_document.uri.to_string();
task_collection if let Err(e) = split_and_upsert_file(&uri, &mut collection, file_store, splitter).await
.upsert_documents( {
vec![json!({ error!("{e:?}")
"id": uri, }
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
}); });
if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) { if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) {
error!("{e}") error!("{e:?}")
} }
Ok(()) Ok(())
} }
@ -300,32 +433,35 @@ impl MemoryBackend for PostgresML {
#[instrument(skip(self))] #[instrument(skip(self))]
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> { fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
self.file_store.renamed_files(params.clone())?; self.file_store.renamed_files(params.clone())?;
let mut task_collection = self.collection.clone();
let task_params = params.clone(); let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
TOKIO_RUNTIME.spawn(async move { TOKIO_RUNTIME.spawn(async move {
for file in task_params.files { for file in params.files {
task_collection if let Err(e) = collection
.delete_documents( .delete_documents(
json!({ json!({
"id": file.old_uri "uri": {
"$eq": file.old_uri
}
}) })
.into(), .into(),
) )
.await .await
.expect("PGML - Error deleting file"); {
let text = error!("PGML - Error deleting file: {e:?}");
std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file"); }
task_collection if let Err(e) = split_and_upsert_file(
.upsert_documents( &file.new_uri,
vec![json!({ &mut collection,
"id": file.new_uri, file_store.clone(),
"text": text splitter.clone(),
}) )
.into()], .await
None, {
) error!("{e:?}")
.await }
.expect("PGML - Error adding pipeline to collection");
} }
}); });
Ok(()) Ok(())

src/splitters/mod.rs (new file)

@ -0,0 +1,53 @@
use serde::Serialize;
use crate::{config::ValidSplitter, memory_backends::file_store::File};
mod tree_sitter;
#[derive(Serialize)]
pub struct ByteRange {
pub start_byte: usize,
pub end_byte: usize,
}
impl ByteRange {
pub fn new(start_byte: usize, end_byte: usize) -> Self {
Self {
start_byte,
end_byte,
}
}
}
#[derive(Serialize)]
pub struct Chunk {
pub text: String,
pub range: ByteRange,
}
impl Chunk {
fn new(text: String, range: ByteRange) -> Self {
Self { text, range }
}
}
pub trait Splitter {
fn split(&self, file: &File) -> Vec<Chunk>;
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk>;
fn does_use_tree_sitter(&self) -> bool {
false
}
}
impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
type Error = anyhow::Error;
fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
match value {
ValidSplitter::TreeSitter(config) => {
Ok(Box::new(tree_sitter::TreeSitter::new(config)?))
}
}
}
}
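The tree-sitter splitter below still has todo!() placeholders where its error paths say "falling back to default splitter". A hypothetical fixed-size fallback of that kind might look like the sketch below; Chunk and ByteRange are local stand-ins for the types above, and nothing here is part of the commit:

// Local stand-ins for the Chunk/ByteRange types defined above.
struct ByteRange { start_byte: usize, end_byte: usize }
struct Chunk { text: String, range: ByteRange }

// Hypothetical fixed-size splitter: windows of chunk_size bytes with chunk_overlap bytes of overlap.
fn split_fixed(contents: &str, chunk_size: usize, chunk_overlap: usize) -> Vec<Chunk> {
    assert!(chunk_overlap < chunk_size, "overlap must be smaller than the chunk size");
    let bytes = contents.as_bytes();
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < bytes.len() {
        let end = (start + chunk_size).min(bytes.len());
        // A real splitter must respect UTF-8 boundaries; from_utf8_lossy sidesteps that here.
        let text = String::from_utf8_lossy(&bytes[start..end]).into_owned();
        chunks.push(Chunk { text, range: ByteRange { start_byte: start, end_byte: end } });
        if end == bytes.len() {
            break;
        }
        start = end - chunk_overlap;
    }
    chunks
}

fn main() {
    for c in split_fixed("fn main() { println!(\"hello\"); }", 16, 4) {
        println!("{}-{}: {:?}", c.range.start_byte, c.range.end_byte, c.text);
    }
}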

src/splitters/tree_sitter.rs (new file)

@ -0,0 +1,77 @@
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tracing::error;
use tree_sitter::Tree;
use crate::{config, memory_backends::file_store::File, utils::parse_tree};
use super::{ByteRange, Chunk, Splitter};
pub struct TreeSitter {
_config: config::TreeSitter,
splitter: TreeSitterCodeSplitter,
}
impl TreeSitter {
pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
Ok(Self {
splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
_config: config,
})
}
fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result<Vec<Chunk>> {
Ok(self
.splitter
.split(tree, contents)?
.into_iter()
.map(|c| {
Chunk::new(
c.text.to_owned(),
ByteRange::new(c.range.start_byte, c.range.end_byte),
)
})
.collect())
}
}
impl Splitter for TreeSitter {
fn split(&self, file: &File) -> Vec<Chunk> {
if let Some(tree) = file.tree() {
match self.split_tree(tree, file.rope().to_string().as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file with error {e:?}. Falling back to default splitter.",
);
todo!()
}
}
} else {
panic!("TreeSitter splitter requires a tree to split")
}
}
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk> {
match parse_tree(uri, contents, None) {
Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file: {uri} with error {e:?}. Falling back to default splitter.",
);
todo!()
}
},
Err(e) => {
error!(
"Failed to parse tree for file {uri} with error {e:?}. Falling back to default splitter.",
);
todo!()
}
}
}
fn does_use_tree_sitter(&self) -> bool {
true
}
}
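A rough usage sketch of the wrapped splitter, limited to the calls visible in this file (TreeSitterCodeSplitter::new(chunk_size, chunk_overlap) and split(&tree, bytes)); the chunk field names are taken from the mapping above and everything else, including the error conversions, is an assumption:

use splitter_tree_sitter::TreeSitterCodeSplitter;
use tree_sitter::Parser;

fn main() -> anyhow::Result<()> {
    let source = "fn add(a: u32, b: u32) -> u32 { a + b }";

    // Parse with the Rust grammar from the lockfile, then hand the tree to the splitter.
    let mut parser = Parser::new();
    parser.set_language(&tree_sitter_rust::language())?;
    let tree = parser.parse(source, None).expect("parse failed");

    let splitter = TreeSitterCodeSplitter::new(1500, 0)?;
    for chunk in splitter.split(&tree, source.as_bytes())? {
        println!("{}-{}: {:?}", chunk.range.start_byte, chunk.range.end_byte, chunk.text);
    }
    Ok(())
}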


@ -1,8 +1,10 @@
use anyhow::Context;
use lsp_server::ResponseError; use lsp_server::ResponseError;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use tokio::runtime; use tokio::runtime;
use tree_sitter::Tree;
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt}; use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk};
pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| { pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread() runtime::Builder::new_multi_thread()
@ -52,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String
pub fn format_context_code(context: &str, code: &str) -> String { pub fn format_context_code(context: &str, code: &str) -> String {
format!("{context}\n\n{code}") format!("{context}\n\n{code}")
} }
pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String {
format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte)
}
pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result<Tree> {
let path = std::path::Path::new(uri);
let extension = path.extension().map(|x| x.to_string_lossy());
let extension = extension.as_deref().unwrap_or("");
let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?;
parser
.parse(&contents, old_tree)
.with_context(|| format!("parsing tree failed for {uri}"))
}