Periodic commit

Silas Marvin 2024-06-16 09:49:16 -07:00
parent 3eb5102399
commit 58192c4182
10 changed files with 810 additions and 170 deletions

Cargo.lock (generated, 76 changed lines)
View File

@ -1523,6 +1523,7 @@ dependencies = [
"anyhow",
"assert_cmd",
"async-trait",
"cc",
"directories",
"hf-hub",
"ignore",
@ -1539,10 +1540,13 @@ dependencies = [
"ropey",
"serde",
"serde_json",
"splitter-tree-sitter",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
"tree-sitter",
"utils-tree-sitter",
"xxhash-rust",
]
@ -2196,9 +2200,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.3"
version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [
"aho-corasick",
"memchr",
@ -2756,6 +2760,15 @@ dependencies = [
"der",
]
[[package]]
name = "splitter-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
]
[[package]]
name = "spm_precompiled"
version = "0.1.4"
@ -3088,18 +3101,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]]
name = "thiserror"
version = "1.0.58"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.58"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
@ -3339,6 +3352,45 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "tree-sitter"
version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df7cc499ceadd4dcdf7ec6d4cbc34ece92c3fa07821e287aedecd4416c516dca"
dependencies = [
"cc",
"regex",
]
[[package]]
name = "tree-sitter-python"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4066c6cf678f962f8c2c4561f205945c84834cce73d981e71392624fdc390a9"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-rust"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "277690f420bf90741dea984f3da038ace46c4fe6047cba57a66822226cde1c93"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "tree-sitter-zig"
version = "0.0.1"
source = "git+https://github.com/SilasMarvin/tree-sitter-zig?branch=silas-update-tree-sitter-version#2eedab3ff6dda88aedddf0bb32a14f81bb709a73"
dependencies = [
"cc",
"tree-sitter",
]
[[package]]
name = "try-lock"
version = "0.2.5"
@ -3450,6 +3502,18 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "utils-tree-sitter"
version = "0.1.0"
dependencies = [
"cc",
"thiserror",
"tree-sitter",
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-zig",
]
[[package]]
name = "uuid"
version = "1.7.0"

View File

@ -33,6 +33,14 @@ pgml = "1.0.4"
tokio = { version = "1.36.0", features = ["rt-multi-thread", "time"] }
indexmap = "2.2.5"
async-trait = "0.1.78"
tree-sitter = "0.22"
# splitter-tree-sitter = { git = "https://github.com/SilasMarvin/splitter-tree-sitter" }
splitter-tree-sitter = { path = "../../splitter-tree-sitter" }
# utils-tree-sitter = { git = "https://github.com/SilasMarvin/utils-tree-sitter" }
utils-tree-sitter = { path = "../../utils-tree-sitter", features = ["all"] }
[build-dependencies]
cc="*"
[features]
default = []

View File

@ -24,6 +24,43 @@ impl Default for PostProcess {
}
}
#[derive(Debug, Clone, Deserialize)]
pub enum ValidSplitter {
#[serde(rename = "tree_sitter")]
TreeSitter(TreeSitter),
}
impl Default for ValidSplitter {
fn default() -> Self {
ValidSplitter::TreeSitter(TreeSitter::default())
}
}
const fn chunk_size_default() -> usize {
1500
}
const fn chunk_overlap_default() -> usize {
0
}
#[derive(Debug, Clone, Deserialize)]
pub struct TreeSitter {
#[serde(default = "chunk_size_default")]
pub chunk_size: usize,
#[serde(default = "chunk_overlap_default")]
pub chunk_overlap: usize,
}
impl Default for TreeSitter {
fn default() -> Self {
Self {
chunk_size: 1500,
chunk_overlap: 0,
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub enum ValidMemoryBackend {
#[serde(rename = "file_store")]
@ -85,15 +122,21 @@ pub struct FIM {
pub end: String,
}
const fn max_crawl_memory_default() -> u32 {
const fn max_crawl_memory_default() -> u64 {
42
}
const fn max_crawl_file_size_default() -> u64 {
10_000_000
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Crawl {
#[serde(default = "max_crawl_file_size_default")]
pub max_file_size: u64,
#[serde(default = "max_crawl_memory_default")]
pub max_crawl_memory: u32,
pub max_crawl_memory: u64,
#[serde(default)]
pub all_files: bool,
}
@ -103,6 +146,8 @@ pub struct Crawl {
pub struct PostgresML {
pub database_url: Option<String>,
pub crawl: Option<Crawl>,
#[serde(default)]
pub splitter: ValidSplitter,
}
#[derive(Clone, Debug, Deserialize, Default)]
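For reference, a minimal self-contained sketch (not part of the commit; assumes serde with the derive feature plus serde_json as dependencies) of how the new splitter section deserializes, mirroring the ValidSplitter and TreeSitter types added above:

use serde::Deserialize;

const fn chunk_size_default() -> usize {
    1500
}
const fn chunk_overlap_default() -> usize {
    0
}

#[derive(Debug, Deserialize)]
struct TreeSitter {
    #[serde(default = "chunk_size_default")]
    chunk_size: usize,
    #[serde(default = "chunk_overlap_default")]
    chunk_overlap: usize,
}

#[derive(Debug, Deserialize)]
enum ValidSplitter {
    #[serde(rename = "tree_sitter")]
    TreeSitter(TreeSitter),
}

fn main() {
    // Externally tagged enum: the variant rename "tree_sitter" is the JSON key.
    let splitter: ValidSplitter = serde_json::from_value(serde_json::json!({
        "tree_sitter": { "chunk_size": 1024 }
    }))
    .expect("valid splitter config");
    // chunk_overlap was omitted, so it falls back to chunk_overlap_default() = 0.
    println!("{splitter:?}");
}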

View File

@ -18,10 +18,14 @@ impl Crawl {
}
}
pub fn crawl_config(&self) -> &config::Crawl {
&self.crawl_config
}
pub fn maybe_do_crawl(
&mut self,
triggered_file: Option<String>,
mut f: impl FnMut(&str) -> anyhow::Result<()>,
mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
if let Some(root_uri) = &self.config.client_params.root_uri {
if !root_uri.starts_with("file://") {
@ -52,7 +56,7 @@ impl Crawl {
if !path.is_dir() {
if let Some(path_str) = path.to_str() {
if self.crawl_config.all_files {
f(path_str)?;
f(&self.crawl_config, path_str)?;
} else {
match (
path.extension().map(|pe| pe.to_str()).flatten(),
@ -60,7 +64,7 @@ impl Crawl {
) {
(Some(path_extension), Some(extension_to_match)) => {
if path_extension == extension_to_match {
f(path_str)?;
f(&self.crawl_config, path_str)?;
}
}
_ => continue,

View File

@ -18,6 +18,7 @@ mod crawl;
mod custom_requests;
mod memory_backends;
mod memory_worker;
mod splitters;
#[cfg(feature = "llama_cpp")]
mod template;
mod transformer_backends;
@ -51,15 +52,19 @@ where
req.extract(R::METHOD)
}
fn main() -> Result<()> {
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
// Builds a tracing subscriber from the `LSP_AI_LOG` environment variable
// If the variable's value is malformed or missing, sets the default log level to ERROR
fn init_logger() {
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_ansi(false)
.without_time()
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.init();
}
fn main() -> Result<()> {
init_logger();
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(ServerCapabilities {

View File

@ -6,17 +6,50 @@ use ropey::Rope;
use serde_json::Value;
use std::collections::HashMap;
use tracing::{error, instrument};
use tree_sitter::{InputEdit, Point, Tree};
use crate::{
config::{self, Config},
crawl::Crawl,
utils::tokens_to_estimated_characters,
utils::{parse_tree, tokens_to_estimated_characters},
};
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
#[derive(Default)]
pub struct AdditionalFileStoreParams {
build_tree: bool,
}
impl AdditionalFileStoreParams {
pub fn new(build_tree: bool) -> Self {
Self { build_tree }
}
}
#[derive(Clone)]
pub struct File {
rope: Rope,
tree: Option<Tree>,
}
impl File {
fn new(rope: Rope, tree: Option<Tree>) -> Self {
Self { rope, tree }
}
pub fn rope(&self) -> &Rope {
&self.rope
}
pub fn tree(&self) -> Option<&Tree> {
self.tree.as_ref()
}
}
pub struct FileStore {
file_map: Mutex<HashMap<String, Rope>>,
params: AdditionalFileStoreParams,
file_map: Mutex<HashMap<String, File>>,
accessed_files: Mutex<IndexSet<String>>,
crawl: Option<Mutex<Crawl>>,
}
@ -28,29 +61,72 @@ impl FileStore {
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
params: AdditionalFileStoreParams::default(),
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e}")
error!("{e:?}")
}
Ok(s)
}
pub fn new_with_params(
mut file_store_config: config::FileStore,
config: Config,
params: AdditionalFileStoreParams,
) -> anyhow::Result<Self> {
let crawl = file_store_config
.crawl
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
params,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e:?}")
}
Ok(s)
}
fn add_new_file(&self, uri: &str, contents: String) {
let tree = if self.params.build_tree {
match parse_tree(uri, &contents, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!(
"Failed to parse tree for {uri} with error {e}, falling back to no tree"
);
None
}
}
} else {
None
};
self.file_map
.lock()
.insert(uri.to_string(), File::new(Rope::from_str(&contents), tree));
self.accessed_files.lock().insert(uri.to_string());
}
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl {
crawl.lock().maybe_do_crawl(triggered_file, |path| {
let insert_uri = format!("file://{path}");
if self.file_map.lock().contains_key(&insert_uri) {
return Ok(());
}
let contents = std::fs::read_to_string(path)?;
self.file_map
.lock()
.insert(insert_uri, Rope::from_str(&contents));
Ok(())
})?;
crawl
.lock()
.maybe_do_crawl(triggered_file, |config, path| {
let insert_uri = format!("file://{path}");
if self.file_map.lock().contains_key(&insert_uri) {
return Ok(());
}
// TODO: actually limit files based on config
let contents = std::fs::read_to_string(path)?;
self.add_new_file(&insert_uri, contents);
Ok(())
})?;
}
Ok(())
}
@ -67,6 +143,7 @@ impl FileStore {
.lock()
.get(&current_document_uri)
.context("Error file not found")?
.rope
.clone();
let mut cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
@ -82,7 +159,7 @@ impl FileStore {
break;
}
let file_map = self.file_map.lock();
let r = file_map.get(file).context("Error file not found")?;
let r = &file_map.get(file).context("Error file not found")?.rope;
let slice_max = needed.min(r.len_chars() + 1);
let rope_str_slice = r
.get_slice(0..slice_max - 1)
@ -105,6 +182,7 @@ impl FileStore {
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.rope
.clone();
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
@ -173,8 +251,8 @@ impl FileStore {
})
}
pub fn get_file_contents(&self, uri: &str) -> Option<String> {
self.file_map.lock().get(uri).clone().map(|x| x.to_string())
pub fn file_map(&self) -> &Mutex<HashMap<String, File>> {
&self.file_map
}
pub fn contains_file(&self, uri: &str) -> bool {
@ -191,6 +269,7 @@ impl MemoryBackend for FileStore {
.lock()
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.rope
.clone();
let line = rope
.get_line(position.position.line as usize)
@ -217,12 +296,10 @@ impl MemoryBackend for FileStore {
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
self.file_map.lock().insert(uri.clone(), rope);
self.accessed_files.lock().shift_insert(0, uri.clone());
self.add_new_file(&uri, params.text_document.text);
if let Err(e) = self.maybe_do_crawl(Some(uri)) {
error!("{e}")
error!("{e:?}")
}
Ok(())
}
@ -234,20 +311,95 @@ impl MemoryBackend for FileStore {
) -> anyhow::Result<()> {
let uri = params.text_document.uri.to_string();
let mut file_map = self.file_map.lock();
let rope = file_map
let file = file_map
.get_mut(&uri)
.context("Error trying to get file that does not exist")?;
.with_context(|| format!("Trying to get file that does not exist {uri}"))?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
let start_index =
rope.line_to_char(range.start.line as usize) + range.start.character as usize;
// Record old positions
let (old_end_position, old_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the document
let start_index = file.rope.line_to_char(range.start.line as usize)
+ range.start.character as usize;
let end_index =
rope.line_to_char(range.end.line as usize) + range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
file.rope.line_to_char(range.end.line as usize) + range.end.character as usize;
file.rope.remove(start_index..end_index);
file.rope.insert(start_index, &change.text);
// Set new end positions
let (new_end_position, new_end_byte) = {
let last_line_index = file.rope.len_lines() - 1;
(
file.rope
.get_line(last_line_index)
.context("getting last line for edit")
.map(|last_line| Point::new(last_line_index, last_line.len_chars())),
file.rope.bytes().count(),
)
};
// Update the tree
if self.params.build_tree {
let mut old_tree = file.tree.take();
let start_byte = file
.rope
.try_line_to_char(range.start.line as usize)
.and_then(|start_char| {
file.rope
.try_char_to_byte(start_char + range.start.character as usize)
})
.map_err(anyhow::Error::msg);
if let Some(old_tree) = &mut old_tree {
match (start_byte, old_end_position, new_end_position) {
(Ok(start_byte), Ok(old_end_position), Ok(new_end_position)) => {
old_tree.edit(&InputEdit {
start_byte,
old_end_byte,
new_end_byte,
start_position: Point::new(
range.start.line as usize,
range.start.character as usize,
),
old_end_position,
new_end_position,
});
file.tree = match parse_tree(
&uri,
&file.rope.to_string(),
Some(old_tree),
) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to edit tree: {e:?}");
None
}
};
}
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
error!("failed to build tree edit: {e:?}");
}
}
}
}
} else {
*rope = Rope::from_str(&change.text);
file.rope = Rope::from_str(&change.text);
if self.params.build_tree {
file.tree = match parse_tree(&uri, &change.text, None) {
Ok(tree) => Some(tree),
Err(e) => {
error!("failed to parse new tree: {e:?}");
None
}
};
}
}
}
self.accessed_files.lock().shift_insert(0, uri);
@ -299,8 +451,8 @@ mod tests {
}
}
#[tokio::test]
async fn can_open_document() -> anyhow::Result<()> {
#[test]
fn can_open_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
@ -312,12 +464,12 @@ mod tests {
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(())
}
#[tokio::test]
async fn can_rename_document() -> anyhow::Result<()> {
#[test]
fn can_rename_document() -> anyhow::Result<()> {
let params = lsp_types::DidOpenTextDocumentParams {
text_document: generate_filler_text_document(None, None),
};
@ -338,12 +490,12 @@ mod tests {
.get("file://filler2/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Here is the document body");
assert_eq!(file.rope.to_string(), "Here is the document body");
Ok(())
}
#[tokio::test]
async fn can_change_document() -> anyhow::Result<()> {
#[test]
fn can_change_document() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, None);
let params = DidOpenTextDocumentParams {
@ -379,7 +531,7 @@ mod tests {
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "Hae is the document body");
assert_eq!(file.rope.to_string(), "Hae is the document body");
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
@ -399,7 +551,7 @@ mod tests {
.get("file://filler/")
.unwrap()
.clone();
assert_eq!(file.to_string(), "abc");
assert_eq!(file.rope.to_string(), "abc");
Ok(())
}
@ -579,43 +731,123 @@ The end with a trailing new line
Ok(())
}
// #[tokio::test]
// async fn test_fim_placement_corner_cases() -> anyhow::Result<()> {
// let text_document = generate_filler_text_document(None, Some("test\n"));
// let params = lsp_types::DidOpenTextDocumentParams {
// text_document: text_document.clone(),
// };
// let file_store = generate_base_file_store()?;
// file_store.opened_text_document(params).await?;
#[test]
fn test_file_store_tree_sitter() -> anyhow::Result<()> {
crate::init_logger();
// // Test FIM
// let params = json!({
// "fim": {
// "start": "SS",
// "middle": "MM",
// "end": "EE"
// }
// });
// let prompt = file_store
// .build_prompt(
// &TextDocumentPositionParams {
// text_document: TextDocumentIdentifier {
// uri: text_document.uri.clone(),
// },
// position: Position {
// line: 1,
// character: 0,
// },
// },
// params,
// )
// .await?;
// assert_eq!(prompt.context, "");
// let text = r#"test
// "#
// .to_string();
// assert_eq!(text, prompt.code);
let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
{
file_store_config
} else {
anyhow::bail!("requires a file_store_config")
};
let params = AdditionalFileStoreParams { build_tree: true };
let file_store = FileStore::new_with_params(file_store_config, config, params)?;
// Ok(())
// }
let uri = "file://filler/test.rs";
let text = r#"#[derive(Debug)]
struct Rectangle {
width: u32,
height: u32,
}
impl Rectangle {
fn area(&self) -> u32 {
}
}
fn main() {
let rect1 = Rectangle {
width: 30,
height: 50,
};
println!(
"The area of the rectangle is {} square pixels.",
rect1.area()
);
}"#;
let text_document = TextDocumentItem {
uri: reqwest::Url::parse(uri).unwrap(),
language_id: "".to_string(),
version: 0,
text: text.to_string(),
};
let params = DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
file_store.opened_text_document(params)?;
// Test insert
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 8,
character: 0,
},
end: Position {
line: 8,
character: 0,
},
}),
range_length: None,
text: " self.width * self.height".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (attribute_item (attribute (identifier) arguments: (token_tree (identifier)))) (struct_item name: (type_identifier) body: (field_declaration_list (field_declaration name: (field_identifier) type: (primitive_type)) (field_declaration name: (field_identifier) type: (primitive_type)))) (impl_item type: (type_identifier) body: (declaration_list (function_item name: (identifier) parameters: (parameters (self_parameter (self))) return_type: (primitive_type) body: (block (binary_expression left: (field_expression value: (self) field: (field_identifier)) right: (field_expression value: (self) field: (field_identifier))))))) (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test delete
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri.clone(),
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 12,
character: 0,
},
}),
range_length: None,
text: "".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block (let_declaration pattern: (identifier) value: (struct_expression name: (type_identifier) body: (field_initializer_list (field_initializer field: (field_identifier) value: (integer_literal)) (field_initializer field: (field_identifier) value: (integer_literal))))) (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content)) (identifier) (identifier) (token_tree)))))))");
// Test replace
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
uri: text_document.uri,
version: 1,
},
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: "fn main() {}".to_string(),
}],
};
file_store.changed_text_document(params)?;
let file = file_store.file_map.lock().get(uri).unwrap().clone();
assert_eq!(file.tree.unwrap().root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))");
Ok(())
}
}
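As a companion to the incremental tree updates in changed_text_document above, here is a condensed, self-contained sketch of the edit-then-reparse flow (illustrative only; it assumes the tree-sitter 0.22 and tree-sitter-rust 0.21 crates pinned in the Cargo.lock above, and the byte/Point values apply only to this toy edit):

use tree_sitter::{InputEdit, Parser, Point};

fn main() {
    let mut parser = Parser::new();
    parser.set_language(&tree_sitter_rust::language()).unwrap();
    let old_src = "fn main() {}";
    let mut tree = parser.parse(old_src, None).unwrap();
    // Replace the empty body "{}" (bytes 10..12) with "{ () }" (bytes 10..16).
    let new_src = "fn main() { () }";
    tree.edit(&InputEdit {
        start_byte: 10,
        old_end_byte: 12,
        new_end_byte: 16,
        start_position: Point::new(0, 10),
        old_end_position: Point::new(0, 12),
        new_end_position: Point::new(0, 16),
    });
    // Passing the edited old tree lets tree-sitter reuse the unchanged nodes.
    let new_tree = parser.parse(new_src, Some(&tree)).unwrap();
    assert_eq!(new_tree.root_node().kind(), "source_file");
}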

View File

@ -1,30 +1,65 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use std::{
io::Read,
sync::{
mpsc::{self, Sender},
Arc,
},
time::Duration,
};
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use tokio::time;
use tracing::{error, instrument};
use tracing::{error, instrument, warn};
use crate::{
config::{self, Config},
crawl::Crawl,
utils::{tokens_to_estimated_characters, TOKIO_RUNTIME},
splitters::{Chunk, Splitter},
utils::{chunk_to_id, tokens_to_estimated_characters, TOKIO_RUNTIME},
};
use super::{
file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt,
PromptType,
file_store::{AdditionalFileStoreParams, FileStore},
ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};
fn chunk_to_document(uri: &str, chunk: Chunk) -> Value {
json!({
"id": chunk_to_id(uri, &chunk),
"uri": uri,
"text": chunk.text,
"range": chunk.range
})
}
async fn split_and_upsert_file(
uri: &str,
collection: &mut Collection,
file_store: Arc<FileStore>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
) -> anyhow::Result<()> {
// We need to make sure we don't hold the file_store lock while performing a network call
let chunks = {
file_store
.file_map()
.lock()
.get(uri)
.map(|f| splitter.split(f))
};
let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
let documents = chunks
.into_iter()
.map(|chunk| chunk_to_document(uri, chunk).into())
.collect();
collection
.upsert_documents(documents, None)
.await
.context("PGML - Error upserting documents")
}
#[derive(Clone)]
pub struct PostgresML {
_config: Config,
@ -33,6 +68,7 @@ pub struct PostgresML {
pipeline: Pipeline,
debounce_tx: Sender<String>,
crawl: Option<Arc<Mutex<Crawl>>>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
}
impl PostgresML {
@ -45,10 +81,16 @@ impl PostgresML {
.crawl
.take()
.map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
let file_store = Arc::new(FileStore::new(
let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
Arc::new(postgresml_config.splitter.try_into()?);
let file_store = Arc::new(FileStore::new_with_params(
config::FileStore::new_without_crawl(),
configuration.clone(),
AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
)?);
let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url
} else {
@ -86,6 +128,7 @@ impl PostgresML {
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
let mut task_collection = collection.clone();
let task_file_store = file_store.clone();
let task_splitter = splitter.clone();
TOKIO_RUNTIME.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_uris = Vec::new();
@ -102,36 +145,83 @@ impl PostgresML {
if file_uris.is_empty() {
continue;
}
let documents = match file_uris
// Build the chunks for our changed files
let chunks: Vec<Vec<Chunk>> = match file_uris
.iter()
.map(|uri| {
let text = task_file_store
.get_file_contents(&uri)
.context("Error reading file contents from file_store")?;
anyhow::Ok(
json!({
"id": uri,
"text": text
})
.into(),
)
let file_store = task_file_store.file_map().lock();
let file = file_store
.get(uri)
.with_context(|| format!("getting file for splitting: {uri}"))?;
anyhow::Ok(task_splitter.split(file))
})
.collect()
{
Ok(documents) => documents,
Ok(chunks) => chunks,
Err(e) => {
error!("{e}");
continue;
}
};
// Delete old chunks that no longer exist after the latest file changes
let delete_or_statements: Vec<Value> = file_uris
.iter()
.zip(&chunks)
.map(|(uri, chunks)| {
let ids: Vec<String> =
chunks.iter().map(|c| chunk_to_id(uri, c)).collect();
json!({
"$and": [
{
"uri": {
"$eq": uri
}
},
{
"id": {
"$nin": ids
}
}
]
})
})
.collect();
if let Err(e) = task_collection
.delete_documents(
json!({
"$or": delete_or_statements
})
.into(),
)
.await
{
error!("PGML - Error deleting file: {e:?}");
}
// Prepare and upsert our new chunks
let documents: Vec<pgml::types::Json> = chunks
.into_iter()
.zip(&file_uris)
.map(|(chunks, uri)| {
chunks
.into_iter()
.map(|chunk| chunk_to_document(&uri, chunk))
.collect::<Vec<Value>>()
})
.flatten()
.map(|f: Value| f.into())
.collect();
if let Err(e) = task_collection
.upsert_documents(documents, None)
.await
.context("PGML - Error adding pipeline to collection")
.context("PGML - Error upserting changed files")
{
error!("{e}");
continue;
}
file_uris = Vec::new();
}
}
@ -144,6 +234,7 @@ impl PostgresML {
pipeline,
debounce_tx,
crawl,
splitter,
};
if let Err(e) = s.maybe_do_crawl(None) {
@ -154,28 +245,73 @@ impl PostgresML {
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl {
let mut _collection = self.collection.clone();
let mut _pipeline = self.pipeline.clone();
let mut documents: Vec<pgml::types::Json> = vec![];
crawl.lock().maybe_do_crawl(triggered_file, |path| {
let uri = format!("file://{path}");
// This means it has been opened before
if self.file_store.contains_file(&uri) {
return Ok(());
}
// Get the contents, split, and upsert it
let contents = std::fs::read_to_string(path)?;
documents.push(
json!({
"id": uri,
"text": contents
})
.into(),
);
// Track the size of the documents we have
// If it is over some amount in bytes, upsert it
Ok(())
})?;
let mut documents: Vec<(String, Vec<Chunk>)> = vec![];
let mut total_bytes = 0;
let mut current_bytes = 0;
crawl
.lock()
.maybe_do_crawl(triggered_file, |config, path| {
let uri = format!("file://{path}");
// This means it has been opened before
if self.file_store.contains_file(&uri) {
return Ok(());
}
// Open the file and see if it is small enough to read
let mut f = std::fs::File::open(path)?;
if f.metadata()
.map(|m| m.len() > config.max_file_size)
.unwrap_or(true)
{
warn!("Skipping file because it is too large: {path}");
return Ok(());
}
// Read the file contents
let mut contents = vec![];
f.read_to_end(&mut contents)?;
if let Ok(contents) = String::from_utf8(contents) {
current_bytes += contents.len();
total_bytes += contents.len();
let chunks = self.splitter.split_file_contents(&uri, &contents);
documents.push((uri, chunks));
}
// If we have over 100 megabytes of data, do the upsert
if current_bytes >= 100_000_000 || total_bytes as u64 >= config.max_crawl_memory
{
// Prepare our chunks
let to_upsert_documents: Vec<pgml::types::Json> =
std::mem::take(&mut documents)
.into_iter()
.map(|(uri, chunks)| {
chunks
.into_iter()
.map(|chunk| chunk_to_document(&uri, chunk))
.collect::<Vec<Value>>()
})
.flatten()
.map(|f: Value| f.into())
.collect();
// Do the upsert
let mut collection = self.collection.clone();
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = collection
.upsert_documents(to_upsert_documents, None)
.await
.context("PGML - Error upserting changed files")
{
error!("{e}");
}
});
// Reset everything
current_bytes = 0;
documents = vec![];
}
// Break if total bytes is over the max crawl memory
if total_bytes as u64 >= config.max_crawl_memory {
warn!("Ending crawl eraly do to max_crawl_memory");
return Ok(());
}
Ok(())
})?;
}
Ok(())
}
@ -263,25 +399,22 @@ impl MemoryBackend for PostgresML {
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
self.file_store.opened_text_document(params.clone())?;
let mut task_collection = self.collection.clone();
let saved_uri = params.text_document.uri.to_string();
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
TOKIO_RUNTIME.spawn(async move {
let text = params.text_document.text.clone();
let uri = params.text_document.uri.to_string();
task_collection
.upsert_documents(
vec![json!({
"id": uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error upserting documents");
if let Err(e) = split_and_upsert_file(&uri, &mut collection, file_store, splitter).await
{
error!("{e:?}")
}
});
if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) {
error!("{e}")
error!("{e:?}")
}
Ok(())
}
@ -300,32 +433,35 @@ impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
self.file_store.renamed_files(params.clone())?;
let mut task_collection = self.collection.clone();
let task_params = params.clone();
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
TOKIO_RUNTIME.spawn(async move {
for file in task_params.files {
task_collection
for file in params.files {
if let Err(e) = collection
.delete_documents(
json!({
"id": file.old_uri
"uri": {
"$eq": file.old_uri
}
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text =
std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error adding pipeline to collection");
{
error!("PGML - Error deleting file: {e:?}");
}
if let Err(e) = split_and_upsert_file(
&file.new_uri,
&mut collection,
file_store.clone(),
splitter.clone(),
)
.await
{
error!("{e:?}")
}
}
});
Ok(())

src/splitters/mod.rs (new file, 53 lines)
View File

@ -0,0 +1,53 @@
use serde::Serialize;
use crate::{config::ValidSplitter, memory_backends::file_store::File};
mod tree_sitter;
#[derive(Serialize)]
pub struct ByteRange {
pub start_byte: usize,
pub end_byte: usize,
}
impl ByteRange {
pub fn new(start_byte: usize, end_byte: usize) -> Self {
Self {
start_byte,
end_byte,
}
}
}
#[derive(Serialize)]
pub struct Chunk {
pub text: String,
pub range: ByteRange,
}
impl Chunk {
fn new(text: String, range: ByteRange) -> Self {
Self { text, range }
}
}
pub trait Splitter {
fn split(&self, file: &File) -> Vec<Chunk>;
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk>;
fn does_use_tree_sitter(&self) -> bool {
false
}
}
impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
type Error = anyhow::Error;
fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
match value {
ValidSplitter::TreeSitter(config) => {
Ok(Box::new(tree_sitter::TreeSitter::new(config)?))
}
}
}
}
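A brief hedged usage sketch (not in the commit): converting the configured splitter into a trait object, mirroring what the PostgresML backend above does with postgresml_config.splitter.try_into():

fn _example_splitter_from_config() -> anyhow::Result<()> {
    let splitter: Box<dyn Splitter + Send + Sync> = ValidSplitter::default().try_into()?;
    // The default splitter is tree-sitter based, so backends know to build parse trees.
    assert!(splitter.does_use_tree_sitter());
    Ok(())
}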

View File

@ -0,0 +1,77 @@
use splitter_tree_sitter::TreeSitterCodeSplitter;
use tracing::error;
use tree_sitter::Tree;
use crate::{config, memory_backends::file_store::File, utils::parse_tree};
use super::{ByteRange, Chunk, Splitter};
pub struct TreeSitter {
_config: config::TreeSitter,
splitter: TreeSitterCodeSplitter,
}
impl TreeSitter {
pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
Ok(Self {
splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
_config: config,
})
}
fn split_tree(&self, tree: &Tree, contents: &[u8]) -> anyhow::Result<Vec<Chunk>> {
Ok(self
.splitter
.split(tree, contents)?
.into_iter()
.map(|c| {
Chunk::new(
c.text.to_owned(),
ByteRange::new(c.range.start_byte, c.range.end_byte),
)
})
.collect())
}
}
impl Splitter for TreeSitter {
fn split(&self, file: &File) -> Vec<Chunk> {
if let Some(tree) = file.tree() {
match self.split_tree(tree, file.rope().to_string().as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file with error {e:?}. Falling back to default splitter.",
);
todo!()
}
}
} else {
panic!("TreeSitter splitter requires a tree to split")
}
}
fn split_file_contents(&self, uri: &str, contents: &str) -> Vec<Chunk> {
match parse_tree(uri, contents, None) {
Ok(tree) => match self.split_tree(&tree, contents.as_bytes()) {
Ok(chunks) => chunks,
Err(e) => {
error!(
"Failed to parse tree for file: {uri} with error {e:?}. Falling back to default splitter.",
);
todo!()
}
},
Err(e) => {
error!(
"Failed to parse tree for file {uri} with error {e:?}. Falling back to default splitter.",
);
todo!()
}
}
}
fn does_use_tree_sitter(&self) -> bool {
true
}
}
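Purely illustrative and not part of this commit: the error paths above mention falling back to a "default splitter" but currently hit todo!(). A naive fixed-window splitter honoring the same Chunk/ByteRange shape (mirrored here so the sketch stands alone; chunk_overlap is ignored for brevity) could look roughly like this:

struct ByteRange {
    start_byte: usize,
    end_byte: usize,
}

struct Chunk {
    text: String,
    range: ByteRange,
}

fn window_split(contents: &str, chunk_size: usize) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut start = 0;
    let mut len = 0;
    for (idx, ch) in contents.char_indices() {
        len += ch.len_utf8();
        if len >= chunk_size {
            // idx + ch.len_utf8() is always a char boundary, so slicing is safe.
            let end = idx + ch.len_utf8();
            chunks.push(Chunk {
                text: contents[start..end].to_owned(),
                range: ByteRange { start_byte: start, end_byte: end },
            });
            start = end;
            len = 0;
        }
    }
    if start < contents.len() {
        chunks.push(Chunk {
            text: contents[start..].to_owned(),
            range: ByteRange { start_byte: start, end_byte: contents.len() },
        });
    }
    chunks
}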

View File

@ -1,8 +1,10 @@
use anyhow::Context;
use lsp_server::ResponseError;
use once_cell::sync::Lazy;
use tokio::runtime;
use tree_sitter::Tree;
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt};
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt, splitters::Chunk};
pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
@ -52,3 +54,17 @@ pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String
pub fn format_context_code(context: &str, code: &str) -> String {
format!("{context}\n\n{code}")
}
pub fn chunk_to_id(uri: &str, chunk: &Chunk) -> String {
format!("{uri}#{}-{}", chunk.range.start_byte, chunk.range.end_byte)
}
pub fn parse_tree(uri: &str, contents: &str, old_tree: Option<&Tree>) -> anyhow::Result<Tree> {
let path = std::path::Path::new(uri);
let extension = path.extension().map(|x| x.to_string_lossy());
let extension = extension.as_deref().unwrap_or("");
let mut parser = utils_tree_sitter::get_parser_for_extension(extension)?;
parser
.parse(&contents, old_tree)
.with_context(|| format!("parsing tree failed for {uri}"))
}
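A hedged usage sketch for parse_tree (the uri and assertions are illustrative; the parser is selected from the uri's file extension via utils-tree-sitter):

#[cfg(test)]
mod parse_tree_example {
    use super::parse_tree;

    #[test]
    fn parses_by_extension() -> anyhow::Result<()> {
        // "rs" selects the Rust grammar, so the root node is a source_file.
        let tree = parse_tree("file:///tmp/example.rs", "fn main() {}", None)?;
        assert_eq!(tree.root_node().kind(), "source_file");
        // For incremental reparsing, call Tree::edit on the previous tree first and pass
        // it as old_tree, as the file_store changed_text_document handler above does.
        Ok(())
    }
}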