diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index ac38b02..d2ac562 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -9,7 +9,8 @@ use crate::{ memory_backends::Prompt, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, - }, utils::{format_chat_messages, format_context_code}, + }, + utils::{format_chat_messages, format_context_code}, }; const fn max_tokens_default() -> usize { @@ -23,12 +24,22 @@ const fn top_p_default() -> f32 { const fn temperature_default() -> f32 { 0.1 } +#[derive(Debug, Serialize, Deserialize, Clone)] +struct Part { + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiChatMessage { + role: String, + parts: Vec, +} // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] pub struct GeminiRunParams { pub fim: Option, - messages: Option>, + contents: Option>, #[serde(default = "max_tokens_default")] pub max_tokens: usize, #[serde(default = "top_p_default")] @@ -118,16 +129,61 @@ impl Gemini { anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); } } - + async fn get_chat( + &self, + messages: &[GeminiChatMessage], + _params: GeminiRunParams, + ) -> anyhow::Result { + let client = reqwest::Client::new(); + let token = self.get_token()?; + let res: serde_json::Value = client + .post( + self.configuration + .chat_endpoint + .as_ref() + .context("must specify `chat_endpoint` to use gemini")? + .to_owned() + + self.configuration.model.as_ref() + + ":generateContent?key=" + + token.as_ref(), + ) + .header("Content-Type", "application/json") + .json(&json!({ + "contents": messages + })) + .send() + .await? + .json() + .await?; + if let Some(error) = res.get("error") { + anyhow::bail!("{:?}", error.to_string()) + } else if let Some(candidates) = res.get("candidates") { + Ok(candidates + .get(0) + .unwrap() + .get("content") + .unwrap() + .get("parts") + .unwrap() + .get(0) + .unwrap() + .get("text") + .unwrap() + .clone() + .to_string()) + } else { + anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); + } + } async fn do_chat_completion( &self, prompt: &Prompt, params: GeminiRunParams, ) -> anyhow::Result { match prompt { - Prompt::ContextAndCode(code_and_context) => match ¶ms.messages { + Prompt::ContextAndCode(code_and_context) => match ¶ms.contents { Some(completion_messages) => { - todo!(); + self.get_chat(completion_messages, params.clone()).await } None => { self.get_completion( @@ -204,4 +260,37 @@ mod test { Ok(()) } // gemini_chat_do_generate TODO + #[tokio::test] + async fn gemini_chat_do_generate() -> anyhow::Result<()> { + let configuration: config::Gemini = serde_json::from_value(json!({ + "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", + "model": "gemini-1.5-flash", + "auth_token_env_var_name": "GEMINI_API_KEY", + }))?; + let gemini = Gemini::new(configuration); + let prompt = Prompt::default_with_cursor(); + let run_params = json!({ + "contents": [ + { + "role":"user", + "parts":[{ + "text": "Pretend you're a snowman and stay in character for each response."}] + }, + { + "role": "model", + "parts":[{ + "text": "Hello! It's so cold! Isn't that great?"}] + }, + { + "role": "user", + "parts":[{ + "text": "What's your favorite season of the year?"}] + } + ] + }); + let response = gemini.do_generate(&prompt, run_params).await?; + dbg!(&response.generated_text); + assert!(!response.generated_text.is_empty()); + Ok(()) + } }