From e878089b64f1c470e6426c7bad434da5576eb4a3 Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Tue, 18 Jun 2024 11:56:32 +0900
Subject: [PATCH] use system format

---
 src/transformer_backends/gemini.rs      | 99 ++++++++++++++++++-------
 src/transformer_backends/open_ai/mod.rs |  2 +-
 2 files changed, 72 insertions(+), 29 deletions(-)

diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
index 218c61e..ec17d1f 100644
--- a/src/transformer_backends/gemini.rs
+++ b/src/transformer_backends/gemini.rs
@@ -1,15 +1,15 @@
 use anyhow::Context;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use tracing::instrument;
 
 use super::TransformerBackend;
 use crate::{
-    config,
+    config::{self, ChatMessage, FIM},
     memory_backends::{FIMPrompt, Prompt, PromptType},
     transformer_worker::{
         DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
-    },
+    }, utils::{format_chat_messages, format_context_code},
 };
 
 const fn max_tokens_default() -> usize {
@@ -27,6 +27,8 @@ const fn temperature_default() -> f32 {
 // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
 #[derive(Debug, Deserialize)]
 pub struct GeminiRunParams {
+    pub fim: Option<FIM>,
+    messages: Option<Vec<ChatMessage>>,
     #[serde(default = "max_tokens_default")]
     pub max_tokens: usize,
     #[serde(default = "top_p_default")]
@@ -40,18 +42,18 @@ pub struct GeminiRunParams {
 }
 
 pub struct Gemini {
-    config: config::Gemini,
+    configuration: config::Gemini,
 }
 
 impl Gemini {
-    pub fn new(config: config::Gemini) -> Self {
-        Self { config }
+    pub fn new(configuration: config::Gemini) -> Self {
+        Self { configuration }
     }
 
     fn get_token(&self) -> anyhow::Result<String> {
-        if let Some(env_var_name) = &self.config.auth_token_env_var_name {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
             Ok(std::env::var(env_var_name)?)
-        } else if let Some(token) = &self.config.auth_token {
+        } else if let Some(token) = &self.configuration.auth_token {
             Ok(token.to_string())
         } else {
             anyhow::bail!(
@@ -60,17 +62,21 @@ impl Gemini {
         }
     }
 
-    async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
+    async fn get_completion(
+        &self,
+        prompt: &str,
+        _params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
         let client = reqwest::Client::new();
         let token = self.get_token()?;
         let res: serde_json::Value = client
             .post(
-                self.config
+                self.configuration
                     .completions_endpoint
                     .as_ref()
-                    .context("must specify `gemini_endpoint` to use gemini")?
+                    .context("must specify `completions_endpoint` to use gemini")?
                     .to_owned()
-                    + self.config.model.as_ref()
+                    + self.configuration.model.as_ref()
                     + ":generateContent?key="
                     + token.as_ref(),
             )
@@ -81,7 +87,7 @@ impl Gemini {
                     {
                         "parts":[
                             {
-                                "text": prompt.prompt
+                                "text": prompt
                             }
                         ]
                     }
@@ -112,22 +118,24 @@ impl Gemini {
             anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
         }
     }
-    async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result<String> {
+
+    async fn get_chat(&self, messages: Vec<ChatMessage>, params: GeminiRunParams) -> anyhow::Result<String> {
         let client = reqwest::Client::new();
         let token = self.get_token()?;
         let res: serde_json::Value = client
             .post(
-                self.config
+                self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `gemini_endpoint` to use gemini")?
+                    .context("must specify `chat_endpoint` to use gemini")?
                     .to_owned()
-                    + self.config.model.as_ref()
+                    + self.configuration.model.as_ref()
                     + ":generateContent?key="
                     + token.as_ref(),
             )
             .header("Content-Type", "application/json")
-            .json(&params)
+            .json(&messages)
+            // .json(params)
             .send()
             .await?
             .json()
             .await?;
@@ -152,6 +160,44 @@ impl Gemini {
             anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
         }
     }
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        match prompt {
+            Prompt::ContextAndCode(code_and_context) => match &params.messages {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, code_and_context);
+                    self.get_chat(messages, params).await
+                }
+                None => {
+                    self.get_completion(
+                        &format_context_code(&code_and_context.context, &code_and_context.code),
+                        params,
+                    )
+                    .await
+                }
+            },
+            Prompt::FIM(fim) => match &params.fim {
+                Some(fim_params) => {
+                    self.get_completion(
+                        &format!(
+                            "{}{}{}{}{}",
+                            fim_params.start,
+                            fim.prompt,
+                            fim_params.middle,
+                            fim.suffix,
+                            fim_params.end
+                        ),
+                        params,
+                    )
+                    .await
+                }
+                None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
+            },
+        }
+    }
 }
 
 #[async_trait::async_trait]
@@ -163,7 +209,7 @@ impl TransformerBackend for Gemini {
         params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
         let params: GeminiRunParams = serde_json::from_value(params)?;
-        let generated_text = self.do_fim(prompt.try_into()?, params).await?;
+        let generated_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
 
@@ -175,10 +221,6 @@ impl TransformerBackend for Gemini {
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         anyhow::bail!("GenerationStream is not yet implemented")
     }
-
-    fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
-        Ok(PromptType::FIM)
-    }
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
     use serde_json::json;
@@ -194,9 +236,9 @@
         "auth_token_env_var_name": "GEMINI_API_KEY",
         }))?;
         let gemini = Gemini::new(configuration);
-        let prompt = Prompt::default_fim();
+        let prompt = Prompt::default_without_cursor();
         let run_params = json!({
-            "max_tokens": 2
+            "max_tokens": 64
         });
         let response = gemini.do_generate(&prompt, run_params).await?;
         assert!(!response.generated_text.is_empty());
@@ -207,6 +249,7 @@
     async fn gemini_chat_do_generate() -> anyhow::Result<()> {
         let configuration: config::Gemini = serde_json::from_value(json!({
             "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
             "model": "gemini-1.5-flash",
             "auth_token_env_var_name": "GEMINI_API_KEY",
         }))?;
@@ -231,9 +274,9 @@
         let gemini = Gemini::new(configuration);
         let prompt = Prompt::default_with_cursor();
         let run_params = json!({
             "messages": [
                 {
                     "role": "system",
                     "content": "Test"
                 },
                 {
                     "role": "user",
                     "content": "Test {CONTEXT} - {CODE}"
                 }
             ]
         });
-        let response = gemini.do_chat_completion(&prompt, run_params).await?;
-        dbg!(&response);
-        assert!(!response.is_empty());
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        dbg!(&response.generated_text);
+        assert!(!response.generated_text.is_empty());
         Ok(())
     }
 }
diff --git a/src/transformer_backends/open_ai/mod.rs b/src/transformer_backends/open_ai/mod.rs
index a004614..d516adf 100644
--- a/src/transformer_backends/open_ai/mod.rs
+++ b/src/transformer_backends/open_ai/mod.rs
@@ -163,7 +163,7 @@ impl OpenAI {
                 self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions")?,
+                    .context("must specify `chat_endpoint` to use completions")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
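
Reviewer note (not part of the patch): the sketch below mirrors the fill-in-the-middle assembly that the new `do_chat_completion` performs when `params.fim` is set. `FimParams` and `FimPrompt` are local stand-ins for the crate's `config::FIM` and `FIMPrompt` types, and the delimiter strings are placeholders, not Gemini-defined tokens.

// Minimal, self-contained sketch of the FIM prompt assembly in this patch.
// Stand-in types; the real code uses config::FIM and memory_backends::FIMPrompt.
struct FimParams {
    start: String,  // text inserted before the code preceding the cursor
    middle: String, // separator between the prefix and the suffix
    end: String,    // marker appended after the suffix
}

struct FimPrompt {
    prompt: String, // code before the cursor
    suffix: String, // code after the cursor
}

// Same interleaving as the patch: start, prompt, middle, suffix, end.
fn assemble_fim_prompt(fim_params: &FimParams, fim: &FimPrompt) -> String {
    format!(
        "{}{}{}{}{}",
        fim_params.start, fim.prompt, fim_params.middle, fim.suffix, fim_params.end
    )
}

fn main() {
    let fim_params = FimParams {
        start: "<fim_prefix>".into(),
        middle: "<fim_suffix>".into(),
        end: "<fim_middle>".into(),
    };
    let fim = FimPrompt {
        prompt: "fn add(a: i32, b: i32) -> i32 {".into(),
        suffix: "}".into(),
    };
    assert_eq!(
        assemble_fim_prompt(&fim_params, &fim),
        "<fim_prefix>fn add(a: i32, b: i32) -> i32 {<fim_suffix>}<fim_middle>"
    );
}

When `params.fim` is `None` for a FIM prompt, the patch deliberately bails with an error rather than guessing delimiters.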
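
Also for reviewers, and equally illustrative: example run-params payloads for the two branches of the new dispatch. The field names (`messages`, `fim`, `max_tokens`) are taken from `GeminiRunParams` and the tests in the patch; the values are made up, and the snippet assumes a `serde_json` dependency.

// Example payloads accepted by the reworked GeminiRunParams (illustrative values).
use serde_json::json;

fn main() {
    // Chat-style generation: a `messages` array selects the get_chat path.
    let chat_params = json!({
        "max_tokens": 64,
        "messages": [
            { "role": "system", "content": "Test" },
            { "role": "user", "content": "Test {CONTEXT} - {CODE}" }
        ]
    });

    // FIM generation: a `fim` object selects the fill-in-the-middle path.
    // The delimiter strings are placeholders, not Gemini-defined tokens.
    let fim_params = json!({
        "max_tokens": 64,
        "fim": {
            "start": "<fim_prefix>",
            "middle": "<fim_suffix>",
            "end": "<fim_middle>"
        }
    });

    println!("{chat_params}\n{fim_params}");
}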