fin gemini_chat_do_generate

This commit is contained in:
Asuka Minato 2024-06-18 02:45:07 +09:00
parent 4548de4e3e
commit b5302f9748

View File

@ -1,13 +1,11 @@
use std::collections::HashMap;
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;
use super::{open_ai::OpenAIChatChoices, TransformerBackend};
use super::TransformerBackend;
use crate::{
config::{self},
config,
memory_backends::{FIMPrompt, Prompt, PromptType},
transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
@ -70,7 +68,9 @@ impl Gemini {
self.config
.completions_endpoint
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?.to_owned() + token.as_ref(),
.context("must specify `gemini_endpoint` to use gemini")?
.to_owned()
+ token.as_ref(),
)
.header("Content-Type", "application/json")
.json(&json!(
@ -110,6 +110,48 @@ impl Gemini {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: serde_json::Value = client
.post(
self.config
.chat_endpoint
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?
.to_owned()
+ token.as_ref(),
)
.header("Content-Type", "application/json")
.json(&params)
.send()
.await?
.json()
.await?;
if let Some(error) = res.get("error") {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(candidates) = res.get("candidates") {
Ok(candidates
.get(0)
.unwrap()
.get("content")
.unwrap()
.get("parts")
.unwrap()
.get(0)
.unwrap()
.get("text")
.unwrap()
.clone()
.to_string())
} else {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
}
#[async_trait::async_trait]
@ -158,32 +200,40 @@ mod test {
});
let response = anthropic.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
dbg!(response.generated_text);
Ok(())
}
#[tokio::test]
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = serde_json::from_value(json!({
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/",
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=",
"model": "gemini-1.5-flash",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let gemini = Gemini::new(configuration);
let prompt = Prompt::default_with_cursor();
let run_params = json!({
"messages": [
{
"role": "system",
"content": "Test"
"contents": [
{
"role":"user",
"parts":[{
"text": "Pretend you're a snowman and stay in character for each response."}]
},
{
"role": "user",
"content": "Test {CONTEXT} - {CODE}"
{
"role": "model",
"parts":[{
"text": "Hello! It's so cold! Isn't that great?"}]
},
{
"role": "user",
"parts":[{
"text": "What's your favorite season of the year?"}]
}
],
"max_tokens": 64
]
});
let response = gemini.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
let response = gemini.do_chat_completion(&prompt, run_params).await?;
dbg!(&response);
assert!(!response.is_empty());
Ok(())
}
}