mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2024-08-15 23:30:34 +03:00
use system format
This commit is contained in:
parent
5d4a04ac0e
commit
e878089b64
@ -1,15 +1,15 @@
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::instrument;
|
||||
|
||||
use super::TransformerBackend;
|
||||
use crate::{
|
||||
config,
|
||||
config::{self, ChatMessage, FIM},
|
||||
memory_backends::{FIMPrompt, Prompt, PromptType},
|
||||
transformer_worker::{
|
||||
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
|
||||
},
|
||||
}, utils::{format_chat_messages, format_context_code},
|
||||
};
|
||||
|
||||
const fn max_tokens_default() -> usize {
|
||||
@ -27,6 +27,8 @@ const fn temperature_default() -> f32 {
|
||||
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GeminiRunParams {
|
||||
pub fim: Option<FIM>,
|
||||
messages: Option<Vec<ChatMessage>>,
|
||||
#[serde(default = "max_tokens_default")]
|
||||
pub max_tokens: usize,
|
||||
#[serde(default = "top_p_default")]
|
||||
@ -40,18 +42,18 @@ pub struct GeminiRunParams {
|
||||
}
|
||||
|
||||
pub struct Gemini {
|
||||
config: config::Gemini,
|
||||
configuration: config::Gemini,
|
||||
}
|
||||
|
||||
impl Gemini {
|
||||
pub fn new(config: config::Gemini) -> Self {
|
||||
Self { config }
|
||||
pub fn new(configuration: config::Gemini) -> Self {
|
||||
Self { configuration }
|
||||
}
|
||||
|
||||
fn get_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(env_var_name) = &self.config.auth_token_env_var_name {
|
||||
if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
|
||||
Ok(std::env::var(env_var_name)?)
|
||||
} else if let Some(token) = &self.config.auth_token {
|
||||
} else if let Some(token) = &self.configuration.auth_token {
|
||||
Ok(token.to_string())
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
@ -60,17 +62,21 @@ impl Gemini {
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
|
||||
async fn get_completion(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_params: GeminiRunParams,
|
||||
) -> anyhow::Result<String> {
|
||||
let client = reqwest::Client::new();
|
||||
let token = self.get_token()?;
|
||||
let res: serde_json::Value = client
|
||||
.post(
|
||||
self.config
|
||||
self.configuration
|
||||
.completions_endpoint
|
||||
.as_ref()
|
||||
.context("must specify `gemini_endpoint` to use gemini")?
|
||||
.context("must specify `completions_endpoint` to use gemini")?
|
||||
.to_owned()
|
||||
+ self.config.model.as_ref()
|
||||
+ self.configuration.model.as_ref()
|
||||
+ ":generateContent?key="
|
||||
+ token.as_ref(),
|
||||
)
|
||||
@ -81,7 +87,7 @@ impl Gemini {
|
||||
{
|
||||
"parts":[
|
||||
{
|
||||
"text": prompt.prompt
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -112,22 +118,24 @@ impl Gemini {
|
||||
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
|
||||
}
|
||||
}
|
||||
async fn do_chat_completion(&self, prompt: &Prompt, params: Value) -> anyhow::Result<String> {
|
||||
|
||||
async fn get_chat(&self, messages: Vec<ChatMessage>, params: GeminiRunParams) -> anyhow::Result<String> {
|
||||
let client = reqwest::Client::new();
|
||||
let token = self.get_token()?;
|
||||
let res: serde_json::Value = client
|
||||
.post(
|
||||
self.config
|
||||
self.configuration
|
||||
.chat_endpoint
|
||||
.as_ref()
|
||||
.context("must specify `gemini_endpoint` to use gemini")?
|
||||
.context("must specify `chat_endpoint` to use gemini")?
|
||||
.to_owned()
|
||||
+ self.config.model.as_ref()
|
||||
+ self.configuration.model.as_ref()
|
||||
+ ":generateContent?key="
|
||||
+ token.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(¶ms)
|
||||
.json(&messages)
|
||||
// .json(params)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
@ -152,6 +160,44 @@ impl Gemini {
|
||||
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
|
||||
}
|
||||
}
|
||||
async fn do_chat_completion(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
params: GeminiRunParams,
|
||||
) -> anyhow::Result<String> {
|
||||
match prompt {
|
||||
Prompt::ContextAndCode(code_and_context) => match ¶ms.messages {
|
||||
Some(completion_messages) => {
|
||||
let messages = format_chat_messages(completion_messages, code_and_context);
|
||||
self.get_chat(messages, params).await
|
||||
}
|
||||
None => {
|
||||
self.get_completion(
|
||||
&format_context_code(&code_and_context.context, &code_and_context.code),
|
||||
params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
},
|
||||
Prompt::FIM(fim) => match ¶ms.fim {
|
||||
Some(fim_params) => {
|
||||
self.get_completion(
|
||||
&format!(
|
||||
"{}{}{}{}{}",
|
||||
fim_params.start,
|
||||
fim.prompt,
|
||||
fim_params.middle,
|
||||
fim.suffix,
|
||||
fim_params.end
|
||||
),
|
||||
params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@ -163,7 +209,7 @@ impl TransformerBackend for Gemini {
|
||||
params: Value,
|
||||
) -> anyhow::Result<DoGenerationResponse> {
|
||||
let params: GeminiRunParams = serde_json::from_value(params)?;
|
||||
let generated_text = self.do_fim(prompt.try_into()?, params).await?;
|
||||
let generated_text = self.do_chat_completion(prompt, params).await?;
|
||||
Ok(DoGenerationResponse { generated_text })
|
||||
}
|
||||
|
||||
@ -175,10 +221,6 @@ impl TransformerBackend for Gemini {
|
||||
) -> anyhow::Result<DoGenerationStreamResponse> {
|
||||
anyhow::bail!("GenerationStream is not yet implemented")
|
||||
}
|
||||
|
||||
fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
|
||||
Ok(PromptType::FIM)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -194,9 +236,9 @@ mod test {
|
||||
"auth_token_env_var_name": "GEMINI_API_KEY",
|
||||
}))?;
|
||||
let gemini = Gemini::new(configuration);
|
||||
let prompt = Prompt::default_fim();
|
||||
let prompt = Prompt::default_without_cursor();
|
||||
let run_params = json!({
|
||||
"max_tokens": 2
|
||||
"max_tokens": 64
|
||||
});
|
||||
let response = gemini.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
@ -207,6 +249,7 @@ mod test {
|
||||
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
|
||||
let configuration: config::Gemini = serde_json::from_value(json!({
|
||||
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
|
||||
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
|
||||
"model": "gemini-1.5-flash",
|
||||
"auth_token_env_var_name": "GEMINI_API_KEY",
|
||||
}))?;
|
||||
@ -231,9 +274,9 @@ mod test {
|
||||
}
|
||||
]
|
||||
});
|
||||
let response = gemini.do_chat_completion(&prompt, run_params).await?;
|
||||
dbg!(&response);
|
||||
assert!(!response.is_empty());
|
||||
let response = gemini.do_generate(&prompt, run_params).await?;
|
||||
dbg!(&response.generated_text);
|
||||
assert!(!response.generated_text.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ impl OpenAI {
|
||||
self.configuration
|
||||
.chat_endpoint
|
||||
.as_ref()
|
||||
.context("must specify `completions_endpoint` to use completions")?,
|
||||
.context("must specify `chat_endpoint` to use completions")?,
|
||||
)
|
||||
.bearer_auth(token)
|
||||
.header("Content-Type", "application/json")
|
||||
|
Loading…
Reference in New Issue
Block a user