mirror of
https://github.com/SilasMarvin/lsp-ai.git
synced 2024-09-17 15:17:23 +03:00
The beginning of something awesome
This commit is contained in:
commit
1a1c328ae7
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/target
|
||||
/models
|
1997
Cargo.lock
generated
Normal file
1997
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
Normal file
26
Cargo.toml
Normal file
@ -0,0 +1,26 @@
|
||||
[package]
name = "lsp-ai"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.75"
lsp-server = "0.7.4"
lsp-types = "0.94.1"
once_cell = "1.18.0"
parking_lot = "0.12.1"
ropey = "1.6.1"
serde = "1.0.190"
serde_json = "1.0.108"
# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] }
# NOTE(review): the path dependencies below assume a sibling checkout of
# huggingface/candle at ../candle and will not build on any other machine;
# the commented git entries above look like the portable form — confirm and
# switch back before sharing this crate.
# NOTE(review): the "accelerate" feature is macOS-only — confirm this crate
# is not expected to build on Linux/Windows as-is.
candle-core = { path = "../candle/candle-core", version = "0.3.1", features = ["accelerate"] }
candle-nn = { path = "../candle/candle-nn", version = "0.3.1", features = ["accelerate"] }
candle-transformers = { path = "../candle/candle-transformers", version = "0.3.1", features = ["accelerate"] }
hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" }
rand = "0.8.5"
tokenizers = "0.14.1"
3
run.sh
Executable file
3
run.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash
# Launch the release build of the lsp-ai language server.
# Fix: the original hard-coded /Users/silas/Projects/lsp-ai/..., which breaks
# for every other checkout; resolve the binary relative to this script instead.
exec "$(dirname "$0")/target/release/lsp-ai"
194
src/main.rs
Normal file
194
src/main.rs
Normal file
@ -0,0 +1,194 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use core::panic;
|
||||
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId, Response};
|
||||
use lsp_types::{
|
||||
request::Completion, CompletionItem, CompletionItemKind, CompletionList, CompletionOptions,
|
||||
CompletionResponse, DidChangeTextDocumentParams, DidOpenTextDocumentParams, Position, Range,
|
||||
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind, TextEdit,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use serde::Deserialize;
|
||||
// use pyo3::prelude::*;
|
||||
// use pyo3::types::PyTuple;
|
||||
use ropey::Rope;
|
||||
use std::collections::HashMap;
|
||||
|
||||
mod transformer;
|
||||
|
||||
// Maps an open document's URI to its current text. Populated by didOpen,
// mutated by didChange, and re-keyed by didRenameFiles in `main_loop`.
static FILE_MAP: once_cell::sync::Lazy<Mutex<HashMap<String, Rope>>> =
    once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
// Taken directly from: https://github.com/rust-lang/rust-analyzer
/// Returns true when `notification` carries the method name of the LSP
/// notification type `N` (e.g. `textDocument/didOpen`).
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
    notification.method == N::METHOD
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let (connection, io_threads) = Connection::stdio();
|
||||
let server_capabilities = serde_json::to_value(&ServerCapabilities {
|
||||
completion_provider: Some(CompletionOptions::default()),
|
||||
text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
|
||||
TextDocumentSyncKind::INCREMENTAL,
|
||||
)),
|
||||
..Default::default()
|
||||
})?;
|
||||
let initialization_params = connection.initialize(server_capabilities)?;
|
||||
main_loop(connection, initialization_params)?;
|
||||
io_threads.join()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Model configuration accepted in the client's `initializationOptions`.
#[derive(Deserialize)]
struct Params {
    // NOTE(review): these fields are deserialized in `main_loop` but never
    // read afterwards — presumably intended to drive `transformer::build`;
    // confirm before wiring them up.
    model: Option<String>,
    model_file: Option<String>,
    model_type: Option<String>,
    device: Option<String>,
}
|
||||
|
||||
fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
|
||||
let params: Params = serde_json::from_value(params)?;
|
||||
let mut text_generation = transformer::build()?;
|
||||
for msg in &connection.receiver {
|
||||
match msg {
|
||||
Message::Request(req) => {
|
||||
if connection.handle_shutdown(&req)? {
|
||||
return Ok(());
|
||||
}
|
||||
match cast::<Completion>(req) {
|
||||
Ok((id, params)) => {
|
||||
// Get rope
|
||||
let file_map = FILE_MAP.lock();
|
||||
let mut rope = file_map
|
||||
.get(params.text_document_position.text_document.uri.as_str())
|
||||
.context("Error file not found")?
|
||||
.clone();
|
||||
let filter_text = rope
|
||||
.get_line(params.text_document_position.position.line as usize)
|
||||
.context("Error getting line with ropey")?
|
||||
.to_string();
|
||||
|
||||
// Convert rope to correct prompt for llm
|
||||
let start_index = rope
|
||||
.line_to_char(params.text_document_position.position.line as usize)
|
||||
+ params.text_document_position.position.character as usize;
|
||||
rope.insert(start_index, "<fim_suffix>");
|
||||
let prompt = format!("<fim_prefix>{}<fim_middle>", rope);
|
||||
let insert_text = text_generation.run(&prompt, 64)?;
|
||||
|
||||
// Create and return the completion
|
||||
let completion_text_edit = TextEdit::new(
|
||||
Range::new(
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
Position::new(
|
||||
params.text_document_position.position.line,
|
||||
params.text_document_position.position.character,
|
||||
),
|
||||
),
|
||||
insert_text.clone(),
|
||||
);
|
||||
let item = CompletionItem {
|
||||
label: format!("ai - {insert_text}"),
|
||||
filter_text: Some(filter_text),
|
||||
text_edit: Some(lsp_types::CompletionTextEdit::Edit(
|
||||
completion_text_edit,
|
||||
)),
|
||||
kind: Some(CompletionItemKind::TEXT),
|
||||
..Default::default()
|
||||
};
|
||||
let completion_list = CompletionList {
|
||||
is_incomplete: false,
|
||||
items: vec![item],
|
||||
};
|
||||
let result = Some(CompletionResponse::List(completion_list));
|
||||
let result = serde_json::to_value(&result).unwrap();
|
||||
let resp = Response {
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
};
|
||||
connection.sender.send(Message::Response(resp))?;
|
||||
continue;
|
||||
}
|
||||
Err(err @ ExtractError::JsonError { .. }) => panic!("{err:?}"),
|
||||
Err(ExtractError::MethodMismatch(req)) => req,
|
||||
};
|
||||
}
|
||||
Message::Notification(not) => {
|
||||
eprintln!("got notification: {not:?}");
|
||||
if notification_is::<lsp_types::notification::DidOpenTextDocument>(¬) {
|
||||
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
|
||||
let rope = Rope::from_str(¶ms.text_document.text);
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
file_map.insert(params.text_document.uri.to_string(), rope);
|
||||
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(¬) {
|
||||
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
let rope = file_map
|
||||
.get_mut(params.text_document.uri.as_str())
|
||||
.context("Error trying to get file that does not exist")?;
|
||||
for change in params.content_changes {
|
||||
// If range is ommitted, text is the new text of the document
|
||||
if let Some(range) = change.range {
|
||||
let start_index = rope.line_to_char(range.start.line as usize)
|
||||
+ range.start.character as usize;
|
||||
let end_index = rope.line_to_char(range.end.line as usize)
|
||||
+ range.end.character as usize;
|
||||
rope.remove(start_index..end_index);
|
||||
rope.insert(start_index, &change.text);
|
||||
} else {
|
||||
*rope = Rope::from_str(&change.text);
|
||||
}
|
||||
}
|
||||
} else if notification_is::<lsp_types::notification::DidRenameFiles>(¬) {
|
||||
let params: RenameFilesParams = serde_json::from_value(not.params)?;
|
||||
let mut file_map = FILE_MAP.lock();
|
||||
for file_rename in params.files {
|
||||
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
|
||||
file_map.insert(file_rename.new_uri, rope);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Attempts to extract the typed params of LSP request type `R` from a raw
/// request, returning the request back on a method mismatch so the caller
/// can try another handler.
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
    R: lsp_types::request::Request,
    R::Params: serde::de::DeserializeOwned,
{
    req.extract(R::METHOD)
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
|
||||
// #[test]
|
||||
// fn test_lsp() -> Result<()> {
|
||||
// let prompt = "def sum_two_numers(x: int, y:";
|
||||
// let result = Python::with_gil(|py| -> Result<String> {
|
||||
// let transform: Py<PyAny> = PY_MODULE
|
||||
// .as_ref()
|
||||
// .expect("Error getting python module")
|
||||
// .getattr(py, "transform")
|
||||
// .expect("Error getting transform");
|
||||
|
||||
// let output = transform
|
||||
// .call1(py, PyTuple::new(py, &[prompt]))
|
||||
// .expect("Error calling transform");
|
||||
|
||||
// Ok(output.extract(py).expect("Error extracting result"))
|
||||
// })?;
|
||||
// println!("\n\nTHE RESULT\n{:?}\n\n", result);
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
103
src/transformer.rs
Normal file
103
src/transformer.rs
Normal file
@ -0,0 +1,103 @@
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::bigcode::{Config, GPTBigCode};
|
||||
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
/// Bundles everything needed to run autoregressive generation with a
/// GPTBigCode (StarCoder-family) model.
pub struct TextGeneration {
    model: GPTBigCode,
    // Device the input tensors are created on in `run`.
    device: Device,
    tokenizer: Tokenizer,
    // Token sampling configured with seed / temperature / top-p in `new`.
    logits_processor: LogitsProcessor,
}
|
||||
|
||||
impl TextGeneration {
|
||||
fn new(
|
||||
model: GPTBigCode,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<String> {
|
||||
eprintln!("Starting to generate tokens");
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut new_tokens = vec![];
|
||||
let mut outputs = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
|
||||
(1, tokens.len().saturating_sub(1))
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, past_len)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
outputs.push(token);
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
self.model.clear_cache();
|
||||
eprintln!(
|
||||
"GENERATED {} tokens in {} seconds",
|
||||
outputs.len(),
|
||||
dt.as_secs()
|
||||
);
|
||||
Ok(outputs.join(""))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build() -> Result<TextGeneration> {
|
||||
let start = std::time::Instant::now();
|
||||
eprintln!("Loading in model");
|
||||
let api = ApiBuilder::new()
|
||||
.with_token(Some(std::env::var("HF_TOKEN")?.to_string()))
|
||||
.build()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
"bigcode/starcoderbase-1b".to_string(),
|
||||
RepoType::Model,
|
||||
"main".to_string(),
|
||||
));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filenames = ["model.safetensors"]
|
||||
.iter()
|
||||
.map(|f| repo.get(f))
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let device = Device::Cpu;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let config = Config::starcoder_1b();
|
||||
let model = GPTBigCode::load(vb, config)?;
|
||||
eprintln!("loaded the model in {:?}", start.elapsed());
|
||||
Ok(TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
0,
|
||||
Some(0.85),
|
||||
None,
|
||||
&device,
|
||||
))
|
||||
}
|
Loading…
Reference in New Issue
Block a user