Added backends

Silas Marvin 2024-02-19 18:51:05 -10:00
parent 60df273a4f
commit d0bb139bc9
18 changed files with 773 additions and 465 deletions

design.txt Normal file

@@ -0,0 +1,23 @@
# Overview
LSP AI should support multiple transform_backends:
- Python - LLAMA CPP
- Python - SOME OTHER LIBRARY
- PostgresML
pub trait TransformBackend {
// These all take memory backends as an argument
do_completion()
do_generate()
do_generate_stream()
}
LSP AI should support multiple memory_backends:
- SIMPLE FILE STORE
- IN MEMORY VECTOR STORE
- PostgresML
pub trait MemoryBackend {
// Plus methods for handling file open/change/rename notifications
get_context() // Depending on the memory backend this will do very different things
}
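
A minimal Rust sketch of how the two kinds of backends described above could fit together, following the note that the transform methods take a memory backend as an argument. All names and signatures here are illustrative assumptions, not the final API:

```rust
// Illustrative only: trait names follow the note above; signatures are assumptions.
use anyhow::Result;

pub trait MemoryBackend {
    // Builds the context/prompt for the given document URI and cursor offset.
    fn get_context(&self, uri: &str, cursor: usize) -> Result<String>;
}

pub trait TransformBackend {
    // Per the note, each call takes a memory backend to pull context from.
    fn do_completion(&self, memory: &dyn MemoryBackend, uri: &str, cursor: usize) -> Result<String>;
    fn do_generate(&self, memory: &dyn MemoryBackend, uri: &str, cursor: usize) -> Result<String>;
}

// The composition the note implies: any transform backend paired with any memory backend.
pub fn complete(
    transformer: &dyn TransformBackend,
    memory: &dyn MemoryBackend,
    uri: &str,
    cursor: usize,
) -> Result<String> {
    transformer.do_completion(memory, uri, cursor)
}
```

In the Rust added by this commit the wiring is slightly looser: the worker holds both backends and hands the prompt string built by the memory backend to the transformer backend.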


@@ -24,7 +24,17 @@
"command": "lsp-ai.generateStream",
"title": "LSP AI Generate Stream"
}
]
],
"configuration": {
"title": "Configuration",
"properties": {
"configuration.json": {
"type": "json",
"default": "{}",
"description": "JSON configuration for LSP AI"
}
}
}
},
"devDependencies": {
"@types/node": "^20.11.0",

src/configuration.rs Normal file

@@ -0,0 +1,181 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
#[cfg(target_os = "macos")]
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
const DEFAULT_MAX_COMPLETION_TOKENS: usize = 32;
const DEFAULT_MAX_GENERATION_TOKENS: usize = 256;
pub enum ValidMemoryBackend {
FileStore,
PostgresML,
}
pub enum ValidTransformerBackend {
LlamaCPP,
PostgresML,
}
// TODO: Review this for real lol
#[derive(Clone, Deserialize)]
pub struct FIM {
prefix: String,
middle: String,
suffix: String,
}
// TODO: Add some default things
#[derive(Clone, Deserialize)]
pub struct MaxNewTokens {
pub completion: usize,
pub generation: usize,
}
impl Default for MaxNewTokens {
fn default() -> Self {
Self {
completion: DEFAULT_MAX_COMPLETION_TOKENS,
generation: DEFAULT_MAX_GENERATION_TOKENS,
}
}
}
#[derive(Clone, Deserialize)]
struct ValidMemoryConfiguration {
file_store: Option<Value>,
}
#[derive(Clone, Deserialize)]
struct ModelGGUF {
repository: String,
name: String,
// Fill in the middle support
fim: Option<FIM>,
// The maximum number of new tokens to generate
#[serde(default)]
max_new_tokens: MaxNewTokens,
// Kwargs passed to LlamaCPP
#[serde(flatten)]
kwargs: HashMap<String, Value>,
}
#[derive(Clone, Deserialize)]
struct ValidMacTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
#[derive(Clone, Deserialize)]
struct ValidLinuxTransformerConfiguration {
model_gguf: Option<ModelGGUF>,
}
#[derive(Clone, Deserialize)]
struct ValidConfiguration {
memory: ValidMemoryConfiguration,
// TODO: Add rename here
#[cfg(target_os = "macos")]
#[serde(alias = "macos")]
transformer: ValidMacTransformerConfiguration,
#[cfg(target_os = "linux")]
#[serde(alias = "linux")]
transformer: ValidLinuxTransformerConfiguration,
}
#[derive(Clone)]
pub struct Configuration {
valid_config: ValidConfiguration,
}
impl Configuration {
pub fn new(mut args: Value) -> Result<Self> {
let configuration_args = args
.as_object_mut()
.context("Server configuration must be a JSON object")?
.remove("initializationOptions")
.unwrap_or_default();
let valid_args: ValidConfiguration = serde_json::from_value(configuration_args)?;
// TODO: Make sure they only specified one model or something ya know
Ok(Self {
valid_config: valid_args,
})
}
pub fn get_memory_backend(&self) -> Result<ValidMemoryBackend> {
if self.valid_config.memory.file_store.is_some() {
Ok(ValidMemoryBackend::FileStore)
} else {
anyhow::bail!("Invalid memory configuration")
}
}
pub fn get_transformer_backend(&self) -> Result<ValidTransformerBackend> {
if self.valid_config.transformer.model_gguf.is_some() {
Ok(ValidTransformerBackend::LlamaCPP)
} else {
anyhow::bail!("Invalid model configuration")
}
}
pub fn get_maximum_context_length(&self) -> usize {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf
.kwargs
.get("n_ctx")
.map(|v| {
v.as_u64()
.map(|u| u as usize)
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
})
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
} else {
panic!("We currently only support gguf models using llama cpp")
}
}
pub fn get_max_new_tokens(&self) -> &MaxNewTokens {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
&model_gguf.max_new_tokens
} else {
panic!("We currently only support gguf models using llama cpp")
}
}
pub fn supports_fim(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn custom_mac_gguf_model() {
let args = json!({
"initializationOptions": {
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
"repository": "deepseek-coder-6.7b-base",
"name": "Q4_K_M.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"n_ctx": 2048,
"n_threads": 8,
"n_gpu_layers": 35,
"chat_template": "",
}
},
}
});
let _ = Configuration::new(args).unwrap();
}
}
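
A minimal usage sketch, written as if it sat next to the tests above (so `Configuration` and the default constants are in scope) and assuming a macOS build, where the transformer section is read from the `macos` key:

```rust
use serde_json::json;

fn example() -> anyhow::Result<()> {
    // The same shape a client would send as initializationOptions.
    let args = json!({
        "initializationOptions": {
            "memory": { "file_store": {} },
            "macos": {
                "model_gguf": {
                    "repository": "deepseek-coder-6.7b-base",
                    "name": "Q4_K_M.gguf",
                    "n_ctx": 2048
                }
            }
        }
    });
    let config = Configuration::new(args)?;
    // n_ctx comes from the flattened llama.cpp kwargs; without it,
    // get_maximum_context_length falls back to DEFAULT_LLAMA_CPP_N_CTX.
    assert_eq!(config.get_maximum_context_length(), 2048);
    // max_new_tokens was not set, so the defaults apply.
    assert_eq!(config.get_max_new_tokens().completion, DEFAULT_MAX_COMPLETION_TOKENS);
    Ok(())
}
```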


@@ -1,4 +1,4 @@
use lsp_types::{PartialResultParams, ProgressToken, TextDocumentPositionParams};
use lsp_types::{ProgressToken, TextDocumentPositionParams};
use serde::{Deserialize, Serialize};
pub enum GenerateStream {}


@@ -1,31 +1,28 @@
use anyhow::{Context, Result};
use anyhow::Result;
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
use lsp_types::{
request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use pyo3::prelude::*;
use ropey::Rope;
use serde::Deserialize;
use std::{collections::HashMap, sync::Arc, thread};
use std::{sync::Arc, thread};
mod configuration;
mod custom_requests;
mod memory_backends;
mod transformer_backends;
mod utils;
mod worker;
use configuration::Configuration;
use custom_requests::generate::Generate;
use worker::{CompletionRequest, GenerateRequest, WorkerRequest};
use memory_backends::MemoryBackend;
use transformer_backends::TransformerBackend;
use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};
pub static PY_MODULE: Lazy<Result<Py<PyAny>>> = Lazy::new(|| {
pyo3::Python::with_gil(|py| -> Result<Py<PyAny>> {
let src = include_str!("python/transformers.py");
Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into())
})
});
// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
notification.method == N::METHOD
@@ -52,55 +49,47 @@ fn main() -> Result<()> {
)),
..Default::default()
})?;
let initialization_params = connection.initialize(server_capabilities)?;
let initialization_args = connection.initialize(server_capabilities)?;
// Activate the python venv
Python::with_gil(|py| -> Result<()> {
let activate: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "activate_venv")?;
activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?;
Ok(())
})?;
main_loop(connection, initialization_params)?;
main_loop(connection, initialization_args)?;
io_threads.join()?;
Ok(())
}
#[derive(Deserialize)]
struct Params {}
// This main loop is tricky
// We create a worker thread that actually does the heavy lifting because we do not want to process every completion request we get
// Completion requests may take several seconds depending on the model configuration and the available hardware, and we only want to process the latest completion request
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
let _params: Params = serde_json::from_value(params)?;
// Note that the memory backend is also used from the worker thread, since building the prompt may involve heavy computation as well
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let args = Configuration::new(args)?;
// Set the model
Python::with_gil(|py| -> Result<()> {
let activate: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "set_model")?;
activate.call1(py, ("",))?;
Ok(())
})?;
// Set the transformer_backend
let transformer_backend: Box<dyn TransformerBackend + Send> = args.clone().try_into()?;
transformer_backend.init()?;
// Prep variables
// Set the memory_backend
let memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>> =
Arc::new(Mutex::new(args.clone().try_into()?));
// Wrap the connection for sharing between threads
let connection = Arc::new(connection);
let mut file_map: HashMap<String, Rope> = HashMap::new();
// How we communicate between the worker and receiver threads
let last_worker_request = Arc::new(Mutex::new(None));
// Clones to move into the worker thread
let thread_memory_backend = memory_backend.clone();
let thread_last_worker_request = last_worker_request.clone();
let thread_connection = connection.clone();
// TODO: Pass some backend into here
thread::spawn(move || {
worker::run(thread_last_worker_request, thread_connection);
Worker::new(
transformer_backend,
thread_memory_backend,
thread_last_worker_request,
thread_connection,
)
.run();
});
for msg in &connection.receiver {
@@ -115,41 +104,30 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
let rope = file_map
.get(params.text_document_position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
eprintln!("******{:?}********", id);
let mut lcr = last_worker_request.lock();
let completion_request = CompletionRequest::new(id, params, rope);
let completion_request = CompletionRequest::new(id, params);
*lcr = Some(WorkerRequest::Completion(completion_request));
}
Err(err) => panic!("{err:?}"),
Err(err) => eprintln!("{err:?}"),
}
} else if request_is::<Generate>(&req) {
match cast::<Generate>(req) {
Ok((id, params)) => {
let rope = file_map
.get(params.text_document_position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let mut lcr = last_worker_request.lock();
let completion_request = GenerateRequest::new(id, params, rope);
let completion_request = GenerateRequest::new(id, params);
*lcr = Some(WorkerRequest::Generate(completion_request));
}
Err(err) => panic!("{err:?}"),
Err(err) => eprintln!("{err:?}"),
}
} else if request_is::<GenerateStream>(&req) {
match cast::<GenerateStream>(req) {
Ok((id, params)) => {
let rope = file_map
.get(params.text_document_position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let mut lcr = last_worker_request.lock();
let completion_request = GenerateStreamRequest::new(id, params, rope);
let completion_request = GenerateStreamRequest::new(id, params);
*lcr = Some(WorkerRequest::GenerateStream(completion_request));
}
Err(err) => panic!("{err:?}"),
Err(err) => eprintln!("{err:?}"),
}
} else {
eprintln!("lsp-ai currently only supports textDocument/completion, textDocument/generate and textDocument/generateStream")
@@ -158,33 +136,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
Message::Notification(not) => {
if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
let rope = Rope::from_str(&params.text_document.text);
file_map.insert(params.text_document.uri.to_string(), rope);
memory_backend.lock().opened_text_document(params)?;
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
let rope = file_map
.get_mut(params.text_document.uri.as_str())
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
let start_index = rope.line_to_char(range.start.line as usize)
+ range.start.character as usize;
let end_index = rope.line_to_char(range.end.line as usize)
+ range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
} else {
*rope = Rope::from_str(&change.text);
}
}
memory_backend.lock().changed_text_document(params)?;
} else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
let params: RenameFilesParams = serde_json::from_value(not.params)?;
for file_rename in params.files {
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
memory_backend.lock().renamed_file(params)?;
}
}
_ => (),
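
The "only process the latest request" pattern the comments in main_loop describe, reduced to a standalone sketch (the request is just a String here for illustration, and the pattern names are made up): the receiver overwrites a shared slot, the worker takes whatever is newest, and anything it never got to is silently dropped.

```rust
use parking_lot::Mutex;
use std::{sync::Arc, thread, time::Duration};

fn spawn_worker(slot: Arc<Mutex<Option<String>>>) {
    thread::spawn(move || loop {
        // take() leaves None behind, so each request is handled at most once,
        // and newer requests overwrite older ones before the worker sees them.
        let request = slot.lock().take();
        if let Some(prompt) = request {
            // ... the slow transformer backend call would go here ...
            eprintln!("processing: {prompt}");
        }
        thread::sleep(Duration::from_millis(5));
    });
}

fn main() {
    let slot = Arc::new(Mutex::new(None));
    spawn_worker(slot.clone());
    // The LSP receiver side just overwrites the slot with the newest request.
    *slot.lock() = Some("latest completion request".to_string());
    thread::sleep(Duration::from_millis(50));
}
```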


@@ -0,0 +1,98 @@
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use ropey::Rope;
use std::collections::HashMap;
use crate::configuration::Configuration;
use super::MemoryBackend;
pub struct FileStore {
configuration: Configuration,
file_map: HashMap<String, Rope>,
}
impl FileStore {
pub fn new(configuration: Configuration) -> Self {
Self {
configuration,
file_map: HashMap::new(),
}
}
}
impl MemoryBackend for FileStore {
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
if self.configuration.supports_fim() {
// We will want to add some kind of infill (FIM) support here, e.g.:
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
unimplemented!()
} else {
// Convert rope to correct prompt for llm
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index
.checked_sub(self.configuration.get_maximum_context_length())
.unwrap_or(0);
eprintln!("############ {start} - {cursor_index} #############");
Ok(rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?
.to_string())
}
}
fn opened_text_document(
&mut self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let rope = Rope::from_str(&params.text_document.text);
self.file_map
.insert(params.text_document.uri.to_string(), rope);
Ok(())
}
fn changed_text_document(
&mut self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let rope = self
.file_map
.get_mut(params.text_document.uri.as_str())
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
// If range is omitted, text is the new text of the document
if let Some(range) = change.range {
let start_index =
rope.line_to_char(range.start.line as usize) + range.start.character as usize;
let end_index =
rope.line_to_char(range.end.line as usize) + range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
} else {
*rope = Rope::from_str(&change.text);
}
}
Ok(())
}
fn renamed_file(&mut self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
if let Some(rope) = self.file_map.remove(&file_rename.old_uri) {
self.file_map.insert(file_rename.new_uri, rope);
}
}
Ok(())
}
}
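
A worked example of the slicing `build_prompt` does above: convert the LSP (line, character) position to a character index with ropey, then keep at most the configured maximum context length of characters before the cursor. This is a standalone sketch, not part of the crate, and like the code above it treats `character` as a plain char offset.

```rust
use ropey::Rope;

fn prompt_before_cursor(text: &str, line: usize, character: usize, max_context: usize) -> String {
    let rope = Rope::from_str(text);
    let cursor_index = rope.line_to_char(line) + character;
    // Same effect as the checked_sub(..).unwrap_or(0) above.
    let start = cursor_index.saturating_sub(max_context);
    rope.slice(start..cursor_index).to_string()
}

fn main() {
    let text = "hello\nworld\n";
    // Cursor on line 1 at character 3, i.e. between "wor" and "ld".
    let prompt = prompt_before_cursor(text, 1, 3, 5);
    assert_eq!(prompt, "o\nwor");
}
```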


@@ -0,0 +1,31 @@
use lsp_types::{
DidChangeTextDocumentParams, DidOpenTextDocumentParams, RenameFilesParams,
TextDocumentPositionParams,
};
use crate::configuration::{Configuration, ValidMemoryBackend};
pub mod file_store;
pub trait MemoryBackend {
fn init(&self) -> anyhow::Result<()> {
Ok(())
}
fn opened_text_document(&mut self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&mut self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_file(&mut self, params: RenameFilesParams) -> anyhow::Result<()>;
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
}
impl TryFrom<Configuration> for Box<dyn MemoryBackend + Send> {
type Error = anyhow::Error;
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
match configuration.get_memory_backend()? {
ValidMemoryBackend::FileStore => {
Ok(Box::new(file_store::FileStore::new(configuration)))
}
_ => unimplemented!(),
}
}
}


@@ -1,44 +0,0 @@
import sys
import os
from llama_cpp import Llama
model = None
def activate_venv(venv):
if sys.platform in ('win32', 'win64', 'cygwin'):
activate_this = os.path.join(venv, 'Scripts', 'activate_this.py')
else:
activate_this = os.path.join(venv, 'bin', 'activate_this.py')
if os.path.exists(activate_this):
exec(open(activate_this).read(), dict(__file__=activate_this))
return True
else:
print(f"Virtualenv not found: {venv}", file=sys.stderr)
return False
def set_model(filler):
global model
model = Llama(
# model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", # Download the model file first
model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf", # Download the model file first
n_ctx=2048, # The max sequence length to use - note that longer sequence lengths require much more resources
n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance
n_gpu_layers=35 # The number of layers to offload to GPU, if you have GPU acceleration available
)
def transform(input):
# Simple inference example
output = model(
input, # Prompt
max_tokens=32, # Generate up to 32 tokens
stop=["<|EOT|>"], # Example stop token - not necessarily correct for this specific model! Please check before using.
echo=False # Whether to echo the prompt
)
return output["choices"][0]["text"]


@@ -0,0 +1,88 @@
use crate::{
configuration::Configuration,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
},
};
use super::TransformerBackend;
use once_cell::sync::Lazy;
use pyo3::prelude::*;
pub static PY_MODULE: Lazy<anyhow::Result<Py<PyAny>>> = Lazy::new(|| {
pyo3::Python::with_gil(|py| -> anyhow::Result<Py<PyAny>> {
let src = include_str!("python/transformers.py");
Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into())
})
});
pub struct LlamaCPP {
configuration: Configuration,
}
impl LlamaCPP {
pub fn new(configuration: Configuration) -> Self {
Self { configuration }
}
}
impl TransformerBackend for LlamaCPP {
fn init(&self) -> anyhow::Result<()> {
// Activate the python venv
Python::with_gil(|py| -> anyhow::Result<()> {
let activate: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "activate_venv")?;
activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?;
Ok(())
})?;
// Set the model
Python::with_gil(|py| -> anyhow::Result<()> {
let activate: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "set_model")?;
activate.call1(py, ("",))?;
Ok(())
})?;
Ok(())
}
fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
let max_new_tokens = self.configuration.get_max_new_tokens().completion;
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?;
Ok(out)
})
.map(|insert_text| DoCompletionResponse { insert_text })
}
fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
let max_new_tokens = self.configuration.get_max_new_tokens().generation;
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt, max_new_tokens))?.extract(py)?;
Ok(out)
})
.map(|generated_text| DoGenerateResponse { generated_text })
}
fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse> {
Ok(DoGenerateStreamResponse {
generated_text: "".to_string(),
})
}
}


@@ -0,0 +1,45 @@
import sys
import os
from llama_cpp import Llama
model = None
def activate_venv(venv):
if sys.platform in ("win32", "win64", "cygwin"):
activate_this = os.path.join(venv, "Scripts", "activate_this.py")
else:
activate_this = os.path.join(venv, "bin", "activate_this.py")
if os.path.exists(activate_this):
exec(open(activate_this).read(), dict(__file__=activate_this))
return True
else:
print(f"Virtualenv not found: {venv}", file=sys.stderr)
return False
def set_model(filler):
global model
model = Llama(
# model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", # Download the model file first
model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf", # Download the model file first
n_ctx=2048, # The max sequence length to use - note that longer sequence lengths require much more resources
n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance
n_gpu_layers=35, # The number of layers to offload to GPU, if you have GPU acceleration available
)
def transform(input, max_tokens):
# Simple inference example
output = model(
input, # Prompt
max_tokens=max_tokens, # Generate up to max tokens
# stop=[
# "<|EOT|>"
# ], # Example stop token - not necessarily correct for this specific model! Please check before using.
echo=False, # Whether to echo the prompt
)
return output["choices"][0]["text"]


@@ -0,0 +1,32 @@
use crate::{
configuration::{Configuration, ValidTransformerBackend},
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateRequest,
GenerateStreamRequest,
},
};
pub mod llama_cpp;
pub trait TransformerBackend {
fn init(&self) -> anyhow::Result<()>;
fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse>;
fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse>;
fn do_generate_stream(
&self,
request: &GenerateStreamRequest,
) -> anyhow::Result<DoGenerateStreamResponse>;
}
impl TryFrom<Configuration> for Box<dyn TransformerBackend + Send> {
type Error = anyhow::Error;
fn try_from(configuration: Configuration) -> Result<Self, Self::Error> {
match configuration.get_transformer_backend()? {
ValidTransformerBackend::LlamaCPP => {
Ok(Box::new(llama_cpp::LlamaCPP::new(configuration)))
}
_ => unimplemented!(),
}
}
}
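
A hedged sketch of how another backend could slot in later, which the `ValidTransformerBackend::PostgresML` variant and the design note both gesture at: implement the trait in a sibling module (say `src/transformer_backends/postgresml.rs`) and add a match arm to the `TryFrom` above. Everything here is a stub for illustration, not a real PostgresML integration.

```rust
use super::TransformerBackend;
use crate::{
    configuration::Configuration,
    worker::{
        DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
    },
};

pub struct PostgresML {
    _configuration: Configuration,
}

impl PostgresML {
    pub fn new(configuration: Configuration) -> Self {
        Self { _configuration: configuration }
    }
}

impl TransformerBackend for PostgresML {
    fn init(&self) -> anyhow::Result<()> {
        Ok(())
    }

    fn do_completion(&self, prompt: &str) -> anyhow::Result<DoCompletionResponse> {
        // A real implementation would send the prompt to PostgresML here.
        Ok(DoCompletionResponse { insert_text: format!("stub completion for: {prompt}") })
    }

    fn do_generate(&self, prompt: &str) -> anyhow::Result<DoGenerateResponse> {
        Ok(DoGenerateResponse { generated_text: format!("stub generation for: {prompt}") })
    }

    fn do_generate_stream(
        &self,
        _request: &GenerateStreamRequest,
    ) -> anyhow::Result<DoGenerateStreamResponse> {
        Ok(DoGenerateStreamResponse { generated_text: String::new() })
    }
}

// And in transformer_backends/mod.rs:
// ValidTransformerBackend::PostgresML => Ok(Box::new(postgresml::PostgresML::new(configuration))),
```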

src/utils.rs Normal file

@@ -0,0 +1,15 @@
use lsp_server::ResponseError;
pub trait ToResponseError {
fn to_response_error(&self, code: i32) -> ResponseError;
}
impl ToResponseError for anyhow::Error {
fn to_response_error(&self, code: i32) -> ResponseError {
ResponseError {
code,
message: self.to_string(),
data: None,
}
}
}
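
A short usage sketch, as if written next to the trait above: any `anyhow::Error` escaping a handler becomes an LSP error response (this mirrors how the worker uses it below).

```rust
use lsp_server::{RequestId, Response};

fn to_error_response(id: RequestId, err: anyhow::Error) -> Response {
    Response {
        id,
        result: None,
        error: Some(err.to_response_error(-32603)),
    }
}
```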

src/worker.rs Normal file

@@ -0,0 +1,187 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
use parking_lot::Mutex;
use std::{sync::Arc, thread};
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult};
use crate::memory_backends::MemoryBackend;
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
#[derive(Clone)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
}
impl CompletionRequest {
pub fn new(id: RequestId, params: CompletionParams) -> Self {
Self { id, params }
}
}
#[derive(Clone)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
}
impl GenerateRequest {
pub fn new(id: RequestId, params: GenerateParams) -> Self {
Self { id, params }
}
}
#[derive(Clone)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
}
impl GenerateStreamRequest {
pub fn new(id: RequestId, params: GenerateStreamParams) -> Self {
Self { id, params }
}
}
#[derive(Clone)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
GenerateStream(GenerateStreamRequest),
}
pub struct DoCompletionResponse {
pub insert_text: String,
}
pub struct DoGenerateResponse {
pub generated_text: String,
}
pub struct DoGenerateStreamResponse {
pub generated_text: String,
}
pub struct Worker {
transformer_backend: Box<dyn TransformerBackend>,
memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
connection: Arc<Connection>,
}
impl Worker {
pub fn new(
transformer_backend: Box<dyn TransformerBackend>,
memory_backend: Arc<Mutex<Box<dyn MemoryBackend + Send>>>,
last_worker_request: Arc<Mutex<Option<WorkerRequest>>>,
connection: Arc<Connection>,
) -> Self {
Self {
transformer_backend,
memory_backend,
last_worker_request,
connection,
}
}
fn do_completion(&self, request: &CompletionRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
eprintln!("\n\n****************{}***************\n\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\n\n****************{}***************\n\n",
response.insert_text
);
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
),
response.insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {}", response.insert_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(completion_text_edit)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}
fn do_generate(&self, request: &GenerateRequest) -> anyhow::Result<Response> {
let prompt = self
.memory_backend
.lock()
.build_prompt(&request.params.text_document_position)?;
eprintln!("\n\n****************{}***************\n\n", prompt);
let response = self.transformer_backend.do_generate(&prompt)?;
let result = GenerateResult {
generated_text: response.generated_text,
};
let result = serde_json::to_value(&result).unwrap();
Ok(Response {
id: request.id.clone(),
result: Some(result),
error: None,
})
}
pub fn run(self) {
loop {
let option_worker_request: Option<WorkerRequest> = {
let mut completion_request = self.last_worker_request.lock();
std::mem::take(&mut *completion_request)
};
if let Some(request) = option_worker_request {
let response = match request {
WorkerRequest::Completion(request) => match self.do_completion(&request) {
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::Generate(request) => match self.do_generate(&request) {
Ok(r) => r,
Err(e) => Response {
id: request.id,
result: None,
error: Some(e.to_response_error(-32603)),
},
},
WorkerRequest::GenerateStream(_) => panic!("Streaming is not supported yet"),
};
self.connection
.sender
.send(Message::Response(response))
.expect("Error sending message");
}
thread::sleep(std::time::Duration::from_millis(5));
}
}
}


@@ -1,61 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;
use super::CompletionRequest;
use crate::PY_MODULE;
pub struct DoCompletionResponse {
pub insert_text: String,
pub filter_text: String,
}
pub fn do_completion(request: &CompletionRequest) -> Result<DoCompletionResponse, ResponseError> {
let filter_text = request
.rope
.get_line(request.params.text_document_position.position.line as usize)
.ok_or(ResponseError {
code: -32603, // Maybe we want a different error code here?
message: "Error getting line in requested document".to_string(),
data: None,
})?
.to_string();
// Convert rope to correct prompt for llm
let cursor_index = request
.rope
.line_to_char(request.params.text_document_position.position.line as usize)
+ request.params.text_document_position.position.character as usize;
// We will want to have some kind of infill support we add
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
let prompt = request
.rope
.get_slice(0..cursor_index)
.expect("Error getting rope slice")
.to_string();
eprintln!("\n\n****{prompt}****\n\n");
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
Ok(out)
})
.map(|insert_text| DoCompletionResponse {
insert_text,
filter_text,
})
.map_err(|e| ResponseError {
code: -32603,
message: e.to_string(),
data: None,
})
}


@@ -1,47 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;
use super::{GenerateRequest, GenerateStreamRequest};
use crate::PY_MODULE;
pub struct DoGenerateResponse {
pub generated_text: String,
}
pub fn do_generate(request: &GenerateRequest) -> Result<DoGenerateResponse, ResponseError> {
// Convert rope to correct prompt for llm
let cursor_index = request
.rope
.line_to_char(request.params.text_document_position.position.line as usize)
+ request.params.text_document_position.position.character as usize;
// We will want to have some kind of infill support we add
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
let prompt = request
.rope
.get_slice(0..cursor_index)
.expect("Error getting rope slice")
.to_string();
eprintln!("\n\n****{prompt}****\n\n");
Python::with_gil(|py| -> anyhow::Result<String> {
let transform: Py<PyAny> = PY_MODULE
.as_ref()
.map_err(anyhow::Error::msg)?
.getattr(py, "transform")?;
let out: String = transform.call1(py, (prompt,))?.extract(py)?;
Ok(out)
})
.map(|generated_text| DoGenerateResponse { generated_text })
.map_err(|e| ResponseError {
code: -32603,
message: e.to_string(),
data: None,
})
}


@@ -1,45 +0,0 @@
use lsp_server::ResponseError;
use pyo3::prelude::*;
use super::{GenerateRequest, GenerateStreamRequest};
use crate::PY_MODULE;
pub fn do_generate_stream(request: &GenerateStreamRequest) -> Result<(), ResponseError> {
// Convert rope to correct prompt for llm
// let cursor_index = request
// .rope
// .line_to_char(request.params.text_document_position.position.line as usize)
// + request.params.text_document_position.position.character as usize;
// // We will want to have some kind of infill support we add
// // rope.insert(cursor_index, "<fim_hole>");
// // rope.insert(0, "<fim_start>");
// // rope.insert(rope.len_chars(), "<fim_end>");
// // let prompt = rope.to_string();
// let prompt = request
// .rope
// .get_slice(0..cursor_index)
// .expect("Error getting rope slice")
// .to_string();
// eprintln!("\n\n****{prompt}****\n\n");
// Python::with_gil(|py| -> anyhow::Result<String> {
// let transform: Py<PyAny> = PY_MODULE
// .as_ref()
// .map_err(anyhow::Error::msg)?
// .getattr(py, "transform")?;
// let out: String = transform.call1(py, (prompt,))?.extract(py)?;
// Ok(out)
// })
// .map(|generated_text| DoGenerateResponse { generated_text })
// .map_err(|e| ResponseError {
// code: -32603,
// message: e.to_string(),
// data: None,
// })
Ok(())
}


@@ -1,181 +0,0 @@
use lsp_server::{Connection, Message, RequestId, Response};
use lsp_types::{
CompletionItem, CompletionItemKind, CompletionList, CompletionParams, CompletionResponse,
Position, Range, TextEdit,
};
use parking_lot::Mutex;
use ropey::Rope;
use std::{sync::Arc, thread};
mod completion;
mod generate;
mod generate_stream;
use crate::custom_requests::generate::{GenerateParams, GenerateResult};
use crate::custom_requests::generate_stream::{GenerateStreamParams, GenerateStreamResult};
use completion::do_completion;
use generate::do_generate;
use generate_stream::do_generate_stream;
#[derive(Clone)]
pub struct CompletionRequest {
id: RequestId,
params: CompletionParams,
rope: Rope,
}
impl CompletionRequest {
pub fn new(id: RequestId, params: CompletionParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
#[derive(Clone)]
pub struct GenerateRequest {
id: RequestId,
params: GenerateParams,
rope: Rope,
}
impl GenerateRequest {
pub fn new(id: RequestId, params: GenerateParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
#[derive(Clone)]
pub struct GenerateStreamRequest {
id: RequestId,
params: GenerateStreamParams,
rope: Rope,
}
impl GenerateStreamRequest {
pub fn new(id: RequestId, params: GenerateStreamParams, rope: Rope) -> Self {
Self { id, params, rope }
}
}
#[derive(Clone)]
pub enum WorkerRequest {
Completion(CompletionRequest),
Generate(GenerateRequest),
GenerateStream(GenerateStreamRequest),
}
pub fn run(last_worker_request: Arc<Mutex<Option<WorkerRequest>>>, connection: Arc<Connection>) {
loop {
let option_worker_request: Option<WorkerRequest> = {
let mut completion_request = last_worker_request.lock();
std::mem::take(&mut *completion_request)
};
if let Some(request) = option_worker_request {
let response = match request {
WorkerRequest::Completion(request) => match do_completion(&request) {
Ok(response) => {
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
Position::new(
request.params.text_document_position.position.line,
request.params.text_document_position.position.character,
),
),
response.insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {}", response.insert_text),
filter_text: Some(response.filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
completion_text_edit,
)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
Response {
id: request.id,
result: Some(result),
error: None,
}
}
Err(e) => Response {
id: request.id,
result: None,
error: Some(e),
},
},
WorkerRequest::Generate(request) => match do_generate(&request) {
Ok(result) => {
let result = GenerateResult {
generated_text: result.generated_text,
};
let result = serde_json::to_value(&result).unwrap();
Response {
id: request.id,
result: Some(result),
error: None,
}
}
Err(e) => Response {
id: request.id,
result: None,
error: Some(e),
},
},
WorkerRequest::GenerateStream(request) => match do_generate_stream(&request) {
Ok(result) => {
// let result = GenerateResult {
// generated_text: result.generated_text,
// };
// let result = serde_json::to_value(&result).unwrap();
let result = GenerateStreamResult {
generated_text: "test".to_string(),
partial_result_token: request.params.partial_result_token,
};
let result = serde_json::to_value(&result).unwrap();
Response {
id: request.id,
result: Some(result),
error: None,
}
}
Err(e) => Response {
id: request.id,
result: None,
error: Some(e),
},
},
};
connection
.sender
.send(Message::Response(response.clone()))
.expect("Error sending response");
connection
.sender
.send(Message::Response(response.clone()))
.expect("Error sending response");
connection
.sender
.send(Message::Response(response.clone()))
.expect("Error sending response");
// connection
// .sender
// .send(Message::Response(Response {
// id: response.id,
// result: None,
// error: None,
// }))
// .expect("Error sending message");
}
thread::sleep(std::time::Duration::from_millis(5));
}
}

test.json Normal file

@@ -0,0 +1,18 @@
{
"macos": {
"model_gguf": {
"repository": "deepseek-coder-6.7b-base",
"name": "Q4_K_M.gguf",
"fim": false,
"n_ctx": 2048,
"n_threads": 8,
"n_gpu_layers": 35
}
},
"linux": {
"model_gptq": {
"repository": "theblokesomething",
"name": "some q5 or something"
}
}
}