Working RAG

This commit is contained in:
Silas Marvin 2024-06-21 08:30:22 -07:00
parent 9166aaf4b6
commit 09f602ee12
4 changed files with 66 additions and 47 deletions

View File

@ -149,6 +149,13 @@ pub struct Crawl {
pub all_files: bool,
}
/// Embedding-model settings for the PostgresML memory backend.
///
/// Deserialized from the `postgresml.embedding_model` section of the
/// configuration. Splitting the parameters in two lets callers pass
/// different options when embedding stored documents versus when
/// embedding search queries.
#[derive(Clone, Debug, Deserialize)]
pub struct PostgresMLEmbeddingModel {
/// Name of the embedding model (forwarded as `"model"` in the
/// pipeline's `semantic_search` schema).
pub model: String,
/// Optional parameters used when embedding documents for storage
/// (forwarded as `"parameters"` in the pipeline schema).
pub embed_parameters: Option<Value>,
/// Optional parameters used when embedding search queries; when absent
/// the backend falls back to `{"prompt": "query: "}` at search time.
pub query_parameters: Option<Value>,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PostgresML {
@ -156,8 +163,7 @@ pub struct PostgresML {
pub crawl: Option<Crawl>,
#[serde(default)]
pub splitter: ValidSplitter,
pub embedding_model: Option<String>,
pub embedding_model_parameters: Option<Value>,
pub embedding_model: Option<PostgresMLEmbeddingModel>,
}
#[derive(Clone, Debug, Deserialize, Default)]

View File

@ -223,16 +223,13 @@ impl FileStore {
params: MemoryRunParams,
pull_from_multiple_files: bool,
) -> anyhow::Result<Prompt> {
let (mut rope, cursor_index) = self.get_rope_for_position(
position,
params.max_context_length,
pull_from_multiple_files,
)?;
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?;
Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.is_for_chat {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let max_length = tokens_to_estimated_characters(params.max_context);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
@ -248,7 +245,7 @@ impl FileStore {
))
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
.saturating_sub(tokens_to_estimated_characters(params.max_context));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
@ -259,7 +256,7 @@ impl FileStore {
}
}
PromptType::FIM => {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let max_length = tokens_to_estimated_characters(params.max_context);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()

View File

@ -18,13 +18,13 @@ pub enum PromptType {
#[derive(Clone)]
pub struct MemoryRunParams {
pub is_for_chat: bool,
pub max_context_length: usize,
pub max_context: usize,
}
impl From<&Value> for MemoryRunParams {
fn from(value: &Value) -> Self {
Self {
max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
// messages are for most backends, contents are for Gemini
is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
}

View File

@ -5,6 +5,7 @@ use pgml::{Collection, Pipeline};
use rand::{distributions::Alphanumeric, Rng};
use serde_json::{json, Value};
use std::{
collections::HashSet,
io::Read,
sync::{
mpsc::{self, Sender},
@ -29,7 +30,7 @@ use super::{
const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
let path = match root_uri {
Some(root_uri) => {
if uri.starts_with(root_uri) {
@ -42,9 +43,8 @@ fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> Strin
};
format!(
r#"--{path}--
{}
{excerpt}
"#,
chunk.text
)
}
@ -52,7 +52,7 @@ fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
json!({
"id": chunk_to_id(uri, &chunk),
"uri": uri,
"text": format_chunk_chunk(uri, &chunk, root_uri),
"text": format_file_excerpt(uri, &chunk.text, root_uri),
"range": chunk.range
})
}
@ -86,6 +86,7 @@ async fn split_and_upsert_file(
#[derive(Clone)]
pub struct PostgresML {
config: Config,
postgresml_config: config::PostgresML,
file_store: Arc<FileStore>,
collection: Collection,
pipeline: Pipeline,
@ -106,7 +107,7 @@ impl PostgresML {
.map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
Arc::new(postgresml_config.splitter.try_into()?);
Arc::new(postgresml_config.splitter.clone().try_into()?);
let file_store = Arc::new(FileStore::new_with_params(
config::FileStore::new_without_crawl(),
@ -114,20 +115,20 @@ impl PostgresML {
AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
)?);
let database_url = if let Some(database_url) = postgresml_config.database_url {
let database_url = if let Some(database_url) = postgresml_config.database_url.clone() {
database_url
} else {
std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
};
// Build our pipeline schema
let pipeline = match postgresml_config.embedding_model {
let pipeline = match &postgresml_config.embedding_model {
Some(embedding_model) => {
json!({
"text": {
"semantic_search": {
"model": embedding_model,
"parameters": postgresml_config.embedding_model_parameters
"model": embedding_model.model,
"parameters": embedding_model.embed_parameters
}
}
})
@ -281,6 +282,7 @@ impl PostgresML {
let s = Self {
config: configuration,
postgresml_config,
file_store,
collection,
pipeline,
@ -332,12 +334,19 @@ impl PostgresML {
let mut documents_to_delete = vec![];
let mut chunks_to_upsert = vec![];
let mut current_chunks_bytes = 0;
let mut checked_uris = HashSet::new();
for document in documents.into_iter() {
let uri = match document["document"]["uri"].as_str() {
Some(uri) => uri,
None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it
};
// Check if we have already loaded in this file
if checked_uris.contains(uri) {
continue;
}
checked_uris.insert(uri.to_string());
let path = uri.replace("file://", "");
let path = std::path::Path::new(&path);
if !path.exists() {
@ -458,9 +467,9 @@ impl PostgresML {
if let Err(e) = collection
.upsert_documents(to_upsert_documents, None)
.await
.context("PGML - Error upserting changed files")
.context("PGML - error upserting changed files")
{
error!("{e}");
error!("{e:?}");
}
});
// Reset everything
@ -476,9 +485,9 @@ impl PostgresML {
if let Err(e) = collection
.upsert_documents(documents, None)
.await
.context("PGML - Error upserting changed files")
.context("PGML - error upserting changed files")
{
error!("{e}");
error!("{e:?}");
}
});
}
@ -502,32 +511,38 @@ impl MemoryBackend for PostgresML {
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
// TOOD: FIGURE THIS OUT
// let prompt_size = params.max_context_length
let chunk_size = self.splitter.chunk_size();
let total_allowed_characters = tokens_to_estimated_characters(params.max_context);
// Build the query
let query = self
.file_store
.get_characters_around_position(position, 512)?;
.get_characters_around_position(position, chunk_size)?;
// Build the prompt
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
file_store_params.max_context = chunk_size;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params, false)?;
// Get the byte of the cursor
let cursor_byte = self.file_store.position_to_byte(position)?;
eprintln!(
"CURSOR BYTE: {} IN DOCUMENT: {}",
cursor_byte,
position.text_document.uri.to_string()
);
// Get the context
let limit = params.max_context_length / 512;
let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
let parameters = match self
.postgresml_config
.embedding_model
.as_ref()
.map(|m| m.query_parameters.clone())
.flatten()
{
Some(query_parameters) => query_parameters,
None => json!({
"prompt": "query: "
}),
};
let res = self
.collection
.vector_search_local(
@ -536,9 +551,7 @@ impl MemoryBackend for PostgresML {
"fields": {
"text": {
"query": query,
"parameters": {
"prompt": "query: "
}
"parameters": parameters
}
},
"filter": {
@ -581,17 +594,20 @@ impl MemoryBackend for PostgresML {
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
eprintln!("THE CONTEXT:\n\n{context}\n\n");
let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
let context = &context[..chars.min(context.len())];
let context = &context[..(total_allowed_characters - chunk_size).min(context.len())];
// Reconstruct the Prompts
Ok(match code {
Prompt::ContextAndCode(context_and_code) => Prompt::ContextAndCode(
ContextAndCodePrompt::new(context.to_owned(), context_and_code.code),
),
Prompt::ContextAndCode(context_and_code) => {
Prompt::ContextAndCode(ContextAndCodePrompt::new(
context.to_owned(),
format_file_excerpt(
&position.text_document.uri.to_string(),
&context_and_code.code,
self.config.client_params.root_uri.as_deref(),
),
))
}
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
format!("{context}\n\n{}", fim.prompt),
fim.suffix,