Getting closer

Silas Marvin 2024-06-19 12:14:56 -07:00
parent 3e8c99b237
commit 9166aaf4b6
7 changed files with 186 additions and 35 deletions

View File

@ -156,6 +156,8 @@ pub struct PostgresML {
pub crawl: Option<Crawl>,
#[serde(default)]
pub splitter: ValidSplitter,
pub embedding_model: Option<String>,
pub embedding_model_parameters: Option<Value>,
}
#[derive(Clone, Debug, Deserialize, Default)]
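The two new optional fields let a user choose which embedding model the PostgresML pipeline uses. A minimal sketch of how they might be supplied (struct renamed for the sketch; the model name and parameters are illustrative assumptions, not defaults from this commit):

use serde::Deserialize;
use serde_json::{json, Value};

#[derive(Debug, Deserialize)]
struct PostgresMLConfig {
    embedding_model: Option<String>,
    embedding_model_parameters: Option<Value>,
}

fn main() -> anyhow::Result<()> {
    let raw = json!({
        "embedding_model": "intfloat/e5-small-v2",               // assumed value
        "embedding_model_parameters": { "prompt": "passage: " }  // assumed value
    });
    let config: PostgresMLConfig = serde_json::from_value(raw)?;
    assert_eq!(config.embedding_model.as_deref(), Some("intfloat/e5-small-v2"));
    Ok(())
}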

View File

@ -1,6 +1,6 @@
use anyhow::Context;
use indexmap::IndexSet;
-use lsp_types::TextDocumentPositionParams;
+use lsp_types::{Position, TextDocumentPositionParams};
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
@ -154,6 +154,7 @@ impl FileStore {
&self,
position: &TextDocumentPositionParams,
characters: usize,
pull_from_multiple_files: bool,
) -> anyhow::Result<(Rope, usize)> {
// Get the rope and set our initial cursor index
let current_document_uri = position.text_document.uri.to_string();
@ -174,7 +175,7 @@ impl FileStore {
.filter(|f| **f != current_document_uri)
{
let needed = characters.saturating_sub(rope.len_chars() + 1);
-        if needed == 0 {
+        if needed == 0 || !pull_from_multiple_files {
break;
}
let file_map = self.file_map.lock();
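The `needed` computation above uses saturating subtraction, so once the rope already holds enough characters (or `pull_from_multiple_files` is false), the loop over other open files stops. A tiny self-contained illustration with assumed values:

fn main() {
    let characters = 100usize; // requested context window size
    // The current file already provides enough characters: needed == 0, so break.
    assert_eq!(characters.saturating_sub(150 + 1), 0);
    // More characters are still needed, so keep pulling from other files.
    assert_eq!(characters.saturating_sub(40 + 1), 59);
}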
@ -220,9 +221,13 @@ impl FileStore {
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: MemoryRunParams,
pull_from_multiple_files: bool,
) -> anyhow::Result<Prompt> {
-        let (mut rope, cursor_index) =
-            self.get_rope_for_position(position, params.max_context_length)?;
+        let (mut rope, cursor_index) = self.get_rope_for_position(
+            position,
+            params.max_context_length,
+            pull_from_multiple_files,
+        )?;
Ok(match prompt_type {
PromptType::ContextAndCode => {
@ -277,6 +282,20 @@ impl FileStore {
pub fn contains_file(&self, uri: &str) -> bool {
self.file_map.lock().contains_key(uri)
}
pub fn position_to_byte(&self, position: &TextDocumentPositionParams) -> anyhow::Result<usize> {
let file_map = self.file_map.lock();
let uri = position.text_document.uri.to_string();
let file = file_map
.get(&uri)
.with_context(|| format!("trying to get file that does not exist {uri}"))?;
let line_char_index = file
.rope
.try_line_to_char(position.position.line as usize)?;
Ok(file
.rope
.try_char_to_byte(line_char_index + position.position.character as usize)?)
}
}
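The new position_to_byte helper converts an LSP line/character position into a byte offset through the rope's coordinate conversions (note that, like the code above, it treats `character` as a char offset). A minimal sketch of the same conversion using ropey directly:

use ropey::Rope;

fn main() -> anyhow::Result<()> {
    let rope = Rope::from_str("fn main() {\n    println!(\"hi\");\n}\n");
    let (line, character) = (1usize, 4usize); // 0-indexed, as in LSP positions
    let line_char = rope.try_line_to_char(line)?;
    let byte = rope.try_char_to_byte(line_char + character)?;
    assert_eq!(byte, 16); // the `p` of `println!` on the second line
    Ok(())
}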
#[async_trait::async_trait]
@ -307,7 +326,7 @@ impl MemoryBackend for FileStore {
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
-        self.build_code(position, prompt_type, params)
+        self.build_code(position, prompt_type, params, true)
}
#[instrument(skip(self))]

View File

@ -29,11 +29,30 @@ use super::{
const RESYNC_MAX_FILE_SIZE: u64 = 10_000_000;
-fn chunk_to_document(uri: &str, chunk: Chunk) -> Value {
+fn format_chunk_chunk(uri: &str, chunk: &Chunk, root_uri: Option<&str>) -> String {
+    let path = match root_uri {
+        Some(root_uri) => {
+            if uri.starts_with(root_uri) {
+                &uri[root_uri.chars().count()..]
+            } else {
+                uri
+            }
+        }
+        None => uri,
+    };
+    format!(
+        r#"--{path}--
+{}
+"#,
+        chunk.text
+    )
+}
+fn chunk_to_document(uri: &str, chunk: Chunk, root_uri: Option<&str>) -> Value {
     json!({
         "id": chunk_to_id(uri, &chunk),
         "uri": uri,
-        "text": chunk.text,
+        "text": format_chunk_chunk(uri, &chunk, root_uri),
"range": chunk.range
})
}
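A chunk's stored text now begins with a header naming its file, relative to the workspace root when the URI falls under it. A small sketch with assumed values (using the byte length to strip the prefix, which equals the char count here because the URI is ASCII):

fn main() {
    let uri = "file:///home/user/project/src/main.rs"; // assumed
    let root_uri = "file:///home/user/project"; // assumed
    let text = "fn main() {}"; // assumed chunk text
    let path = &uri[root_uri.len()..];
    let formatted = format!("--{path}--\n{text}\n");
    assert_eq!(formatted, "--/src/main.rs--\nfn main() {}\n");
}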
@ -43,6 +62,7 @@ async fn split_and_upsert_file(
collection: &mut Collection,
file_store: Arc<FileStore>,
splitter: Arc<Box<dyn Splitter + Send + Sync>>,
root_uri: Option<&str>,
) -> anyhow::Result<()> {
// We need to make sure we don't hold the file_store lock while performing a network call
let chunks = {
@ -55,7 +75,7 @@ async fn split_and_upsert_file(
let chunks = chunks.with_context(|| format!("file not found for splitting: {uri}"))?;
let documents = chunks
.into_iter()
-        .map(|chunk| chunk_to_document(uri, chunk).into())
+        .map(|chunk| chunk_to_document(uri, chunk, root_uri).into())
.collect();
collection
.upsert_documents(documents, None)
@ -65,7 +85,7 @@ async fn split_and_upsert_file(
#[derive(Clone)]
pub struct PostgresML {
-    _config: Config,
+    config: Config,
file_store: Arc<FileStore>,
collection: Collection,
pipeline: Pipeline,
@ -100,21 +120,19 @@ impl PostgresML {
std::env::var("PGML_DATABASE_URL").context("please provide either the `database_url` in the `postgresml` config, or set the `PGML_DATABASE_URL` environment variable")?
};
-        let collection_name = match configuration.client_params.root_uri.clone() {
-            Some(root_uri) => format!("{:x}", md5::compute(root_uri.as_bytes())),
-            None => {
-                warn!("no root_uri provided in server configuration - generating random string for collection name");
-                rand::thread_rng()
-                    .sample_iter(&Alphanumeric)
-                    .take(21)
-                    .map(char::from)
-                    .collect()
+        // Build our pipeline schema
+        let pipeline = match postgresml_config.embedding_model {
+            Some(embedding_model) => {
+                json!({
+                    "text": {
+                        "semantic_search": {
+                            "model": embedding_model,
+                            "parameters": postgresml_config.embedding_model_parameters
+                        }
+                    }
+                })
             }
-        };
-        let mut collection = Collection::new(&collection_name, Some(database_url))?;
-        let mut pipeline = Pipeline::new(
-            "v1",
-            Some(
+            None => {
                 json!({
                     "text": {
                         "semantic_search": {
@ -125,16 +143,36 @@ impl PostgresML {
}
}
})
-                .into(),
-        )?;
+            }
+        };
+        // When building the collection name we include the Pipeline schema
+        // If the user changes the Pipeline schema, it will take effect without them having to delete the old files
+        let collection_name = match configuration.client_params.root_uri.clone() {
+            Some(root_uri) => format!(
+                "{:x}",
+                md5::compute(
+                    format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes()
+                )
+            ),
+            None => {
+                warn!("no root_uri provided in server configuration - generating random string for collection name");
+                rand::thread_rng()
+                    .sample_iter(&Alphanumeric)
+                    .take(21)
+                    .map(char::from)
+                    .collect()
+            }
+        };
+        let mut collection = Collection::new(&collection_name, Some(database_url))?;
+        let mut pipeline = Pipeline::new("v1", Some(pipeline.into()))?;
         // Add the Pipeline to the Collection
         TOKIO_RUNTIME.block_on(async {
             collection
                 .add_pipeline(&mut pipeline)
                 .await
-                .context("PGML - Error adding pipeline to collection")
+                .context("PGML - error adding pipeline to collection")
})?;
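A quick sketch of the collection-name derivation above: hashing the root URI together with the serialized pipeline schema means a schema change lands in a fresh collection instead of mixing with rows embedded under the old schema (values assumed):

fn main() -> anyhow::Result<()> {
    let root_uri = "file:///home/user/project"; // assumed
    let pipeline = serde_json::json!({
        "text": { "semantic_search": { "model": "intfloat/e5-small-v2" } } // assumed schema
    });
    let name = format!(
        "{:x}",
        md5::compute(format!("{root_uri}_{}", serde_json::to_string(&pipeline)?).as_bytes())
    );
    assert_eq!(name.len(), 32); // hex-encoded 128-bit MD5 digest
    Ok(())
}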
// Set up a debouncer for changed text documents
@ -142,6 +180,7 @@ impl PostgresML {
let mut task_collection = collection.clone();
let task_file_store = file_store.clone();
let task_splitter = splitter.clone();
let task_root_uri = configuration.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_uris = Vec::new();
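The spawned task debounces document changes: URIs accumulate in `file_uris` and are flushed as one batch every 500 ms. A stripped-down sketch of that pattern (tokio assumed; the real task also drains a channel of change events and upserts each batch):

use std::time::Duration;

#[tokio::main]
async fn main() {
    let duration = Duration::from_millis(500);
    let mut file_uris: Vec<String> = vec!["file:///a.rs".into()]; // assumed pending change
    loop {
        tokio::time::sleep(duration).await;
        if file_uris.is_empty() {
            continue;
        }
        let batch: Vec<String> = file_uris.drain(..).collect();
        println!("upserting {} changed files", batch.len()); // stand-in for the real upsert
        break; // sketch only; the real task loops forever
    }
}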
@ -218,7 +257,9 @@ impl PostgresML {
.map(|(chunks, uri)| {
chunks
.into_iter()
-                            .map(|chunk| chunk_to_document(&uri, chunk))
+                            .map(|chunk| {
+                                chunk_to_document(&uri, chunk, task_root_uri.as_deref())
+                            })
.collect::<Vec<Value>>()
})
.flatten()
@ -227,7 +268,7 @@ impl PostgresML {
if let Err(e) = task_collection
.upsert_documents(documents, None)
.await
.context("PGML - Error upserting changed files")
.context("PGML - error upserting changed files")
{
error!("{e:?}");
continue;
@ -239,7 +280,7 @@ impl PostgresML {
});
let s = Self {
-            _config: configuration,
+            config: configuration,
file_store,
collection,
pipeline,
@ -317,7 +358,14 @@ impl PostgresML {
.splitter
.split_file_contents(&uri, &contents)
.into_iter()
-                    .map(|chunk| chunk_to_document(&uri, chunk).into())
+                    .map(|chunk| {
+                        chunk_to_document(
+                            &uri,
+                            chunk,
+                            self.config.client_params.root_uri.as_deref(),
+                        )
+                        .into()
+                    })
.collect();
chunks_to_upsert.extend(chunks);
// If we have over 10 megabytes of chunks, do the upsert
@ -326,10 +374,18 @@ impl PostgresML {
.upsert_documents(chunks_to_upsert, None)
.await
.context("PGML - error upserting documents during resync")?;
+                        chunks_to_upsert = vec![];
+                        current_chunks_bytes = 0;
                     }
-                    chunks_to_upsert = vec![];
}
}
// Upsert any remaining chunks
if chunks_to_upsert.len() > 0 {
collection
.upsert_documents(chunks_to_upsert, None)
.await
.context("PGML - error upserting documents during resync")?;
}
// Delete documents
if !documents_to_delete.is_empty() {
collection
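The resync path batches documents, flushes when the accumulated bytes cross the threshold (now also resetting the buffer and byte counter inside that branch), and flushes the remainder once at the end. The pattern in isolation, with a shrunken threshold for illustration:

fn main() {
    let chunks: Vec<String> = (0..10).map(|i| format!("chunk-{i}")).collect();
    let max_bytes = 20; // stand-in for RESYNC_MAX_FILE_SIZE
    let (mut batch, mut batch_bytes) = (Vec::new(), 0);
    for chunk in chunks {
        batch_bytes += chunk.len();
        batch.push(chunk);
        if batch_bytes > max_bytes {
            println!("upserting {} chunks", batch.len()); // stand-in for upsert_documents
            batch = Vec::new();
            batch_bytes = 0;
        }
    }
    if !batch.is_empty() {
        println!("upserting {} remaining chunks", batch.len());
    }
}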
@ -382,7 +438,14 @@ impl PostgresML {
.splitter
.split_file_contents(&uri, &contents)
.into_iter()
-                    .map(|chunk| chunk_to_document(&uri, chunk).into())
+                    .map(|chunk| {
+                        chunk_to_document(
+                            &uri,
+                            chunk,
+                            self.config.client_params.root_uri.as_deref(),
+                        )
+                        .into()
+                    })
.collect();
documents.extend(chunks);
// If we have over 10 megabytes of data, do the upsert
@ -440,17 +503,28 @@ impl MemoryBackend for PostgresML {
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
// TODO: FIGURE THIS OUT
// let prompt_size = params.max_context_length
// Build the query
let query = self
.file_store
.get_characters_around_position(position, 512)?;
// Get the code around the Cursor
// Build the prompt
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self
.file_store
-            .build_code(position, prompt_type, file_store_params)?;
+            .build_code(position, prompt_type, file_store_params, false)?;
// Get the byte of the cursor
let cursor_byte = self.file_store.position_to_byte(position)?;
eprintln!(
"CURSOR BYTE: {} IN DOCUMENT: {}",
cursor_byte,
position.text_document.uri.to_string()
);
// Get the context
let limit = params.max_context_length / 512;
@ -467,6 +541,29 @@ impl MemoryBackend for PostgresML {
}
}
},
"filter": {
"$or": [
{
"uri": {
"$ne": position.text_document.uri.to_string()
}
},
{
"range": {
"start": {
"$gt": cursor_byte
},
},
},
{
"range": {
"end": {
"$lt": cursor_byte
},
}
}
]
}
},
"limit": limit
})
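The new filter keeps a stored chunk when it comes from a different file, or when its byte range lies strictly after or strictly before the cursor; the chunk containing the cursor is excluded because the prompt already carries that code verbatim. The same filter built standalone (URI and offset assumed):

use serde_json::json;

fn main() {
    let current_uri = "file:///home/user/project/src/main.rs"; // assumed
    let cursor_byte = 42; // assumed
    let filter = json!({
        "$or": [
            { "uri": { "$ne": current_uri } },
            { "range": { "start": { "$gt": cursor_byte } } },
            { "range": { "end": { "$lt": cursor_byte } } }
        ]
    });
    println!("{}", serde_json::to_string_pretty(&filter).unwrap());
}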
@ -485,6 +582,8 @@ impl MemoryBackend for PostgresML {
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
eprintln!("THE CONTEXT:\n\n{context}\n\n");
let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
let context = &context[..chars.min(context.len())];
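Of the token budget, 512 tokens stay reserved for the code window around the cursor and the remainder becomes a character allowance for the retrieved context. A sketch of the arithmetic, assuming tokens_to_estimated_characters is the common ~4 characters-per-token heuristic (an assumption, not taken from this commit):

fn main() {
    let max_context_length = 1024usize; // tokens, assumed
    let context_tokens = max_context_length.saturating_sub(512); // 512 kept for the code window
    let chars = context_tokens * 4; // assumed 4 chars/token estimate
    assert_eq!(chars, 2048);
    let context = "retrieved context ".repeat(200); // ASCII, so byte slicing is safe
    let truncated = &context[..chars.min(context.len())];
    assert!(truncated.len() <= 2048);
}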
@ -512,9 +611,17 @@ impl MemoryBackend for PostgresML {
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
let root_uri = self.config.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
let uri = params.text_document.uri.to_string();
-            if let Err(e) = split_and_upsert_file(&uri, &mut collection, file_store, splitter).await
+            if let Err(e) = split_and_upsert_file(
+                &uri,
+                &mut collection,
+                file_store,
+                splitter,
+                root_uri.as_deref(),
+            )
+            .await
.await
{
error!("{e:?}")
}
@ -544,6 +651,7 @@ impl MemoryBackend for PostgresML {
let mut collection = self.collection.clone();
let file_store = self.file_store.clone();
let splitter = self.splitter.clone();
let root_uri = self.config.client_params.root_uri.clone();
TOKIO_RUNTIME.spawn(async move {
for file in params.files {
if let Err(e) = collection
@ -564,6 +672,7 @@ impl MemoryBackend for PostgresML {
&mut collection,
file_store.clone(),
splitter.clone(),
root_uri.as_deref(),
)
.await
{

View File

@ -39,6 +39,8 @@ pub trait Splitter {
fn does_use_tree_sitter(&self) -> bool {
false
}
fn chunk_size(&self) -> usize;
}
impl TryFrom<ValidSplitter> for Box<dyn Splitter + Send + Sync> {

View File

@ -3,18 +3,21 @@ use crate::{config, memory_backends::file_store::File};
use super::{ByteRange, Chunk, Splitter};
pub struct TextSplitter {
chunk_size: usize,
splitter: text_splitter::TextSplitter<text_splitter::Characters>,
}
impl TextSplitter {
pub fn new(config: config::TextSplitter) -> Self {
Self {
chunk_size: config.chunk_size,
splitter: text_splitter::TextSplitter::new(config.chunk_size),
}
}
pub fn new_with_chunk_size(chunk_size: usize) -> Self {
Self {
chunk_size,
splitter: text_splitter::TextSplitter::new(chunk_size),
}
}
@ -37,4 +40,8 @@ impl Splitter for TextSplitter {
acc
})
}
fn chunk_size(&self) -> usize {
self.chunk_size
}
}

View File

@ -7,6 +7,7 @@ use crate::{config, memory_backends::file_store::File, utils::parse_tree};
use super::{text_splitter::TextSplitter, ByteRange, Chunk, Splitter};
pub struct TreeSitter {
chunk_size: usize,
splitter: TreeSitterCodeSplitter,
text_splitter: TextSplitter,
}
@ -15,6 +16,7 @@ impl TreeSitter {
pub fn new(config: config::TreeSitter) -> anyhow::Result<Self> {
let text_splitter = TextSplitter::new_with_chunk_size(config.chunk_size);
Ok(Self {
chunk_size: config.chunk_size,
splitter: TreeSitterCodeSplitter::new(config.chunk_size, config.chunk_overlap)?,
text_splitter,
})
@ -75,4 +77,8 @@ impl Splitter for TreeSitter {
fn does_use_tree_sitter(&self) -> bool {
true
}
fn chunk_size(&self) -> usize {
self.chunk_size
}
}

View File

@ -156,6 +156,12 @@ impl OpenAI {
messages: Vec<ChatMessage>,
params: OpenAIRunParams,
) -> anyhow::Result<String> {
eprintln!("\n\n\n\n");
for message in &messages {
eprintln!("{}:\n{}\n", message.role.to_string(), message.content);
}
eprintln!("\n\n\n\n");
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAIChatResponse = client