now the test gemini_completion_do_generate works

Asuka Minato 2024-06-17 17:25:33 +09:00
parent cd46ecf61a
commit 91f03457fc
4 changed files with 248 additions and 1 deletion

Cargo.lock generated
View File

@@ -1518,7 +1518,7 @@ dependencies = [
[[package]]
name = "lsp-ai"
version = "0.2.0"
version = "0.3.0"
dependencies = [
"anyhow",
"assert_cmd",

View File

@@ -46,6 +46,8 @@ pub enum ValidModel {
MistralFIM(MistralFIM),
#[serde(rename = "ollama")]
Ollama(Ollama),
#[serde(rename = "gemini")]
Gemini(Gemini),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -171,6 +173,24 @@ pub struct OpenAI {
pub model: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Gemini {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
// The auth token
pub auth_token: Option<String>,
// The completions endpoint
pub completions_endpoint: Option<String>,
// The chat endpoint
pub chat_endpoint: Option<String>,
// The maximum requests per second
#[serde(default = "max_requests_per_second_default")]
pub max_requests_per_second: f32,
// The model name
pub model: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Anthropic {
@@ -272,6 +292,7 @@ impl Config {
#[cfg(feature = "llama_cpp")]
ValidModel::LLaMACPP(llama_cpp) => Ok(llama_cpp.max_requests_per_second),
ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
ValidModel::Gemini(gemini) => Ok(gemini.max_requests_per_second),
ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
ValidModel::Ollama(ollama) => Ok(ollama.max_requests_per_second),
@@ -402,6 +423,41 @@ mod test {
});
Config::new(args).unwrap();
}
#[test]
fn gemini_config() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"models": {
"model1": {
"type": "gemini",
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent",
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY",
},
},
"completion": {
"model": "model1",
"parameters": {
"messages": [
{
"role": "system",
"content": "Test",
},
{
"role": "user",
"content": "Test {CONTEXT} - {CODE}"
}
],
"max_new_tokens": 32,
}
}
}
});
Config::new(args).unwrap();
}
#[test]
fn anthropic_config() {

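A quick companion sketch (not part of this commit) for the config additions above: in the new `Gemini` struct only `model` has neither a serde default nor an `Option`, so a minimal entry should deserialize on its own. Assuming it sits in the same test module as the configs above:

#[test]
fn gemini_config_minimal() {
// `model` is the only required field; the endpoints and auth fields are
// `Option`s and `max_requests_per_second` is filled in by its serde default.
let gemini: super::Gemini = serde_json::from_value(serde_json::json!({
"model": "gemini-1.5-flash-latest"
}))
.unwrap();
assert!(gemini.completions_endpoint.is_none());
assert!(gemini.chat_endpoint.is_none());
}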
View File

@@ -0,0 +1,189 @@
use std::collections::HashMap;
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;
use super::{open_ai::OpenAIChatChoices, TransformerBackend};
use crate::{
config::{self},
memory_backends::{FIMPrompt, Prompt, PromptType},
transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
},
};
const fn max_tokens_default() -> usize {
64
}
const fn top_p_default() -> f32 {
0.95
}
const fn temperature_default() -> f32 {
0.1
}
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
#[derive(Debug, Deserialize)]
pub struct GeminiRunParams {
#[serde(default = "max_tokens_default")]
pub max_tokens: usize,
#[serde(default = "top_p_default")]
pub top_p: f32,
#[serde(default = "temperature_default")]
pub temperature: f32,
pub min_tokens: Option<u64>,
pub random_seed: Option<u64>,
#[serde(default)]
pub stop: Vec<String>,
}
pub struct Gemini {
config: config::Gemini,
}
impl Gemini {
pub fn new(config: config::Gemini) -> Self {
Self { config }
}
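// Resolve the API key: prefer the environment variable named by
// `auth_token_env_var_name`, otherwise fall back to a literal `auth_token`.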
fn get_token(&self) -> anyhow::Result<String> {
if let Some(env_var_name) = &self.config.auth_token_env_var_name {
Ok(std::env::var(env_var_name)?)
} else if let Some(token) = &self.config.auth_token {
Ok(token.to_string())
} else {
anyhow::bail!(
"set `auth_token_env_var_name` or `auth_token` to use the Gemini API"
)
}
}
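// Send the FIM prompt to the generateContent endpoint and return the first
// candidate's text, surfacing any API error object as an anyhow error.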
async fn do_fim(&self, prompt: &FIMPrompt, params: GeminiRunParams) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: serde_json::Value = client
.post(
self.config
.completions_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use Gemini")?,
)
// Gemini expects the API key as a `key` query parameter on the endpoint URL
.query(&[("key", token)])
.header("Content-Type", "application/json")
.json(&json!(
{
"contents":[
{
"parts":[
{
"text": prompt.prompt
}
]
}
]
}
))
.send()
.await?
.json()
.await?;
if let Some(error) = res.get("error") {
anyhow::bail!("{}", error)
} else if let Some(candidates) = res.get("candidates") {
// Gemini nests the generated text at candidates[0].content.parts[0].text
Ok(candidates
.get(0)
.and_then(|candidate| candidate.get("content"))
.and_then(|content| content.get("parts"))
.and_then(|parts| parts.get(0))
.and_then(|part| part.get("text"))
.and_then(|text| text.as_str())
.context("unexpected response format from Gemini")?
.to_string())
} else {
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
}
}
}
#[async_trait::async_trait]
impl TransformerBackend for Gemini {
#[instrument(skip(self))]
async fn do_generate(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: GeminiRunParams = serde_json::from_value(params)?;
let generated_text = self.do_fim(prompt.try_into()?, params).await?;
Ok(DoGenerationResponse { generated_text })
}
#[instrument(skip(self))]
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
_params: Value,
) -> anyhow::Result<DoGenerationStreamResponse> {
anyhow::bail!("GenerationStream is not yet implemented")
}
fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
Ok(PromptType::FIM)
}
}
#[cfg(test)]
mod test {
use super::*;
use serde_json::{from_value, json};
#[tokio::test]
async fn gemini_completion_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = from_value(json!({
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent",
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let gemini = Gemini::new(configuration);
let prompt = Prompt::default_fim();
let run_params = json!({
"max_tokens": 2
});
let response = gemini.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
#[tokio::test]
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Gemini = serde_json::from_value(json!({
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/",
"model": "gemini-1.5-flash",
"auth_token_env_var_name": "GEMINI_API_KEY",
}))?;
let gemini = Gemini::new(configuration);
let prompt = Prompt::default_with_cursor();
let run_params = json!({
"messages": [
{
"role": "system",
"content": "Test"
},
{
"role": "user",
"content": "Test {CONTEXT} - {CODE}"
}
],
"max_tokens": 64
});
let response = gemini.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
}
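For context on the request `do_fim` builds and the response it unpacks: the body is a single `contents` entry holding one `parts`/`text` object, and the reply is expected to nest the generated text at candidates[0].content.parts[0].text, with an `error` object taking its place on failure. A standalone sketch of those shapes with illustrative values, not taken from the commit:

fn main() {
// Request body posted to the generateContent endpoint:
let _request = serde_json::json!({
"contents": [
{ "parts": [ { "text": "fn add(a: i32, b: i32) -> i32 {" } ] }
]
});
// Successful response shape that the candidate extraction above walks:
let response = serde_json::json!({
"candidates": [
{ "content": { "role": "model", "parts": [ { "text": "a + b }" } ] } }
]
});
let text = response["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.expect("`text` should be a JSON string");
assert_eq!(text, "a + b }");
}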

View File

@@ -11,6 +11,7 @@ use crate::{
};
mod anthropic;
mod gemini;
#[cfg(feature = "llama_cpp")]
mod llama_cpp;
mod mistral_fim;
@@ -66,6 +67,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
ValidModel::OpenAI(open_ai_config) => {
Ok(Box::new(open_ai::OpenAI::new(open_ai_config)))
}
ValidModel::Gemini(gemini_config) => Ok(Box::new(gemini::Gemini::new(gemini_config))),
ValidModel::Anthropic(anthropic_config) => {
Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
}
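With the new `Gemini` arm in the `TryFrom` dispatch above, a gemini model entry resolves to the new backend like any other. A rough sketch of exercising that path (hypothetical test, assuming it lives alongside this module so `TransformerBackend` is in scope):

#[test]
fn gemini_backend_dispatch() -> anyhow::Result<()> {
// Build a Gemini model config and convert it through the `TryFrom<ValidModel>`
// impl above into a boxed backend.
let gemini_config: crate::config::Gemini = serde_json::from_value(serde_json::json!({
"model": "gemini-1.5-flash-latest",
"auth_token_env_var_name": "GEMINI_API_KEY"
}))?;
let backend: Box<dyn TransformerBackend + Send + Sync> =
crate::config::ValidModel::Gemini(gemini_config).try_into()?;
// The Gemini backend reports a FIM prompt type, so completions go through
// its completions endpoint.
assert!(matches!(
backend.get_prompt_type(&serde_json::json!({}))?,
crate::memory_backends::PromptType::FIM
));
Ok(())
}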