use system format

This commit is contained in:
Asuka Minato 2024-06-18 11:56:32 +09:00
parent 5d4a04ac0e
commit e878089b64
2 changed files with 72 additions and 29 deletions

View File

@ -1,15 +1,15 @@
use anyhow::Context;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::instrument;
use super::TransformerBackend;
use crate::{
config,
config::{self, ChatMessage, FIM},
memory_backends::{FIMPrompt, Prompt, PromptType},
transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
},
}, utils::{format_chat_messages, format_context_code},
};
const fn max_tokens_default() -> usize {
@ -27,6 +27,8 @@ const fn temperature_default() -> f32 {
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
#[derive(Debug, Deserialize)]
pub struct GeminiRunParams {
pub fim: Option<FIM>,
messages: Option<Vec<ChatMessage>>,
#[serde(default = "max_tokens_default")]
pub max_tokens: usize,
#[serde(default = "top_p_default")]
@ -40,18 +42,18 @@ pub struct GeminiRunParams {
}
pub struct Gemini {
config: config::Gemini,
configuration: config::Gemini,
}
impl Gemini {
pub fn new(config: config::Gemini) -> Self {
Self { config }
pub fn new(configuration: config::Gemini) -> Self {
Self { configuration }
}
fn get_token(&self) -> anyhow::Result<String> {
if let Some(env_var_name) = &self.config.auth_token_env_var_name {
if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
Ok(std::env::var(env_var_name)?)
} else if let Some(token) = &self.config.auth_token {
} else if let Some(token) = &self.configuration.auth_token {
Ok(token.to_string())
} else {
anyhow::bail!(
@ -60,17 +62,21 @@ impl Gemini {
}
}
async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
async fn get_completion(
&self,
prompt: &str,
_params: GeminiRunParams,
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: serde_json::Value = client
.post(
self.config
self.configuration
.completions_endpoint
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?
.context("must specify `completions_endpoint` to use gemini")?
.to_owned()
+ self.config.model.as_ref()
+ self.configuration.model.as_ref()
+ ":generateContent?key="
+ token.as_ref(),
)
@ -81,7 +87,7 @@ impl Gemini {
{
"parts":[
{
"text": prompt.prompt
"text": prompt
}
]
}
@ -112,22 +118,24 @@ impl Gemini {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result<String> {
async fn get_chat(&self, messages: Vec<ChatMessage>, params: GeminiRunParams) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: serde_json::Value = client
.post(
self.config
self.configuration
.chat_endpoint
.as_ref()
.context("must specify `gemini_endpoint` to use gemini")?
.context("must specify `chat_endpoint` to use gemini")?
.to_owned()
+ self.config.model.as_ref()
+ self.configuration.model.as_ref()
+ ":generateContent?key="
+ token.as_ref(),
)
.header("Content-Type", "application/json")
.json(&params)
.json(&messages)
// .json(params)
.send()
.await?
.json()
@ -152,6 +160,44 @@ impl Gemini {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
async fn do_chat_completion(
&self,
prompt: &Prompt,
params: GeminiRunParams,
) -> anyhow::Result<String> {
match prompt {
Prompt::ContextAndCode(code_and_context) => match &params.messages {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, code_and_context);
self.get_chat(messages, params).await
}
None => {
self.get_completion(
&format_context_code(&code_and_context.context, &code_and_context.code),
params,
)
.await
}
},
Prompt::FIM(fim) => match &params.fim {
Some(fim_params) => {
self.get_completion(
&format!(
"{}{}{}{}{}",
fim_params.start,
fim.prompt,
fim_params.middle,
fim.suffix,
fim_params.end
),
params,
)
.await
}
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
},
}
}
}
#[async_trait::async_trait]
@ -163,7 +209,7 @@ impl TransformerBackend for Gemini {
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: GeminiRunParams = serde_json::from_value(params)?;
let generated_text = self.do_fim(prompt.try_into()?, params).await?;
let generated_text = self.do_chat_completion(prompt, params).await?;
Ok(DoGenerationResponse { generated_text })
}
@ -175,10 +221,6 @@ impl TransformerBackend for Gemini {
) -> anyhow::Result<DoGenerationStreamResponse> {
anyhow::bail!("GenerationStream is not yet implemented")
}
/// Reports which prompt type this backend asks for.
///
/// The run parameters are ignored (`_params` is unused): the result is always
/// `PromptType::FIM`, and this method never returns an error.
fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
Ok(PromptType::FIM)
}
}
#[cfg(test)]
@ -194,9 +236,9 @@ mod test {
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let gemini = Gemini::new(configuration);
let prompt = Prompt::default_fim();
let prompt = Prompt::default_without_cursor();
let run_params = json!({
"max_tokens": 2
"max_tokens": 64
});
let response = gemini.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
@ -207,6 +249,7 @@ mod test {
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = serde_json::from_value(json!({
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
"model": "gemini-1.5-flash",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
@ -231,9 +274,9 @@ mod test {
}
]
});
let response = gemini.do_chat_completion(&prompt, run_params).await?;
dbg!(&response);
assert!(!response.is_empty());
let response = gemini.do_generate(&prompt, run_params).await?;
dbg!(&response.generated_text);
assert!(!response.generated_text.is_empty());
Ok(())
}
}

View File

@ -163,7 +163,7 @@ impl OpenAI {
self.configuration
.chat_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use completions")?,
.context("must specify `chat_endpoint` to use completions")?,
)
.bearer_auth(token)
.header("Content-Type", "application/json")