From 09f602ee1218b3046726eae06f443ca7b03f6af2 Mon Sep 17 00:00:00 2001
From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 21 Jun 2024 08:30:22 -0700
Subject: [PATCH] Working RAG

---
 crates/lsp-ai/src/config.rs                   | 10 ++-
 .../lsp-ai/src/memory_backends/file_store.rs  | 13 ++-
 crates/lsp-ai/src/memory_backends/mod.rs      |  4 +-
 .../src/memory_backends/postgresml/mod.rs     | 86 +++++++++++--------
 4 files changed, 66 insertions(+), 47 deletions(-)

diff --git a/crates/lsp-ai/src/config.rs b/crates/lsp-ai/src/config.rs
index 92ef755..5d29b66 100644
--- a/crates/lsp-ai/src/config.rs
+++ b/crates/lsp-ai/src/config.rs
@@ -149,6 +149,13 @@ pub struct Crawl {
     pub all_files: bool,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+pub struct PostgresMLEmbeddingModel {
+    pub model: String,
+    pub embed_parameters: Option<Value>,
+    pub query_parameters: Option<Value>,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct PostgresML {
@@ -156,8 +163,7 @@ pub struct PostgresML {
     pub crawl: Option<Crawl>,
     #[serde(default)]
     pub splitter: ValidSplitter,
-    pub embedding_model: Option<String>,
-    pub embedding_model_parameters: Option<Value>,
+    pub embedding_model: Option<PostgresMLEmbeddingModel>,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
diff --git a/crates/lsp-ai/src/memory_backends/file_store.rs b/crates/lsp-ai/src/memory_backends/file_store.rs
index ec50151..9f8b805 100644
--- a/crates/lsp-ai/src/memory_backends/file_store.rs
+++ b/crates/lsp-ai/src/memory_backends/file_store.rs
@@ -223,16 +223,13 @@ impl FileStore {
         params: MemoryRunParams,
         pull_from_multiple_files: bool,
     ) -> anyhow::Result<Prompt> {
-        let (mut rope, cursor_index) = self.get_rope_for_position(
-            position,
-            params.max_context_length,
-            pull_from_multiple_files,
-        )?;
+        let (mut rope, cursor_index) =
+            self.get_rope_for_position(position, params.max_context, pull_from_multiple_files)?;
 
         Ok(match prompt_type {
             PromptType::ContextAndCode => {
                 if params.is_for_chat {
-                    let max_length = tokens_to_estimated_characters(params.max_context_length);
+                    let max_length = tokens_to_estimated_characters(params.max_context);
                     let start = cursor_index.saturating_sub(max_length / 2);
                     let end = rope
                         .len_chars()
@@ -248,7 +245,7 @@ impl FileStore {
                     ))
                 } else {
                     let start = cursor_index
-                        .saturating_sub(tokens_to_estimated_characters(params.max_context_length));
+                        .saturating_sub(tokens_to_estimated_characters(params.max_context));
                     let rope_slice = rope
                         .get_slice(start..cursor_index)
                         .context("Error getting rope slice")?;
@@ -259,7 +256,7 @@ impl FileStore {
                 }
             }
             PromptType::FIM => {
-                let max_length = tokens_to_estimated_characters(params.max_context_length);
+                let max_length = tokens_to_estimated_characters(params.max_context);
                 let start = cursor_index.saturating_sub(max_length / 2);
                 let end = rope
                     .len_chars()
diff --git a/crates/lsp-ai/src/memory_backends/mod.rs b/crates/lsp-ai/src/memory_backends/mod.rs
index 6b54cff..9d6fcc5 100644
--- a/crates/lsp-ai/src/memory_backends/mod.rs
+++ b/crates/lsp-ai/src/memory_backends/mod.rs
@@ -18,13 +18,13 @@ pub enum PromptType {
 #[derive(Clone)]
 pub struct MemoryRunParams {
     pub is_for_chat: bool,
-    pub max_context_length: usize,
+    pub max_context: usize,
 }
 
 impl From<&Value> for MemoryRunParams {
     fn from(value: &Value) -> Self {
         Self {
-            max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
+            max_context: value["max_context"].as_u64().unwrap_or(1024) as usize,
             // messages are for most backends, contents are for Gemini
             is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
         }
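For context before the PostgresML backend changes: a minimal, self-contained sketch (assuming serde and serde_json as dependencies) of how the new nested `embedding_model` section is expected to deserialize. The struct mirrors `PostgresMLEmbeddingModel` above; the model name and prompt values are hypothetical examples, not lsp-ai defaults.

    use serde::Deserialize;
    use serde_json::{json, Value};

    // Mirrors the new struct in crates/lsp-ai/src/config.rs.
    #[derive(Debug, Deserialize)]
    struct PostgresMLEmbeddingModel {
        model: String,
        embed_parameters: Option<Value>,
        query_parameters: Option<Value>,
    }

    fn main() -> Result<(), serde_json::Error> {
        // Hypothetical user config: embed-time and query-time parameters can
        // now differ, which the old flat `embedding_model_parameters` field
        // could not express.
        let parsed: PostgresMLEmbeddingModel = serde_json::from_value(json!({
            "model": "intfloat/e5-small-v2",
            "embed_parameters": { "prompt": "passage: " },
            "query_parameters": { "prompt": "query: " }
        }))?;
        assert_eq!(parsed.model, "intfloat/e5-small-v2");
        Ok(())
    }

Splitting the two parameter sets matters for instruction-tuned embedding models such as the E5 family, which expect a "passage: " prefix when embedding documents and a "query: " prefix when embedding queries.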
diff --git a/crates/lsp-ai/src/memory_backends/postgresml/mod.rs b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs
index 858e29e..2c08065 100644
--- a/crates/lsp-ai/src/memory_backends/postgresml/mod.rs
+++ b/crates/lsp-ai/src/memory_backends/postgresml/mod.rs
@@ -5,6 +5,7 @@ use pgml::{Collection, Pipeline};
 use rand::{distributions::Alphanumeric, Rng};
 use serde_json::{json, Value};
 use std::{
+    collections::HashSet,
     io::Read,
     sync::{
         mpsc::{self, Sender},
@@ -29,7 +30,7 @@ use super::{
 
 const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
 
-fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
+fn format_file_excerpt(uri: &str, excerpt: &str, root_uri: Option<&str>) -> String {
     let path = match root_uri {
         Some(root_uri) => {
             if uri.starts_with(root_uri) {
@@ -42,9 +43,8 @@ fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
     };
     format!(
         r#"--{path}--
-{}
+{excerpt}
 "#,
-        chunk.text
     )
 }
 
@@ -52,7 +52,7 @@ fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
     json!({
         "id": chunk_to_id(uri, &chunk),
         "uri": uri,
-        "text": format_chunk_chunk(uri, &chunk, root_uri),
+        "text": format_file_excerpt(uri, &chunk.text, root_uri),
        "range": chunk.range
     })
 }
@@ -86,6 +86,7 @@ async fn split_and_upsert_file(
 #[derive(Clone)]
 pub struct PostgresML {
     config: Config,
+    postgresml_config: config::PostgresML,
     file_store: Arc<FileStore>,
     collection: Collection,
     pipeline: Pipeline,
@@ -106,7 +107,7 @@ impl PostgresML {
             .map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
 
         let splitter: Arc<Box<dyn Splitter + Send + Sync>> =
-            Arc::new(postgresml_config.splitter.try_into()?);
+            Arc::new(postgresml_config.splitter.clone().try_into()?);
 
         let file_store = Arc::new(FileStore::new_with_params(
             config::FileStore::new_without_crawl(),
@@ -114,20 +115,20 @@ impl PostgresML {
             AdditionalFileStoreParams::new(splitter.does_use_tree_sitter()),
         )?);
 
-        let database_url = if let Some(database_url) = postgresml_config.database_url {
+        let database_url = if let Some(database_url) = postgresml_config.database_url.clone() {
             database_url
         } else {
             std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
         };
 
         // Build our pipeline schema
-        let pipeline = match postgresml_config.embedding_model {
+        let pipeline = match &postgresml_config.embedding_model {
             Some(embedding_model) => {
                 json!({
                     "text": {
                         "semantic_search": {
-                            "model": embedding_model,
-                            "parameters": postgresml_config.embedding_model_parameters
+                            "model": embedding_model.model,
+                            "parameters": embedding_model.embed_parameters
                         }
                     }
                 })
@@ -281,6 +282,7 @@ impl PostgresML {
 
         let s = Self {
             config: configuration,
+            postgresml_config,
             file_store,
             collection,
             pipeline,
@@ -332,12 +334,19 @@ impl PostgresML {
         let mut documents_to_delete = vec![];
         let mut chunks_to_upsert = vec![];
         let mut current_chunks_bytes = 0;
+        let mut checked_uris = HashSet::new();
         for document in documents.into_iter() {
             let uri = match document["document"]["uri"].as_str() {
                 Some(uri) => uri,
                 None => continue, // This should never happen, but is really bad as we now have a document with essentially no way to delete it
             };
 
+            // Check if we have already loaded in this file
+            if checked_uris.contains(uri) {
+                continue;
+            }
+            checked_uris.insert(uri.to_string());
+
             let path = uri.replace("file://", "");
             let path = std::path::Path::new(&path);
             if !path.exists() {
@@ -458,9 +467,9 @@ impl PostgresML {
                 if let Err(e) = collection
                     .upsert_documents(to_upsert_documents, None)
                     .await
-                    .context("PGML - Error upserting changed files")
+                    .context("PGML - error upserting changed files")
                 {
-                    error!("{e}");
+                    error!("{e:?}");
                 }
             });
             // Reset everything
@@ -476,9 +485,9 @@ impl PostgresML {
                 if let Err(e) = collection
                     .upsert_documents(documents, None)
                     .await
-                    .context("PGML - Error upserting changed files")
+                    .context("PGML - error upserting changed files")
                 {
-                    error!("{e}");
+                    error!("{e:?}");
                 }
             });
         }
@@ -502,32 +511,38 @@ impl MemoryBackend for PostgresML {
         params: &Value,
     ) -> anyhow::Result<Prompt> {
         let params: MemoryRunParams = params.try_into()?;
-
-        // TOOD: FIGURE THIS OUT
-        // let prompt_size = params.max_context_length
+        let chunk_size = self.splitter.chunk_size();
+        let total_allowed_characters = tokens_to_estimated_characters(params.max_context);
 
         // Build the query
         let query = self
             .file_store
-            .get_characters_around_position(position, 512)?;
+            .get_characters_around_position(position, chunk_size)?;
 
         // Build the prompt
         let mut file_store_params = params.clone();
-        file_store_params.max_context_length = 512;
+        file_store_params.max_context = chunk_size;
         let code = self
             .file_store
             .build_code(position, prompt_type, file_store_params, false)?;
 
         // Get the byte of the cursor
         let cursor_byte = self.file_store.position_to_byte(position)?;
-        eprintln!(
-            "CURSOR BYTE: {} IN DOCUMENT: {}",
-            cursor_byte,
-            position.text_document.uri.to_string()
-        );
 
         // Get the context
-        let limit = params.max_context_length / 512;
+        let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
+        let parameters = match self
+            .postgresml_config
+            .embedding_model
+            .as_ref()
+            .map(|m| m.query_parameters.clone())
+            .flatten()
+        {
+            Some(query_parameters) => query_parameters,
+            None => json!({
+                "prompt": "query: "
+            }),
+        };
         let res = self
             .collection
             .vector_search_local(
@@ -536,9 +551,7 @@ impl MemoryBackend for PostgresML {
                     "fields": {
                         "text": {
                             "query": query,
-                            "parameters": {
-                                "prompt": "query: "
-                            }
+                            "parameters": parameters
                         }
                     },
                     "filter": {
@@ -581,17 +594,20 @@ impl MemoryBackend for PostgresML {
                 })
                 .collect::<anyhow::Result<Vec<String>>>()?
                 .join("\n\n");
-
-        eprintln!("THE CONTEXT:\n\n{context}\n\n");
-
-        let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
-        let context = &context[..chars.min(context.len())];
+        let context = &context[..(total_allowed_characters - chunk_size).min(context.len())];
 
         // Reconstruct the Prompts
         Ok(match code {
-            Prompt::ContextAndCode(context_and_code) => Prompt::ContextAndCode(
-                ContextAndCodePrompt::new(context.to_owned(), context_and_code.code),
-            ),
+            Prompt::ContextAndCode(context_and_code) => {
+                Prompt::ContextAndCode(ContextAndCodePrompt::new(
+                    context.to_owned(),
+                    format_file_excerpt(
+                        &position.text_document.uri.to_string(),
+                        &context_and_code.code,
+                        self.config.client_params.root_uri.as_deref(),
+                    ),
+                ))
+            }
             Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
                 format!("{context}\n\n{}", fim.prompt),
                 fim.suffix,
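Two sketches may help reviewers of the postgresml changes. First, how a single `embedding_model` config entry now feeds two call sites: `embed_parameters` goes into the pipeline's `semantic_search` schema, while `query_parameters` is passed to `vector_search_local` and falls back to `{"prompt": "query: "}` when unset, matching the hunks above. The struct here is a simplified stand-in for `config::PostgresMLEmbeddingModel`, and the values are hypothetical.

    use serde_json::{json, Value};

    // Simplified stand-in for config::PostgresMLEmbeddingModel.
    struct EmbeddingModel {
        model: String,
        embed_parameters: Option<Value>,
        query_parameters: Option<Value>,
    }

    // Embed-time parameters end up in the pipeline schema (see the
    // "Build our pipeline schema" hunk above).
    fn pipeline_schema(m: &EmbeddingModel) -> Value {
        json!({
            "text": {
                "semantic_search": {
                    "model": m.model,
                    "parameters": m.embed_parameters
                }
            }
        })
    }

    // Query-time parameters default to the "query: " prompt, matching the
    // fallback in build_prompt above.
    fn query_parameters(m: Option<&EmbeddingModel>) -> Value {
        m.and_then(|m| m.query_parameters.clone())
            .unwrap_or_else(|| json!({ "prompt": "query: " }))
    }

    fn main() {
        let m = EmbeddingModel {
            model: "intfloat/e5-small-v2".into(), // hypothetical model name
            embed_parameters: Some(json!({ "prompt": "passage: " })),
            query_parameters: None,
        };
        println!("{}", pipeline_schema(&m));
        println!("{}", query_parameters(Some(&m))); // falls back to "query: "
    }

For Option, the patch's `.map(...).flatten()` chain is equivalent to the `and_then` used here.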
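Second, a worked example of the new context budgeting in `build_prompt`, assuming the common 1 token to roughly 4 characters heuristic for `tokens_to_estimated_characters` (the real helper lives elsewhere in lsp-ai and may differ); the `max_context` and `chunk_size` values are hypothetical.

    // Assumed heuristic; the actual lsp-ai helper may differ.
    fn tokens_to_estimated_characters(tokens: usize) -> usize {
        tokens * 4
    }

    fn main() {
        let max_context = 1024; // default when the request sets no "max_context"
        let chunk_size = 512; // hypothetical splitter chunk size
        let total_allowed_characters = tokens_to_estimated_characters(max_context);
        // One chunk-sized slice of the budget is spent on the code around the
        // cursor (build_code runs with max_context = chunk_size), so the vector
        // search requests one fewer chunk than would fill the budget.
        let limit = (total_allowed_characters / chunk_size).saturating_sub(1);
        assert_eq!(total_allowed_characters, 4096);
        assert_eq!(limit, 7); // 4096 / 512 = 8 chunks, minus 1 for the code
        // The retrieved context is then truncated to what remains
        // (ASCII-only here, so byte slicing is safe):
        let context = "a".repeat(10_000);
        let context = &context[..(total_allowed_characters - chunk_size).min(context.len())];
        assert_eq!(context.len(), 3584);
    }

This replaces the old hard-coded 512-character budget with one derived from the splitter's configured chunk size.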