Updated memory backend to build prompt correctly

This commit is contained in:
Silas Marvin 2024-06-19 08:26:06 -07:00 committed by Asuka Minato
parent f85c964a30
commit a95c27602e
4 changed files with 25 additions and 21 deletions

View File

@ -114,7 +114,7 @@ impl FileStore {
Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.messages.is_some() {
if params.is_for_chat {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
@ -185,9 +185,9 @@ impl MemoryBackend for FileStore {
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = serde_json::from_value(params)?;
let params: MemoryRunParams = params.try_into()?;
self.build_code(position, prompt_type, params)
}
@ -414,7 +414,7 @@ The end with a trailing new line
},
},
PromptType::ContextAndCode,
json!({}),
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
@ -434,7 +434,7 @@ The end with a trailing new line
},
},
PromptType::FIM,
json!({}),
&json!({}),
)
.await?;
let prompt: FIMPrompt = prompt.try_into()?;
@ -463,7 +463,7 @@ The end with a trailing new line
},
},
PromptType::ContextAndCode,
json!({
&json!({
"messages": []
}),
)
@ -510,7 +510,7 @@ The end with a trailing new line
},
},
PromptType::ContextAndCode,
json!({}),
&json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
@ -542,7 +542,7 @@ The end with a trailing new line
},
},
PromptType::ContextAndCode,
json!({"messages": []}),
&json!({"messages": []}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;

View File

@ -2,31 +2,35 @@ use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use serde::Deserialize;
use serde_json::Value;
use crate::config::{ChatMessage, Config, ValidMemoryBackend};
use crate::config::{Config, ValidMemoryBackend};
pub mod file_store;
mod postgresml;
/// Default for `max_context_length` when the caller's params omit it.
// NOTE(review): the same fallback value (1024) is hard-coded in
// `From<&Value> for MemoryRunParams`; keep the two in sync.
const fn max_context_length_default() -> usize {
    1024
}
/// Selects what shape of prompt a memory backend's `build_prompt` produces.
#[derive(Clone, Debug)]
pub enum PromptType {
    /// A prompt made of a context section plus the code around the cursor
    /// (converted to `ContextAndCodePrompt` by callers).
    ContextAndCode,
    /// A fill-in-the-middle prompt (converted to `FIMPrompt` by callers).
    FIM,
}
// NOTE(review): this span is a diff rendered with the +/- markers stripped, so
// pre-change and post-change lines are interleaved: the `Deserialize` derive,
// the `messages` field, and its `#[serde(default)]` attribute were removed by
// this commit, while the plain `Clone` derive and `is_for_chat` were added.
// As written (both derives, plus a serde attribute with no serde derive on the
// struct) this text would not compile; consult the post-image of the commit.
#[derive(Clone, Deserialize)]
#[derive(Clone)]
pub struct MemoryRunParams {
    pub messages: Option<Vec<ChatMessage>>,
    #[serde(default = "max_context_length_default")]
    // True when the request carries a chat message list — set by the
    // `From<&Value>` impl below, no longer deserialized directly.
    pub is_for_chat: bool,
    pub max_context_length: usize,
}
impl From<&Value> for MemoryRunParams {
fn from(value: &Value) -> Self {
Self {
max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
// messages are for most backends, contents are for Gemini
is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
}
}
}
#[derive(Debug)]
pub struct ContextAndCodePrompt {
pub context: String,
@ -119,7 +123,7 @@ pub trait MemoryBackend {
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
params: &Value,
) -> anyhow::Result<Prompt>;
async fn get_filter_text(
&self,

View File

@ -132,9 +132,9 @@ impl MemoryBackend for PostgresML {
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = serde_json::from_value(params)?;
let params: MemoryRunParams = params.try_into()?;
let query = self
.file_store
.get_characters_around_position(position, 512)?;

View File

@ -70,7 +70,7 @@ async fn do_task(
}
WorkerRequest::Prompt(params) => {
let prompt = memory_backend
.build_prompt(&params.position, params.prompt_type, params.params)
.build_prompt(&params.position, params.prompt_type, &params.params)
.await?;
params
.tx