This commit is contained in:
Asuka Minato 2024-06-18 19:47:28 +09:00
parent ad9f7381ea
commit c8993bf740

View File

@ -9,7 +9,8 @@ use crate::{
memory_backends::Prompt,
transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
}, utils::{format_chat_messages, format_context_code},
},
utils::{format_chat_messages, format_context_code},
};
const fn max_tokens_default() -> usize {
@ -23,12 +24,22 @@ const fn top_p_default() -> f32 {
/// Returns the fallback sampling temperature, `0.1`.
///
/// NOTE(review): presumably wired up through a `#[serde(default = "...")]`
/// attribute on the run parameters, like the sibling `*_default` helpers —
/// confirm against the struct definition.
const fn temperature_default() -> f32 {
    const DEFAULT_TEMPERATURE: f32 = 0.1;
    DEFAULT_TEMPERATURE
}
/// A single text fragment of a Gemini chat message.
///
/// Serialized as one element of the `parts` array of a [`GeminiChatMessage`].
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Part {
/// The textual content of this part.
pub text: String,
}
/// One turn of a conversation in the shape Gemini's `generateContent`
/// request expects inside its `contents` array.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GeminiChatMessage {
/// Who produced this turn; the tests in this file use `"user"` and `"model"`.
role: String,
/// The text fragments that make up this turn.
parts: Vec<Part>,
}
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone)]
pub struct GeminiRunParams {
pub fim: Option<FIM>,
messages: Option<Vec<ChatMessage>>,
contents: Option<Vec<GeminiChatMessage>>,
#[serde(default = "max_tokens_default")]
pub max_tokens: usize,
#[serde(default = "top_p_default")]
@ -118,16 +129,61 @@ impl Gemini {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
/// Sends a chat conversation to Gemini's `generateContent` endpoint and
/// returns the text of the first candidate reply.
///
/// The request URL is the configured `chat_endpoint` followed by the model
/// name, `:generateContent?key=` and the API token. `messages` is posted as
/// the `contents` field of the JSON body. `_params` is accepted for interface
/// symmetry but is currently not forwarded to the API.
///
/// # Errors
///
/// Returns an error when:
/// - the API token cannot be obtained,
/// - `chat_endpoint` is not configured,
/// - the HTTP request or the decoding of the JSON response fails,
/// - the response contains an `error` object,
/// - the response has neither `error` nor `candidates`, or
/// - `candidates[0].content.parts[0].text` is missing or not a string.
async fn get_chat(
    &self,
    messages: &[GeminiChatMessage],
    _params: GeminiRunParams,
) -> anyhow::Result<String> {
    let client = reqwest::Client::new();
    let token = self.get_token()?;

    // Build `<chat_endpoint><model>:generateContent?key=<token>`.
    let mut url: String = self
        .configuration
        .chat_endpoint
        .as_ref()
        .context("must specify `chat_endpoint` to use gemini")?
        .to_owned();
    url.push_str(self.configuration.model.as_ref());
    url.push_str(":generateContent?key=");
    url.push_str(token.as_ref());

    let res: serde_json::Value = client
        .post(url)
        .header("Content-Type", "application/json")
        .json(&json!({
            "contents": messages
        }))
        .send()
        .await?
        .json()
        .await?;

    if let Some(error) = res.get("error") {
        anyhow::bail!("{:?}", error.to_string())
    }
    if res.get("candidates").is_none() {
        anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
    }

    // Walk the response without panicking: a candidate may lack the expected
    // `content.parts[0].text` path (NOTE(review): e.g. presumably when the
    // reply is filtered — confirm against the Gemini API docs).
    // `as_str` yields the raw text; `Value::to_string` would return the JSON
    // encoding, i.e. the text wrapped in quotes with escaped newlines.
    let text = res
        .pointer("/candidates/0/content/parts/0/text")
        .and_then(serde_json::Value::as_str)
        .with_context(|| {
            format!(
                "Gemini response did not contain any generated text: {:?}",
                res
            )
        })?;
    Ok(text.to_owned())
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
params: GeminiRunParams,
) -> anyhow::Result<String> {
match prompt {
Prompt::ContextAndCode(code_and_context) => match &params.messages {
Prompt::ContextAndCode(code_and_context) => match &params.contents {
Some(completion_messages) => {
todo!();
self.get_chat(completion_messages, params.clone()).await
}
None => {
self.get_completion(
@ -204,4 +260,37 @@ mod test {
Ok(())
}
// gemini_chat_do_generate TODO
/// Integration test: sends a three-turn conversation through the `contents`
/// run parameter and checks that Gemini answers with non-empty text.
///
/// Talks to the real Gemini endpoint; the API key is read from the
/// environment variable named in `auth_token_env_var_name`.
#[tokio::test]
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
    let gemini_config = serde_json::from_value::<config::Gemini>(json!({
        "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
        "model": "gemini-1.5-flash",
        "auth_token_env_var_name": "GEMINI_API_KEY",
    }))?;
    let backend = Gemini::new(gemini_config);
    let cursor_prompt = Prompt::default_with_cursor();

    // Builds one conversation turn in the `{ role, parts: [{ text }] }` shape.
    let turn = |role: &str, text: &str| json!({ "role": role, "parts": [{ "text": text }] });
    let chat_params = json!({
        "contents": [
            turn(
                "user",
                "Pretend you're a snowman and stay in character for each response."
            ),
            turn("model", "Hello! It's so cold! Isn't that great?"),
            turn("user", "What's your favorite season of the year?")
        ]
    });

    let reply = backend.do_generate(&cursor_prompt, chat_params).await?;
    dbg!(&reply.generated_text);
    assert!(!reply.generated_text.is_empty());
    Ok(())
}
}