Added chat prompt templating via minijinja, plus an inline completion provider in the VS Code extension and other improvements

Silas Marvin 2024-03-08 15:12:37 -08:00
parent d818cdca6d
commit aa7c4061cf
10 changed files with 196 additions and 39 deletions

Cargo.lock generated

@ -712,6 +712,7 @@ dependencies = [
"tokenizers",
"tracing",
"tracing-subscriber",
"xxhash-rust",
]
[[package]]
@ -770,12 +771,20 @@ version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
[[package]]
name = "memo-map"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83"
[[package]]
name = "minijinja"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe0ff215195a22884d867b547c70a0c4815cbbcc70991f281dca604b20d10ce"
dependencies = [
"memo-map",
"self_cell",
"serde",
]
@ -1307,6 +1316,12 @@ dependencies = [
"libc",
]
[[package]]
name = "self_cell"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
[[package]]
name = "serde"
version = "1.0.197"
@ -1897,6 +1912,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
[[package]]
name = "xxhash-rust"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03"
[[package]]
name = "zeroize"
version = "1.7.0"


@ -21,9 +21,10 @@ once_cell = "1.19.0"
directories = "5.0.1"
# llama-cpp-2 = "0.1.31"
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2" }
minijinja = "1.0.12"
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
[features]
default = []


@ -15,6 +15,7 @@
},
"devDependencies": {
"@types/node": "^20.11.0",
"@types/uuid": "^9.0.8",
"typescript": "^5.3.3"
},
"engines": {
@ -30,6 +31,12 @@
"undici-types": "~5.26.4"
}
},
"node_modules/@types/uuid": {
"version": "9.0.8",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz",
"integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==",
"dev": true
},
"node_modules/@types/vscode": {
"version": "1.85.0",
"resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.85.0.tgz",


@ -38,6 +38,7 @@
},
"devDependencies": {
"@types/node": "^20.11.0",
"@types/uuid": "^9.0.8",
"typescript": "^5.3.3"
},
"dependencies": {


@ -1,18 +1,18 @@
import * as vscode from 'vscode';
import * as vscode from 'vscode';
import {
LanguageClient,
LanguageClientOptions,
ServerOptions,
TransportKind
} from 'vscode-languageclient/node';
import { v4 as uuidv4 } from 'uuid';
// import { v4 as uuidv4 } from 'uuid';
let client: LanguageClient;
export function activate(context: vscode.ExtensionContext) {
// Configure the server options
let serverOptions: ServerOptions = {
command: "lsp-ai",
command: "lsp-ai",
transport: TransportKind.stdio,
};
@ -34,7 +34,7 @@ export function activate(context: vscode.ExtensionContext) {
// Register generate function
const generateCommand = 'lsp-ai.generate';
const generateCommandHandler = (editor) => {
const generateCommandHandler = (editor: vscode.TextEditor) => {
let params = {
textDocument: {
uri: editor.document.uri.toString(),
@ -42,7 +42,6 @@ export function activate(context: vscode.ExtensionContext) {
position: editor.selection.active
};
client.sendRequest("textDocument/generate", params).then(result => {
console.log("RECEIVED RESULT", result);
editor.edit((edit) => {
edit.insert(editor.selection.active, result["generatedText"]);
});
@ -52,28 +51,43 @@ export function activate(context: vscode.ExtensionContext) {
};
context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateCommand, generateCommandHandler));
// Register functions
const generateStreamCommand = 'lsp-ai.generateStream';
const generateStreamCommandHandler = (editor) => {
let params = {
textDocument: {
uri: editor.document.uri.toString(),
},
position: editor.selection.active,
partialResultToken: uuidv4()
};
console.log("PARAMS: ", params);
client.sendRequest("textDocument/generateStream", params).then(result => {
console.log("RECEIVED RESULT", result);
editor.edit((edit) => {
edit.insert(editor.selection.active, result["generatedText"]);
});
}).catch(error => {
console.error("Error making generate request", error);
});
};
context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
// This function is not ready to go
// const generateStreamCommand = 'lsp-ai.generateStream';
// const generateStreamCommandHandler = (editor: vscode.TextEditor) => {
// let params = {
// textDocument: {
// uri: editor.document.uri.toString(),
// },
// position: editor.selection.active,
// partialResultToken: uuidv4()
// };
// console.log("PARAMS: ", params);
// client.sendRequest("textDocument/generateStream", params).then(result => {
// console.log("RECEIVED RESULT", result);
// editor.edit((edit) => {
// edit.insert(editor.selection.active, result["generatedText"]);
// });
// }).catch(error => {
// console.error("Error making generate request", error);
// });
// };
// context.subscriptions.push(vscode.commands.registerTextEditorCommand(generateStreamCommand, generateStreamCommandHandler));
vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' },
{
provideInlineCompletionItems: async (document: vscode.TextDocument, position: vscode.Position) => {
let params = {
textDocument: {
uri: document.uri.toString(),
},
position: position
};
const result = await client.sendRequest("textDocument/generate", params);
return [new vscode.InlineCompletionItem(result["generatedText"])];
}
}
);
}
export function deactivate(): Thenable<void> | undefined {


@ -1,5 +1,5 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
@ -21,7 +21,7 @@ pub enum ValidTransformerBackend {
PostgresML,
}
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
@ -241,3 +241,59 @@ impl Configuration {
}
}
}
#[cfg(test)]
mod test {
use super::*;
use serde_json::json;
#[test]
fn macos_model_gguf() {
let args = json!({
"memory": {
"file_store": {}
},
"macos": {
"model_gguf": {
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
"completion": [
{
"role": "system",
"content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
},
{
"role": "user",
"content": "Complete the following code: \n\n{code}"
}
],
"generation": [
{
"role": "system",
"content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
},
{
"role": "user",
"content": "Complete the following code: \n\n{code}"
}
],
"chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
},
"n_ctx": 2048,
"n_gpu_layers": 35,
}
},
});
Configuration::new(args).unwrap();
}
}


@ -13,6 +13,7 @@ use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod configuration;
mod custom_requests;
mod memory_backends;
mod template;
mod transformer_backends;
mod utils;
mod worker;
@ -25,7 +26,6 @@ use worker::{CompletionRequest, GenerateRequest, Worker, WorkerRequest};
use crate::{custom_requests::generate_stream::GenerateStream, worker::GenerateStreamRequest};
// Taken directly from: https://github.com/rust-lang/rust-analyzer
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
notification.method == N::METHOD
}
@ -48,7 +48,7 @@ fn main() -> Result<()> {
FmtSubscriber::builder()
.with_writer(std::io::stderr)
.with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
.with_max_level(tracing::Level::TRACE)
// .with_max_level(tracing::Level::TRACE)
.init();
let (connection, io_threads) = Connection::stdio();

src/template.rs Normal file

@ -0,0 +1,35 @@
use minijinja::{context, Environment, ErrorKind};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use crate::configuration::ChatMessage;
static MINININJA_ENVIRONMENT: Lazy<Mutex<Environment>> =
Lazy::new(|| Mutex::new(Environment::new()));
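// Templates are registered in this shared environment under a name derived from
// an xxh3 hash of the template source, so repeated calls with the same template
// string reuse the already-parsed template instead of re-adding it.
// (add_template_owned is gated behind minijinja's "loader" feature, which the
// Cargo.toml change above enables.)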
fn template_name_from_template_string(template: &str) -> String {
xxhash_rust::xxh3::xxh3_64(template.as_bytes()).to_string()
}
pub fn apply_chat_template(
template: &str,
chat_messages: Vec<ChatMessage>,
bos_token: &str,
eos_token: &str,
) -> anyhow::Result<String> {
let template_name = template_name_from_template_string(template);
let mut env = MINININJA_ENVIRONMENT.lock();
let template = match env.get_template(&template_name) {
Ok(template) => template,
Err(e) => match e.kind() {
ErrorKind::TemplateNotFound => {
env.add_template_owned(template_name.clone(), template.to_owned())?;
env.get_template(&template_name)?
}
_ => anyhow::bail!(e.to_string()),
},
};
Ok(template.render(
context!(messages => chat_messages, bos_token => bos_token, eos_token => eos_token),
)?)
}
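For reference, a minimal usage sketch of apply_chat_template as added above. Only the function signature and the ChatMessage struct come from this commit; the template string, the token values, and the render_example wrapper are illustrative assumptions:

use crate::configuration::ChatMessage;
use crate::template::apply_chat_template;

fn render_example() -> anyhow::Result<()> {
    // ChatMessage derives Serialize in this commit, so minijinja can expose it as `messages`.
    let messages = vec![
        ChatMessage {
            role: "system".to_string(),
            content: "You are a code completion chatbot.".to_string(),
        },
        ChatMessage {
            role: "user".to_string(),
            content: "Complete the following code: ...".to_string(),
        },
    ];
    // Illustrative template and tokens; a real call would pass the model's configured
    // chat_template and the BOS/EOS tokens read from the loaded model.
    let template =
        "{{ bos_token }}{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}";
    let rendered = apply_chat_template(template, messages, "<s>", "</s>")?;
    println!("{rendered}");
    Ok(())
}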


@ -1,11 +1,12 @@
use anyhow::Context;
use hf_hub::api::sync::Api;
use hf_hub::api::sync::ApiBuilder;
use tracing::{debug, instrument};
use super::TransformerBackend;
use crate::{
configuration::Configuration,
memory_backends::Prompt,
template::apply_chat_template,
utils::format_chat_messages,
worker::{
DoCompletionResponse, DoGenerateResponse, DoGenerateStreamResponse, GenerateStreamRequest,
@ -23,7 +24,7 @@ pub struct LlamaCPP {
impl LlamaCPP {
#[instrument]
pub fn new(configuration: Configuration) -> anyhow::Result<Self> {
let api = Api::new()?;
let api = ApiBuilder::new().with_progress(true).build()?;
let model = configuration.get_model()?;
let name = model
.name
@ -45,8 +46,13 @@ impl LlamaCPP {
Some(c) => {
if let Some(completion_messages) = &c.completion {
let chat_messages = format_chat_messages(completion_messages, prompt);
self.model
.apply_chat_template(chat_messages, c.chat_template.to_owned())?
if let Some(chat_template) = &c.chat_template {
let bos_token = self.model.get_bos_token()?;
let eos_token = self.model.get_eos_token()?;
apply_chat_template(&chat_template, chat_messages, &bos_token, &eos_token)?
} else {
self.model.apply_chat_template(chat_messages, None)?
}
} else {
prompt.code.to_owned()
}
@ -59,8 +65,9 @@ impl LlamaCPP {
impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_completion(&self, prompt: &Prompt) -> anyhow::Result<DoCompletionResponse> {
let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
// let prompt = self.get_prompt_string(prompt)?;
let prompt = &prompt.code;
debug!("Prompt string for LLM: {}", prompt);
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)
@ -69,8 +76,9 @@ impl TransformerBackend for LlamaCPP {
#[instrument(skip(self))]
fn do_generate(&self, prompt: &Prompt) -> anyhow::Result<DoGenerateResponse> {
let prompt = self.get_prompt_string(prompt)?;
// let prompt = self.get_prompt_string(prompt)?;
// debug!("Prompt string for LLM: {}", prompt);
let prompt = &prompt.code;
let max_new_tokens = self.configuration.get_max_new_tokens()?.completion;
self.model
.complete(&prompt, max_new_tokens)


@ -64,7 +64,9 @@ impl Model {
#[instrument(skip(self))]
pub fn complete(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String> {
// initialize the context
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(self.n_ctx.clone()));
let ctx_params = LlamaContextParams::default()
.with_n_ctx(Some(self.n_ctx.clone()))
.with_n_batch(self.n_ctx.get());
let mut ctx = self
.model
@ -157,4 +159,16 @@ impl Model {
.model
.apply_chat_template(template, llama_chat_messages, true)?)
}
#[instrument(skip(self))]
pub fn get_eos_token(&self) -> anyhow::Result<String> {
let token = self.model.token_eos();
Ok(self.model.token_to_str(token)?)
}
#[instrument(skip(self))]
pub fn get_bos_token(&self) -> anyhow::Result<String> {
let token = self.model.token_bos();
Ok(self.model.token_to_str(token)?)
}
}