merge pull/1

This commit is contained in:
Asuka Minato 2024-06-19 23:05:49 +09:00
parent be577c19e0
commit f85c964a30
3 changed files with 95 additions and 152 deletions

View File

@ -423,6 +423,7 @@ mod test {
}); });
Config::new(args).unwrap(); Config::new(args).unwrap();
} }
#[test] #[test]
fn gemini_config() { fn gemini_config() {
let args = json!({ let args = json!({
@ -441,24 +442,22 @@ mod test {
"completion": { "completion": {
"model": "model1", "model": "model1",
"parameters": { "parameters": {
"systemInstruction": {
"role": "system",
"parts": [{
"text": "TEST system instruction"
}]
},
"generationConfig": {
"maxOutputTokens": 10
},
"contents": [ "contents": [
{ {
"role": "user", "role": "user",
"parts":[{ "parts":[{
"text": "Pretend you're a snowman and stay in character for each response."}] "text": "TEST - {CONTEXT} and {CODE}"}]
}, }
{ ]
"role": "model",
"parts":[{
"text": "Hello! It's so cold! Isn't that great?"}]
},
{
"role": "user",
"parts":[{
"text": "What's your favorite season of the year?"}]
}
],
"max_new_tokens": 32,
} }
} }
} }

View File

@ -5,51 +5,79 @@ use tracing::instrument;
use super::TransformerBackend; use super::TransformerBackend;
use crate::{ use crate::{
config::{self, ChatMessage, FIM}, config,
memory_backends::Prompt, memory_backends::{ContextAndCodePrompt, Prompt},
transformer_worker::{ transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest, DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
}, },
utils::{format_chat_messages, format_context_code}, utils::format_context_code_in_str,
}; };
/// Expands the `{CONTEXT}` and `{CODE}` placeholders in every part of every
/// message, producing a fresh list of contents ready to send to Gemini.
///
/// The role of each message is preserved; only the part text is rewritten.
fn format_gemini_contents(
    messages: &[GeminiContent],
    prompt: &ContextAndCodePrompt,
) -> Vec<GeminiContent> {
    let mut formatted = Vec::with_capacity(messages.len());
    for message in messages {
        let parts = message
            .parts
            .iter()
            .map(|part| Part {
                text: format_context_code_in_str(&part.text, &prompt.context, &prompt.code),
            })
            .collect();
        formatted.push(GeminiContent::new(message.role.clone(), parts));
    }
    formatted
}
const fn max_tokens_default() -> usize { const fn max_tokens_default() -> usize {
64 64
} }
/// Serde default for `top_p` (nucleus sampling) when the user's
/// configuration omits it.
const fn top_p_default() -> f32 {
    0.95
}
/// Serde default for the sampling `temperature` when the user's
/// configuration omits it.
const fn temperature_default() -> f32 {
    0.1
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
struct Part { struct Part {
pub text: String, pub text: String,
} }
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
struct GeminiChatMessage { struct GeminiContent {
role: String, role: String,
parts: Vec<Part>, parts: Vec<Part>,
} }
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes impl GeminiContent {
#[derive(Debug, Deserialize, Clone)] fn new(role: String, parts: Vec<Part>) -> Self {
pub struct GeminiRunParams { Self { role, parts }
pub fim: Option<FIM>, }
contents: Option<Vec<GeminiChatMessage>>, }
#[serde(default = "max_tokens_default")]
pub max_tokens: usize, #[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(default = "top_p_default")] #[serde(deny_unknown_fields)]
pub top_p: f32, pub struct GeminiGenerationConfig {
#[serde(default = "temperature_default")] #[serde(rename = "stopSequences")]
pub temperature: f32,
pub min_tokens: Option<u64>,
pub random_seed: Option<u64>,
#[serde(default)] #[serde(default)]
pub stop: Vec<String>, pub stop_sequences: Vec<String>,
#[serde(rename = "maxOutputTokens")]
#[serde(default = "max_tokens_default")]
pub max_output_tokens: usize,
pub temperature: Option<f32>,
#[serde(rename = "topP")]
pub top_p: Option<f32>,
#[serde(rename = "topK")]
pub top_k: Option<f32>,
}
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct GeminiRunParams {
contents: Vec<GeminiContent>,
#[serde(rename = "systemInstruction")]
system_instruction: GeminiContent,
#[serde(rename = "generationConfig")]
generation_config: Option<GeminiGenerationConfig>,
} }
pub struct Gemini { pub struct Gemini {
@ -73,66 +101,10 @@ impl Gemini {
} }
} }
/// Sends a plain (non-chat) completion request to the Gemini
/// `generateContent` endpoint and returns the generated text.
///
/// `prompt` is wrapped as a single part in a single content entry;
/// `_params` is currently unused on this code path.
///
/// # Errors
/// Fails when `completions_endpoint` is not configured, when the HTTP
/// request or JSON decoding fails, or when the API response carries an
/// `error` object or an unrecognized shape.
async fn get_completion(
&self,
prompt: &str,
_params: GeminiRunParams,
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
// Request URL shape: {completions_endpoint}{model}:generateContent?key={token}
let res: serde_json::Value = client
.post(
self.configuration
.completions_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use gemini")?
.to_owned()
+ self.configuration.model.as_ref()
+ ":generateContent?key="
+ token.as_ref(),
)
.header("Content-Type", "application/json")
.json(&json!(
{
"contents":[
{
"parts":[
{
"text": prompt
}
]
}
]
}
))
.send()
.await?
.json()
.await?;
// Gemini reports failures in-band via an `error` member of the body,
// so inspect the decoded JSON before assuming success.
if let Some(error) = res.get("error") {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(candidates) = res.get("candidates") {
// Extract candidates[0].content.parts[0].text.
// NOTE(review): these chained `unwrap()`s panic on any unexpected
// response shape, and `Value::to_string()` serializes the JSON value,
// so a string result keeps its surrounding quotes — confirm callers
// expect that (using `as_str()` would avoid both issues).
Ok(candidates
.get(0)
.unwrap()
.get("content")
.unwrap()
.get("parts")
.unwrap()
.get(0)
.unwrap()
.get("text")
.unwrap()
.clone()
.to_string())
} else {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
async fn get_chat( async fn get_chat(
&self, &self,
messages: &[GeminiChatMessage], messages: Vec<GeminiContent>,
_params: GeminiRunParams, params: GeminiRunParams,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let token = self.get_token()?; let token = self.get_token()?;
@ -149,7 +121,9 @@ impl Gemini {
) )
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.json(&json!({ .json(&json!({
"contents": messages "contents": messages,
"systemInstruction": params.system_instruction,
"generationConfig": params.generation_config,
})) }))
.send() .send()
.await? .await?
@ -181,35 +155,11 @@ impl Gemini {
params: GeminiRunParams, params: GeminiRunParams,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
match prompt { match prompt {
Prompt::ContextAndCode(code_and_context) => match &params.contents { Prompt::ContextAndCode(code_and_context) => {
Some(completion_messages) => { let messages = format_gemini_contents(&params.contents, code_and_context);
self.get_chat(completion_messages, params.clone()).await self.get_chat(messages, params).await
} }
None => { _ => anyhow::bail!("Google Gemini backend does not yet support FIM"),
self.get_completion(
&format_context_code(&code_and_context.context, &code_and_context.code),
params,
)
.await
}
},
Prompt::FIM(fim) => match &params.fim {
Some(fim_params) => {
self.get_completion(
&format!(
"{}{}{}{}{}",
fim_params.start,
fim.prompt,
fim_params.middle,
fim.suffix,
fim_params.end
),
params,
)
.await
}
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
},
} }
} }
} }
@ -240,25 +190,8 @@ impl TransformerBackend for Gemini {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
use serde_json::{from_value, json}; use serde_json::json;
// Integration test: performs a real network call to the Gemini API.
// Requires the GEMINI_API_KEY environment variable to contain a valid key
// and outbound network access; it is not hermetic.
#[tokio::test]
async fn gemini_completion_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = from_value(json!({
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let gemini = Gemini::new(configuration);
// A prompt without a cursor exercises the plain-completion path.
let prompt = Prompt::default_without_cursor();
let run_params = json!({
"max_tokens": 64
});
let response = gemini.do_generate(&prompt, run_params).await?;
// Only asserts that *something* was generated; the model output itself
// is non-deterministic, so no exact value can be pinned.
assert!(!response.generated_text.is_empty());
dbg!(response.generated_text);
Ok(())
}
#[tokio::test] #[tokio::test]
async fn gemini_chat_do_generate() -> anyhow::Result<()> { async fn gemini_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = serde_json::from_value(json!({ let configuration: config::Gemini = serde_json::from_value(json!({
@ -269,9 +202,18 @@ mod test {
let gemini = Gemini::new(configuration); let gemini = Gemini::new(configuration);
let prompt = Prompt::default_with_cursor(); let prompt = Prompt::default_with_cursor();
let run_params = json!({ let run_params = json!({
"systemInstruction": {
"role": "system",
"parts": [{
"text": "You are a helpful and willing chatbot that will do whatever the user asks"
}]
},
"generationConfig": {
"maxOutputTokens": 10
},
"contents": [ "contents": [
{ {
"role":"user", "role": "user",
"parts":[{ "parts":[{
"text": "Pretend you're a snowman and stay in character for each response."}] "text": "Pretend you're a snowman and stay in character for each response."}]
}, },

View File

@ -29,14 +29,16 @@ pub fn format_chat_messages(
.map(|m| { .map(|m| {
ChatMessage::new( ChatMessage::new(
m.role.to_owned(), m.role.to_owned(),
m.content format_context_code_in_str(&m.content, &prompt.context, &prompt.code),
.replace("{CONTEXT}", &prompt.context)
.replace("{CODE}", &prompt.code),
) )
}) })
.collect() .collect()
} }
/// Substitutes every `{CONTEXT}` placeholder in `s` with `context`, then
/// every `{CODE}` placeholder with `code`, returning the expanded string.
pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String {
    let with_context = s.replace("{CONTEXT}", context);
    with_context.replace("{CODE}", code)
}
pub fn format_context_code(context: &str, code: &str) -> String { pub fn format_context_code(context: &str, code: &str) -> String {
format!("{context}\n\n{code}") format!("{context}\n\n{code}")
} }