Working RAG

Silas Marvin 2024-06-21 08:30:22 -07:00
parent 9166aaf4b6
commit 09f602ee12
4 changed files with 66 additions and 47 deletions

View File

@@ -149,6 +149,13 @@ pub struct Crawl {
     pub all_files: bool,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct PostgresMLEmbeddingModel {
+    pub model: String,
+    pub embed_parameters: Option<Value>,
+    pub query_parameters: Option<Value>,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct PostgresML {
@@ -156,8 +163,7 @@ pub struct PostgresML {
     pub crawl: Option<Crawl>,
     #[serde(default)]
     pub splitter: ValidSplitter,
-    pub embedding_model: Option<String>,
-    pub embedding_model_parameters: Option<Value>,
+    pub embedding_model: Option<PostgresMLEmbeddingModel>,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
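Note: the new `PostgresMLEmbeddingModel` struct replaces the flat `embedding_model: Option<String>` plus `embedding_model_parameters` pair, so the model name, embed-time parameters, and query-time parameters now travel together. A minimal sketch of what now deserializes; the model name and prompt values are hypothetical examples, not from this commit:

```rust
use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Clone, Debug, Deserialize)]
pub struct PostgresMLEmbeddingModel {
    pub model: String,
    pub embed_parameters: Option<Value>,
    pub query_parameters: Option<Value>,
}

fn main() -> Result<(), serde_json::Error> {
    // Hypothetical config body matching the new Deserialize shape.
    let raw = json!({
        "model": "intfloat/e5-small-v2",
        "embed_parameters": { "prompt": "passage: " },
        "query_parameters": { "prompt": "query: " }
    });
    let parsed: PostgresMLEmbeddingModel = serde_json::from_value(raw)?;
    assert_eq!(parsed.model, "intfloat/e5-small-v2");
    Ok(())
}
```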

View File

@@ -223,16 +223,13 @@ impl FileStore {
         params: MemoryRunParams,
         pull_from_multiple_files: bool,
     ) -> anyhow::Result<Prompt> {
-        let (mut rope, cursor_index) = self.get_rope_for_position(
-            position,
-            params.max_context_length,
-            pull_from_multiple_files,
-        )?;
+        let (mut rope, cursor_index) =
+            self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?;
 
         Ok(match prompt_type {
             PromptType::ContextAndCode => {
                 if params.is_for_chat {
-                    let max_length = tokens_to_estimated_characters(params.max_context_length);
+                    let max_length = tokens_to_estimated_characters(params.max_context);
                     let start = cursor_index.saturating_sub(max_length / 2);
                     let end = rope
                         .len_chars()
@@ -248,7 +245,7 @@ impl FileStore {
                     ))
                 } else {
                     let start = cursor_index
-                        .saturating_sub(tokens_to_estimated_characters(params.max_context_length));
+                        .saturating_sub(tokens_to_estimated_characters(params.max_context));
                     let rope_slice = rope
                         .get_slice(start..cursor_index)
                         .context("Error getting rope slice")?;
@@ -259,7 +256,7 @@ impl FileStore {
                 }
             }
             PromptType::FIM => {
-                let max_length = tokens_to_estimated_characters(params.max_context_length);
+                let max_length = tokens_to_estimated_characters(params.max_context);
                 let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
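Note: `max_context_length` is renamed to `max_context` throughout, and the window math is unchanged: the estimated character budget is centered on the cursor. A standalone sketch, assuming `tokens_to_estimated_characters` uses the common ~4 characters-per-token heuristic (the actual multiplier is not shown in this diff):

```rust
// Assumed heuristic: ~4 characters per token.
fn tokens_to_estimated_characters(tokens: usize) -> usize {
    tokens * 4
}

// Compute the [start, end) character window centered on the cursor,
// clamped to the document bounds, as in the FIM branch above.
fn window(cursor_index: usize, doc_len: usize, max_context: usize) -> (usize, usize) {
    let max_length = tokens_to_estimated_characters(max_context);
    let start = cursor_index.saturating_sub(max_length / 2);
    let end = doc_len.min(cursor_index + max_length / 2);
    (start, end)
}

fn main() {
    // 1024 tokens ~= 4096 characters, split evenly around the cursor.
    assert_eq!(window(5000, 100_000, 1024), (2952, 7048));
}
```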

View File

@@ -18,13 +18,13 @@ pub enum PromptType {
 #[derive(Clone)]
 pub struct MemoryRunParams {
     pub is_for_chat: bool,
-    pub max_context_length: usize,
+    pub max_context: usize,
 }
 
 impl From<&Value> for MemoryRunParams {
     fn from(value: &Value) -> Self {
         Self {
-            max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
+            max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
             // messages are for most backends, contents are for Gemini
             is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
         }

View File

@@ -5,6 +5,7 @@ use pgml::{Collection, Pipeline};
 use rand::{distributions::Alphanumeric, Rng};
 use serde_json::{json, Value};
 use std::{
+    collections::HashSet,
     io::Read,
     sync::{
         mpsc::{self, Sender},
@@ -29,7 +30,7 @@ use super::{
 
 const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
 
-fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
+fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
     let path = match root_uri {
         Some(root_uri) => {
             if uri.starts_with(root_uri) {
@@ -42,9 +43,8 @@ fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
     };
     format!(
         r#"--{path}--
-{}
+{excerpt}
 "#,
-        chunk.text
     )
 }
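Note: the rename from `format_chunk_chunk` to `format_file_excerpt` also generalizes the helper: it now takes any excerpt string rather than a `Chunk`, which is what lets the prompt-building code below reuse it for the code around the cursor. A self-contained reproduction of its output; the paths are hypothetical:

```rust
// Mirrors the helper above: strip the workspace root by prefix so the
// header line shows a workspace-relative path.
fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
    let path = match root_uri {
        Some(root_uri) if uri.starts_with(root_uri) => &uri[root_uri.len()..],
        _ => uri,
    };
    format!("--{path}--\n{excerpt}\n")
}

fn main() {
    let s = format_file_excerpt(
        "file:///home/user/project/src/main.rs",
        "fn main() {}",
        Some("file:///home/user/project"),
    );
    assert_eq!(s, "--/src/main.rs--\nfn main() {}\n");
}
```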
@@ -52,7 +52,7 @@ fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
     json!({
         "id": chunk_to_id(uri, &chunk),
         "uri": uri,
-        "text": format_chunk_chunk(uri, &chunk, root_uri),
+        "text": format_file_excerpt(uri, &chunk.text, root_uri),
         "range": chunk.range
     })
 }
@@ -86,6 +86,7 @@ async fn split_and_upsert_file(
 #[derive(Clone)]
 pub struct PostgresML {
     config: Config,
+    postgresml_config: config::PostgresML,
     file_store: Arc<FileStore>,
     collection: Collection,
     pipeline: Pipeline,
@@ -106,7 +107,7 @@ impl PostgresML {
             .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
 
         let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
-            Arc::new(postgresml_config.splitter.try_into()?);
+            Arc::new(postgresml_config.splitter.clone().try_into()?);
 
         let file_store = Arc::new(FileStore::new_with_params(
             config::FileStore::new_without_crawl(),
@@ -114,20 +115,20 @@ impl PostgresML {
             AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
         )?);
 
-        let database_url = if let Some(database_url) = postgresml_config.database_url {
+        let database_url = if let Some(database_url) = postgresml_config.database_url.clone() {
            database_url
         } else {
             std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
         };
 
         // Build our pipeline schema
-        let pipeline = match postgresml_config.embedding_model {
+        let pipeline = match &postgresml_config.embedding_model {
             Some(embedding_model) => {
                 json!({
                     "text": {
                         "semantic_search": {
-                            "model": embedding_model,
-                            "parameters": postgresml_config.embedding_model_parameters
+                            "model": embedding_model.model,
+                            "parameters": embedding_model.embed_parameters
                         }
                     }
                 })
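Note: with the new config struct, the `Some` arm builds the pipeline schema from `embedding_model.model` and `embedding_model.embed_parameters` instead of two separate config fields. A sketch of the resulting JSON for a hypothetical config (model name and parameters are examples, not from this commit):

```rust
use serde_json::json;

fn main() {
    let model = "intfloat/e5-small-v2"; // hypothetical
    let embed_parameters = Some(json!({ "prompt": "passage: " }));
    // Shape produced by the Some(embedding_model) arm above.
    let schema = json!({
        "text": {
            "semantic_search": {
                "model": model,
                "parameters": embed_parameters
            }
        }
    });
    println!("{schema:#}"); // pretty-printed pipeline schema
}
```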
@@ -281,6 +282,7 @@ impl PostgresML {
         let s = Self {
             config: configuration,
+            postgresml_config,
             file_store,
             collection,
             pipeline,
@@ -332,12 +334,19 @@ impl PostgresML {
                 let mut documents_to_delete = vec![];
                 let mut chunks_to_upsert = vec![];
                 let mut current_chunks_bytes = 0;
+                let mut checked_uris = HashSet::new();
                 for document in documents.into_iter() {
                     let uri = match document["document"]["uri"].as_str() {
                         Some(uri) => uri,
                         None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it
                     };
+
+                    // Check if we have already loaded in this file
+                    if checked_uris.contains(uri) {
+                        continue;
+                    }
+                    checked_uris.insert(uri.to_string());
 
                     let path = uri.replace("file://", "");
                     let path = std::path::Path::new(&path);
                     if !path.exists() {
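Note: the collection stores one document per chunk, so the same file URI can come back many times; the new `HashSet` guard ensures each file is examined only once per resync pass. The pattern in isolation, with hypothetical URIs:

```rust
use std::collections::HashSet;

fn main() {
    // Rows as they might come back from the collection: one per chunk,
    // so the same uri repeats.
    let rows = ["file:///a.rs", "file:///a.rs", "file:///b.rs"];
    let mut checked_uris = HashSet::new();
    for uri in rows {
        // insert returns false when the uri was already present.
        if !checked_uris.insert(uri.to_string()) {
            continue; // already handled this file
        }
        println!("processing {uri}");
    }
}
```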
@@ -458,9 +467,9 @@ impl PostgresML {
                         if let Err(e) = collection
                             .upsert_documents(to_upsert_documents, None)
                             .await
-                            .context("PGML - Error upserting changed files")
+                            .context("PGML - error upserting changed files")
                         {
-                            error!("{e}");
+                            error!("{e:?}");
                         }
                     });
 
                     // Reset everything
@@ -476,9 +485,9 @@ impl PostgresML {
                     if let Err(e) = collection
                         .upsert_documents(documents, None)
                         .await
-                        .context("PGML - Error upserting changed files")
+                        .context("PGML - error upserting changed files")
                     {
-                        error!("{e}");
+                        error!("{e:?}");
                     }
                 });
             }
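Note: the switch from `{e}` to `{e:?}` matters with `anyhow`: `Display` shows only the outermost context, while `Debug` prints the full cause chain. A small demonstration:

```rust
use anyhow::{anyhow, Context, Result};

fn main() {
    let err: Result<()> = Err(anyhow!("connection refused"))
        .context("PGML - error upserting changed files");
    if let Err(e) = err {
        println!("{e}");   // only: PGML - error upserting changed files
        println!("{e:?}"); // adds: Caused by: connection refused
    }
}
```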
@@ -502,32 +511,38 @@ impl MemoryBackend for PostgresML {
         params: &Value,
     ) -> anyhow::Result<Prompt> {
         let params: MemoryRunParams = params.try_into()?;
-        // TOOD: FIGURE THIS OUT
-        // let prompt_size = params.max_context_length
+        let chunk_size = self.splitter.chunk_size();
+        let total_allowed_characters = tokens_to_estimated_characters(params.max_context);
 
         // Build the query
         let query = self
             .file_store
-            .get_characters_around_position(position, 512)?;
+            .get_characters_around_position(position, chunk_size)?;
 
         // Build the prompt
         let mut file_store_params = params.clone();
-        file_store_params.max_context_length = 512;
+        file_store_params.max_context = chunk_size;
         let code = self
             .file_store
             .build_code(position, prompt_type, file_store_params, false)?;
 
         // Get the byte of the cursor
         let cursor_byte = self.file_store.position_to_byte(position)?;
+        eprintln!(
+            "CURSOR BYTE: {} IN DOCUMENT: {}",
+            cursor_byte,
+            position.text_document.uri.to_string()
+        );
 
         // Get the context
-        let limit = params.max_context_length / 512;
+        let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
+        let parameters = match self
+            .postgresml_config
+            .embedding_model
+            .as_ref()
+            .map(|m| m.query_parameters.clone())
+            .flatten()
+        {
+            Some(query_parameters) => query_parameters,
+            None => json!({
+                "prompt": "query: "
+            }),
+        };
         let res = self
             .collection
             .vector_search_local(
@@ -536,9 +551,7 @@ impl MemoryBackend for PostgresML {
                     "fields": {
                         "text": {
                             "query": query,
-                            "parameters": {
-                                "prompt": "query: "
-                            }
+                            "parameters": parameters
                         }
                     },
                     "filter": {
@@ -581,17 +594,20 @@ impl MemoryBackend for PostgresML {
                 })
                 .collect::<anyhow::Result<Vec<String>>>()?
                 .join("\n\n");
-
-        let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
-        let context = &context[..chars.min(context.len())];
+        let context = &context[..(total_allowed_characters - chunk_size).min(context.len())];
+        eprintln!("THE CONTEXT:\n\n{context}\n\n");
 
         // Reconstruct the Prompts
         Ok(match code {
-            Prompt::ContextAndCode(context_and_code) => Prompt::ContextAndCode(
-                ContextAndCodePrompt::new(context.to_owned(), context_and_code.code),
-            ),
+            Prompt::ContextAndCode(context_and_code) => {
+                Prompt::ContextAndCode(ContextAndCodePrompt::new(
+                    context.to_owned(),
+                    format_file_excerpt(
+                        &position.text_document.uri.to_string(),
+                        &context_and_code.code,
+                        self.config.client_params.root_uri.as_deref(),
+                    ),
+                ))
+            }
             Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
                 format!("{context}\n\n{}", fim.prompt),
                 fim.suffix,