diff --git a/src/config.rs b/src/config.rs index 8b7b394..f4e049a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -85,19 +85,36 @@ pub struct FIM { pub end: String, } +const fn max_crawl_memory_default() -> u32 { + 42 +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Crawl { + #[serde(default = "max_crawl_memory_default")] + pub max_crawl_memory: u32, + #[serde(default)] + pub all_files: bool, +} + #[derive(Clone, Debug, Deserialize)] #[serde(deny_unknown_fields)] pub struct PostgresML { pub database_url: Option, - #[serde(default)] - pub crawl: bool, + pub crawl: Option, } #[derive(Clone, Debug, Deserialize, Default)] #[serde(deny_unknown_fields)] pub struct FileStore { - #[serde(default)] - pub crawl: bool, + pub crawl: Option, +} + +impl FileStore { + pub fn new_without_crawl() -> Self { + Self { crawl: None } + } } const fn n_gpu_layers_default() -> u32 { @@ -230,15 +247,14 @@ pub struct ValidConfig { #[derive(Clone, Debug, Deserialize, Default)] pub struct ValidClientParams { - #[serde(alias = "rootURI")] - _root_uri: Option, - _workspace_folders: Option>, + #[serde(alias = "rootUri")] + pub root_uri: Option, } #[derive(Clone, Debug)] pub struct Config { pub config: ValidConfig, - _client_params: ValidClientParams, + pub client_params: ValidClientParams, } impl Config { @@ -255,7 +271,7 @@ impl Config { let client_params: ValidClientParams = serde_json::from_value(args)?; Ok(Self { config: valid_args, - _client_params: client_params, + client_params, }) } @@ -306,13 +322,13 @@ impl Config { pub fn default_with_file_store_without_models() -> Self { Self { config: ValidConfig { - memory: ValidMemoryBackend::FileStore(FileStore { crawl: false }), + memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }), models: HashMap::new(), completion: None, }, - _client_params: ValidClientParams { - _root_uri: None, - _workspace_folders: None, + client_params: ValidClientParams { + root_uri: None, + workspace_folders: None, }, } } diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index 4d70509..9f2123d 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -1,11 +1,12 @@ use anyhow::Context; +use ignore::WalkBuilder; use indexmap::IndexSet; use lsp_types::TextDocumentPositionParams; use parking_lot::Mutex; use ropey::Rope; use serde_json::Value; -use std::collections::HashMap; -use tracing::instrument; +use std::collections::{HashMap, HashSet}; +use tracing::{error, instrument}; use crate::{ config::{self, Config}, @@ -15,28 +16,106 @@ use crate::{ use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType}; pub struct FileStore { - _crawl: bool, - _config: Config, + config: Config, + file_store_config: config::FileStore, + crawled_file_types: Mutex>, file_map: Mutex>, accessed_files: Mutex>, } impl FileStore { - pub fn new(file_store_config: config::FileStore, config: Config) -> Self { + pub fn new(file_store_config: config::FileStore, config: Config) -> anyhow::Result { + let s = Self { + config, + file_store_config, + crawled_file_types: Mutex::new(HashSet::new()), + file_map: Mutex::new(HashMap::new()), + accessed_files: Mutex::new(IndexSet::new()), + }; + if let Err(e) = s.maybe_do_crawl(None) { + error!("{e}") + } + Ok(s) + } + + pub fn new_without_crawl(config: Config) -> Self { Self { - _crawl: file_store_config.crawl, - _config: config, + config, + file_store_config: config::FileStore::new_without_crawl(), + crawled_file_types: Mutex::new(HashSet::new()), file_map: Mutex::new(HashMap::new()), accessed_files: Mutex::new(IndexSet::new()), } } - pub fn new_without_crawl(config: Config) -> Self { - Self { - _crawl: false, - _config: config, - file_map: Mutex::new(HashMap::new()), - accessed_files: Mutex::new(IndexSet::new()), + pub fn maybe_do_crawl(&self, triggered_file: Option) -> anyhow::Result<()> { + match ( + &self.config.client_params.root_uri, + &self.file_store_config.crawl, + ) { + (Some(root_uri), Some(crawl)) => { + let extension_to_match = triggered_file + .map(|tf| { + let path = std::path::Path::new(&tf); + path.extension().map(|f| f.to_str().map(|f| f.to_owned())) + }) + .flatten() + .flatten(); + + if let Some(extension_to_match) = &extension_to_match { + if self.crawled_file_types.lock().contains(extension_to_match) { + return Ok(()); + } + } + + if !crawl.all_files && extension_to_match.is_none() { + return Ok(()); + } + + if !root_uri.starts_with("file://") { + anyhow::bail!("Skipping crawling as root_uri does not begin with file://") + } + + for result in WalkBuilder::new(&root_uri[7..]).build() { + let result = result?; + let path = result.path(); + if !path.is_dir() { + if let Some(path_str) = path.to_str() { + let insert_uri = format!("file://{path_str}"); + if self.file_map.lock().contains_key(&insert_uri) { + continue; + } + if crawl.all_files { + let contents = std::fs::read_to_string(path)?; + self.file_map + .lock() + .insert(insert_uri, Rope::from_str(&contents)); + } else { + match ( + path.extension().map(|pe| pe.to_str()).flatten(), + &extension_to_match, + ) { + (Some(path_extension), Some(extension_to_match)) => { + if path_extension == extension_to_match { + let contents = std::fs::read_to_string(path)?; + self.file_map + .lock() + .insert(insert_uri, Rope::from_str(&contents)); + } + } + _ => continue, + } + } + } + } + } + + if let Some(extension_to_match) = extension_to_match { + self.crawled_file_types.lock().insert(extension_to_match); + } + Ok(()) + } + _ => Ok(()), } } @@ -199,7 +278,10 @@ impl MemoryBackend for FileStore { let rope = Rope::from_str(¶ms.text_document.text); let uri = params.text_document.uri.to_string(); self.file_map.lock().insert(uri.clone(), rope); - self.accessed_files.lock().shift_insert(0, uri); + self.accessed_files.lock().shift_insert(0, uri.clone()); + if let Err(e) = self.maybe_do_crawl(Some(uri)) { + error!("{e}") + } Ok(()) } @@ -261,7 +343,7 @@ mod tests { } else { anyhow::bail!("requires a file_store_config") }; - Ok(FileStore::new(file_store_config, config)) + FileStore::new(file_store_config, config) } fn generate_filler_text_document(uri: Option<&str>, text: Option<&str>) -> TextDocumentItem { diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index 52a8974..9db0dbf 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -137,7 +137,7 @@ impl TryFrom for Box { fn try_from(configuration: Config) -> Result { match configuration.config.memory.clone() { ValidMemoryBackend::FileStore(file_store_config) => Ok(Box::new( - file_store::FileStore::new(file_store_config, configuration), + file_store::FileStore::new(file_store_config, configuration)?, )), ValidMemoryBackend::PostgresML(postgresml_config) => Ok(Box::new( postgresml::PostgresML::new(postgresml_config, configuration)?,