mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2024-09-11 12:25:48 +03:00
merge pull/1
This commit is contained in:
parent
be577c19e0
commit
f85c964a30
@ -423,6 +423,7 @@ mod test {
|
|||||||
});
|
});
|
||||||
Config::new(args).unwrap();
|
Config::new(args).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gemini_config() {
|
fn gemini_config() {
|
||||||
let args = json!({
|
let args = json!({
|
||||||
@ -441,24 +442,22 @@ mod test {
|
|||||||
"completion": {
|
"completion": {
|
||||||
"model": "model1",
|
"model": "model1",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
|
"systemInstruction": {
|
||||||
|
"role": "system",
|
||||||
|
"parts": [{
|
||||||
|
"text": "TEST system instruction"
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": 10
|
||||||
|
},
|
||||||
"contents": [
|
"contents": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts":[{
|
"parts":[{
|
||||||
"text": "Pretend you're a snowman and stay in character for each response."}]
|
"text": "TEST - {CONTEXT} and {CODE}"}]
|
||||||
},
|
}
|
||||||
{
|
]
|
||||||
"role": "model",
|
|
||||||
"parts":[{
|
|
||||||
"text": "Hello! It's so cold! Isn't that great?"}]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"parts":[{
|
|
||||||
"text": "What's your favorite season of the year?"}]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_new_tokens": 32,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,51 +5,79 @@ use tracing::instrument;
|
|||||||
|
|
||||||
use super::TransformerBackend;
|
use super::TransformerBackend;
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{self, ChatMessage, FIM},
|
config,
|
||||||
memory_backends::Prompt,
|
memory_backends::{ContextAndCodePrompt, Prompt},
|
||||||
transformer_worker::{
|
transformer_worker::{
|
||||||
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
|
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
|
||||||
},
|
},
|
||||||
utils::{format_chat_messages, format_context_code},
|
utils::format_context_code_in_str,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
fn format_gemini_contents(
|
||||||
|
messages: &[GeminiContent],
|
||||||
|
prompt: &ContextAndCodePrompt,
|
||||||
|
) -> Vec<GeminiContent> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
GeminiContent::new(
|
||||||
|
m.role.to_owned(),
|
||||||
|
m.parts
|
||||||
|
.iter()
|
||||||
|
.map(|p| Part {
|
||||||
|
text: format_context_code_in_str(&p.text, &prompt.context, &prompt.code),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
const fn max_tokens_default() -> usize {
|
const fn max_tokens_default() -> usize {
|
||||||
64
|
64
|
||||||
}
|
}
|
||||||
|
|
||||||
const fn top_p_default() -> f32 {
|
|
||||||
0.95
|
|
||||||
}
|
|
||||||
|
|
||||||
const fn temperature_default() -> f32 {
|
|
||||||
0.1
|
|
||||||
}
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
struct Part {
|
struct Part {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
struct GeminiChatMessage {
|
struct GeminiContent {
|
||||||
role: String,
|
role: String,
|
||||||
parts: Vec<Part>,
|
parts: Vec<Part>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
|
impl GeminiContent {
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
fn new(role: String, parts: Vec<Part>) -> Self {
|
||||||
pub struct GeminiRunParams {
|
Self { role, parts }
|
||||||
pub fim: Option<FIM>,
|
}
|
||||||
contents: Option<Vec<GeminiChatMessage>>,
|
}
|
||||||
#[serde(default = "max_tokens_default")]
|
|
||||||
pub max_tokens: usize,
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
#[serde(default = "top_p_default")]
|
#[serde(deny_unknown_fields)]
|
||||||
pub top_p: f32,
|
pub struct GeminiGenerationConfig {
|
||||||
#[serde(default = "temperature_default")]
|
#[serde(rename = "stopSequences")]
|
||||||
pub temperature: f32,
|
|
||||||
pub min_tokens: Option<u64>,
|
|
||||||
pub random_seed: Option<u64>,
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub stop: Vec<String>,
|
pub stop_sequences: Vec<String>,
|
||||||
|
#[serde(rename = "maxOutputTokens")]
|
||||||
|
#[serde(default = "max_tokens_default")]
|
||||||
|
pub max_output_tokens: usize,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(rename = "topP")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(rename = "topK")]
|
||||||
|
pub top_k: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
|
pub struct GeminiRunParams {
|
||||||
|
contents: Vec<GeminiContent>,
|
||||||
|
#[serde(rename = "systemInstruction")]
|
||||||
|
system_instruction: GeminiContent,
|
||||||
|
#[serde(rename = "generationConfig")]
|
||||||
|
generation_config: Option<GeminiGenerationConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Gemini {
|
pub struct Gemini {
|
||||||
@ -73,66 +101,10 @@ impl Gemini {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_completion(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_params: GeminiRunParams,
|
|
||||||
) -> anyhow::Result<String> {
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let token = self.get_token()?;
|
|
||||||
let res: serde_json::Value = client
|
|
||||||
.post(
|
|
||||||
self.configuration
|
|
||||||
.completions_endpoint
|
|
||||||
.as_ref()
|
|
||||||
.context("must specify `completions_endpoint` to use gemini")?
|
|
||||||
.to_owned()
|
|
||||||
+ self.configuration.model.as_ref()
|
|
||||||
+ ":generateContent?key="
|
|
||||||
+ token.as_ref(),
|
|
||||||
)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.json(&json!(
|
|
||||||
{
|
|
||||||
"contents":[
|
|
||||||
{
|
|
||||||
"parts":[
|
|
||||||
{
|
|
||||||
"text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
))
|
|
||||||
.send()
|
|
||||||
.await?
|
|
||||||
.json()
|
|
||||||
.await?;
|
|
||||||
if let Some(error) = res.get("error") {
|
|
||||||
anyhow::bail!("{:?}", error.to_string())
|
|
||||||
} else if let Some(candidates) = res.get("candidates") {
|
|
||||||
Ok(candidates
|
|
||||||
.get(0)
|
|
||||||
.unwrap()
|
|
||||||
.get("content")
|
|
||||||
.unwrap()
|
|
||||||
.get("parts")
|
|
||||||
.unwrap()
|
|
||||||
.get(0)
|
|
||||||
.unwrap()
|
|
||||||
.get("text")
|
|
||||||
.unwrap()
|
|
||||||
.clone()
|
|
||||||
.to_string())
|
|
||||||
} else {
|
|
||||||
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async fn get_chat(
|
async fn get_chat(
|
||||||
&self,
|
&self,
|
||||||
messages: &[GeminiChatMessage],
|
messages: Vec<GeminiContent>,
|
||||||
_params: GeminiRunParams,
|
params: GeminiRunParams,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let token = self.get_token()?;
|
let token = self.get_token()?;
|
||||||
@ -149,7 +121,9 @@ impl Gemini {
|
|||||||
)
|
)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.json(&json!({
|
.json(&json!({
|
||||||
"contents": messages
|
"contents": messages,
|
||||||
|
"systemInstruction": params.system_instruction,
|
||||||
|
"generationConfig": params.generation_config,
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
@ -181,35 +155,11 @@ impl Gemini {
|
|||||||
params: GeminiRunParams,
|
params: GeminiRunParams,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
match prompt {
|
match prompt {
|
||||||
Prompt::ContextAndCode(code_and_context) => match ¶ms.contents {
|
Prompt::ContextAndCode(code_and_context) => {
|
||||||
Some(completion_messages) => {
|
let messages = format_gemini_contents(¶ms.contents, code_and_context);
|
||||||
self.get_chat(completion_messages, params.clone()).await
|
self.get_chat(messages, params).await
|
||||||
}
|
}
|
||||||
None => {
|
_ => anyhow::bail!("Google Gemini backend does not yet support FIM"),
|
||||||
self.get_completion(
|
|
||||||
&format_context_code(&code_and_context.context, &code_and_context.code),
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Prompt::FIM(fim) => match ¶ms.fim {
|
|
||||||
Some(fim_params) => {
|
|
||||||
self.get_completion(
|
|
||||||
&format!(
|
|
||||||
"{}{}{}{}{}",
|
|
||||||
fim_params.start,
|
|
||||||
fim.prompt,
|
|
||||||
fim_params.middle,
|
|
||||||
fim.suffix,
|
|
||||||
fim_params.end
|
|
||||||
),
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -240,25 +190,8 @@ impl TransformerBackend for Gemini {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
use serde_json::{from_value, json};
|
use serde_json::json;
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn gemini_completion_do_generate() -> anyhow::Result<()> {
|
|
||||||
let configuration: config::Gemini = from_value(json!({
|
|
||||||
"completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
|
|
||||||
"model": "gemini-1.5-flash-latest",
|
|
||||||
"auth_token_env_var_name": "GEMINI_API_KEY",
|
|
||||||
}))?;
|
|
||||||
let gemini = Gemini::new(configuration);
|
|
||||||
let prompt = Prompt::default_without_cursor();
|
|
||||||
let run_params = json!({
|
|
||||||
"max_tokens": 64
|
|
||||||
});
|
|
||||||
let response = gemini.do_generate(&prompt, run_params).await?;
|
|
||||||
assert!(!response.generated_text.is_empty());
|
|
||||||
dbg!(response.generated_text);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
|
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
|
||||||
let configuration: config::Gemini = serde_json::from_value(json!({
|
let configuration: config::Gemini = serde_json::from_value(json!({
|
||||||
@ -269,9 +202,18 @@ mod test {
|
|||||||
let gemini = Gemini::new(configuration);
|
let gemini = Gemini::new(configuration);
|
||||||
let prompt = Prompt::default_with_cursor();
|
let prompt = Prompt::default_with_cursor();
|
||||||
let run_params = json!({
|
let run_params = json!({
|
||||||
|
"systemInstruction": {
|
||||||
|
"role": "system",
|
||||||
|
"parts": [{
|
||||||
|
"text": "You are a helpful and willing chatbot that will do whatever the user asks"
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": 10
|
||||||
|
},
|
||||||
"contents": [
|
"contents": [
|
||||||
{
|
{
|
||||||
"role":"user",
|
"role": "user",
|
||||||
"parts":[{
|
"parts":[{
|
||||||
"text": "Pretend you're a snowman and stay in character for each response."}]
|
"text": "Pretend you're a snowman and stay in character for each response."}]
|
||||||
},
|
},
|
||||||
|
@ -29,14 +29,16 @@ pub fn format_chat_messages(
|
|||||||
.map(|m| {
|
.map(|m| {
|
||||||
ChatMessage::new(
|
ChatMessage::new(
|
||||||
m.role.to_owned(),
|
m.role.to_owned(),
|
||||||
m.content
|
format_context_code_in_str(&m.content, &prompt.context, &prompt.code),
|
||||||
.replace("{CONTEXT}", &prompt.context)
|
|
||||||
.replace("{CODE}", &prompt.code),
|
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String {
|
||||||
|
s.replace("{CONTEXT}", context).replace("{CODE}", code)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn format_context_code(context: &str, code: &str) -> String {
|
pub fn format_context_code(context: &str, code: &str) -> String {
|
||||||
format!("{context}\n\n{code}")
|
format!("{context}\n\n{code}")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user