From a95c27602e89e5933a70f8902e59fce1f9d023f6 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 19 Jun 2024 08:26:06 -0700 Subject: [PATCH] Updated memory backend to build prompt correctly --- src/memory_backends/file_store.rs | 16 ++++++++-------- src/memory_backends/mod.rs | 24 ++++++++++++++---------- src/memory_backends/postgresml/mod.rs | 4 ++-- src/memory_worker.rs | 2 +- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs index 219998b..4d70509 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -114,7 +114,7 @@ impl FileStore { Ok(match prompt_type { PromptType::ContextAndCode => { - if params.messages.is_some() { + if params.is_for_chat { let max_length = tokens_to_estimated_characters(params.max_context_length); let start = cursor_index.saturating_sub(max_length / 2); let end = rope @@ -185,9 +185,9 @@ impl MemoryBackend for FileStore { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result { - let params: MemoryRunParams = serde_json::from_value(params)?; + let params: MemoryRunParams = params.try_into()?; self.build_code(position, prompt_type, params) } @@ -414,7 +414,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({}), + &json!({}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; @@ -434,7 +434,7 @@ The end with a trailing new line }, }, PromptType::FIM, - json!({}), + &json!({}), ) .await?; let prompt: FIMPrompt = prompt.try_into()?; @@ -463,7 +463,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({ + &json!({ "messages": [] }), ) @@ -510,7 +510,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({}), + &json!({}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; @@ -542,7 +542,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({"messages": []}), + &json!({"messages": []}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index e824d3f..52a8974 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -2,31 +2,35 @@ use lsp_types::{ DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, TextDocumentPositionParams, }; -use serde::Deserialize; use serde_json::Value; -use crate::config::{ChatMessage, Config, ValidMemoryBackend}; +use crate::config::{Config, ValidMemoryBackend}; pub mod file_store; mod postgresml; -const fn max_context_length_default() -> usize { - 1024 -} - #[derive(Clone, Debug)] pub enum PromptType { ContextAndCode, FIM, } -#[derive(Clone, Deserialize)] +#[derive(Clone)] pub struct MemoryRunParams { - pub messages: Option>, - #[serde(default = "max_context_length_default")] + pub is_for_chat: bool, pub max_context_length: usize, } +impl From<&Value> for MemoryRunParams { + fn from(value: &Value) -> Self { + Self { + max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize, + // messages are for most backends, contents are for Gemini + is_for_chat: value["messages"].is_array() || value["contents"].is_array(), + } + } +} + #[derive(Debug)] pub struct ContextAndCodePrompt { pub context: String, @@ -119,7 +123,7 @@ pub trait MemoryBackend { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result; async fn get_filter_text( &self, diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index 8e0f748..8b007ab 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -132,9 +132,9 @@ impl MemoryBackend for PostgresML { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result { - let params: MemoryRunParams = serde_json::from_value(params)?; + let params: MemoryRunParams = params.try_into()?; let query = self .file_store .get_characters_around_position(position, 512)?; diff --git a/src/memory_worker.rs b/src/memory_worker.rs index 86082f8..39cad6c 100644 --- a/src/memory_worker.rs +++ b/src/memory_worker.rs @@ -70,7 +70,7 @@ async fn do_task( } WorkerRequest::Prompt(params) => { let prompt = memory_backend - .build_prompt(¶ms.position, params.prompt_type, params.params) + .build_prompt(¶ms.position, params.prompt_type, ¶ms.params) .await?; params .tx