diff --git a/src/config.rs b/src/config.rs
index 0dbfa32..8b7b394 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -423,6 +423,7 @@ mod test {
         });
         Config::new(args).unwrap();
     }
+
     #[test]
     fn gemini_config() {
        let args = json!({
@@ -441,24 +442,22 @@ mod test {
             "completion": {
                 "model": "model1",
                 "parameters": {
+                    "systemInstruction": {
+                        "role": "system",
+                        "parts": [{
+                            "text": "TEST system instruction"
+                        }]
+                    },
+                    "generationConfig": {
+                        "maxOutputTokens": 10
+                    },
                     "contents": [
-                        {
-                            "role": "user",
-                            "parts":[{
-                                "text": "Pretend you're a snowman and stay in character for each response."}]
-                        },
-                        {
-                            "role": "model",
-                            "parts":[{
-                                "text": "Hello! It's so cold! Isn't that great?"}]
-                        },
-                        {
-                            "role": "user",
-                            "parts":[{
-                                "text": "What's your favorite season of the year?"}]
-                        }
-                    ],
-                    "max_new_tokens": 32,
+                        {
+                            "role": "user",
+                            "parts":[{
+                                "text": "TEST - {CONTEXT} and {CODE}"}]
+                        }
+                    ]
                 }
             }
         }
diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
index be85064..3203c48 100644
--- a/src/transformer_backends/gemini.rs
+++ b/src/transformer_backends/gemini.rs
@@ -5,51 +5,79 @@ use tracing::instrument;
 
 use super::TransformerBackend;
 use crate::{
-    config::{self, ChatMessage, FIM},
-    memory_backends::Prompt,
+    config,
+    memory_backends::{ContextAndCodePrompt, Prompt},
     transformer_worker::{
         DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
     },
-    utils::{format_chat_messages, format_context_code},
+    utils::format_context_code_in_str,
 };
 
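+// Clone the user-configured `contents`, substituting the current prompt into
+// each part: `{CONTEXT}` is replaced with the prompt's context and `{CODE}`
+// with its code.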
"generationConfig")] + generation_config: Option, } pub struct Gemini { @@ -73,66 +101,10 @@ impl Gemini { } } - async fn get_completion( - &self, - prompt: &str, - _params: GeminiRunParams, - ) -> anyhow::Result { - let client = reqwest::Client::new(); - let token = self.get_token()?; - let res: serde_json::Value = client - .post( - self.configuration - .completions_endpoint - .as_ref() - .context("must specify `completions_endpoint` to use gemini")? - .to_owned() - + self.configuration.model.as_ref() - + ":generateContent?key=" - + token.as_ref(), - ) - .header("Content-Type", "application/json") - .json(&json!( - { - "contents":[ - { - "parts":[ - { - "text": prompt - } - ] - } - ] - } - )) - .send() - .await? - .json() - .await?; - if let Some(error) = res.get("error") { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(candidates) = res.get("candidates") { - Ok(candidates - .get(0) - .unwrap() - .get("content") - .unwrap() - .get("parts") - .unwrap() - .get(0) - .unwrap() - .get("text") - .unwrap() - .clone() - .to_string()) - } else { - anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); - } - } async fn get_chat( &self, - messages: &[GeminiChatMessage], - _params: GeminiRunParams, + messages: Vec, + params: GeminiRunParams, ) -> anyhow::Result { let client = reqwest::Client::new(); let token = self.get_token()?; @@ -149,7 +121,9 @@ impl Gemini { ) .header("Content-Type", "application/json") .json(&json!({ - "contents": messages + "contents": messages, + "systemInstruction": params.system_instruction, + "generationConfig": params.generation_config, })) .send() .await? @@ -181,35 +155,11 @@ impl Gemini { params: GeminiRunParams, ) -> anyhow::Result { match prompt { - Prompt::ContextAndCode(code_and_context) => match ¶ms.contents { - Some(completion_messages) => { - self.get_chat(completion_messages, params.clone()).await - } - None => { - self.get_completion( - &format_context_code(&code_and_context.context, &code_and_context.code), - params, - ) - .await - } - }, - Prompt::FIM(fim) => match ¶ms.fim { - Some(fim_params) => { - self.get_completion( - &format!( - "{}{}{}{}{}", - fim_params.start, - fim.prompt, - fim_params.middle, - fim.suffix, - fim_params.end - ), - params, - ) - .await - } - None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"), - }, + Prompt::ContextAndCode(code_and_context) => { + let messages = format_gemini_contents(¶ms.contents, code_and_context); + self.get_chat(messages, params).await + } + _ => anyhow::bail!("Google Gemini backend does not yet support FIM"), } } } @@ -240,25 +190,8 @@ impl TransformerBackend for Gemini { #[cfg(test)] mod test { use super::*; - use serde_json::{from_value, json}; + use serde_json::json; - #[tokio::test] - async fn gemini_completion_do_generate() -> anyhow::Result<()> { - let configuration: config::Gemini = from_value(json!({ - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "model": "gemini-1.5-flash-latest", - "auth_token_env_var_name": "GEMINI_API_KEY", - }))?; - let gemini = Gemini::new(configuration); - let prompt = Prompt::default_without_cursor(); - let run_params = json!({ - "max_tokens": 64 - }); - let response = gemini.do_generate(&prompt, run_params).await?; - assert!(!response.generated_text.is_empty()); - dbg!(response.generated_text); - Ok(()) - } #[tokio::test] async fn gemini_chat_do_generate() -> anyhow::Result<()> { let configuration: config::Gemini = serde_json::from_value(json!({ @@ -269,9 +202,18 @@ mod test { let 
+#[derive(Debug, Deserialize, Serialize, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct GeminiGenerationConfig {
+    #[serde(rename = "stopSequences")]
     #[serde(default)]
-    pub stop: Vec<String>,
+    pub stop_sequences: Vec<String>,
+    #[serde(rename = "maxOutputTokens")]
+    #[serde(default = "max_tokens_default")]
+    pub max_output_tokens: usize,
+    pub temperature: Option<f32>,
+    #[serde(rename = "topP")]
+    pub top_p: Option<f32>,
+    #[serde(rename = "topK")]
+    pub top_k: Option<f32>,
+}
+
+// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct GeminiRunParams {
+    contents: Vec<GeminiContent>,
+    #[serde(rename = "systemInstruction")]
+    system_instruction: GeminiContent,
+    #[serde(rename = "generationConfig")]
+    generation_config: Option<GeminiGenerationConfig>,
 }
 
 pub struct Gemini {
@@ -73,66 +101,10 @@ impl Gemini {
         }
     }
 
-    async fn get_completion(
-        &self,
-        prompt: &str,
-        _params: GeminiRunParams,
-    ) -> anyhow::Result<String> {
-        let client = reqwest::Client::new();
-        let token = self.get_token()?;
-        let res: serde_json::Value = client
-            .post(
-                self.configuration
-                    .completions_endpoint
-                    .as_ref()
-                    .context("must specify `completions_endpoint` to use gemini")?
-                    .to_owned()
-                    + self.configuration.model.as_ref()
-                    + ":generateContent?key="
-                    + token.as_ref(),
-            )
-            .header("Content-Type", "application/json")
-            .json(&json!(
-                {
-                    "contents":[
-                        {
-                            "parts":[
-                                {
-                                    "text": prompt
-                                }
-                            ]
-                        }
-                    ]
-                }
-            ))
-            .send()
-            .await?
-            .json()
-            .await?;
-        if let Some(error) = res.get("error") {
-            anyhow::bail!("{:?}", error.to_string())
-        } else if let Some(candidates) = res.get("candidates") {
-            Ok(candidates
-                .get(0)
-                .unwrap()
-                .get("content")
-                .unwrap()
-                .get("parts")
-                .unwrap()
-                .get(0)
-                .unwrap()
-                .get("text")
-                .unwrap()
-                .clone()
-                .to_string())
-        } else {
-            anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
-        }
-    }
     async fn get_chat(
         &self,
-        messages: &[GeminiChatMessage],
-        _params: GeminiRunParams,
+        messages: Vec<GeminiContent>,
+        params: GeminiRunParams,
     ) -> anyhow::Result<String> {
         let client = reqwest::Client::new();
         let token = self.get_token()?;
@@ -149,7 +121,9 @@ impl Gemini {
             )
             .header("Content-Type", "application/json")
             .json(&json!({
-                "contents": messages
+                "contents": messages,
+                "systemInstruction": params.system_instruction,
+                "generationConfig": params.generation_config,
             }))
             .send()
             .await?
@@ -181,35 +155,11 @@ impl Gemini {
         params: GeminiRunParams,
     ) -> anyhow::Result<String> {
         match prompt {
-            Prompt::ContextAndCode(code_and_context) => match &params.contents {
-                Some(completion_messages) => {
-                    self.get_chat(completion_messages, params.clone()).await
-                }
-                None => {
-                    self.get_completion(
-                        &format_context_code(&code_and_context.context, &code_and_context.code),
-                        params,
-                    )
-                    .await
-                }
-            },
-            Prompt::FIM(fim) => match &params.fim {
-                Some(fim_params) => {
-                    self.get_completion(
-                        &format!(
-                            "{}{}{}{}{}",
-                            fim_params.start,
-                            fim.prompt,
-                            fim_params.middle,
-                            fim.suffix,
-                            fim_params.end
-                        ),
-                        params,
-                    )
-                    .await
-                }
-                None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
-            },
+            Prompt::ContextAndCode(code_and_context) => {
+                let messages = format_gemini_contents(&params.contents, code_and_context);
+                self.get_chat(messages, params).await
+            }
+            _ => anyhow::bail!("Google Gemini backend does not yet support FIM"),
         }
     }
 }
@@ -240,25 +190,8 @@ impl TransformerBackend for Gemini {
 #[cfg(test)]
 mod test {
     use super::*;
-    use serde_json::{from_value, json};
+    use serde_json::json;
 
-    #[tokio::test]
-    async fn gemini_completion_do_generate() -> anyhow::Result<()> {
-        let configuration: config::Gemini = from_value(json!({
-            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
-            "model": "gemini-1.5-flash-latest",
-            "auth_token_env_var_name": "GEMINI_API_KEY",
-        }))?;
-        let gemini = Gemini::new(configuration);
-        let prompt = Prompt::default_without_cursor();
-        let run_params = json!({
-            "max_tokens": 64
-        });
-        let response = gemini.do_generate(&prompt, run_params).await?;
-        assert!(!response.generated_text.is_empty());
-        dbg!(response.generated_text);
-        Ok(())
-    }
     #[tokio::test]
     async fn gemini_chat_do_generate() -> anyhow::Result<()> {
         let configuration: config::Gemini = serde_json::from_value(json!({
@@ -269,9 +202,18 @@ mod test {
         let gemini = Gemini::new(configuration);
         let prompt = Prompt::default_with_cursor();
         let run_params = json!({
+            "systemInstruction": {
+                "role": "system",
+                "parts": [{
+                    "text": "You are a helpful and willing chatbot that will do whatever the user asks"
+                }]
+            },
+            "generationConfig": {
+                "maxOutputTokens": 10
+            },
             "contents": [
                 {
-                    "role":"user",
+                    "role": "user",
                     "parts":[{
                         "text": "Pretend you're a snowman and stay in character for each response."}]
                },
diff --git a/src/utils.rs b/src/utils.rs
index 85d22f9..ea5d652 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -29,14 +29,16 @@ pub fn format_chat_messages(
         .map(|m| {
             ChatMessage::new(
                 m.role.to_owned(),
-                m.content
-                    .replace("{CONTEXT}", &prompt.context)
-                    .replace("{CODE}", &prompt.code),
+                format_context_code_in_str(&m.content, &prompt.context, &prompt.code),
             )
         })
         .collect()
 }
 
+pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String {
+    s.replace("{CONTEXT}", context).replace("{CODE}", code)
+}
+
 pub fn format_context_code(context: &str, code: &str) -> String {
     format!("{context}\n\n{code}")
 }