mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2024-08-15 23:30:34 +03:00
fin gemini_chat_do_generate
This commit is contained in:
parent
4548de4e3e
commit
b5302f9748
@ -1,13 +1,11 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::instrument;
|
||||
|
||||
use super::{open_ai::OpenAIChatChoices, TransformerBackend};
|
||||
use super::TransformerBackend;
|
||||
use crate::{
|
||||
config::{self},
|
||||
config,
|
||||
memory_backends::{FIMPrompt, Prompt, PromptType},
|
||||
transformer_worker::{
|
||||
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
|
||||
@ -70,7 +68,9 @@ impl Gemini {
|
||||
self.config
|
||||
.completions_endpoint
|
||||
.as_ref()
|
||||
.context("must specify `gemini_endpoint` to use gemini")?.to_owned() + token.as_ref(),
|
||||
.context("must specify `gemini_endpoint` to use gemini")?
|
||||
.to_owned()
|
||||
+ token.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&json!(
|
||||
@ -110,6 +110,48 @@ impl Gemini {
|
||||
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
|
||||
}
|
||||
}
|
||||
async fn do_chat_completion(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
params: Value,
|
||||
) -> anyhow::Result<String> {
|
||||
let client = reqwest::Client::new();
|
||||
let token = self.get_token()?;
|
||||
let res: serde_json::Value = client
|
||||
.post(
|
||||
self.config
|
||||
.chat_endpoint
|
||||
.as_ref()
|
||||
.context("must specify `gemini_endpoint` to use gemini")?
|
||||
.to_owned()
|
||||
+ token.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(¶ms)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
if let Some(error) = res.get("error") {
|
||||
anyhow::bail!("{:?}", error.to_string())
|
||||
} else if let Some(candidates) = res.get("candidates") {
|
||||
Ok(candidates
|
||||
.get(0)
|
||||
.unwrap()
|
||||
.get("content")
|
||||
.unwrap()
|
||||
.get("parts")
|
||||
.unwrap()
|
||||
.get(0)
|
||||
.unwrap()
|
||||
.get("text")
|
||||
.unwrap()
|
||||
.clone()
|
||||
.to_string())
|
||||
} else {
|
||||
anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@ -158,32 +200,40 @@ mod test {
|
||||
});
|
||||
let response = anthropic.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
dbg!(response.generated_text);
|
||||
Ok(())
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn gemini_chat_do_generate() -> anyhow::Result<()> {
|
||||
let configuration: config::Gemini = serde_json::from_value(json!({
|
||||
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/",
|
||||
"chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=",
|
||||
"model": "gemini-1.5-flash",
|
||||
"auth_token_env_var_name": "GEMINI_API_KEY",
|
||||
}))?;
|
||||
let gemini = Gemini::new(configuration);
|
||||
let prompt = Prompt::default_with_cursor();
|
||||
let run_params = json!({
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Test"
|
||||
"contents": [
|
||||
{
|
||||
"role":"user",
|
||||
"parts":[{
|
||||
"text": "Pretend you're a snowman and stay in character for each response."}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Test {CONTEXT} - {CODE}"
|
||||
{
|
||||
"role": "model",
|
||||
"parts":[{
|
||||
"text": "Hello! It's so cold! Isn't that great?"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"parts":[{
|
||||
"text": "What's your favorite season of the year?"}]
|
||||
}
|
||||
],
|
||||
"max_tokens": 64
|
||||
]
|
||||
});
|
||||
let response = gemini.do_generate(&prompt, run_params).await?;
|
||||
assert!(!response.generated_text.is_empty());
|
||||
let response = gemini.do_chat_completion(&prompt, run_params).await?;
|
||||
dbg!(&response);
|
||||
assert!(!response.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user