Almost working RAG

Silas Marvin 2024-06-16 19:33:09 -07:00
parent f2b8c1eda3
commit cbe487ca3a
8 changed files with 97 additions and 30 deletions

Cargo.lock (generated)

@@ -3182,7 +3182,6 @@ dependencies = [
  "regex",
  "strum",
  "thiserror",
- "tree-sitter",
  "unicode-segmentation",
 ]


@@ -35,7 +35,7 @@ async-trait = "0.1.78"
 tree-sitter = "0.22"
 utils-tree-sitter = { workspace = true, features = ["all"] }
 splitter-tree-sitter = { workspace = true }
-text-splitter = { version = "0.13.3", features = ["code"] }
+text-splitter = { version = "0.13.3" }
 
 [build-dependencies]
 cc="*"


@@ -28,6 +28,8 @@ impl Default for PostProcess {
 pub enum ValidSplitter {
     #[serde(rename = "tree_sitter")]
     TreeSitter(TreeSitter),
+    #[serde(rename = "text_splitter")]
+    TextSplitter(TextSplitter),
 }
 
 impl Default for ValidSplitter {
@@ -61,6 +63,12 @@ impl Default for TreeSitter {
     }
 }
 
+#[derive(Debug, Clone, Deserialize)]
+pub struct TextSplitter {
+    #[serde(default = "chunk_size_default")]
+    pub chunk_size: usize,
+}
+
 #[derive(Debug, Clone, Deserialize)]
 pub enum ValidMemoryBackend {
     #[serde(rename = "file_store")]
@@ -123,7 +131,7 @@ pub struct FIM {
 }
 
 const fn max_crawl_memory_default() -> u64 {
-    42
+    100_000_000
 }
 
 const fn max_crawl_file_size_default() -> u64 {
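
The new variant rides on serde's externally tagged enum representation: the key in the user's configuration selects the splitter. A minimal, self-contained sketch of that behavior (not the repo's exact code; the 1500 default is an assumption standing in for whatever chunk_size_default() returns):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
pub struct TextSplitter {
    #[serde(default = "chunk_size_default")]
    pub chunk_size: usize,
}

// Assumed value; the real default lives in config.rs.
fn chunk_size_default() -> usize {
    1500
}

#[derive(Debug, Deserialize)]
pub enum ValidSplitter {
    #[serde(rename = "text_splitter")]
    TextSplitter(TextSplitter),
}

fn main() {
    // The variant name is the JSON key:
    let s: ValidSplitter =
        serde_json::from_str(r#"{ "text_splitter": { "chunk_size": 512 } }"#).unwrap();
    println!("{s:?}");

    // chunk_size falls back to the default when omitted:
    let d: ValidSplitter = serde_json::from_str(r#"{ "text_splitter": {} }"#).unwrap();
    println!("{d:?}");
}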


@@ -8,6 +8,7 @@ pub struct Crawl {
     crawl_config: config::Crawl,
     config: Config,
     crawled_file_types: HashSet<String>,
+    crawled_all: bool,
 }
 
 impl Crawl {
@@ -16,6 +17,7 @@ impl Crawl {
             crawl_config,
             config,
             crawled_file_types: HashSet::new(),
+            crawled_all: false,
         }
     }
@@ -25,6 +27,10 @@ impl Crawl {
         triggered_file: Option<String>,
         mut f: impl FnMut(&config::Crawl, &str) -> anyhow::Result<bool>,
     ) -> anyhow::Result<()> {
+        if self.crawled_all {
+            return Ok(());
+        }
+
         if let Some(root_uri) = &self.config.client_params.root_uri {
             if !root_uri.starts_with("file://") {
                 anyhow::bail!("Skipping crawling as root_uri does not begin with file://")
@@ -51,13 +57,14 @@ impl Crawl {
             for result in WalkBuilder::new(&root_uri[7..]).build() {
                 let result = result?;
                 let path = result.path();
+                eprintln!("CRAWLING: {}", path.display());
                 if !path.is_dir() {
                     if let Some(path_str) = path.to_str() {
                         if self.crawl_config.all_files {
                             match f(&self.crawl_config, path_str) {
                                 Ok(c) => {
                                     if !c {
-                                        return Ok(());
+                                        break;
                                     }
                                 }
                                 Err(e) => error!("{e:?}"),
@@ -72,7 +79,7 @@ impl Crawl {
                             match f(&self.crawl_config, path_str) {
                                 Ok(c) => {
                                     if !c {
-                                        return Ok(());
+                                        break;
                                     }
                                 }
                                 Err(e) => error!("{e:?}"),
@@ -88,6 +95,8 @@ impl Crawl {
             if let Some(extension_to_match) = extension_to_match {
                 self.crawled_file_types.insert(extension_to_match);
+            } else {
+                self.crawled_all = true
             }
         }
 
         Ok(())
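
The closure handed to maybe_do_crawl returns Ok(true) to keep walking and Ok(false) to stop early; switching return Ok(()) to break matters because control now reaches the bookkeeping after the loop, where crawled_all is recorded. A toy sketch of that contract (illustrative names, not the repo's types):

// The callback contract: Ok(true) continues the walk, Ok(false) stops it.
fn walk_until(paths: &[&str], mut f: impl FnMut(&str) -> anyhow::Result<bool>) {
    for path in paths {
        match f(path) {
            // `break` rather than an early return: the post-loop
            // bookkeeping (e.g. setting crawled_all) still runs.
            Ok(keep_going) if !keep_going => break,
            Ok(_) => {}
            Err(e) => eprintln!("{e:?}"),
        }
    }
    println!("post-walk bookkeeping runs either way");
}

fn main() {
    let mut budget = 2;
    walk_until(&["a.rs", "b.rs", "c.rs"], |path| {
        println!("crawling {path}");
        budget -= 1;
        Ok(budget > 0) // stop once the toy memory budget is spent
    });
}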


@@ -245,7 +245,7 @@ impl PostgresML {
     fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
         if let Some(crawl) = &self.crawl {
-            let mut documents: Vec<(String, Vec<Chunk>)> = vec![];
+            let mut documents = vec![];
             let mut total_bytes = 0;
             let mut current_bytes = 0;
             crawl
@@ -253,7 +253,7 @@ impl PostgresML {
                 .maybe_do_crawl(triggered_file, |config, path| {
                     // Break if total bytes is over the max crawl memory
                     if total_bytes as u64 >= config.max_crawl_memory {
-                        warn!("Ending crawl early due to `max_crawl_memory` resetraint");
+                        warn!("Ending crawl early due to `max_crawl_memory` restraint");
                         return Ok(false);
                     }
                     // This means it has been opened before
@@ -274,26 +274,19 @@ impl PostgresML {
                     let contents = String::from_utf8(contents)?;
                     current_bytes += contents.len();
                     total_bytes += contents.len();
-                    let chunks = self.splitter.split_file_contents(&uri, &contents);
-                    documents.push((uri, chunks));
+                    let chunks: Vec<pgml::types::Json> = self
+                        .splitter
+                        .split_file_contents(&uri, &contents)
+                        .into_iter()
+                        .map(|chunk| chunk_to_document(&uri, chunk).into())
+                        .collect();
+                    documents.extend(chunks);
                     // If we have over 100 mega bytes of data do the upsert
                     if current_bytes >= 100_000_000 || total_bytes as u64 >= config.max_crawl_memory
                     {
-                        // Prepare our chunks
-                        let to_upsert_documents: Vec<pgml::types::Json> =
-                            std::mem::take(&mut documents)
-                                .into_iter()
-                                .map(|(uri, chunks)| {
-                                    chunks
-                                        .into_iter()
-                                        .map(|chunk| chunk_to_document(&uri, chunk))
-                                        .collect::<Vec<Value>>()
-                                })
-                                .flatten()
-                                .map(|f: Value| f.into())
-                                .collect();
-                        // Do the upsert
+                        // Upsert the documents
                         let mut collection = self.collection.clone();
+                        let to_upsert_documents = std::mem::take(&mut documents);
                         TOKIO_RUNTIME.spawn(async move {
                             if let Err(e) = collection
                                 .upsert_documents(to_upsert_documents, None)
@@ -309,6 +302,19 @@ impl PostgresML {
                     }
                     Ok(true)
                 })?;
+            // Upsert any remaining documents
+            if documents.len() > 0 {
+                let mut collection = self.collection.clone();
+                TOKIO_RUNTIME.spawn(async move {
+                    if let Err(e) = collection
+                        .upsert_documents(documents, None)
+                        .await
+                        .context("PGML - Error upserting changed files")
+                    {
+                        error!("{e}");
+                    }
+                });
+            }
         }
         Ok(())
     }
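
The shape here is batch-and-flush: chunk each crawled file into pgml documents, flush the batch once it crosses the 100 MB threshold (or the crawl memory cap), and flush whatever remains when the crawl ends. A condensed sketch with toy types; `upsert` stands in for collection.upsert_documents, and resetting current_bytes after a flush is an assumption the diff does not show:

fn upsert(batch: Vec<String>) {
    println!("upserting {} documents", batch.len());
}

fn main() {
    const FLUSH_BYTES: usize = 16; // the real code uses 100_000_000; tiny for the demo
    let files = ["alpha", "beta beta", "gamma", "delta"];

    let mut documents = Vec::new();
    let mut current_bytes = 0;
    for contents in files.map(String::from) {
        current_bytes += contents.len();
        documents.push(contents);
        if current_bytes >= FLUSH_BYTES {
            // std::mem::take swaps in an empty Vec so batching can continue.
            upsert(std::mem::take(&mut documents));
            current_bytes = 0; // assumed reset; not shown in the diff
        }
    }
    // Flush the remainder, mirroring the trailing upsert in the diff.
    if !documents.is_empty() {
        upsert(documents);
    }
}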


@@ -2,6 +2,7 @@ use serde::Serialize;
 
 use crate::{config::ValidSplitter, memory_backends::file_store::File};
 
+mod text_splitter;
 mod tree_sitter;
 
 #[derive(Serialize)]
@@ -48,6 +49,9 @@ impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
             ValidSplitter::TreeSitter(config) => {
                 Ok(Box::new(tree_sitter::TreeSitter::new(config)?))
             }
+            ValidSplitter::TextSplitter(config) => {
+                Ok(Box::new(text_splitter::TextSplitter::new(config)))
+            }
         }
     }
 }
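
The TryFrom impl is effectively a factory: a ValidSplitter config value becomes a boxed trait object, so callers only ever see dyn Splitter. A stripped-down sketch of that shape (stub types, not the repo's):

trait Splitter {
    fn name(&self) -> &'static str;
}

struct TextSplitter;
impl Splitter for TextSplitter {
    fn name(&self) -> &'static str {
        "text_splitter"
    }
}

enum ValidSplitter {
    TextSplitter,
}

impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {
    type Error = anyhow::Error;

    fn try_from(value: ValidSplitter) -> Result<Self, Self::Error> {
        match value {
            // This arm is infallible; the tree-sitter arm in the real code
            // needs `?` because loading grammars can fail.
            ValidSplitter::TextSplitter => Ok(Box::new(TextSplitter)),
        }
    }
}

fn main() -> anyhow::Result<()> {
    let splitter: Box<dyn Splitter + Send + Sync> = ValidSplitter::TextSplitter.try_into()?;
    println!("built splitter: {}", splitter.name());
    Ok(())
}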


@@ -0,0 +1,40 @@
+use crate::{config, memory_backends::file_store::File};
+
+use super::{ByteRange, Chunk, Splitter};
+
+pub struct TextSplitter {
+    splitter: text_splitter::TextSplitter<text_splitter::Characters>,
+}
+
+impl TextSplitter {
+    pub fn new(config: config::TextSplitter) -> Self {
+        Self {
+            splitter: text_splitter::TextSplitter::new(config.chunk_size),
+        }
+    }
+
+    pub fn new_with_chunk_size(chunk_size: usize) -> Self {
+        Self {
+            splitter: text_splitter::TextSplitter::new(chunk_size),
+        }
+    }
+}
+
+impl Splitter for TextSplitter {
+    fn split(&self, file: &File) -> Vec<Chunk> {
+        self.split_file_contents("", &file.rope().to_string())
+    }
+
+    fn split_file_contents(&self, _uri: &str, contents: &str) -> Vec<Chunk> {
+        self.splitter
+            .chunk_indices(contents)
+            .fold(vec![], |mut acc, (start_byte, text)| {
+                let end_byte = start_byte + text.len();
+                acc.push(Chunk::new(
+                    text.to_string(),
+                    ByteRange::new(start_byte, end_byte),
+                ));
+                acc
+            })
+    }
+}
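
The wrapper leans entirely on the text-splitter crate: chunk_indices yields (byte_offset, chunk) pairs, which is where each ByteRange comes from. A rough usage sketch against text-splitter 0.13, using the same calls as above (the chunk size of 20 is arbitrary):

fn main() {
    // `new` accepts a plain usize for the default Characters sizer.
    let splitter = text_splitter::TextSplitter::new(20);
    let contents = "fn main() {\n    println!(\"hello world\");\n}\n";
    for (start_byte, text) in splitter.chunk_indices(contents) {
        let end_byte = start_byte + text.len();
        println!("{start_byte}..{end_byte} {text:?}");
    }
}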


@@ -4,18 +4,19 @@ use tree_sitter::Tree;
 
 use crate::{config, memory_backends::file_store::File, utils::parse_tree};
 
-use super::{ByteRange, Chunk, Splitter};
+use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter};
 
 pub struct TreeSitter {
     _config: config::TreeSitter,
     splitter: TreeSitterCodeSplitter,
+    text_splitter: TextSplitter,
 }
 
 impl TreeSitter {
     pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
+        let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size);
         Ok(Self {
             splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
             _config: config,
+            text_splitter,
         })
     }
@@ -43,11 +44,11 @@ impl Splitter for TreeSitter {
                 Err(e) => {
                     error!(
                         "Failed to parse tree for file with error: {e:?}. Falling back to default splitter.",
                     );
-                    todo!()
+                    self.text_splitter.split(file)
                 }
             }
         } else {
-            panic!("TreeSitter splitter requires a tree to split")
+            self.text_splitter.split(file)
        }
     }
@@ -59,14 +60,14 @@ impl Splitter for TreeSitter {
                     error!(
                         "Failed to parse tree for file: {uri} with error: {e:?}. Falling back to default splitter.",
                     );
-                    todo!()
+                    self.text_splitter.split_file_contents(uri, contents)
                 }
             },
             Err(e) => {
                 error!(
                     "Failed to parse tree for file {uri} with error: {e:?}. Falling back to default splitter.",
                 );
-                todo!()
+                self.text_splitter.split_file_contents(uri, contents)
             }
         }
     }
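
The net effect in this file is graceful degradation: where a failed parse used to hit todo!() or panic!(), the splitter now logs and falls back to the character-based TextSplitter. A toy sketch of that flow (stub types; tree_sitter_split always fails here to force the fallback):

struct Chunk(String);

// Stands in for the tree-sitter splitter; always fails, as if no grammar matched.
fn tree_sitter_split(_contents: &str) -> anyhow::Result<Vec<Chunk>> {
    anyhow::bail!("no grammar available for this file type")
}

// Stands in for the character-based TextSplitter fallback.
fn text_split(contents: &str) -> Vec<Chunk> {
    contents
        .as_bytes()
        .chunks(16)
        .map(|c| Chunk(String::from_utf8_lossy(c).into_owned()))
        .collect()
}

fn split(contents: &str) -> Vec<Chunk> {
    match tree_sitter_split(contents) {
        Ok(chunks) => chunks,
        Err(e) => {
            // Log and degrade instead of todo!()/panic!(), as in the diff.
            eprintln!("Failed to parse tree: {e:?}. Falling back to default splitter.");
            text_split(contents)
        }
    }
}

fn main() {
    println!("{} chunks", split("fn main() { println!(\"hi\"); }").len());
}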