diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs index ec17d1f..ac38b02 100644 --- a/src/transformer_backends/gemini.rs +++ b/src/transformer_backends/gemini.rs @@ -6,7 +6,7 @@ use tracing::instrument; use super::TransformerBackend; use crate::{ config::{self, ChatMessage, FIM}, - memory_backends::{FIMPrompt, Prompt, PromptType}, + memory_backends::Prompt, transformer_worker::{ DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, }, utils::{format_chat_messages, format_context_code}, @@ -119,47 +119,6 @@ impl Gemini { } } - async fn get_chat(&self, messages: Vec, params: GeminiRunParams) -> anyhow::Result { - let client = reqwest::Client::new(); - let token = self.get_token()?; - let res: serde_json::Value = client - .post( - self.configuration - .chat_endpoint - .as_ref() - .context("must specify `chat_endpoint` to use gemini")? - .to_owned() - + self.configuration.model.as_ref() - + ":generateContent?key=" - + token.as_ref(), - ) - .header("Content-Type", "application/json") - .json(&messages) - // .json(params) - .send() - .await? - .json() - .await?; - if let Some(error) = res.get("error") { - anyhow::bail!("{:?}", error.to_string()) - } else if let Some(candidates) = res.get("candidates") { - Ok(candidates - .get(0) - .unwrap() - .get("content") - .unwrap() - .get("parts") - .unwrap() - .get(0) - .unwrap() - .get("text") - .unwrap() - .clone() - .to_string()) - } else { - anyhow::bail!("Unknown error while making request to Gemini: {:?}", res); - } - } async fn do_chat_completion( &self, prompt: &Prompt, @@ -168,8 +127,7 @@ impl Gemini { match prompt { Prompt::ContextAndCode(code_and_context) => match ¶ms.messages { Some(completion_messages) => { - let messages = format_chat_messages(completion_messages, code_and_context); - self.get_chat(messages, params).await + todo!(); } None => { self.get_completion( @@ -245,38 +203,5 @@ mod test { dbg!(response.generated_text); Ok(()) } - #[tokio::test] - async fn gemini_chat_do_generate() -> anyhow::Result<()> { - let configuration: config::Gemini = serde_json::from_value(json!({ - "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/", - "model": "gemini-1.5-flash", - "auth_token_env_var_name": "GEMINI_API_KEY", - }))?; - let gemini = Gemini::new(configuration); - let prompt = Prompt::default_with_cursor(); - let run_params = json!({ - "contents": [ - { - "role":"user", - "parts":[{ - "text": "Pretend you're a snowman and stay in character for each response."}] - }, - { - "role": "model", - "parts":[{ - "text": "Hello! It's so cold! Isn't that great?"}] - }, - { - "role": "user", - "parts":[{ - "text": "What's your favorite season of the year?"}] - } - ] - }); - let response = gemini.do_generate(&prompt, run_params).await?; - dbg!(&response.generated_text); - assert!(!response.generated_text.is_empty()); - Ok(()) - } + // gemini_chat_do_generate TODO }