From 91f03457fc85f8bfb86fcc440ce4068e6b3c96b9 Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Mon, 17 Jun 2024 17:25:33 +0900
Subject: [PATCH 01/12] now the test gemini_completion_do_generate works

---
 Cargo.lock                         |   2 +-
 src/config.rs                      |  56 +++++++++
 src/transformer_backends/gemini.rs | 189 +++++++++++++++++++++++++++++
 src/transformer_backends/mod.rs    |   2 +
 4 files changed, 248 insertions(+), 1 deletion(-)
 create mode 100644 src/transformer_backends/gemini.rs

diff --git a/Cargo.lock b/Cargo.lock
index 34f5542..524e12f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1518,7 +1518,7 @@ dependencies = [
 
 [[package]]
 name = "lsp-ai"
-version = "0.2.0"
+version = "0.3.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
diff --git a/src/config.rs b/src/config.rs
index 8cbeadd..c6445b2 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -46,6 +46,8 @@ pub enum ValidModel {
     MistralFIM(MistralFIM),
     #[serde(rename = "ollama")]
     Ollama(Ollama),
+    #[serde(rename = "gemini")]
+    Gemini(Gemini),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -171,6 +173,24 @@ pub struct OpenAI {
     pub model: String,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct Gemini {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    // The auth token
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
+    // The model name
+    pub model: String,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct Anthropic {
@@ -272,6 +292,7 @@ impl Config {
             #[cfg(feature = "llama_cpp")]
             ValidModel::LLaMACPP(llama_cpp) => Ok(llama_cpp.max_requests_per_second),
             ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
+            ValidModel::Gemini(gemini) => Ok(gemini.max_requests_per_second),
             ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
             ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
             ValidModel::Ollama(ollama) => Ok(ollama.max_requests_per_second),
@@ -402,6 +423,41 @@ mod test {
         });
         Config::new(args).unwrap();
     }
+    #[test]
+    fn gemini_config() {
+        let args = json!({
+            "initializationOptions": {
+                "memory": {
+                    "file_store": {}
+                },
+                "models": {
+                    "model1": {
+                        "type": "gemini",
+                        "completions_endpoint": "https://api.fireworks.ai/inference/v1/completions",
+                        "model": "accounts/fireworks/models/llama-v2-34b-code",
+                        "auth_token_env_var_name": "FIREWORKS_API_KEY",
+                    },
+                },
+                "completion": {
+                    "model": "model1",
+                    "parameters": {
+                        "messages": [
+                            {
+                                "role": "system",
+                                "content": "Test",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Test {CONTEXT} - {CODE}"
+                            }
+                        ],
+                        "max_new_tokens": 32,
+                    }
+                }
+            }
+        });
+        Config::new(args).unwrap();
+    }
 
     #[test]
     fn anthropic_config() {
diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
new file mode 100644
index 0000000..9dacd63
--- /dev/null
+++ b/src/transformer_backends/gemini.rs
@@ -0,0 +1,189 @@
+use std::collections::HashMap;
+
+use anyhow::Context;
+use serde::Deserialize;
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use super::{open_ai::OpenAIChatChoices, TransformerBackend};
+use crate::{
+    config::{self},
+    memory_backends::{FIMPrompt, Prompt, PromptType},
+    transformer_worker::{
+        DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
+    },
+};
+
+const fn max_tokens_default() -> usize {
+    64
+}
+
+const fn top_p_default() -> f32 {
+    0.95
+}
+
+const fn temperature_default() -> f32 {
+    0.1
+}
+
+// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
+#[derive(Debug, Deserialize)]
+pub struct GeminiRunParams {
+    #[serde(default = "max_tokens_default")]
+    pub max_tokens: usize,
+    #[serde(default = "top_p_default")]
+    pub top_p: f32,
+    #[serde(default = "temperature_default")]
+    pub temperature: f32,
+    pub min_tokens: Option<u64>,
+    pub random_seed: Option<u64>,
+    #[serde(default)]
+    pub stop: Vec<String>,
+}
+
+pub struct Gemini {
+    config: config::Gemini,
+}
+
+impl Gemini {
+    pub fn new(config: config::Gemini) -> Self {
+        Self { config }
+    }
+
+    fn get_token(&self) -> anyhow::Result<String> {
+        if let Some(env_var_name) = &self.config.auth_token_env_var_name {
+            Ok(std::env::var(env_var_name)?)
+        } else if let Some(token) = &self.config.auth_token {
+            Ok(token.to_string())
+        } else {
+            anyhow::bail!(
+                "set `auth_token_env_var_name` or `auth_token` to use an Gemini compatible API"
+            )
+        }
+    }
+
+    async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
+        let res: serde_json::Value = client
+            .post(
+                self.config
+                    .completions_endpoint
+                    .as_ref()
+                    .context("must specify `gemini_endpoint` to use gemini")?,
+            )
+            .header("Content-Type", "application/json")
+            .json(&json!(
+                {
+                    "contents":[
+                        {
+                            "parts":[
+                                {
+                                    "text": prompt.prompt
+                                }
+                            ]
+                        }
+                    ]
+                }
+            ))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.get("error") {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(candidates) = res.get("candidates") {
+            Ok(candidates
+                .get(0)
+                .unwrap()
+                .get("content")
+                .unwrap()
+                .get("parts")
+                .unwrap()
+                .get(0)
+                .unwrap()
+                .get("text")
+                .unwrap()
+                .clone()
+                .to_string())
+        } else {
+            anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Gemini {
+    #[instrument(skip(self))]
+    async fn do_generate(
+        &self,
+        prompt: &Prompt,
+        params: Value,
+    ) -> anyhow::Result<DoGenerationResponse> {
+        let params: GeminiRunParams = serde_json::from_value(params)?;
+        let generated_text = self.do_fim(prompt.try_into()?, params).await?;
+        Ok(DoGenerationResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerationStreamRequest,
+        _params: Value,
+    ) -> anyhow::Result<DoGenerationStreamResponse> {
+        anyhow::bail!("GenerationStream is not yet implemented")
+    }
+
+    fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
+        Ok(PromptType::FIM)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::{from_value, json};
+
+    #[tokio::test]
+    async fn gemini_completion_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Gemini = from_value(json!({
+            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=", // here fill the key
+            "model": "gemini-1.5-flash-latest",
+            "auth_token_env_var_name": "GEMINI_API_KEY",
+        }))?;
+        let anthropic = Gemini::new(configuration);
+        let prompt = Prompt::default_fim();
+        let run_params = json!({
+            "max_tokens": 2
+        });
+        let response = anthropic.do_generate(&prompt, run_params).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+    #[tokio::test]
+    async fn gemini_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Gemini =
+            serde_json::from_value(json!({
+                "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/",
+                "model": "gemini-1.5-flash",
+                "auth_token_env_var_name": "GEMINI_API_KEY",
+            }))?;
+        let gemini = Gemini::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let run_params = json!({
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "Test"
+                },
+                {
+                    "role": "user",
+                    "content": "Test {CONTEXT} - {CODE}"
+                }
+            ],
+            "max_tokens": 64
+        });
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index c56c109..f962564 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -11,6 +11,7 @@ use crate::{
 };
 
 mod anthropic;
+mod gemini;
 #[cfg(feature = "llama_cpp")]
 mod llama_cpp;
 mod mistral_fim;
@@ -66,6 +67,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
             ValidModel::OpenAI(open_ai_config) => {
                 Ok(Box::new(open_ai::OpenAI::new(open_ai_config)))
             }
+            ValidModel::Gemini(gemini_config) => Ok(Box::new(gemini::Gemini::new(gemini_config))),
             ValidModel::Anthropic(anthropic_config) => {
                 Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
             }

From 4548de4e3e0ef8d8c67217dee7a2a2d9f7d7162b Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Mon, 17 Jun 2024 19:34:59 +0900
Subject: [PATCH 02/12] pass the gemini_completion_do_generate with cli env

---
 src/transformer_backends/gemini.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
index 9dacd63..641c918 100644
--- a/src/transformer_backends/gemini.rs
+++ b/src/transformer_backends/gemini.rs
@@ -70,7 +70,7 @@ impl Gemini {
                 self.config
                     .completions_endpoint
                     .as_ref()
-                    .context("must specify `gemini_endpoint` to use gemini")?,
+                    .context("must specify `gemini_endpoint` to use gemini")?.to_owned() + token.as_ref(),
             )
             .header("Content-Type", "application/json")
             .json(&json!(
@@ -147,7 +147,7 @@ mod test {
     #[tokio::test]
     async fn gemini_completion_do_generate() -> anyhow::Result<()> {
         let configuration: config::Gemini = from_value(json!({
-            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=", // here fill the key
+            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=",
             "model": "gemini-1.5-flash-latest",
             "auth_token_env_var_name": "GEMINI_API_KEY",
         }))?;

From b5302f974809df0f451aac02b9acc3f9a76b97f3 Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Tue, 18 Jun 2024 02:45:07 +0900
Subject: [PATCH 03/12] fin gemini_chat_do_generate

---
 src/transformer_backends/gemini.rs | 84 ++++++++++++++++++++++++------
 1 file changed, 67 insertions(+), 17 deletions(-)

diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
index 641c918..667987a 100644
--- a/src/transformer_backends/gemini.rs
+++ b/src/transformer_backends/gemini.rs
@@ -1,13 +1,11 @@
-use std::collections::HashMap;
-
 use anyhow::Context;
 use serde::Deserialize;
 use serde_json::{json, Value};
 use tracing::instrument;
 
-use super::{open_ai::OpenAIChatChoices, TransformerBackend};
+use super::TransformerBackend;
 use crate::{
-    config::{self},
+    config,
     memory_backends::{FIMPrompt, Prompt, PromptType},
     transformer_worker::{
         DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
     },
 };
@@ -70,7 +68,9 @@ impl Gemini {
                 self.config
                     .completions_endpoint
.as_ref() - .context("must specify `gemini_endpoint` to use gemini")?.to_owned() + token.as_ref(), + .context("must specify `gemini_endpoint` to use gemini")? + .to_owned() + + token.as_ref(), ) .header("Content-Type", "application/json") .json(&json!( @@ -110,6 +110,48 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } + async fn do_chat_completion( + &self, + prompt: &Prompt, + params: Value, + ) -> anyhow::Result { + let client = reqwest::Client::new(); + let token = self.get_token()?; + let res: serde_json::Value = client + .post( + self.config + .chat_endpoint + .as_ref() + .context("must specify `gemini_endpoint` to use gemini")? + .to_owned() + + token.as_ref(), + ) + .header("Content-Type", "application/json") + .json(¶ms) + .send() + .await? + .json() + .await?; + if let Some(error) = res.get("error") { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(candidates) = res.get("candidates") { + Ok(candidates + .get(0) + .unwrap() + .get("content") + .unwrap() + .get("parts") + .unwrap() + .get(0) + .unwrap() + .get("text") + .unwrap() + .clone() + .to_string()) + } else { + anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); + } + } } #[async_trait::async_trait] @@ -158,32 +200,40 @@ mod test { }); let response = anthropic.do_generate(&prompt, run_params).await?; assert!(!response.generated_text.is_empty()); + dbg!(response.generated_text); Ok(()) } #[tokio::test] async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ - "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/", + "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=", "model": "gemini-1.5-flash", "auth_token_env_var_name": "GEMINI_API_KEY", }))?; let gemini = Gemini::new(configuration); let prompt = Prompt::default_with_cursor(); let run_params = json!({ - "messages": [ - { - "role": "system", - "content": "Test" + "contents": [ + { + "role":"user", + "parts":[{ + "text": "Pretend you're a snowman and stay in character for each response."}] }, - { - "role": "user", - "content": "Test {CONTEXT} - {CODE}" + { + "role": "model", + "parts":[{ + "text": "Hello! It's so cold! 
Isn't that great?"}] + }, + { + "role": "user", + "parts":[{ + "text": "What's your favorite season of the year?"}] } - ], - "max_tokens": 64 + ] }); - let response = gemini.do_generate(&prompt, run_params).await?; - assert!(!response.generated_text.is_empty()); + let response = gemini.do_chat_completion(&prompt, run_params).await?; + dbg!(&response); + assert!(!response.is_empty()); Ok(()) } } From 0c732c604d13383b8f8f9be354e7a998cb6fbae3 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 02:56:33 +0900 Subject: [PATCH 04/12] fin config.rs --- src/config.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/config.rs b/src/config.rs index c6445b2..7c6e068 100644 --- a/src/config.rs +++ b/src/config.rs @@ -433,9 +433,9 @@ mod test { "models": { "model1": { "type": "gemini", - "completions_endpoint": "https://api.fireworks.ai/inference/v1/completions", - "model": "accounts/fireworks/models/llama-v2-34b-code", - "auth_token_env_var_name": "FIREWORKS_API_KEY", + "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=", + "model": "gemini-1.5-flash-latest", + "auth_token_env_var_name": "GEMINI_API_KEY", }, }, "completion": { From c6c9dc316fb9421334828f151e07d707762a6929 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 02:58:19 +0900 Subject: [PATCH 05/12] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 569cd0a..b91e428 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ LSP-AI aims to fill this gap by providing a language server that integrates AI-p - LSP-AI supports any editor that adheres to the Language Server Protocol (LSP), ensuring that a wide range of editors can leverage the AI capabilities provided by LSP-AI. 5. **Flexible LLM Backend Support**: - - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow. + - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs, Gemini-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow. 6. **Future-Ready**: - LSP-AI is committed to staying updated with the latest advancements in LLM-driven software development. 
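For reference, a minimal client configuration using the Gemini backend as of this patch might look like the snippet below. It mirrors the `gemini_config` test above: the `completions_endpoint` ends with `?key=` because the API key, read from the `GEMINI_API_KEY` environment variable, is appended to the URL at request time (per patch 02). The message contents are placeholders; `{CONTEXT}` and `{CODE}` are substituted by the server.

    {
      "initializationOptions": {
        "memory": { "file_store": {} },
        "models": {
          "model1": {
            "type": "gemini",
            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=",
            "model": "gemini-1.5-flash-latest",
            "auth_token_env_var_name": "GEMINI_API_KEY"
          }
        },
        "completion": {
          "model": "model1",
          "parameters": {
            "messages": [
              { "role": "system", "content": "Test" },
              { "role": "user", "content": "Test {CONTEXT} - {CODE}" }
            ],
            "max_new_tokens": 32
          }
        }
      }
    }
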
From 5d4a04ac0ee0946475c0a258f486cc87d29c15c7 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 03:07:37 +0900 Subject: [PATCH 06/12] dont hard code the url --- src/config.rs | 2 +- src/transformer_backends/gemini.rs | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/config.rs b/src/config.rs index 7c6e068..adbf279 100644 --- a/src/config.rs +++ b/src/config.rs @@ -433,7 +433,7 @@ mod test { "models": { "model1": { "type": "gemini", - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=", + "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", "model": "gemini-1.5-flash-latest", "auth_token_env_var_name": "GEMINI_API_KEY", }, diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index 667987a..218c61e 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -70,6 +70,8 @@ impl Gemini { .as_ref() .context("must specify `gemini_endpoint` to use gemini")? .to_owned() + + self.config.model.as_ref() + + ":generateContent?key=" + token.as_ref(), ) .header("Content-Type", "application/json") @@ -110,11 +112,7 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } - async fn do_chat_completion( - &self, - prompt: &Prompt, - params: Value, - ) -> anyhow::Result { + async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result { let client = reqwest::Client::new(); let token = self.get_token()?; let res: serde_json::Value = client @@ -124,6 +122,8 @@ impl Gemini { .as_ref() .context("must specify `gemini_endpoint` to use gemini")? .to_owned() + + self.config.model.as_ref() + + ":generateContent?key=" + token.as_ref(), ) .header("Content-Type", "application/json") @@ -189,16 +189,16 @@ mod test { #[tokio::test] async fn gemini_completion_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = from_value(json!({ - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=", + "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", "model": "gemini-1.5-flash-latest", "auth_token_env_var_name": "GEMINI_API_KEY", }))?; - let anthropic = Gemini::new(configuration); + let gemini = Gemini::new(configuration); let prompt = Prompt::default_fim(); let run_params = json!({ "max_tokens": 2 }); - let response = anthropic.do_generate(&prompt, run_params).await?; + let response = gemini.do_generate(&prompt, run_params).await?; assert!(!response.generated_text.is_empty()); dbg!(response.generated_text); Ok(()) @@ -206,7 +206,7 @@ mod test { #[tokio::test] async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ - "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=", + "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", "model": "gemini-1.5-flash", "auth_token_env_var_name": "GEMINI_API_KEY", }))?; From e878089b64f1c470e6426c7bad434da5576eb4a3 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 11:56:32 +0900 Subject: [PATCH 07/12] use system format --- src/transformer_backends/gemini.rs | 99 ++++++++++++++++++------- src/transformer_backends/open_ai/mod.rs | 2 +- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git 
a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index 218c61e..ec17d1f 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -1,15 +1,15 @@ use anyhow::Context; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::instrument; use super::TransformerBackend; use crate::{ - config, + config::{self, ChatMessage, FIM}, memory_backends::{FIMPrompt, Prompt, PromptType}, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, - }, + }, utils::{format_chat_messages, format_context_code}, }; const fn max_tokens_default() -> usize { @@ -27,6 +27,8 @@ const fn temperature_default() -> f32 { // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes #[derive(Debug, Deserialize)] pub struct GeminiRunParams { + pub fim: Option, + messages: Option>, #[serde(default = "max_tokens_default")] pub max_tokens: usize, #[serde(default = "top_p_default")] @@ -40,18 +42,18 @@ pub struct GeminiRunParams { } pub struct Gemini { - config: config::Gemini, + configuration: config::Gemini, } impl Gemini { - pub fn new(config: config::Gemini) -> Self { - Self { config } + pub fn new(configuration: config::Gemini) -> Self { + Self { configuration } } fn get_token(&self) -> anyhow::Result { - if let Some(env_var_name) = &self.config.auth_token_env_var_name { + if let Some(env_var_name) = &self.configuration.auth_token_env_var_name { Ok(std::env::var(env_var_name)?) - } else if let Some(token) = &self.config.auth_token { + } else if let Some(token) = &self.configuration.auth_token { Ok(token.to_string()) } else { anyhow::bail!( @@ -60,17 +62,21 @@ impl Gemini { } } - async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result { + async fn get_completion( + &self, + prompt: &str, + _params: GeminiRunParams, + ) -> anyhow::Result { let client = reqwest::Client::new(); let token = self.get_token()?; let res: serde_json::Value = client .post( - self.config + self.configuration .completions_endpoint .as_ref() - .context("must specify `gemini_endpoint` to use gemini")? + .context("must specify `completions_endpoint` to use gemini")? .to_owned() - + self.config.model.as_ref() + + self.configuration.model.as_ref() + ":generateContent?key=" + token.as_ref(), ) @@ -81,7 +87,7 @@ impl Gemini { { "parts":[ { - "text": prompt.prompt + "text": prompt } ] } @@ -112,22 +118,24 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } - async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result { + + async fn get_chat(&self, messages: Vec, params: GeminiRunParams) -> anyhow::Result { let client = reqwest::Client::new(); let token = self.get_token()?; let res: serde_json::Value = client .post( - self.config + self.configuration .chat_endpoint .as_ref() - .context("must specify `gemini_endpoint` to use gemini")? + .context("must specify `chat_endpoint` to use gemini")? .to_owned() - + self.config.model.as_ref() + + self.configuration.model.as_ref() + ":generateContent?key=" + token.as_ref(), ) .header("Content-Type", "application/json") - .json(¶ms) + .json(&messages) + // .json(params) .send() .await? 
.json() @@ -152,6 +160,44 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } + async fn do_chat_completion( + &self, + prompt: &Prompt, + params: GeminiRunParams, + ) -> anyhow::Result { + match prompt { + Prompt::ContextAndCode(code_and_context) => match ¶ms.messages { + Some(completion_messages) => { + let messages = format_chat_messages(completion_messages, code_and_context); + self.get_chat(messages, params).await + } + None => { + self.get_completion( + &format_context_code(&code_and_context.context, &code_and_context.code), + params, + ) + .await + } + }, + Prompt::FIM(fim) => match ¶ms.fim { + Some(fim_params) => { + self.get_completion( + &format!( + "{}{}{}{}{}", + fim_params.start, + fim.prompt, + fim_params.middle, + fim.suffix, + fim_params.end + ), + params, + ) + .await + } + None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"), + }, + } + } } #[async_trait::async_trait] @@ -163,7 +209,7 @@ impl TransformerBackend for Gemini { params: Value, ) -> anyhow::Result { let params: GeminiRunParams = serde_json::from_value(params)?; - let generated_text = self.do_fim(prompt.try_into()?, params).await?; + let generated_text = self.do_chat_completion(prompt, params).await?; Ok(DoGenerationResponse { generated_text }) } @@ -175,10 +221,6 @@ impl TransformerBackend for Gemini { ) -> anyhow::Result { anyhow::bail!("GenerationStream is not yet implemented") } - - fn get_prompt_type(&self, _params: &Value) -> anyhow::Result { - Ok(PromptType::FIM) - } } #[cfg(test)] @@ -194,9 +236,9 @@ mod test { "auth_token_env_var_name": "GEMINI_API_KEY", }))?; let gemini = Gemini::new(configuration); - let prompt = Prompt::default_fim(); + let prompt = Prompt::default_without_cursor(); let run_params = json!({ - "max_tokens": 2 + "max_tokens": 64 }); let response = gemini.do_generate(&prompt, run_params).await?; assert!(!response.generated_text.is_empty()); @@ -207,6 +249,7 @@ mod test { async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", + "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", "model": "gemini-1.5-flash", "auth_token_env_var_name": "GEMINI_API_KEY", }))?; @@ -231,9 +274,9 @@ mod test { } ] }); - let response = gemini.do_chat_completion(&prompt, run_params).await?; - dbg!(&response); - assert!(!response.is_empty()); + let response = gemini.do_generate(&prompt, run_params).await?; + dbg!(&response.generated_text); + assert!(!response.generated_text.is_empty()); Ok(()) } } diff --git a/src/transformer_backends/open_ai/mod.rs b/src/transformer_backends/open_ai/mod.rs index a004614..d516adf 100644 --- a/src/transformer_backends/open_ai/mod.rs +++ b/src/transformer_backends/open_ai/mod.rs @@ -163,7 +163,7 @@ impl OpenAI { self.configuration .chat_endpoint .as_ref() - .context("must specify `completions_endpoint` to use completions")?, + .context("must specify `chat_endpoint` to use completions")?, ) .bearer_auth(token) .header("Content-Type", "application/json") From ad9f7381ea46b7535e8a0f99eb28c3477bc5aacd Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 12:11:09 +0900 Subject: [PATCH 08/12] support completion first --- src/transformer_backends/gemini.rs | 81 ++---------------------------- 1 file changed, 3 insertions(+), 78 deletions(-) diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index 
ec17d1f..ac38b02 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -6,7 +6,7 @@ use tracing::instrument; use super::TransformerBackend; use crate::{ config::{self, ChatMessage, FIM}, - memory_backends::{FIMPrompt, Prompt, PromptType}, + memory_backends::Prompt, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, utils::{format_chat_messages, format_context_code}, @@ -119,47 +119,6 @@ impl Gemini { } } - async fn get_chat(&self, messages: Vec, params: GeminiRunParams) -> anyhow::Result { - let client = reqwest::Client::new(); - let token = self.get_token()?; - let res: serde_json::Value = client - .post( - self.configuration - .chat_endpoint - .as_ref() - .context("must specify `chat_endpoint` to use gemini")? - .to_owned() - + self.configuration.model.as_ref() - + ":generateContent?key=" - + token.as_ref(), - ) - .header("Content-Type", "application/json") - .json(&messages) - // .json(params) - .send() - .await? - .json() - .await?; - if let Some(error) = res.get("error") { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(candidates) = res.get("candidates") { - Ok(candidates - .get(0) - .unwrap() - .get("content") - .unwrap() - .get("parts") - .unwrap() - .get(0) - .unwrap() - .get("text") - .unwrap() - .clone() - .to_string()) - } else { - anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); - } - } async fn do_chat_completion( &self, prompt: &Prompt, @@ -168,8 +127,7 @@ impl Gemini { match prompt { Prompt::ContextAndCode(code_and_context) => match ¶ms.messages { Some(completion_messages) => { - let messages = format_chat_messages(completion_messages, code_and_context); - self.get_chat(messages, params).await + todo!(); } None => { self.get_completion( @@ -245,38 +203,5 @@ mod test { dbg!(response.generated_text); Ok(()) } - #[tokio::test] - async fn gemini_chat_do_generate() -> anyhow::Result<()> { - let configuration: config::Gemini = serde_json::from_value(json!({ - "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "model": "gemini-1.5-flash", - "auth_token_env_var_name": "GEMINI_API_KEY", - }))?; - let gemini = Gemini::new(configuration); - let prompt = Prompt::default_with_cursor(); - let run_params = json!({ - "contents": [ - { - "role":"user", - "parts":[{ - "text": "Pretend you're a snowman and stay in character for each response."}] - }, - { - "role": "model", - "parts":[{ - "text": "Hello! It's so cold! 
Isn't that great?"}] - }, - { - "role": "user", - "parts":[{ - "text": "What's your favorite season of the year?"}] - } - ] - }); - let response = gemini.do_generate(&prompt, run_params).await?; - dbg!(&response.generated_text); - assert!(!response.generated_text.is_empty()); - Ok(()) - } + // gemini_chat_do_generate TODO } From c8993bf740fb189102957cd1dad9640b513d0029 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 18 Jun 2024 19:47:28 +0900 Subject: [PATCH 09/12] fix chat --- src/transformer_backends/gemini.rs | 101 +++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 6 deletions(-) diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index ac38b02..d2ac562 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -9,7 +9,8 @@ use crate::{ memory_backends::Prompt, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, - }, utils::{format_chat_messages, format_context_code}, + }, + utils::{format_chat_messages, format_context_code}, }; const fn max_tokens_default() -> usize { @@ -23,12 +24,22 @@ const fn top_p_default() -> f32 { const fn temperature_default() -> f32 { 0.1 } +#[derive(Debug, Serialize, Deserialize, Clone)] +struct Part { + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiChatMessage { + role: String, + parts: Vec, +} // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] pub struct GeminiRunParams { pub fim: Option, - messages: Option>, + contents: Option>, #[serde(default = "max_tokens_default")] pub max_tokens: usize, #[serde(default = "top_p_default")] @@ -118,16 +129,61 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } - + async fn get_chat( + &self, + messages: &[GeminiChatMessage], + _params: GeminiRunParams, + ) -> anyhow::Result { + let client = reqwest::Client::new(); + let token = self.get_token()?; + let res: serde_json::Value = client + .post( + self.configuration + .chat_endpoint + .as_ref() + .context("must specify `chat_endpoint` to use gemini")? + .to_owned() + + self.configuration.model.as_ref() + + ":generateContent?key=" + + token.as_ref(), + ) + .header("Content-Type", "application/json") + .json(&json!({ + "contents": messages + })) + .send() + .await? 
+ .json() + .await?; + if let Some(error) = res.get("error") { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(candidates) = res.get("candidates") { + Ok(candidates + .get(0) + .unwrap() + .get("content") + .unwrap() + .get("parts") + .unwrap() + .get(0) + .unwrap() + .get("text") + .unwrap() + .clone() + .to_string()) + } else { + anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); + } + } async fn do_chat_completion( &self, prompt: &Prompt, params: GeminiRunParams, ) -> anyhow::Result { match prompt { - Prompt::ContextAndCode(code_and_context) => match ¶ms.messages { + Prompt::ContextAndCode(code_and_context) => match ¶ms.contents { Some(completion_messages) => { - todo!(); + self.get_chat(completion_messages, params.clone()).await } None => { self.get_completion( @@ -204,4 +260,37 @@ mod test { Ok(()) } // gemini_chat_do_generate TODO + #[tokio::test] + async fn gemini_chat_do_generate() -> anyhow::Result<()> { + let configuration: config::Gemini = serde_json::from_value(json!({ + "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", + "model": "gemini-1.5-flash", + "auth_token_env_var_name": "GEMINI_API_KEY", + }))?; + let gemini = Gemini::new(configuration); + let prompt = Prompt::default_with_cursor(); + let run_params = json!({ + "contents": [ + { + "role":"user", + "parts":[{ + "text": "Pretend you're a snowman and stay in character for each response."}] + }, + { + "role": "model", + "parts":[{ + "text": "Hello! It's so cold! Isn't that great?"}] + }, + { + "role": "user", + "parts":[{ + "text": "What's your favorite season of the year?"}] + } + ] + }); + let response = gemini.do_generate(&prompt, run_params).await?; + dbg!(&response.generated_text); + assert!(!response.generated_text.is_empty()); + Ok(()) + } } From be577c19e04e726393fd7385edc4ff13123143f9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 19 Jun 2024 09:36:16 +0900 Subject: [PATCH 10/12] fix gemini_config --- src/config.rs | 21 ++++++++++++++------- src/transformer_backends/gemini.rs | 1 - 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/config.rs b/src/config.rs index adbf279..0dbfa32 100644 --- a/src/config.rs +++ b/src/config.rs @@ -441,15 +441,22 @@ mod test { "completion": { "model": "model1", "parameters": { - "messages": [ + "contents": [ { - "role": "system", - "content": "Test", - }, + "role": "user", + "parts":[{ + "text": "Pretend you're a snowman and stay in character for each response."}] + }, { - "role": "user", - "content": "Test {CONTEXT} - {CODE}" - } + "role": "model", + "parts":[{ + "text": "Hello! It's so cold! 
Isn't that great?"}] + }, + { + "role": "user", + "parts":[{ + "text": "What's your favorite season of the year?"}] + } ], "max_new_tokens": 32, } diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index d2ac562..be85064 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -259,7 +259,6 @@ mod test { dbg!(response.generated_text); Ok(()) } - // gemini_chat_do_generate TODO #[tokio::test] async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ From f85c964a3022e8ab883232cf96b3d50b052bcb0f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 19 Jun 2024 23:05:49 +0900 Subject: [PATCH 11/12] merge pull/1 --- src/config.rs | 33 +++-- src/transformer_backends/gemini.rs | 206 +++++++++++------------------ src/utils.rs | 8 +- 3 files changed, 95 insertions(+), 152 deletions(-) diff --git a/src/config.rs b/src/config.rs index 0dbfa32..8b7b394 100644 --- a/src/config.rs +++ b/src/config.rs @@ -423,6 +423,7 @@ mod test { }); Config::new(args).unwrap(); } + #[test] fn gemini_config() { let args = json!({ @@ -441,24 +442,22 @@ mod test { "completion": { "model": "model1", "parameters": { + "systemInstruction": { + "role": "system", + "parts": [{ + "text": "TEST system instruction" + }] + }, + "generationConfig": { + "maxOutputTokens": 10 + }, "contents": [ - { - "role": "user", - "parts":[{ - "text": "Pretend you're a snowman and stay in character for each response."}] - }, - { - "role": "model", - "parts":[{ - "text": "Hello! It's so cold! Isn't that great?"}] - }, - { - "role": "user", - "parts":[{ - "text": "What's your favorite season of the year?"}] - } - ], - "max_new_tokens": 32, + { + "role": "user", + "parts":[{ + "text": "TEST - {CONTEXT} and {CODE}"}] + } + ] } } } diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index be85064..3203c48 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -5,51 +5,79 @@ use tracing::instrument; use super::TransformerBackend; use crate::{ - config::{self, ChatMessage, FIM}, - memory_backends::Prompt, + config, + memory_backends::{ContextAndCodePrompt, Prompt}, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, - utils::{format_chat_messages, format_context_code}, + utils::format_context_code_in_str, }; +fn format_gemini_contents( + messages: &[GeminiContent], + prompt: &ContextAndCodePrompt, +) -> Vec { + messages + .iter() + .map(|m| { + GeminiContent::new( + m.role.to_owned(), + m.parts + .iter() + .map(|p| Part { + text: format_context_code_in_str(&p.text, &prompt.context, &prompt.code), + }) + .collect(), + ) + }) + .collect() +} + const fn max_tokens_default() -> usize { 64 } -const fn top_p_default() -> f32 { - 0.95 -} - -const fn temperature_default() -> f32 { - 0.1 -} #[derive(Debug, Serialize, Deserialize, Clone)] struct Part { pub text: String, } #[derive(Debug, Serialize, Deserialize, Clone)] -struct GeminiChatMessage { +struct GeminiContent { role: String, parts: Vec, } -// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes -#[derive(Debug, Deserialize, Clone)] -pub struct GeminiRunParams { - pub fim: Option, - contents: Option>, - #[serde(default = "max_tokens_default")] - pub max_tokens: usize, - #[serde(default = "top_p_default")] - pub top_p: f32, - #[serde(default = "temperature_default")] - pub temperature: f32, - 
pub min_tokens: Option, - pub random_seed: Option, +impl GeminiContent { + fn new(role: String, parts: Vec) -> Self { + Self { role, parts } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct GeminiGenerationConfig { + #[serde(rename = "stopSequences")] #[serde(default)] - pub stop: Vec, + pub stop_sequences: Vec, + #[serde(rename = "maxOutputTokens")] + #[serde(default = "max_tokens_default")] + pub max_output_tokens: usize, + pub temperature: Option, + #[serde(rename = "topP")] + pub top_p: Option, + #[serde(rename = "topK")] + pub top_k: Option, +} + +// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct GeminiRunParams { + contents: Vec, + #[serde(rename = "systemInstruction")] + system_instruction: GeminiContent, + #[serde(rename = "generationConfig")] + generation_config: Option, } pub struct Gemini { @@ -73,66 +101,10 @@ impl Gemini { } } - async fn get_completion( - &self, - prompt: &str, - _params: GeminiRunParams, - ) -> anyhow::Result { - let client = reqwest::Client::new(); - let token = self.get_token()?; - let res: serde_json::Value = client - .post( - self.configuration - .completions_endpoint - .as_ref() - .context("must specify `completions_endpoint` to use gemini")? - .to_owned() - + self.configuration.model.as_ref() - + ":generateContent?key=" - + token.as_ref(), - ) - .header("Content-Type", "application/json") - .json(&json!( - { - "contents":[ - { - "parts":[ - { - "text": prompt - } - ] - } - ] - } - )) - .send() - .await? - .json() - .await?; - if let Some(error) = res.get("error") { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(candidates) = res.get("candidates") { - Ok(candidates - .get(0) - .unwrap() - .get("content") - .unwrap() - .get("parts") - .unwrap() - .get(0) - .unwrap() - .get("text") - .unwrap() - .clone() - .to_string()) - } else { - anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); - } - } async fn get_chat( &self, - messages: &[GeminiChatMessage], - _params: GeminiRunParams, + messages: Vec, + params: GeminiRunParams, ) -> anyhow::Result { let client = reqwest::Client::new(); let token = self.get_token()?; @@ -149,7 +121,9 @@ impl Gemini { ) .header("Content-Type", "application/json") .json(&json!({ - "contents": messages + "contents": messages, + "systemInstruction": params.system_instruction, + "generationConfig": params.generation_config, })) .send() .await? 
@@ -181,35 +155,11 @@ impl Gemini { params: GeminiRunParams, ) -> anyhow::Result { match prompt { - Prompt::ContextAndCode(code_and_context) => match ¶ms.contents { - Some(completion_messages) => { - self.get_chat(completion_messages, params.clone()).await - } - None => { - self.get_completion( - &format_context_code(&code_and_context.context, &code_and_context.code), - params, - ) - .await - } - }, - Prompt::FIM(fim) => match ¶ms.fim { - Some(fim_params) => { - self.get_completion( - &format!( - "{}{}{}{}{}", - fim_params.start, - fim.prompt, - fim_params.middle, - fim.suffix, - fim_params.end - ), - params, - ) - .await - } - None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"), - }, + Prompt::ContextAndCode(code_and_context) => { + let messages = format_gemini_contents(¶ms.contents, code_and_context); + self.get_chat(messages, params).await + } + _ => anyhow::bail!("Google Gemini backend does not yet support FIM"), } } } @@ -240,25 +190,8 @@ impl TransformerBackend for Gemini { #[cfg(test)] mod test { use super::*; - use serde_json::{from_value, json}; + use serde_json::json; - #[tokio::test] - async fn gemini_completion_do_generate() -> anyhow::Result<()> { - let configuration: config::Gemini = from_value(json!({ - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "model": "gemini-1.5-flash-latest", - "auth_token_env_var_name": "GEMINI_API_KEY", - }))?; - let gemini = Gemini::new(configuration); - let prompt = Prompt::default_without_cursor(); - let run_params = json!({ - "max_tokens": 64 - }); - let response = gemini.do_generate(&prompt, run_params).await?; - assert!(!response.generated_text.is_empty()); - dbg!(response.generated_text); - Ok(()) - } #[tokio::test] async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ @@ -269,9 +202,18 @@ mod test { let gemini = Gemini::new(configuration); let prompt = Prompt::default_with_cursor(); let run_params = json!({ + "systemInstruction": { + "role": "system", + "parts": [{ + "text": "You are a helpful and willing chatbot that will do whatever the user asks" + }] + }, + "generationConfig": { + "maxOutputTokens": 10 + }, "contents": [ { - "role":"user", + "role": "user", "parts":[{ "text": "Pretend you're a snowman and stay in character for each response."}] }, diff --git a/src/utils.rs b/src/utils.rs index 85d22f9..ea5d652 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -29,14 +29,16 @@ pub fn format_chat_messages( .map(|m| { ChatMessage::new( m.role.to_owned(), - m.content - .replace("{CONTEXT}", &prompt.context) - .replace("{CODE}", &prompt.code), + format_context_code_in_str(&m.content, &prompt.context, &prompt.code), ) }) .collect() } +pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String { + s.replace("{CONTEXT}", context).replace("{CODE}", code) +} + pub fn format_context_code(context: &str, code: &str) -> String { format!("{context}\n\n{code}") } From a95c27602e89e5933a70f8902e59fce1f9d023f6 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 19 Jun 2024 08:26:06 -0700 Subject: [PATCH 12/12] Updated memory backend to build prompt correctly --- src/memory_backends/file_store.rs | 16 ++++++++-------- src/memory_backends/mod.rs | 24 ++++++++++++++---------- src/memory_backends/postgresml/mod.rs | 4 ++-- src/memory_worker.rs | 2 +- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/memory_backends/file_store.rs 
b/src/memory_backends/file_store.rs index 219998b..4d70509 100644 --- a/src/memory_backends/file_store.rs +++ b/src/memory_backends/file_store.rs @@ -114,7 +114,7 @@ impl FileStore { Ok(match prompt_type { PromptType::ContextAndCode => { - if params.messages.is_some() { + if params.is_for_chat { let max_length = tokens_to_estimated_characters(params.max_context_length); let start = cursor_index.saturating_sub(max_length / 2); let end = rope @@ -185,9 +185,9 @@ impl MemoryBackend for FileStore { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result { - let params: MemoryRunParams = serde_json::from_value(params)?; + let params: MemoryRunParams = params.try_into()?; self.build_code(position, prompt_type, params) } @@ -414,7 +414,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({}), + &json!({}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; @@ -434,7 +434,7 @@ The end with a trailing new line }, }, PromptType::FIM, - json!({}), + &json!({}), ) .await?; let prompt: FIMPrompt = prompt.try_into()?; @@ -463,7 +463,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({ + &json!({ "messages": [] }), ) @@ -510,7 +510,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({}), + &json!({}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; @@ -542,7 +542,7 @@ The end with a trailing new line }, }, PromptType::ContextAndCode, - json!({"messages": []}), + &json!({"messages": []}), ) .await?; let prompt: ContextAndCodePrompt = prompt.try_into()?; diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs index e824d3f..52a8974 100644 --- a/src/memory_backends/mod.rs +++ b/src/memory_backends/mod.rs @@ -2,31 +2,35 @@ use lsp_types::{ DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams, TextDocumentPositionParams, }; -use serde::Deserialize; use serde_json::Value; -use crate::config::{ChatMessage, Config, ValidMemoryBackend}; +use crate::config::{Config, ValidMemoryBackend}; pub mod file_store; mod postgresml; -const fn max_context_length_default() -> usize { - 1024 -} - #[derive(Clone, Debug)] pub enum PromptType { ContextAndCode, FIM, } -#[derive(Clone, Deserialize)] +#[derive(Clone)] pub struct MemoryRunParams { - pub messages: Option>, - #[serde(default = "max_context_length_default")] + pub is_for_chat: bool, pub max_context_length: usize, } +impl From<&Value> for MemoryRunParams { + fn from(value: &Value) -> Self { + Self { + max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize, + // messages are for most backends, contents are for Gemini + is_for_chat: value["messages"].is_array() || value["contents"].is_array(), + } + } +} + #[derive(Debug)] pub struct ContextAndCodePrompt { pub context: String, @@ -119,7 +123,7 @@ pub trait MemoryBackend { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result; async fn get_filter_text( &self, diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs index 8e0f748..8b007ab 100644 --- a/src/memory_backends/postgresml/mod.rs +++ b/src/memory_backends/postgresml/mod.rs @@ -132,9 +132,9 @@ impl MemoryBackend for PostgresML { &self, position: &TextDocumentPositionParams, prompt_type: PromptType, - params: Value, + params: &Value, ) -> anyhow::Result { - let params: MemoryRunParams = serde_json::from_value(params)?; + 
let params: MemoryRunParams = params.try_into()?;
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
diff --git a/src/memory_worker.rs b/src/memory_worker.rs
index 86082f8..39cad6c 100644
--- a/src/memory_worker.rs
+++ b/src/memory_worker.rs
@@ -70,7 +70,7 @@ async fn do_task(
         }
         WorkerRequest::Prompt(params) => {
             let prompt = memory_backend
-                .build_prompt(&params.position, params.prompt_type, params.params)
+                .build_prompt(&params.position, params.prompt_type, &params.params)
                 .await?;
             params
                 .tx