The beginning of something awesome

Silas Marvin 2023-11-24 09:08:25 -08:00
commit 1a1c328ae7
6 changed files with 2325 additions and 0 deletions

.gitignore vendored Normal file (+2)

@@ -0,0 +1,2 @@
/target
/models

Cargo.lock generated Normal file (+1997)

File diff suppressed because it is too large

Cargo.toml Normal file (+26)

@@ -0,0 +1,26 @@
[package]
name = "lsp-ai"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-types = "0.94.1"
once_cell = "1.18.0"
parking_lot = "0.12.1"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
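# Using a local checkout of candle (expected at ../candle) in place of the git dependency commented out above.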
candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] }
candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] }
candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] }
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"

run.sh Executable file (+3)

@@ -0,0 +1,3 @@
#!/usr/bin/env bash
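# NOTE: hardcoded absolute path to the locally built release binary; adjust for your own checkout.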
/Users/silas/Projects/lsp-ai/target/release/lsp-ai

src/main.rs Normal file (+194)

@@ -0,0 +1,194 @@
use anyhow::Context;
use anyhow::Result;
use core::panic;
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
use lsp_types::{
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
};
use parking_lot::Mutex;
use serde::Deserialize;
// use pyo3::prelude::*;
// use pyo3::types::PyTuple;
use ropey::Rope;
use std::collections::HashMap;
mod transformer;
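// Maps each open document's URI to its rope so edits can be applied incrementally.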
static FILE_MAP: once_cell::sync::Lazy<Mutex<HashMap<String, Rope>>> =
once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
notification.method == N::METHOD
}
fn main() -> Result<()> {
let (connection, io_threads) = Connection::stdio();
let server_capabilities = serde_json::to_value(&ServerCapabilities {
completion_provider: Some(CompletionOptions::default()),
text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::INCREMENTAL,
)),
..Default::default()
})?;
let initialization_params = connection.initialize(server_capabilities)?;
main_loop(connection, initialization_params)?;
io_threads.join()?;
Ok(())
}
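// Initialization options the client may pass; parsed in main_loop but not otherwise consumed yet.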
#[derive(Deserialize)]
struct Params {
model: Option<String>,
model_file: Option<String>,
model_type: Option<String>,
device: Option<String>,
}
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
let params: Params = serde_json::from_value(params)?;
let mut text_generation = transformer::build()?;
for msg in &connection.receiver {
match msg {
Message::Request(req) => {
if connection.handle_shutdown(&req)? {
return Ok(());
}
match cast::<Completion>(req) {
Ok((id, params)) => {
// Get rope
let file_map = FILE_MAP.lock();
let mut rope = file_map
.get(params.text_document_position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
let filter_text = rope
.get_line(params.text_document_position.position.line as usize)
.context("Error getting line with ropey")?
.to_string();
// Convert rope to correct prompt for llm
let start_index = rope
.line_to_char(params.text_document_position.position.line as usize)
+ params.text_document_position.position.character as usize;
rope.insert(start_index, "<fim_suffix>");
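// The rope now holds <fim_suffix> at the cursor, so the format below yields
// "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>", StarCoder's fill-in-the-middle prompt format.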
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
let insert_text = text_generation.run(&prompt, 64)?;
// Create and return the completion
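// The edit uses a zero-width range at the cursor (start == end), so applying it is a pure insertion.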
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
Position::new(
params.text_document_position.position.line,
params.text_document_position.position.character,
),
),
insert_text.clone(),
);
let item = CompletionItem {
label: format!("ai - {insert_text}"),
filter_text: Some(filter_text),
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
completion_text_edit,
)),
kind: Some(CompletionItemKind::TEXT),
..Default::default()
};
let completion_list = CompletionList {
is_incomplete: false,
items: vec![item],
};
let result = Some(CompletionResponse::List(completion_list));
let result = serde_json::to_value(&result).unwrap();
let resp = Response {
id,
result: Some(result),
error: None,
};
connection.sender.send(Message::Response(resp))?;
continue;
}
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
Err(ExtractError::MethodMismatch(req)) => req,
};
}
Message::Notification(not) => {
eprintln!("got notification: {not:?}");
if notification_is::<lsp_types::notification::DidOpenTextDocument>(&not) {
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
let rope = Rope::from_str(&params.text_document.text);
let mut file_map = FILE_MAP.lock();
file_map.insert(params.text_document.uri.to_string(), rope);
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(&not) {
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
let mut file_map = FILE_MAP.lock();
let rope = file_map
.get_mut(params.text_document.uri.as_str())
.context("Error trying to get file that does not exist")?;
for change in params.content_changes {
// If range is omitted, text is the new text of the whole document
if let Some(range) = change.range {
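// Note: LSP positions count UTF-16 code units by default; treating `character` as a
// char offset is correct for ASCII lines but can drift on lines with multi-byte characters.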
let start_index = rope.line_to_char(range.start.line as usize)
+ range.start.character as usize;
let end_index = rope.line_to_char(range.end.line as usize)
+ range.end.character as usize;
rope.remove(start_index..end_index);
rope.insert(start_index, &change.text);
} else {
*rope = Rope::from_str(&change.text);
}
}
} else if notification_is::<lsp_types::notification::DidRenameFiles>(&not) {
let params: RenameFilesParams = serde_json::from_value(not.params)?;
let mut file_map = FILE_MAP.lock();
for file_rename in params.files {
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
file_map.insert(file_rename.new_uri, rope);
}
}
}
}
_ => (),
}
}
Ok(())
}
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
R: lsp_types::request::Request,
R::Params: serde::de::DeserializeOwned,
{
req.extract(R::METHOD)
}
// #[cfg(test)]
// mod tests {
// use super::*;
// #[test]
// fn test_lsp() -> Result<()> {
// let prompt = "def sum_two_numers(x: int, y:";
// let result = Python::with_gil(|py| -> Result<String> {
// let transform: Py<PyAny> = PY_MODULE
// .as_ref()
// .expect("Error getting python module")
// .getattr(py, "transform")
// .expect("Error getting transform");
// let output = transform
// .call1(py, PyTuple::new(py, &[prompt]))
// .expect("Error calling transform");
// Ok(output.extract(py).expect("Error extracting result"))
// })?;
// println!("\n\nTHE RESULT\n{:?}\n\n", result);
// Ok(())
// }
// }

src/transformer.rs Normal file (+103)

@@ -0,0 +1,103 @@
use anyhow::{Error as E, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::bigcode::{Config, GPTBigCode};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
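// Bundles the model, tokenizer, and sampler needed for token-by-token generation.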
pub struct TextGeneration {
model: GPTBigCode,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
}
impl TextGeneration {
fn new(
model: GPTBigCode,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
device: device.clone(),
}
}
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
eprintln!("Starting to generate tokens");
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let mut outputs = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
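// After the first step the KV cache is warm, so only the newest token is fed
// (with past_len as the offset); the first step processes the whole prompt.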
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
(1, tokens.len().saturating_sub(1))
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, past_len)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
outputs.push(token);
}
let dt = start_gen.elapsed();
self.model.clear_cache();
eprintln!(
"GENERATED {} tokens in {} seconds",
outputs.len(),
dt.as_secs()
);
Ok(outputs.join(""))
}
}
pub fn build() -> Result<TextGeneration> {
let start = std::time::Instant::now();
eprintln!("Loading in model");
let api = ApiBuilder::new()
.with_token(Some(std::env::var("HF_TOKEN")?.to_string()))
.build()?;
let repo = api.repo(Repo::with_revision(
"bigcode/starcoderbase-1b".to_string(),
RepoType::Model,
"main".to_string(),
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = ["model.safetensors"]
.iter()
.map(|f| repo.get(f))
.collect::<std::result::Result<Vec<_>, _>>()?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = Device::Cpu;
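// Safety: memory-maps the safetensors files; the files must not be modified while the mapping is alive.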
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
eprintln!("loaded the model in {:?}", start.elapsed());
Ok(TextGeneration::new(
model,
tokenizer,
0,
Some(0.85),
None,
&device,
))
}