diff --git a/Cargo.lock b/Cargo.lock
index 34f5542..524e12f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1518,7 +1518,7 @@ dependencies = [
 
 [[package]]
 name = "lsp-ai"
-version = "0.2.0"
+version = "0.3.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
diff --git a/README.md b/README.md
index e5ee08a..3115638 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ LSP-AI aims to fill this gap by providing a language server that integrates AI-p
    - LSP-AI supports any editor that adheres to the Language Server Protocol (LSP), ensuring that a wide range of editors can leverage the AI capabilities provided by LSP-AI.
 
 5. **Flexible LLM Backend Support**:
-   - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow.
+   - Currently, LSP-AI supports llama.cpp, Ollama, OpenAI-compatible APIs, Anthropic-compatible APIs, Gemini-compatible APIs and Mistral AI FIM-compatible APIs, giving developers the flexibility to choose their preferred backend. This list will soon grow.
 
 6. **Future-Ready**:
    - LSP-AI is committed to staying updated with the latest advancements in LLM-driven software development.
diff --git a/src/config.rs b/src/config.rs
index 8cbeadd..8b7b394 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -46,6 +46,8 @@ pub enum ValidModel {
     MistralFIM(MistralFIM),
     #[serde(rename = "ollama")]
     Ollama(Ollama),
+    #[serde(rename = "gemini")]
+    Gemini(Gemini),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -171,6 +173,24 @@ pub struct OpenAI {
     pub model: String,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct Gemini {
+    // The auth token env var name
+    pub auth_token_env_var_name: Option<String>,
+    // The auth token
+    pub auth_token: Option<String>,
+    // The completions endpoint
+    pub completions_endpoint: Option<String>,
+    // The chat endpoint
+    pub chat_endpoint: Option<String>,
+    // The maximum requests per second
+    #[serde(default = "max_requests_per_second_default")]
+    pub max_requests_per_second: f32,
+    // The model name
+    pub model: String,
+}
+
 #[derive(Clone, Debug, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct Anthropic {
@@ -272,6 +292,7 @@ impl Config {
             #[cfg(feature = "llama_cpp")]
             ValidModel::LLaMACPP(llama_cpp) => Ok(llama_cpp.max_requests_per_second),
             ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
+            ValidModel::Gemini(gemini) => Ok(gemini.max_requests_per_second),
             ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
             ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
             ValidModel::Ollama(ollama) => Ok(ollama.max_requests_per_second),
@@ -403,6 +424,47 @@ mod test {
         Config::new(args).unwrap();
     }
 
+    #[test]
+    fn gemini_config() {
+        let args = json!({
+            "initializationOptions": {
+                "memory": {
+                    "file_store": {}
+                },
+                "models": {
+                    "model1": {
+                        "type": "gemini",
+                        "completions_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+                        "model": "gemini-1.5-flash-latest",
+                        "auth_token_env_var_name": "GEMINI_API_KEY",
+                    },
+                },
+                "completion": {
+                    "model": "model1",
+                    "parameters": {
+                        "systemInstruction": {
+                            "role": "system",
+                            "parts": [{
+                                "text": "TEST system instruction"
+                            }]
+                        },
+                        "generationConfig": {
+                            "maxOutputTokens": 10
+                        },
+                        "contents": [
+                            {
+                                "role": "user",
+                                "parts":[{
+                                    "text": "TEST - {CONTEXT} and {CODE}"}]
+                            }
+                        ]
+                    }
+                }
+            }
+        });
+        Config::new(args).unwrap();
+    }
+
     #[test]
     fn anthropic_config() {
         let args = json!({
diff --git a/src/memory_backends/file_store.rs b/src/memory_backends/file_store.rs
index 219998b..4d70509 100644
--- a/src/memory_backends/file_store.rs
+++ b/src/memory_backends/file_store.rs
@@ -114,7 +114,7 @@ impl FileStore {
 
         Ok(match prompt_type {
             PromptType::ContextAndCode => {
-                if params.messages.is_some() {
+                if params.is_for_chat {
                     let max_length = tokens_to_estimated_characters(params.max_context_length);
                     let start = cursor_index.saturating_sub(max_length / 2);
                     let end = rope
@@ -185,9 +185,9 @@ impl MemoryBackend for FileStore {
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt> {
-        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let params: MemoryRunParams = params.try_into()?;
         self.build_code(position, prompt_type, params)
     }
 
@@ -414,7 +414,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: ContextAndCodePrompt = prompt.try_into()?;
@@ -434,7 +434,7 @@ The end with a trailing new line
                 },
             },
             PromptType::FIM,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: FIMPrompt = prompt.try_into()?;
@@ -463,7 +463,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({
+            &json!({
                 "messages": []
             }),
         )
@@ -510,7 +510,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({}),
+            &json!({}),
         )
         .await?;
         let prompt: ContextAndCodePrompt = prompt.try_into()?;
@@ -542,7 +542,7 @@ The end with a trailing new line
                 },
             },
             PromptType::ContextAndCode,
-            json!({"messages": []}),
+            &json!({"messages": []}),
         )
         .await?;
         let prompt: ContextAndCodePrompt = prompt.try_into()?;
diff --git a/src/memory_backends/mod.rs b/src/memory_backends/mod.rs
index e824d3f..52a8974 100644
--- a/src/memory_backends/mod.rs
+++ b/src/memory_backends/mod.rs
@@ -2,31 +2,35 @@ use lsp_types::{
     DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
     TextDocumentPositionParams,
 };
-use serde::Deserialize;
 use serde_json::Value;
 
-use crate::config::{ChatMessage, Config, ValidMemoryBackend};
+use crate::config::{Config, ValidMemoryBackend};
 
 pub mod file_store;
 mod postgresml;
 
-const fn max_context_length_default() -> usize {
-    1024
-}
-
 #[derive(Clone, Debug)]
 pub enum PromptType {
     ContextAndCode,
     FIM,
 }
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone)]
 pub struct MemoryRunParams {
-    pub messages: Option<Vec<ChatMessage>>,
-    #[serde(default = "max_context_length_default")]
+    pub is_for_chat: bool,
     pub max_context_length: usize,
 }
 
+impl From<&Value> for MemoryRunParams {
+    fn from(value: &Value) -> Self {
+        Self {
+            max_context_length: value["max_context_length"].as_u64().unwrap_or(1024) as usize,
+            // messages are for most backends, contents are for Gemini
+            is_for_chat: value["messages"].is_array() || value["contents"].is_array(),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct ContextAndCodePrompt {
     pub context: String,
@@ -119,7 +123,7 @@ pub trait MemoryBackend {
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt>;
     async fn get_filter_text(
         &self,
diff --git a/src/memory_backends/postgresml/mod.rs b/src/memory_backends/postgresml/mod.rs
index 8e0f748..8b007ab 100644
--- a/src/memory_backends/postgresml/mod.rs
+++ b/src/memory_backends/postgresml/mod.rs
@@ -132,9 +132,9 @@ impl MemoryBackend for PostgresML {
         &self,
         position: &TextDocumentPositionParams,
         prompt_type: PromptType,
-        params: Value,
+        params: &Value,
     ) -> anyhow::Result<Prompt> {
-        let params: MemoryRunParams = serde_json::from_value(params)?;
+        let params: MemoryRunParams = params.try_into()?;
         let query = self
             .file_store
             .get_characters_around_position(position, 512)?;
diff --git a/src/memory_worker.rs b/src/memory_worker.rs
index 86082f8..39cad6c 100644
--- a/src/memory_worker.rs
+++ b/src/memory_worker.rs
@@ -70,7 +70,7 @@ async fn do_task(
         }
         WorkerRequest::Prompt(params) => {
             let prompt = memory_backend
-                .build_prompt(&params.position, params.prompt_type, params.params)
+                .build_prompt(&params.position, params.prompt_type, &params.params)
                 .await?;
             params
                 .tx
diff --git a/src/transformer_backends/gemini.rs b/src/transformer_backends/gemini.rs
new file mode 100644
index 0000000..3203c48
--- /dev/null
+++ b/src/transformer_backends/gemini.rs
@@ -0,0 +1,237 @@
+use anyhow::Context;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};
+use tracing::instrument;
+
+use super::TransformerBackend;
+use crate::{
+    config,
+    memory_backends::{ContextAndCodePrompt, Prompt},
+    transformer_worker::{
+        DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
+    },
+    utils::format_context_code_in_str,
+};
+
+fn format_gemini_contents(
+    messages: &[GeminiContent],
+    prompt: &ContextAndCodePrompt,
+) -> Vec<GeminiContent> {
+    messages
+        .iter()
+        .map(|m| {
+            GeminiContent::new(
+                m.role.to_owned(),
+                m.parts
+                    .iter()
+                    .map(|p| Part {
+                        text: format_context_code_in_str(&p.text, &prompt.context, &prompt.code),
+                    })
+                    .collect(),
+            )
+        })
+        .collect()
+}
+
+const fn max_tokens_default() -> usize {
+    64
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct Part {
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct GeminiContent {
+    role: String,
+    parts: Vec<Part>,
+}
+
+impl GeminiContent {
+    fn new(role: String, parts: Vec<Part>) -> Self {
+        Self { role, parts }
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct GeminiGenerationConfig {
+    #[serde(rename = "stopSequences")]
+    #[serde(default)]
+    pub stop_sequences: Vec<String>,
+    #[serde(rename = "maxOutputTokens")]
+    #[serde(default = "max_tokens_default")]
+    pub max_output_tokens: usize,
+    pub temperature: Option<f32>,
+    #[serde(rename = "topP")]
+    pub top_p: Option<f32>,
+    #[serde(rename = "topK")]
+    pub top_k: Option<usize>,
+}
+
+// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
+#[derive(Debug, Deserialize, Serialize, Clone)]
+pub struct GeminiRunParams {
+    contents: Vec<GeminiContent>,
+    #[serde(rename = "systemInstruction")]
+    system_instruction: GeminiContent,
+    #[serde(rename = "generationConfig")]
+    generation_config: Option<GeminiGenerationConfig>,
+}
+
+pub struct Gemini {
+    configuration: config::Gemini,
+}
+
+impl Gemini {
+    pub fn new(configuration: config::Gemini) -> Self {
+        Self { configuration }
+    }
+
+    fn get_token(&self) -> anyhow::Result<String> {
+        if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
+            Ok(std::env::var(env_var_name)?)
+        } else if let Some(token) = &self.configuration.auth_token {
+            Ok(token.to_string())
+        } else {
+            anyhow::bail!(
+                "set `auth_token_env_var_name` or `auth_token` to use a Gemini compatible API"
+            )
+        }
+    }
+
+    async fn get_chat(
+        &self,
+        messages: Vec<GeminiContent>,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        let client = reqwest::Client::new();
+        let token = self.get_token()?;
+        let res: serde_json::Value = client
+            .post(
+                self.configuration
+                    .chat_endpoint
+                    .as_ref()
+                    .context("must specify `chat_endpoint` to use gemini")?
+                    .to_owned()
+                    + self.configuration.model.as_ref()
+                    + ":generateContent?key="
+                    + token.as_ref(),
+            )
+            .header("Content-Type", "application/json")
+            .json(&json!({
+                "contents": messages,
+                "systemInstruction": params.system_instruction,
+                "generationConfig": params.generation_config,
+            }))
+            .send()
+            .await?
+            .json()
+            .await?;
+        if let Some(error) = res.get("error") {
+            anyhow::bail!("{:?}", error.to_string())
+        } else if let Some(candidates) = res.get("candidates") {
+            Ok(candidates
+                .get(0)
+                .unwrap()
+                .get("content")
+                .unwrap()
+                .get("parts")
+                .unwrap()
+                .get(0)
+                .unwrap()
+                .get("text")
+                .unwrap()
+                .clone()
+                .to_string())
+        } else {
+            anyhow::bail!("Unknown error while making request to Gemini: {:?}", res);
+        }
+    }
+    async fn do_chat_completion(
+        &self,
+        prompt: &Prompt,
+        params: GeminiRunParams,
+    ) -> anyhow::Result<String> {
+        match prompt {
+            Prompt::ContextAndCode(code_and_context) => {
+                let messages = format_gemini_contents(&params.contents, code_and_context);
+                self.get_chat(messages, params).await
+            }
+            _ => anyhow::bail!("Google Gemini backend does not yet support FIM"),
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl TransformerBackend for Gemini {
+    #[instrument(skip(self))]
+    async fn do_generate(
+        &self,
+        prompt: &Prompt,
+        params: Value,
+    ) -> anyhow::Result<DoGenerationResponse> {
+        let params: GeminiRunParams = serde_json::from_value(params)?;
+        let generated_text = self.do_chat_completion(prompt, params).await?;
+        Ok(DoGenerationResponse { generated_text })
+    }
+
+    #[instrument(skip(self))]
+    async fn do_generate_stream(
+        &self,
+        request: &GenerationStreamRequest,
+        _params: Value,
+    ) -> anyhow::Result<DoGenerationStreamResponse> {
+        anyhow::bail!("GenerationStream is not yet implemented")
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use serde_json::json;
+
+    #[tokio::test]
+    async fn gemini_chat_do_generate() -> anyhow::Result<()> {
+        let configuration: config::Gemini = serde_json::from_value(json!({
+            "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
+            "model": "gemini-1.5-flash",
+            "auth_token_env_var_name": "GEMINI_API_KEY",
+        }))?;
+        let gemini = Gemini::new(configuration);
+        let prompt = Prompt::default_with_cursor();
+        let run_params = json!({
+            "systemInstruction": {
+                "role": "system",
+                "parts": [{
+                    "text": "You are a helpful and willing chatbot that will do whatever the user asks"
+                }]
+            },
+            "generationConfig": {
+                "maxOutputTokens": 10
+            },
+            "contents": [
+                {
+                    "role": "user",
+                    "parts":[{
+                        "text": "Pretend you're a snowman and stay in character for each response."}]
+                },
+                {
+                    "role": "model",
+                    "parts":[{
+                        "text": "Hello! It's so cold! Isn't that great?"}]
+                },
+                {
+                    "role": "user",
+                    "parts":[{
+                        "text": "What's your favorite season of the year?"}]
+                }
+            ]
+        });
+        let response = gemini.do_generate(&prompt, run_params).await?;
+        dbg!(&response.generated_text);
+        assert!(!response.generated_text.is_empty());
+        Ok(())
+    }
+}
diff --git a/src/transformer_backends/mod.rs b/src/transformer_backends/mod.rs
index c56c109..f962564 100644
--- a/src/transformer_backends/mod.rs
+++ b/src/transformer_backends/mod.rs
@@ -11,6 +11,7 @@ use crate::{
 };
 
 mod anthropic;
+mod gemini;
 #[cfg(feature = "llama_cpp")]
 mod llama_cpp;
 mod mistral_fim;
@@ -66,6 +67,7 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
             ValidModel::OpenAI(open_ai_config) => {
                 Ok(Box::new(open_ai::OpenAI::new(open_ai_config)))
             }
+            ValidModel::Gemini(gemini_config) => Ok(Box::new(gemini::Gemini::new(gemini_config))),
             ValidModel::Anthropic(anthropic_config) => {
                 Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
             }
diff --git a/src/transformer_backends/open_ai/mod.rs b/src/transformer_backends/open_ai/mod.rs
index a004614..d516adf 100644
--- a/src/transformer_backends/open_ai/mod.rs
+++ b/src/transformer_backends/open_ai/mod.rs
@@ -163,7 +163,7 @@ impl OpenAI {
                 self.configuration
                     .chat_endpoint
                     .as_ref()
-                    .context("must specify `completions_endpoint` to use completions")?,
+                    .context("must specify `chat_endpoint` to use completions")?,
             )
             .bearer_auth(token)
             .header("Content-Type", "application/json")
diff --git a/src/utils.rs b/src/utils.rs
index 85d22f9..ea5d652 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -29,14 +29,16 @@ pub fn format_chat_messages(
         .map(|m| {
             ChatMessage::new(
                 m.role.to_owned(),
-                m.content
-                    .replace("{CONTEXT}", &prompt.context)
-                    .replace("{CODE}", &prompt.code),
+                format_context_code_in_str(&m.content, &prompt.context, &prompt.code),
             )
         })
         .collect()
 }
 
+pub fn format_context_code_in_str(s: &str, context: &str, code: &str) -> String {
+    s.replace("{CONTEXT}", context).replace("{CODE}", code)
+}
+
 pub fn format_context_code(context: &str, code: &str) -> String {
     format!("{context}\n\n{code}")
 }
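For reference, a minimal editor-side `initializationOptions` payload exercising the new `gemini` model type could look like the sketch below. It is assembled from the values used in the `gemini_config` and `gemini_chat_do_generate` tests above; the endpoint, model name, `GEMINI_API_KEY` variable name, token count, and prompt text are illustrative placeholders rather than required values, and `chat_endpoint` is shown because the chat path in `gemini.rs` reads that field.

```json
{
  "memory": { "file_store": {} },
  "models": {
    "model1": {
      "type": "gemini",
      "chat_endpoint": "https://generativelanguage.googleapis.com/v1beta/models/",
      "model": "gemini-1.5-flash-latest",
      "auth_token_env_var_name": "GEMINI_API_KEY"
    }
  },
  "completion": {
    "model": "model1",
    "parameters": {
      "systemInstruction": {
        "role": "system",
        "parts": [{ "text": "You are a code completion assistant." }]
      },
      "generationConfig": { "maxOutputTokens": 64 },
      "contents": [
        { "role": "user", "parts": [{ "text": "{CONTEXT} and {CODE}" }] }
      ]
    }
  }
}
```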