Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2024-09-11 12:25:48 +03:00)

commit c8993bf740
parent ad9f7381ea

    fix chat
@@ -9,7 +9,8 @@ use crate::{
     memory_backends::Prompt,
     transformer_worker::{
         DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
-    }, utils::{format_chat_messages, format_context_code},
+    },
+    utils::{format_chat_messages, format_context_code},
 };
 
 const fn max_tokens_default() -> usize {
@@ -23,12 +24,22 @@ const fn top_p_default() -> f32 {
 const fn temperature_default() -> f32 {
     0.1
 }
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct Part {
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct GeminiChatMessage {
+    role: String,
+    parts: Vec<Part>,
+}
 
 // NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Clone)]
 pub struct GeminiRunParams {
     pub fim: Option<FIM>,
-    messages: Option<Vec<ChatMessage>>,
+    contents: Option<Vec<GeminiChatMessage>>,
     #[serde(default = "max_tokens_default")]
     pub max_tokens: usize,
     #[serde(default = "top_p_default")]
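
Note: this hunk swaps the OpenAI-style `messages: Option<Vec<ChatMessage>>` field for Gemini's native schema, where each turn is a `GeminiChatMessage` whose text lives in a `parts` array and where the assistant role is named "model". If a caller ever needs to bridge the two formats, a minimal conversion sketch (not part of this commit, and assuming the crate's generic `ChatMessage` type exposes `role` and `content` String fields) could look like:

    // Hypothetical helper, not in this commit: map an OpenAI-style chat
    // message (assumed to carry `role` and `content` String fields) into
    // Gemini's role/parts schema.
    impl From<&ChatMessage> for GeminiChatMessage {
        fn from(message: &ChatMessage) -> Self {
            GeminiChatMessage {
                // Gemini calls the assistant role "model"
                role: if message.role == "assistant" {
                    "model".to_string()
                } else {
                    message.role.clone()
                },
                // Gemini wraps the message text in a list of parts
                parts: vec![Part {
                    text: message.content.clone(),
                }],
            }
        }
    }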
@@ -118,16 +129,61 @@ impl Gemini {
             anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
         }
     }
+    async fn get_chat(
+        &self,
+        messages: &[GeminiChatMessage],
+        _params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
+        let res: serde_json::Value = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use gemini")?
+                    .to_owned()
+                    + self.configuration.model.as_ref()
+                    + ":generateContent?key="
+                    + token.as_ref(),
+            )
+            .header("Content-Type", "application/json")
+            .json(&json!({
+                "contents": messages
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.get("error") {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(candidates) = res.get("candidates") {
+            Ok(candidates
+                .get(0)
+                .unwrap()
+                .get("content")
+                .unwrap()
+                .get("parts")
+                .unwrap()
+                .get(0)
+                .unwrap()
+                .get("text")
+                .unwrap()
+                .clone()
+                .to_string())
+        } else {
+            anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
+        }
+    }
     async fn do_chat_completion(
         &self,
         prompt: &Prompt,
         params: GeminiRunParams,
     ) -> anyhow::Result<String> {
         match prompt {
-            Prompt::ContextAndCode(code_and_context) => match &params.messages {
+            Prompt::ContextAndCode(code_and_context) => match &params.contents {
                 Some(completion_messages) => {
-                    todo!();
+                    self.get_chat(completion_messages, params.clone()).await
                 }
                 None => {
                     self.get_completion(
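
Note: the candidate extraction in `get_chat` above panics via the chained `.unwrap()` calls on any unexpected response shape, and `serde_json::Value::to_string` on a JSON string keeps the surrounding quotes in the returned text. A possible hardening, sketched here rather than taken from the commit, threads the lookups through `Option` and turns a missing field into an error:

    use anyhow::Context;
    use serde_json::Value;

    // Sketch of a panic-free version of the extraction in `get_chat`:
    // each lookup short-circuits to None on a missing field, and `as_str`
    // returns the raw string instead of a quoted JSON literal.
    fn extract_text(res: &Value) -> anyhow::Result<String> {
        res.get("candidates")
            .and_then(|candidates| candidates.get(0))
            .and_then(|candidate| candidate.get("content"))
            .and_then(|content| content.get("parts"))
            .and_then(|parts| parts.get(0))
            .and_then(|part| part.get("text"))
            .and_then(|text| text.as_str())
            .map(|text| text.to_string())
            .context("unexpected response shape from Gemini")
    }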
@@ -204,4 +260,37 @@ mod test {
         Ok(())
     }
     // gemini_chat_do_generate TODO
+    #[tokio::test]
+    async fn gemini_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Gemini = serde_json::from_value(json!({
+            "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+            "model": "gemini-1.5-flash",
+            "auth_token_env_var_name": "GEMINI_API_KEY",
+        }))?;
+        let gemini = Gemini::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let run_params = json!({
+            "contents": [
+                {
+                    "role": "user",
+                    "parts": [{
+                        "text": "Pretend you're a snowman and stay in character for each response."}]
+                },
+                {
+                    "role": "model",
+                    "parts": [{
+                        "text": "Hello! It's so cold! Isn't that great?"}]
+                },
+                {
+                    "role": "user",
+                    "parts": [{
+                        "text": "What's your favorite season of the year?"}]
+                }
+            ]
+        });
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        dbg!(&response.generated_text);
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
 }
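
Note: like the other tests in this module, `gemini_chat_do_generate` exercises the live Gemini API, so it presumably only passes when the `GEMINI_API_KEY` environment variable (the `auth_token_env_var_name` configured above) holds a valid key and the machine has network access.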