Mirror of https://github.com/SilasMarvin/lsp-ai.git (synced 2024-08-16 07:40:48 +03:00)
Commit 2c53880a77

Cargo.lock (generated, 2 changes)
@@ -1518,7 +1518,7 @@ dependencies = [
 
 [[package]]
 name = "lsp-ai"
-version = "0.2.0"
+version = "0.3.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
@@ -58,7 +58,7 @@ LSP-AI aims to fill this gap by providing a language server that integrates AI-p
    - LSP-AI supports any editor that adheres to the Language Server Protocol (LSP), ensuring that a wide range of editors can leverage the AI capabilities provided by LSP-AI.
 
 5. **Flexible LLM Backend Support**:
-   - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow.
+   - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs, Gemini-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow.
 
 6. **Future-Ready**:
    - LSP-AI is committed to staying updated with the latest advancements in LLM-driven software development.
@@ -46,6 +46,8 @@ pub enum ValidModel {
     MistralFIM(MistralFIM),
     #[serde(rename = "ollama")]
     Ollama(Ollama),
+    #[serde(rename = "gemini")]
+    Gemini(Gemini),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -171,6 +173,24 @@ pub struct OpenAI {
     pub model: String,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct Gemini {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    // The auth token
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
+    // The model name
+    pub model: String,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct Anthropic {
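
For illustration, the new `Gemini` config block deserializes with serde like the other backend configs. A minimal sketch (the endpoint, model name and env var name mirror the test added in src/transformer_backends/gemini.rs later in this commit):

    // Sketch: building a config::Gemini from JSON; unknown fields are rejected
    // because of #[serde(deny_unknown_fields)] above.
    let gemini: config::Gemini = serde_json::from_value(serde_json::json!({
        "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
        "model": "gemini-1.5-flash",
        "auth_token_env_var_name": "GEMINI_API_KEY",
    }))?;
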
@@ -272,6 +292,7 @@ impl Config {
             #[cfg(feature = "llama_cpp")]
             ValidModel::LLaMACPP(llama_cpp) => Ok(llama_cpp.max_requests_per_second),
             ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
+            ValidModel::Gemini(gemini) => Ok(gemini.max_requests_per_second),
             ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
             ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
             ValidModel::Ollama(ollama) => Ok(ollama.max_requests_per_second),
@@ -403,6 +424,47 @@ mod test {
         Config::new(args).unwrap();
     }
 
+    #[test]
+    fn gemini_config() {
+        let args = json!({
+            "initializationOptions": {
+                "memory": {
+                    "file_store": {}
+                },
+                "models": {
+                    "model1": {
+                        "type": "gemini",
+                        "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+                        "model": "gemini-1.5-flash-latest",
+                        "auth_token_env_var_name": "GEMINI_API_KEY",
+                    },
+                },
+                "completion": {
+                    "model": "model1",
+                    "parameters": {
+                        "systemInstruction": {
+                            "role": "system",
+                            "parts": [{
+                                "text": "TEST system instruction"
+                            }]
+                        },
+                        "generationConfig": {
+                            "maxOutputTokens": 10
+                        },
+                        "contents": [
+                            {
+                                "role": "user",
+                                "parts":[{
+                                    "text": "TEST - {CONTEXT} and {CODE}"}]
+                            }
+                        ]
+                    }
+                }
+            }
+        });
+        Config::new(args).unwrap();
+    }
+
     #[test]
     fn anthropic_config() {
         let args = json!({
@@ -114,7 +114,7 @@ impl FileStore
 
         Ok(match prompt_type {
             PromptType::ContextAndCode => {
-                if params.messages.is_some() {
+                if params.is_for_chat {
                     let max_length = tokens_to_estimated_characters(params.max_context_length);
                     let start = cursor_index.saturating_sub(max_length / 2);
                     let end = rope
@@ -185,9 +185,9 @@ impl MemoryBackend for FileStore
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt> {
-        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let params: MemoryRunParams = params.try_into()?;
        self.build_code(position, prompt_type, params)
     }
 
@@ -414,7 +414,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: ContextAndCodePrompt = prompt.try_into()?;
@@ -434,7 +434,7 @@ The end with a trailing new line
                 },
             },
             PromptType::FIM,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: FIMPrompt = prompt.try_into()?;
@@ -463,7 +463,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({
+            &json!({
                 "messages": []
             }),
         )
@@ -510,7 +510,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: ContextAndCodePrompt = prompt.try_into()?;
@@ -542,7 +542,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({"messages": []}),
+            &json!({"messages": []}),
        )
        .await?;
        let prompt: ContextAndCodePrompt = prompt.try_into()?;
@@ -2,31 +2,35 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
-use serde::Deserialize;
 use serde_json::Value;
 
-use crate::config::{ChatMessage, Config, ValidMemoryBackend};
+use crate::config::{Config, ValidMemoryBackend};
 
 pub mod file_store;
 mod postgresml;
 
-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Clone, Debug)]
 pub enum PromptType {
     ContextAndCode,
     FIM,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone)]
 pub struct MemoryRunParams {
-    pub messages: Option<Vec<ChatMessage>>,
-    #[serde(default = "max_context_length_default")]
+    pub is_for_chat: bool,
     pub max_context_length: usize,
 }
 
+impl From<&Value> for MemoryRunParams {
+    fn from(value: &Value) -> Self {
+        Self {
+            max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
+            // messages are for most backends, contents are for Gemini
+            is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct ContextAndCodePrompt {
     pub context: String,
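
The `From<&Value>` conversion above replaces the previous serde-based deserialization of MemoryRunParams. A small sketch of its behavior, assuming it is handed the raw completion parameters:

    // "messages" (most backends) or "contents" (Gemini) mark a chat-style request;
    // max_context_length falls back to 1024 when the key is absent.
    let params: MemoryRunParams = (&serde_json::json!({ "contents": [] })).into();
    assert!(params.is_for_chat);
    assert_eq!(params.max_context_length, 1024);
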
@@ -119,7 +123,7 @@ pub trait MemoryBackend {
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt>;
     async fn get_filter_text(
         &self,
@@ -132,9 +132,9 @@ impl MemoryBackend for PostgresML
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt> {
-        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let params: MemoryRunParams = params.try_into()?;
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
@@ -70,7 +70,7 @@ async fn do_task(
         }
         WorkerRequest::Prompt(params) => {
             let prompt = memory_backend
-                .build_prompt(&params.position, params.prompt_type, params.params)
+                .build_prompt(&params.position, params.prompt_type, &params.params)
                 .await?;
             params
                 .tx
src/transformer_backends/gemini.rs (new file, 237 lines)
@@ -0,0 +1,237 @@
+use anyhow::Context;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use super::TransformerBackend;
+use crate::{
+    config,
+    memory_backends::{ContextAndCodePrompt, Prompt},
+    transformer_worker::{
+        DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
+    },
+    utils::format_context_code_in_str,
+};
+
+fn format_gemini_contents(
+    messages: &[GeminiContent],
+    prompt: &ContextAndCodePrompt,
+) -> Vec<GeminiContent> {
+    messages
+        .iter()
+        .map(|m| {
+            GeminiContent::new(
+                m.role.to_owned(),
+                m.parts
+                    .iter()
+                    .map(|p| Part {
+                        text: format_context_code_in_str(&p.text, &prompt.context, &prompt.code),
+                    })
+                    .collect(),
+            )
+        })
+        .collect()
+}
+
+const fn max_tokens_default() -> usize {
+    64
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct Part {
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct GeminiContent {
+    role: String,
+    parts: Vec<Part>,
+}
+
+impl GeminiContent {
+    fn new(role: String, parts: Vec<Part>) -> Self {
+        Self { role, parts }
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct GeminiGenerationConfig {
+    #[serde(rename = "stopSequences")]
+    #[serde(default)]
+    pub stop_sequences: Vec<String>,
+    #[serde(rename = "maxOutputTokens")]
+    #[serde(default = "max_tokens_default")]
+    pub max_output_tokens: usize,
+    pub temperature: Option<f32>,
+    #[serde(rename = "topP")]
+    pub top_p: Option<f32>,
+    #[serde(rename = "topK")]
+    pub top_k: Option<f32>,
+}
+
+// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct GeminiRunParams {
+    contents: Vec<GeminiContent>,
+    #[serde(rename = "systemInstruction")]
+    system_instruction: GeminiContent,
+    #[serde(rename = "generationConfig")]
+    generation_config: Option<GeminiGenerationConfig>,
+}
+
+pub struct Gemini {
+    configuration: config::Gemini,
+}
+
+impl Gemini {
+    pub fn new(configuration: config::Gemini) -> Self {
+        Self { configuration }
+    }
+
+    fn get_token(&self) -> anyhow::Result<String> {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            Ok(std::env::var(env_var_name)?)
+        } else if let Some(token) = &self.configuration.auth_token {
+            Ok(token.to_string())
+        } else {
+            anyhow::bail!(
+                "set `auth_token_env_var_name` or `auth_token` to use an Gemini compatible API"
+            )
+        }
+    }
+
+    async fn get_chat(
+        &self,
+        messages: Vec<GeminiContent>,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
+        let res: serde_json::Value = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use gemini")?
+                    .to_owned()
+                    + self.configuration.model.as_ref()
+                    + ":generateContent?key="
+                    + token.as_ref(),
+            )
+            .header("Content-Type", "application/json")
+            .json(&json!({
+                "contents": messages,
+                "systemInstruction": params.system_instruction,
+                "generationConfig": params.generation_config,
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.get("error") {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(candidates) = res.get("candidates") {
+            Ok(candidates
+                .get(0)
+                .unwrap()
+                .get("content")
+                .unwrap()
+                .get("parts")
+                .unwrap()
+                .get(0)
+                .unwrap()
+                .get("text")
+                .unwrap()
+                .clone()
+                .to_string())
+        } else {
+            anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
+        }
+    }
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        match prompt {
+            Prompt::ContextAndCode(code_and_context) => {
+                let messages = format_gemini_contents(&params.contents, code_and_context);
+                self.get_chat(messages, params).await
+            }
+            _ => anyhow::bail!("Google Gemini backend does not yet support FIM"),
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Gemini {
+    #[instrument(skip(self))]
+    async fn do_generate(
+        &self,
+        prompt: &Prompt,
+        params: Value,
+    ) -> anyhow::Result<DoGenerationResponse> {
+        let params: GeminiRunParams = serde_json::from_value(params)?;
+        let generated_text = self.do_chat_completion(prompt, params).await?;
+        Ok(DoGenerationResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerationStreamRequest,
+        _params: Value,
+    ) -> anyhow::Result<DoGenerationStreamResponse> {
+        anyhow::bail!("GenerationStream is not yet implemented")
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::json;
+
+    #[tokio::test]
+    async fn gemini_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Gemini = serde_json::from_value(json!({
+            "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+            "model": "gemini-1.5-flash",
+            "auth_token_env_var_name": "GEMINI_API_KEY",
+        }))?;
+        let gemini = Gemini::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let run_params = json!({
+            "systemInstruction": {
+                "role": "system",
+                "parts": [{
+                    "text": "You are a helpful and willing chatbot that will do whatever the user asks"
+                }]
+            },
+            "generationConfig": {
+                "maxOutputTokens": 10
+            },
+            "contents": [
+                {
+                    "role": "user",
+                    "parts":[{
+                        "text": "Pretend you're a snowman and stay in character for each response."}]
+                },
+                {
+                    "role": "model",
+                    "parts":[{
+                        "text": "Hello! It's so cold! Isn't that great?"}]
+                },
+                {
+                    "role": "user",
+                    "parts":[{
+                        "text": "What's your favorite season of the year?"}]
+                }
+            ]
+        });
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        dbg!(&response.generated_text);
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
@@ -11,6 +11,7 @@ use crate::{
 };
 
 mod anthropic;
+mod gemini;
 #[cfg(feature = "llama_cpp")]
 mod llama_cpp;
 mod mistral_fim;
@@ -66,6 +67,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
             ValidModel::OpenAI(open_ai_config) => {
                 Ok(Box::new(open_ai::OpenAI::new(open_ai_config)))
             }
+            ValidModel::Gemini(gemini_config) => Ok(Box::new(gemini::Gemini::new(gemini_config))),
             ValidModel::Anthropic(anthropic_config) => {
                 Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
             }
@@ -163,7 +163,7 @@ impl OpenAI {
                 self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions")?,
+                    .context("must specify `chat_endpoint` to use completions")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
@@ -29,14 +29,16 @@ pub fn format_chat_messages(
         .map(|m| {
             ChatMessage::new(
                 m.role.to_owned(),
-                m.content
-                    .replace("{CONTEXT}", &prompt.context)
-                    .replace("{CODE}", &prompt.code),
+                format_context_code_in_str(&m.content, &prompt.context, &prompt.code),
             )
         })
         .collect()
 }
 
+pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String {
+    s.replace("{CONTEXT}", context).replace("{CODE}", code)
+}
+
 pub fn format_context_code(context: &str, code: &str) -> String {
     format!("{context}\n\n{code}")
 }
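
The extracted helper behaves exactly like the inline replace chain it supersedes. A quick usage sketch:

    // Both placeholders are substituted in one pass.
    let out = format_context_code_in_str("{CONTEXT}\n\n{CODE}", "// context", "fn main() {}");
    assert_eq!(out, "// context\n\nfn main() {}");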