use system format

Asuka Minato 2024-06-18 11:56:32 +09:00
parent 5d4a04ac0e
commit e878089b64
2 changed files with 72 additions and 29 deletions

View File

@@ -1,15 +1,15 @@
 use anyhow::Context;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use tracing::instrument;
 
 use super::TransformerBackend;
 use crate::{
-    config,
+    config::{self, ChatMessage, FIM},
     memory_backends::{FIMPrompt, Prompt, PromptType},
     transformer_worker::{
         DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
     },
+    utils::{format_chat_messages, format_context_code},
 };
 
 const fn max_tokens_default() -> usize {
@@ -27,6 +27,8 @@ const fn temperature_default() -> f32 {
 // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
 #[derive(Debug, Deserialize)]
 pub struct GeminiRunParams {
+    pub fim: Option<FIM>,
+    messages: Option<Vec<ChatMessage>>,
     #[serde(default = "max_tokens_default")]
     pub max_tokens: usize,
     #[serde(default = "top_p_default")]
@@ -40,18 +42,18 @@ pub struct GeminiRunParams {
 }
 
 pub struct Gemini {
-    config: config::Gemini,
+    configuration: config::Gemini,
 }
 
 impl Gemini {
-    pub fn new(config: config::Gemini) -> Self {
-        Self { config }
+    pub fn new(configuration: config::Gemini) -> Self {
+        Self { configuration }
     }
 
     fn get_token(&self) -> anyhow::Result<String> {
-        if let Some(env_var_name) = &self.config.auth_token_env_var_name {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
             Ok(std::env::var(env_var_name)?)
-        } else if let Some(token) = &self.config.auth_token {
+        } else if let Some(token) = &self.configuration.auth_token {
             Ok(token.to_string())
         } else {
             anyhow::bail!(
@@ -60,17 +62,21 @@ impl Gemini {
         }
     }
 
-    async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
+    async fn get_completion(
+        &self,
+        prompt: &str,
+        _params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
         let client = reqwest::Client::new();
         let token = self.get_token()?;
         let res: serde_json::Value = client
             .post(
-                self.config
+                self.configuration
                     .completions_endpoint
                     .as_ref()
-                    .context("must specify `gemini_endpoint` to use gemini")?
+                    .context("must specify `completions_endpoint` to use gemini")?
                     .to_owned()
-                    + self.config.model.as_ref()
+                    + self.configuration.model.as_ref()
                     + ":generateContent?key="
                     + token.as_ref(),
             )
@@ -81,7 +87,7 @@ impl Gemini {
                 {
                     "parts":[
                         {
-                            "text": prompt.prompt
+                            "text": prompt
                         }
                     ]
                 }
@@ -112,22 +118,24 @@ impl Gemini {
             anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
         }
     }
-    async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result<String> {
+
+    async fn get_chat(&self, messages: Vec<ChatMessage>, params: GeminiRunParams) -> anyhow::Result<String> {
         let client = reqwest::Client::new();
         let token = self.get_token()?;
         let res: serde_json::Value = client
             .post(
-                self.config
+                self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `gemini_endpoint` to use gemini")?
+                    .context("must specify `chat_endpoint` to use gemini")?
                     .to_owned()
-                    + self.config.model.as_ref()
+                    + self.configuration.model.as_ref()
                     + ":generateContent?key="
                     + token.as_ref(),
             )
             .header("Content-Type", "application/json")
-            .json(&params)
+            .json(&messages)
+            // .json(params)
             .send()
             .await?
             .json()
@@ -152,6 +160,44 @@ impl Gemini {
             anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
         }
     }
+
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        match prompt {
+            Prompt::ContextAndCode(code_and_context) => match &params.messages {
+                Some(completion_messages) => {
+                    let messages = format_chat_messages(completion_messages, code_and_context);
+                    self.get_chat(messages, params).await
+                }
+                None => {
+                    self.get_completion(
+                        &format_context_code(&code_and_context.context, &code_and_context.code),
+                        params,
+                    )
+                    .await
+                }
+            },
+            Prompt::FIM(fim) => match &params.fim {
+                Some(fim_params) => {
+                    self.get_completion(
+                        &format!(
+                            "{}{}{}{}{}",
+                            fim_params.start,
+                            fim.prompt,
+                            fim_params.middle,
+                            fim.suffix,
+                            fim_params.end
+                        ),
+                        params,
+                    )
+                    .await
+                }
+                None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
+            },
+        }
+    }
 }
 
 #[async_trait::async_trait]
@@ -163,7 +209,7 @@ impl TransformerBackend for Gemini {
         params: Value,
     ) -> anyhow::Result<DoGenerationResponse> {
         let params: GeminiRunParams = serde_json::from_value(params)?;
-        let generated_text = self.do_fim(prompt.try_into()?, params).await?;
+        let generated_text = self.do_chat_completion(prompt, params).await?;
         Ok(DoGenerationResponse { generated_text })
     }
@@ -175,10 +221,6 @@ impl TransformerBackend for Gemini {
     ) -> anyhow::Result<DoGenerationStreamResponse> {
         anyhow::bail!("GenerationStream is not yet implemented")
     }
-
-    fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
-        Ok(PromptType::FIM)
-    }
 }
 
 #[cfg(test)]
@@ -194,9 +236,9 @@ mod test {
             "auth_token_env_var_name": "GEMINI_API_KEY",
         }))?;
         let gemini = Gemini::new(configuration);
-        let prompt = Prompt::default_fim();
+        let prompt = Prompt::default_without_cursor();
         let run_params = json!({
-            "max_tokens": 2
+            "max_tokens": 64
         });
         let response = gemini.do_generate(&prompt, run_params).await?;
         assert!(!response.generated_text.is_empty());
@@ -207,6 +249,7 @@ mod test {
     async fn gemini_chat_do_generate() -> anyhow::Result<()> {
         let configuration: config::Gemini = serde_json::from_value(json!({
             "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+            "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
             "model": "gemini-1.5-flash",
             "auth_token_env_var_name": "GEMINI_API_KEY",
         }))?;
@@ -231,9 +274,9 @@ mod test {
                 }
             ]
         });
-        let response = gemini.do_chat_completion(&prompt, run_params).await?;
-        dbg!(&response);
-        assert!(!response.is_empty());
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        dbg!(&response.generated_text);
+        assert!(!response.generated_text.is_empty());
         Ok(())
     }
 }
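
For reference, a minimal sketch of the two run-parameter shapes the reworked do_chat_completion dispatches on. The field names (fim, messages, max_tokens) come from GeminiRunParams above; the sentinel tokens, message contents, and the {CONTEXT}/{CODE} placeholders are illustrative assumptions, not values taken from this commit.

    use serde_json::json;

    fn example_run_params() {
        // Chat path: `messages` is present, so do_chat_completion formats the
        // messages against the prompt's context/code and calls get_chat.
        let chat_params = json!({
            "messages": [
                { "role": "system", "content": "You are a code completion assistant." },
                { "role": "user", "content": "{CONTEXT}\n{CODE}" }
            ],
            "max_tokens": 64
        });

        // FIM path: `fim` is present, so do_chat_completion concatenates
        // start + prompt + middle + suffix + end and calls get_completion.
        let fim_params = json!({
            "fim": {
                "start": "<fim_prefix>",   // hypothetical sentinel tokens
                "middle": "<fim_suffix>",
                "end": "<fim_middle>"
            },
            "max_tokens": 64
        });

        let _ = (chat_params, fim_params);
    }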

View File

@@ -163,7 +163,7 @@ impl OpenAI {
                 self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions")?,
+                    .context("must specify `chat_endpoint` to use completions")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")