Reformat to work with MistralFIM

Silas Marvin 2024-05-30 20:05:45 -07:00
parent de95cd0871
commit 7152331b1e
15 changed files with 543 additions and 302 deletions

Cargo.lock

@ -1410,7 +1410,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "llama-cpp-2"
version = "0.1.52"
version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50ee2be825e1e2f979393851319559d6f483303329a501c37799aacd7c9bdc12"
dependencies = [
"llama-cpp-sys-2",
"thiserror",
@ -1419,7 +1421,9 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.52"
version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec0f7b379c8935a01e97f734d15281f1cde9c256e2a6c8a482a11c5ed17ea7"
dependencies = [
"bindgen",
"cc",


@ -18,8 +18,7 @@ tokenizers = "0.14.1"
parking_lot = "0.12.1"
once_cell = "1.19.0"
directories = "5.0.1"
# llama-cpp-2 = { version = "0.1.47", optional = true }
llama-cpp-2 = { path = "../llama-cpp-rs/llama-cpp-2", optional = true }
llama-cpp-2 = { version = "0.1.53", optional = true }
minijinja = { version = "1.0.12", features = ["loader"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
@ -32,8 +31,8 @@ indexmap = "2.2.5"
async-trait = "0.1.78"
[features]
default = ["llamacpp"]
llamacpp = ["dep:llama-cpp-2"]
default = ["llama_cpp"]
llama_cpp = ["dep:llama-cpp-2"]
cublas = ["llama-cpp-2/cublas"]
[dev-dependencies]


@ -16,13 +16,15 @@ pub enum ValidMemoryBackend {
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ValidModel {
#[cfg(feature = "llamacpp")]
#[serde(rename = "llamacpp")]
#[cfg(feature = "llama_cpp")]
#[serde(rename = "llama_cpp")]
LLaMACPP(LLaMACPP),
#[serde(rename = "openai")]
#[serde(rename = "open_ai")]
OpenAI(OpenAI),
#[serde(rename = "anthropic")]
Anthropic(Anthropic),
#[serde(rename = "mistral_fim")]
MistralFIM(MistralFIM),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -30,6 +32,17 @@ pub enum ValidModel {
pub struct ChatMessage {
pub role: String,
pub content: String,
pub tool_calls: Option<Value>, // This is to be compatible with Mistral
}
impl ChatMessage {
pub fn new(role: String, content: String) -> Self {
Self {
role,
content,
tool_calls: None,
}
}
}
#[derive(Debug, Clone, Deserialize)]
@ -80,6 +93,20 @@ const fn n_ctx_default() -> u32 {
1000
}
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MistralFIM {
// The auth token env var name
pub auth_token_env_var_name: Option<String>,
pub auth_token: Option<String>,
// The fim endpoint
pub fim_endpoint: Option<String>,
// The model name
pub model: String,
#[serde(default = "api_max_requests_per_second_default")]
pub max_requests_per_second: f32,
}
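
For reference, a minimal sketch of how this new model type might be declared in a client's initializationOptions, patterned after the config tests further down in this file. The endpoint, model, and env var values mirror the mistral_fim test later in this commit; the memory section is an assumption, not something this commit defines.

// Illustrative sketch only: a `mistral_fim` entry in the models map, mirroring the
// shape the `ValidModel` / `MistralFIM` deserializers above expect.
// The `file_store` memory config is an assumption for completeness.
use serde_json::{json, Value};

fn example_mistral_fim_options() -> Value {
    json!({
        "initializationOptions": {
            "memory": { "file_store": {} },
            "models": {
                "model1": {
                    "type": "mistral_fim",
                    "fim_endpoint": "https://api.mistral.ai/v1/fim/completions",
                    "model": "codestral-latest",
                    "auth_token_env_var_name": "MISTRAL_API_KEY"
                }
            }
        }
    })
}
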
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LLaMACPP {
@ -205,10 +232,11 @@ impl Config {
&self.config.completion.as_ref().unwrap().model
)
})? {
#[cfg(feature = "llamacpp")]
#[cfg(feature = "llama_cpp")]
ValidModel::LLaMACPP(_) => Ok(1.),
ValidModel::OpenAI(openai) => Ok(openai.max_requests_per_second),
ValidModel::OpenAI(open_ai) => Ok(open_ai.max_requests_per_second),
ValidModel::Anthropic(anthropic) => Ok(anthropic.max_requests_per_second),
ValidModel::MistralFIM(mistral_fim) => Ok(mistral_fim.max_requests_per_second),
}
}
}
@ -237,7 +265,7 @@ mod test {
use serde_json::json;
#[test]
#[cfg(feature = "llamacpp")]
#[cfg(feature = "llama_cpp")]
fn llama_cpp_config() {
let args = json!({
"initializationOptions": {
@ -246,7 +274,7 @@ mod test {
},
"models": {
"model1": {
"type": "llamacpp",
"type": "llama_cpp",
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
"n_ctx": 2048,
@ -271,7 +299,7 @@ mod test {
}
#[test]
fn openai_config() {
fn open_ai_config() {
let args = json!({
"initializationOptions": {
"memory": {
@ -279,7 +307,7 @@ mod test {
},
"models": {
"model1": {
"type": "openai",
"type": "open_ai",
"completions_endpoint": "https://api.fireworks.ai/inference/v1/completions",
"model": "accounts/fireworks/models/llama-v2-34b-code",
"auth_token_env_var_name": "FIREWORKS_API_KEY",


@ -17,7 +17,7 @@ mod config;
mod custom_requests;
mod memory_backends;
mod memory_worker;
#[cfg(feature = "llamacpp")]
#[cfg(feature = "llama_cpp")]
mod template;
mod transformer_backends;
mod transformer_worker;


@ -12,7 +12,7 @@ use crate::{
utils::tokens_to_estimated_characters,
};
use super::{MemoryBackend, MemoryRunParams, Prompt};
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
pub struct FileStore {
_crawl: bool,
@ -106,56 +106,54 @@ impl FileStore {
pub fn build_code(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: MemoryRunParams,
) -> anyhow::Result<String> {
) -> anyhow::Result<Prompt> {
let (mut rope, cursor_index) =
self.get_rope_for_position(position, params.max_context_length)?;
// Prioritize doing chat
// If FIM is enabled, make sure the cursor is not at the end of the file as that is just completion
// If not chat and not FIM do completion
Ok(match (params.messages.is_some(), params.fim) {
(true, _) => {
Ok(match prompt_type {
PromptType::ContextAndCode => {
if params.messages.is_some() {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
}
}
PromptType::FIM => {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(cursor_index, "<CURSOR>");
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
rope_slice.to_string()
}
(false, Some(fim)) if rope.len_chars() != cursor_index => {
let max_length = tokens_to_estimated_characters(params.max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (cursor_index - start)));
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle);
rope.insert(start, &fim.start);
let rope_slice = rope
.get_slice(
start
..end
+ fim.start.chars().count()
+ fim.middle.chars().count()
+ fim.end.chars().count(),
)
.context("Error getting rope slice")?;
rope_slice.to_string()
}
_ => {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context_length));
let rope_slice = rope
let prefix = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
rope_slice.to_string()
let suffix = rope
.get_slice(cursor_index..end)
.context("Error getting rope slice")?;
Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
}
})
}
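
The FIM branch above reduces to a windowed prefix/suffix split around the cursor. A minimal standalone sketch of that split, using a plain String and character indices instead of a Rope (a simplification for illustration, not the code this commit ships):

// Simplified sketch of the prefix/suffix windowing done in `build_code` above.
// Uses a char Vec instead of a Rope purely for illustration.
fn fim_split(text: &str, cursor_char: usize, max_chars: usize) -> (String, String) {
    let chars: Vec<char> = text.chars().collect();
    let cursor = cursor_char.min(chars.len());
    let start = cursor.saturating_sub(max_chars / 2);
    let end = chars.len().min(cursor + (max_chars - (cursor - start)));
    let prefix: String = chars[start..cursor].iter().collect();
    let suffix: String = chars[cursor..end].iter().collect();
    (prefix, suffix)
}

// e.g. fim_split("Document Top", 10, 2048) == ("Document T".to_string(), "op".to_string()),
// in the same spirit as the FIM test further down in this file.
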
@ -186,11 +184,11 @@ impl MemoryBackend for FileStore {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = serde_json::from_value(params)?;
let code = self.build_code(position, params)?;
Ok(Prompt::new("".to_string(), code))
self.build_code(position, prompt_type, params)
}
#[instrument(skip(self))]
@ -404,7 +402,6 @@ The end with a trailing new line
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
let params = json!({});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
@ -416,20 +413,15 @@ The end with a trailing new line
character: 10,
},
},
params,
PromptType::ContextAndCode,
json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!("Document T", prompt.code);
// Test FIM
let params = json!({
"fim": {
"start": "SS",
"middle": "MM",
"end": "EE"
}
});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
@ -441,24 +433,24 @@ The end with a trailing new line
character: 10,
},
},
params,
PromptType::FIM,
json!({}),
)
.await?;
assert_eq!(prompt.context, "");
let text = r#"SSDocument TMMop
let prompt: FIMPrompt = prompt.try_into()?;
assert_eq!(prompt.prompt, r#"Document T"#);
assert_eq!(
prompt.suffix,
r#"op
Here is a more complicated document
Some text
The end with a trailing new line
EE"#
.to_string();
assert_eq!(text, prompt.code);
"#
);
// Test chat
let params = json!({
"messages": []
});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
@ -470,9 +462,13 @@ EE"#
character: 10,
},
},
params,
PromptType::ContextAndCode,
json!({
"messages": []
}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"Document T<CURSOR>op
Here is a more complicated document
@ -502,7 +498,6 @@ The end with a trailing new line
};
file_store.opened_text_document(params).await?;
let params = json!({});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
@ -514,9 +509,11 @@ The end with a trailing new line
character: 10,
},
},
params,
PromptType::ContextAndCode,
json!({}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
assert_eq!(format!("{}\nDocument T", text_document2.text), prompt.code);
@ -533,9 +530,6 @@ The end with a trailing new line
file_store.opened_text_document(params).await?;
// Test chat
let params = json!({
"messages": []
});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
@ -547,9 +541,11 @@ The end with a trailing new line
character: 0,
},
},
params,
PromptType::ContextAndCode,
json!({"messages": []}),
)
.await?;
let prompt: ContextAndCodePrompt = prompt.try_into()?;
assert_eq!(prompt.context, "");
let text = r#"test
<CURSOR>"#
@ -559,43 +555,43 @@ The end with a trailing new line
Ok(())
}
#[tokio::test]
async fn test_fim_placement_corner_cases() -> anyhow::Result<()> {
let text_document = generate_filler_text_document(None, Some("test\n"));
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
// #[tokio::test]
// async fn test_fim_placement_corner_cases() -> anyhow::Result<()> {
// let text_document = generate_filler_text_document(None, Some("test\n"));
// let params = lsp_types::DidOpenTextDocumentParams {
// text_document: text_document.clone(),
// };
// let file_store = generate_base_file_store()?;
// file_store.opened_text_document(params).await?;
// Test FIM
let params = json!({
"fim": {
"start": "SS",
"middle": "MM",
"end": "EE"
}
});
let prompt = file_store
.build_prompt(
&TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: text_document.uri.clone(),
},
position: Position {
line: 1,
character: 0,
},
},
params,
)
.await?;
assert_eq!(prompt.context, "");
let text = r#"test
"#
.to_string();
assert_eq!(text, prompt.code);
// // Test FIM
// let params = json!({
// "fim": {
// "start": "SS",
// "middle": "MM",
// "end": "EE"
// }
// });
// let prompt = file_store
// .build_prompt(
// &TextDocumentPositionParams {
// text_document: TextDocumentIdentifier {
// uri: text_document.uri.clone(),
// },
// position: Position {
// line: 1,
// character: 0,
// },
// },
// params,
// )
// .await?;
// assert_eq!(prompt.context, "");
// let text = r#"test
// "#
// .to_string();
// assert_eq!(text, prompt.code);
Ok(())
}
// Ok(())
// }
}


@ -5,7 +5,7 @@ use lsp_types::{
use serde::Deserialize;
use serde_json::Value;
use crate::config::{ChatMessage, Config, ValidMemoryBackend, FIM};
use crate::config::{ChatMessage, Config, ValidMemoryBackend};
pub mod file_store;
mod postgresml;
@ -14,26 +14,96 @@ const fn max_context_length_default() -> usize {
1024
}
#[derive(Clone, Debug)]
pub enum PromptType {
ContextAndCode,
FIM,
}
#[derive(Clone, Deserialize)]
pub struct MemoryRunParams {
pub fim: Option<FIM>,
pub messages: Option<Vec<ChatMessage>>,
#[serde(default = "max_context_length_default")]
pub max_context_length: usize,
}
#[derive(Debug)]
pub struct Prompt {
pub struct ContextAndCodePrompt {
pub context: String,
pub code: String,
}
impl Prompt {
impl ContextAndCodePrompt {
pub fn new(context: String, code: String) -> Self {
Self { context, code }
}
}
#[derive(Debug)]
pub struct FIMPrompt {
pub prompt: String,
pub suffix: String,
}
impl FIMPrompt {
pub fn new(prefix: String, suffix: String) -> Self {
Self {
prompt: prefix,
suffix,
}
}
}
#[derive(Debug)]
pub enum Prompt {
FIM(FIMPrompt),
ContextAndCode(ContextAndCodePrompt),
}
impl<'a> TryFrom<&'a Prompt> for &'a ContextAndCodePrompt {
type Error = anyhow::Error;
fn try_from(value: &'a Prompt) -> Result<Self, Self::Error> {
match value {
Prompt::ContextAndCode(code_and_context) => Ok(code_and_context),
_ => anyhow::bail!("cannot convert Prompt into ContextAndCodePrompt"),
}
}
}
impl TryFrom<Prompt> for ContextAndCodePrompt {
type Error = anyhow::Error;
fn try_from(value: Prompt) -> Result<Self, Self::Error> {
match value {
Prompt::ContextAndCode(code_and_context) => Ok(code_and_context),
_ => anyhow::bail!("cannot convert Prompt into ContextAndCodePrompt"),
}
}
}
impl TryFrom<Prompt> for FIMPrompt {
type Error = anyhow::Error;
fn try_from(value: Prompt) -> Result<Self, Self::Error> {
match value {
Prompt::FIM(fim) => Ok(fim),
_ => anyhow::bail!("cannot convert Prompt into FIMPrompt"),
}
}
}
impl<'a> TryFrom<&'a Prompt> for &'a FIMPrompt {
type Error = anyhow::Error;
fn try_from(value: &'a Prompt) -> Result<Self, Self::Error> {
match value {
Prompt::FIM(fim) => Ok(fim),
_ => anyhow::bail!("cannot convert Prompt into FIMPrompt"),
}
}
}
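
These conversions are what let a backend ask for the concrete prompt shape it supports and fail with a clear error otherwise. A small hedged sketch of them in use (the helper is hypothetical and assumes the types above and anyhow are in scope):

// Hypothetical helper showing the TryFrom impls above in use; not part of this commit.
fn describe_prompt(prompt: &Prompt) -> anyhow::Result<String> {
    // Borrowing conversion: &Prompt -> &FIMPrompt succeeds only for Prompt::FIM.
    if let Ok(fim) = <&FIMPrompt>::try_from(prompt) {
        return Ok(format!("FIM prefix: {} / suffix: {}", fim.prompt, fim.suffix));
    }
    // Otherwise it is the context-and-code shape.
    let context_and_code: &ContextAndCodePrompt = prompt.try_into()?;
    Ok(format!(
        "context: {} / code: {}",
        context_and_code.context, context_and_code.code
    ))
}
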
#[async_trait::async_trait]
pub trait MemoryBackend {
async fn init(&self) -> anyhow::Result<()> {
@ -48,6 +118,7 @@ pub trait MemoryBackend {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
) -> anyhow::Result<Prompt>;
async fn get_filter_text(
@ -76,16 +147,23 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
#[cfg(test)]
impl Prompt {
pub fn default_with_cursor() -> Self {
Self {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n <CURSOR>"#.to_string(),
}
Self::ContextAndCode(ContextAndCodePrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n <CURSOR>"#.to_string(),
))
}
pub fn default_fim() -> Self {
Self::FIM(FIMPrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n "#.to_string(),
))
}
pub fn default_without_cursor() -> Self {
Self {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n "#.to_string(),
}
Self::ContextAndCode(ContextAndCodePrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n "#.to_string(),
))
}
}


@ -15,7 +15,9 @@ use crate::{
utils::tokens_to_estimated_characters,
};
use super::{file_store::FileStore, MemoryBackend, MemoryRunParams, Prompt};
use super::{
file_store::FileStore, ContextAndCodePrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
};
pub struct PostgresML {
_config: Config,
@ -129,6 +131,7 @@ impl MemoryBackend for PostgresML {
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = serde_json::from_value(params)?;
@ -164,13 +167,19 @@ impl MemoryBackend for PostgresML {
.join("\n\n");
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self.file_store.build_code(position, file_store_params)?;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params)?;
let code: ContextAndCodePrompt = code.try_into()?;
let code = code.code;
let max_characters = tokens_to_estimated_characters(params.max_context_length);
let context: String = context
let _context: String = context
.chars()
.take(max_characters - code.chars().count())
.collect();
Ok(Prompt::new(context, code))
// We need to redo this section to work with the new memory backend system
todo!()
// Ok(Prompt::new(context, code))
}
#[instrument(skip(self))]


@ -7,11 +7,12 @@ use lsp_types::{
use serde_json::Value;
use tracing::error;
use crate::memory_backends::{MemoryBackend, Prompt};
use crate::memory_backends::{MemoryBackend, Prompt, PromptType};
#[derive(Debug)]
pub struct PromptRequest {
position: TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
tx: tokio::sync::oneshot::Sender<Prompt>,
}
@ -19,11 +20,13 @@ pub struct PromptRequest {
impl PromptRequest {
pub fn new(
position: TextDocumentPositionParams,
prompt_type: PromptType,
params: Value,
tx: tokio::sync::oneshot::Sender<Prompt>,
) -> Self {
Self {
position,
prompt_type,
params,
tx,
}
@ -67,7 +70,7 @@ async fn do_task(
}
WorkerRequest::Prompt(params) => {
let prompt = memory_backend
.build_prompt(&params.position, params.params)
.build_prompt(&params.position, params.prompt_type, params.params)
.await?;
params
.tx


@ -7,8 +7,7 @@ use crate::{
config::{self, ChatMessage},
memory_backends::Prompt,
transformer_worker::{
DoCompletionResponse, DoGenerationResponse, DoGenerationStreamResponse,
GenerationStreamRequest,
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
},
utils::format_chat_messages,
};
@ -41,7 +40,7 @@ pub struct AnthropicRunParams {
}
pub struct Anthropic {
configuration: config::Anthropic,
config: config::Anthropic,
}
#[derive(Deserialize)]
@ -56,9 +55,8 @@ struct AnthropicChatResponse {
}
impl Anthropic {
#[instrument]
pub fn new(configuration: config::Anthropic) -> Self {
Self { configuration }
pub fn new(config: config::Anthropic) -> Self {
Self { config }
}
async fn get_chat(
@ -68,9 +66,9 @@ impl Anthropic {
params: AnthropicRunParams,
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = if let Some(env_var_name) = &self.configuration.auth_token_env_var_name {
let token = if let Some(env_var_name) = &self.config.auth_token_env_var_name {
std::env::var(env_var_name)?
} else if let Some(token) = &self.configuration.auth_token {
} else if let Some(token) = &self.config.auth_token {
token.to_string()
} else {
anyhow::bail!(
@ -79,7 +77,7 @@ impl Anthropic {
};
let res: AnthropicChatResponse = client
.post(
self.configuration
self.config
.chat_endpoint
.as_ref()
.context("must specify `completions_endpoint` to use completions")?,
@ -89,7 +87,7 @@ impl Anthropic {
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"model": self.configuration.model,
"model": self.config.model,
"system": system_prompt,
"max_tokens": params.max_tokens,
"top_p": params.top_p,
@ -114,12 +112,12 @@ impl Anthropic {
prompt: &Prompt,
params: AnthropicRunParams,
) -> anyhow::Result<String> {
let mut messages = vec![ChatMessage {
role: "system".to_string(),
content: params.system.clone(),
}];
let mut messages = vec![ChatMessage::new(
"system".to_string(),
params.system.clone(),
)];
messages.extend_from_slice(&params.messages);
let mut messages = format_chat_messages(&messages, prompt);
let mut messages = format_chat_messages(&messages, prompt.try_into()?);
let system_prompt = messages.remove(0).content;
self.get_chat(system_prompt, messages, params).await
}
@ -127,17 +125,6 @@ impl Anthropic {
#[async_trait::async_trait]
impl TransformerBackend for Anthropic {
#[instrument(skip(self))]
async fn do_completion(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoCompletionResponse> {
let params: AnthropicRunParams = serde_json::from_value(params)?;
let insert_text = self.do_get_chat(prompt, params).await?;
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
async fn do_generate(
&self,
@ -164,30 +151,6 @@ mod test {
use super::*;
use serde_json::{from_value, json};
#[tokio::test]
async fn anthropic_chat_do_completion() -> anyhow::Result<()> {
let configuration: config::Anthropic = from_value(json!({
"chat_endpoint": "https://api.anthropic.com/v1/messages",
"model": "claude-3-haiku-20240307",
"auth_token_env_var_name": "ANTHROPIC_API_KEY",
}))?;
let anthropic = Anthropic::new(configuration);
let prompt = Prompt::default_with_cursor();
let run_params = json!({
"system": "Test",
"messages": [
{
"role": "user",
"content": "Test {CONTEXT} - {CODE}"
}
],
"max_tokens": 2
});
let response = anthropic.do_completion(&prompt, run_params).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn anthropic_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::Anthropic = from_value(json!({


@ -1,9 +1,4 @@
use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use serde::Deserialize;
use serde_json::Value;
use tracing::instrument;
use super::TransformerBackend;
use crate::{
config::{self, ChatMessage, FIM},
memory_backends::Prompt,
@ -14,12 +9,15 @@ use crate::{
},
utils::format_chat_messages,
};
use anyhow::Context;
use hf_hub::api::sync::ApiBuilder;
use serde::Deserialize;
use serde_json::Value;
use tracing::instrument;
mod model;
use model::Model;
use super::TransformerBackend;
const fn max_new_tokens_default() -> usize {
32
}
@ -61,20 +59,31 @@ impl LLaMACPP {
prompt: &Prompt,
params: &LLaMACPPRunParams,
) -> anyhow::Result<String> {
Ok(match &params.messages {
Some(completion_messages) => {
let chat_messages = format_chat_messages(completion_messages, prompt);
if let Some(chat_template) = &params.chat_template {
let bos_token = self.model.get_bos_token()?;
let eos_token = self.model.get_eos_token()?;
apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
} else {
self.model
.apply_chat_template(chat_messages, params.chat_format.clone())?
match prompt {
Prompt::ContextAndCode(context_and_code) => Ok(match &params.messages {
Some(completion_messages) => {
let chat_messages = format_chat_messages(completion_messages, context_and_code);
if let Some(chat_template) = &params.chat_template {
let bos_token = self.model.get_bos_token()?;
let eos_token = self.model.get_eos_token()?;
apply_chat_template(chat_template, chat_messages, &bos_token, &eos_token)?
} else {
self.model
.apply_chat_template(chat_messages, params.chat_format.clone())?
}
}
}
None => prompt.code.to_owned(),
})
None => context_and_code.code.clone(),
}),
Prompt::FIM(fim) => Ok(match &params.fim {
Some(fim_params) => {
format!(
"{}{}{}{}{}",
fim_params.start, fim.prompt, fim_params.middle, fim.suffix, fim_params.end
)
}
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
}),
}
}
}
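
Both this backend and the OpenAI backend stitch FIM prompts from the configured start/middle/end markers in the same order. A worked sketch using the marker strings from the tests in this commit (the prefix/suffix text is an arbitrary placeholder):

// Same interpolation order as `get_prompt` above: start, prefix, middle, suffix, end.
fn assemble_fim(prefix: &str, suffix: &str, start: &str, middle: &str, end: &str) -> String {
    format!("{start}{prefix}{middle}{suffix}{end}")
}

fn main() {
    let prompt = assemble_fim("Document T", "op", "<fim_prefix>", "<fim_suffix>", "<fim_middle>");
    // -> "<fim_prefix>Document T<fim_suffix>op<fim_middle>"
    println!("{prompt}");
}
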
@ -143,7 +152,7 @@ mod test {
}
],
"chat_format": "llama2",
"max_tokens": 64
"max_tokens": 4
});
let response = llama_cpp.do_completion(&prompt, run_params).await?;
assert!(!response.insert_text.is_empty());
@ -166,7 +175,7 @@ mod test {
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"max_tokens": 64
"max_tokens": 4
});
let response = llama_cpp.do_completion(&prompt, run_params).await?;
assert!(!response.insert_text.is_empty());
@ -189,7 +198,7 @@ mod test {
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"max_tokens": 64
"max_tokens": 4
});
let response = llama_cpp.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());


@ -0,0 +1,153 @@
use anyhow::Context;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::instrument;
use super::{open_ai::OpenAIChatResponse, TransformerBackend};
use crate::{
config::{self},
memory_backends::{FIMPrompt, Prompt, PromptType},
transformer_worker::{
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
},
};
const fn max_tokens_default() -> usize {
64
}
const fn top_p_default() -> f32 {
0.95
}
const fn temperature_default() -> f32 {
0.1
}
// NOTE: We cannot deny unknown fields as the provided parameters may contain other fields relevant to other processes
#[derive(Debug, Deserialize)]
pub struct MistralFIMRunParams {
#[serde(default = "max_tokens_default")]
pub max_tokens: usize,
#[serde(default = "top_p_default")]
pub top_p: f32,
#[serde(default = "temperature_default")]
pub temperature: f32,
pub min_tokens: Option<u64>,
pub random_seed: Option<u64>,
#[serde(default)]
pub stop: Vec<String>,
}
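
Because every field above is either defaulted or optional, an empty params object deserializes cleanly. A quick sketch (assumes this struct and anyhow are in scope):

// Sketch: the serde defaults above make `{}` a valid set of run params.
fn default_run_params() -> anyhow::Result<()> {
    let params: MistralFIMRunParams = serde_json::from_value(serde_json::json!({}))?;
    assert_eq!(params.max_tokens, 64);
    assert!(params.min_tokens.is_none());
    assert!(params.random_seed.is_none());
    assert!(params.stop.is_empty());
    Ok(())
}
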
pub struct MistralFIM {
config: config::MistralFIM,
}
impl MistralFIM {
pub fn new(config: config::MistralFIM) -> Self {
Self { config }
}
fn get_token(&self) -> anyhow::Result<String> {
if let Some(env_var_name) = &self.config.auth_token_env_var_name {
Ok(std::env::var(env_var_name)?)
} else if let Some(token) = &self.config.auth_token {
Ok(token.to_string())
} else {
anyhow::bail!(
"set `auth_token_env_var_name` or `auth_token` to use an MistralFIM compatible API"
)
}
}
async fn do_fim(
&self,
prompt: &FIMPrompt,
params: MistralFIMRunParams,
) -> anyhow::Result<String> {
let client = reqwest::Client::new();
let token = self.get_token()?;
let res: OpenAIChatResponse = client
.post(
self.config
.fim_endpoint
.as_ref()
.context("must specify `fim_endpoint` to use fim")?,
)
.bearer_auth(token)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&json!({
"prompt": prompt.prompt,
"suffix": prompt.suffix,
"model": self.config.model,
"max_tokens": params.max_tokens,
"top_p": params.top_p,
"temperature": params.temperature,
"min_tokens": params.min_tokens,
"random_seed": params.random_seed,
"stop": params.stop
}))
.send()
.await?
.json()
.await?;
if let Some(error) = res.error {
anyhow::bail!("{:?}", error.to_string())
} else if let Some(choices) = res.choices {
Ok(choices[0].message.content.clone())
} else {
anyhow::bail!("Unknown error while making request to OpenAI")
}
}
}
#[async_trait::async_trait]
impl TransformerBackend for MistralFIM {
#[instrument(skip(self))]
async fn do_generate(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoGenerationResponse> {
let params: MistralFIMRunParams = serde_json::from_value(params)?;
let generated_text = self.do_fim(prompt.try_into()?, params).await?;
Ok(DoGenerationResponse { generated_text })
}
#[instrument(skip(self))]
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
_params: Value,
) -> anyhow::Result<DoGenerationStreamResponse> {
anyhow::bail!("GenerationStream is not yet implemented")
}
fn get_prompt_type(&self, _params: &Value) -> anyhow::Result<PromptType> {
Ok(PromptType::FIM)
}
}
#[cfg(test)]
mod test {
use super::*;
use serde_json::{from_value, json};
#[tokio::test]
async fn mistral_fim_do_generate() -> anyhow::Result<()> {
let configuration: config::MistralFIM = from_value(json!({
"fim_endpoint": "https://api.mistral.ai/v1/fim/completions",
"model": "codestral-latest",
"auth_token_env_var_name": "MISTRAL_API_KEY",
}))?;
let mistral_fim = MistralFIM::new(configuration);
let prompt = Prompt::default_fim();
let run_params = json!({
"max_tokens": 2
});
let response = mistral_fim.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
}


@ -1,8 +1,9 @@
use anyhow::Context;
use serde_json::Value;
use crate::{
config::ValidModel,
memory_backends::Prompt,
memory_backends::{Prompt, PromptType},
transformer_worker::{
DoCompletionResponse, DoGenerationResponse, DoGenerationStreamResponse,
GenerationStreamRequest,
@ -10,9 +11,10 @@ use crate::{
};
mod anthropic;
#[cfg(feature = "llamacpp")]
#[cfg(feature = "llama_cpp")]
mod llama_cpp;
mod openai;
mod mistral_fim;
mod open_ai;
#[async_trait::async_trait]
pub trait TransformerBackend {
@ -20,17 +22,37 @@ pub trait TransformerBackend {
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoCompletionResponse>;
) -> anyhow::Result<DoCompletionResponse> {
self.do_generate(prompt, params)
.await
.map(|x| DoCompletionResponse {
insert_text: x.generated_text,
})
}
async fn do_generate(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoGenerationResponse>;
async fn do_generate_stream(
&self,
request: &GenerationStreamRequest,
params: Value,
) -> anyhow::Result<DoGenerationStreamResponse>;
fn get_prompt_type(&self, params: &Value) -> anyhow::Result<PromptType> {
if params
.as_object()
.context("params must be a JSON object")?
.contains_key("fim")
{
Ok(PromptType::FIM)
} else {
Ok(PromptType::ContextAndCode)
}
}
}
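
A quick sketch of how the default `get_prompt_type` above routes requests: any params object carrying a "fim" key is treated as a FIM request, everything else as context-and-code. `backend` here is a hypothetical value that keeps the trait's default implementation; it is not constructed in this sketch.

// Sketch of the default prompt-type routing; assumes a backend that does not
// override `get_prompt_type` (MistralFIM, for example, overrides it to always return FIM).
fn routing_example(backend: &dyn TransformerBackend) -> anyhow::Result<()> {
    let fim_params = serde_json::json!({
        "fim": { "start": "<fim_prefix>", "middle": "<fim_suffix>", "end": "<fim_middle>" }
    });
    let chat_params = serde_json::json!({ "messages": [] });
    assert!(matches!(backend.get_prompt_type(&fim_params)?, PromptType::FIM));
    assert!(matches!(
        backend.get_prompt_type(&chat_params)?,
        PromptType::ContextAndCode
    ));
    Ok(())
}
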
impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
@ -38,12 +60,17 @@ impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
fn try_from(valid_model: ValidModel) -> Result<Self, Self::Error> {
match valid_model {
#[cfg(feature = "llamacpp")]
#[cfg(feature = "llama_cpp")]
ValidModel::LLaMACPP(model_gguf) => Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?)),
ValidModel::OpenAI(openai_config) => Ok(Box::new(openai::OpenAI::new(openai_config))),
ValidModel::OpenAI(open_ai_config) => {
Ok(Box::new(open_ai::OpenAI::new(open_ai_config)))
}
ValidModel::Anthropic(anthropic_config) => {
Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
}
ValidModel::MistralFIM(mistral_fim) => {
Ok(Box::new(mistral_fim::MistralFIM::new(mistral_fim)))
}
}
}
}


@ -7,8 +7,7 @@ use crate::{
config::{self, ChatMessage, FIM},
memory_backends::Prompt,
transformer_worker::{
DoCompletionResponse, DoGenerationResponse, DoGenerationStreamResponse,
GenerationStreamRequest,
DoGenerationResponse, DoGenerationStreamResponse, GenerationStreamRequest,
},
utils::{format_chat_messages, format_context_code},
};
@ -68,14 +67,14 @@ struct OpenAICompletionsResponse {
}
#[derive(Deserialize)]
struct OpenAIChatChoices {
message: ChatMessage,
pub struct OpenAIChatChoices {
pub message: ChatMessage,
}
#[derive(Deserialize)]
struct OpenAIChatResponse {
choices: Option<Vec<OpenAIChatChoices>>,
error: Option<Value>,
pub struct OpenAIChatResponse {
pub choices: Option<Vec<OpenAIChatChoices>>,
pub error: Option<Value>,
}
impl OpenAI {
@ -180,32 +179,43 @@ impl OpenAI {
prompt: &Prompt,
params: OpenAIRunParams,
) -> anyhow::Result<String> {
match &params.messages {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, prompt);
self.get_chat(messages, params).await
}
None => {
self.get_completion(&format_context_code(&prompt.context, &prompt.code), params)
match prompt {
Prompt::ContextAndCode(code_and_context) => match &params.messages {
Some(completion_messages) => {
let messages = format_chat_messages(completion_messages, code_and_context);
self.get_chat(messages, params).await
}
None => {
self.get_completion(
&format_context_code(&code_and_context.context, &code_and_context.code),
params,
)
.await
}
}
},
Prompt::FIM(fim) => match &params.fim {
Some(fim_params) => {
self.get_completion(
&format!(
"{}{}{}{}{}",
fim_params.start,
fim.prompt,
fim_params.middle,
fim.suffix,
fim_params.end
),
params,
)
.await
}
None => anyhow::bail!("Prompt type is FIM but no FIM parameters provided"),
},
}
}
}
#[async_trait::async_trait]
impl TransformerBackend for OpenAI {
#[instrument(skip(self))]
async fn do_completion(
&self,
prompt: &Prompt,
params: Value,
) -> anyhow::Result<DoCompletionResponse> {
let params: OpenAIRunParams = serde_json::from_value(params)?;
let insert_text = self.do_chat_completion(prompt, params).await?;
Ok(DoCompletionResponse { insert_text })
}
#[instrument(skip(self))]
async fn do_generate(
&self,
@ -234,74 +244,30 @@ mod test {
use serde_json::{from_value, json};
#[tokio::test]
async fn openai_completion_do_completion() -> anyhow::Result<()> {
async fn open_ai_completion_do_generate() -> anyhow::Result<()> {
let configuration: config::OpenAI = from_value(json!({
"completions_endpoint": "https://api.openai.com/v1/completions",
"model": "gpt-3.5-turbo-instruct",
"auth_token_env_var_name": "OPENAI_API_KEY",
}))?;
let openai = OpenAI::new(configuration);
let open_ai = OpenAI::new(configuration);
let prompt = Prompt::default_without_cursor();
let run_params = json!({
"max_tokens": 64
});
let response = openai.do_completion(&prompt, run_params).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_chat_do_completion() -> anyhow::Result<()> {
let configuration: config::OpenAI = serde_json::from_value(json!({
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
"model": "gpt-3.5-turbo",
"auth_token_env_var_name": "OPENAI_API_KEY",
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let run_params = json!({
"messages": [
{
"role": "system",
"content": "Test"
},
{
"role": "user",
"content": "Test {CONTEXT} - {CODE}"
}
],
"max_tokens": 64
});
let response = openai.do_completion(&prompt, run_params).await?;
assert!(!response.insert_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_completion_do_generate() -> anyhow::Result<()> {
let configuration: config::OpenAI = from_value(json!({
"completions_endpoint": "https://api.openai.com/v1/completions",
"model": "gpt-3.5-turbo-instruct",
"auth_token_env_var_name": "OPENAI_API_KEY",
}))?;
let openai = OpenAI::new(configuration);
let prompt = Prompt::default_without_cursor();
let run_params = json!({
"max_tokens": 64
});
let response = openai.do_generate(&prompt, run_params).await?;
let response = open_ai.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}
#[tokio::test]
async fn openai_chat_do_generate() -> anyhow::Result<()> {
async fn open_ai_chat_do_generate() -> anyhow::Result<()> {
let configuration: config::OpenAI = serde_json::from_value(json!({
"chat_endpoint": "https://api.openai.com/v1/chat/completions",
"model": "gpt-3.5-turbo",
"auth_token_env_var_name": "OPENAI_API_KEY",
}))?;
let openai = OpenAI::new(configuration);
let open_ai = OpenAI::new(configuration);
let prompt = Prompt::default_with_cursor();
let run_params = json!({
"messages": [
@ -316,7 +282,7 @@ mod test {
],
"max_tokens": 64
});
let response = openai.do_generate(&prompt, run_params).await?;
let response = open_ai.do_generate(&prompt, run_params).await?;
assert!(!response.generated_text.is_empty());
Ok(())
}


@ -252,6 +252,7 @@ async fn do_completion(
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
transformer_backend.get_prompt_type(&params)?,
params.clone(),
tx,
)))?;
@ -307,6 +308,7 @@ async fn do_generate(
let (tx, rx) = oneshot::channel();
memory_backend_tx.send(memory_worker::WorkerRequest::Prompt(PromptRequest::new(
request.params.text_document_position.clone(),
transformer_backend.get_prompt_type(&params)?,
params.clone(),
tx,
)))?;


@ -1,6 +1,6 @@
use lsp_server::ResponseError;
use crate::{config::ChatMessage, memory_backends::Prompt};
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt};
pub trait ToResponseError {
fn to_response_error(&self, code: i32) -> ResponseError;
@ -20,15 +20,19 @@ pub fn tokens_to_estimated_characters(tokens: usize) -> usize {
tokens * 4
}
pub fn format_chat_messages(messages: &[ChatMessage], prompt: &Prompt) -> Vec<ChatMessage> {
pub fn format_chat_messages(
messages: &[ChatMessage],
prompt: &ContextAndCodePrompt,
) -> Vec<ChatMessage> {
messages
.iter()
.map(|m| ChatMessage {
role: m.role.to_owned(),
content: m
.content
.replace("{CONTEXT}", &prompt.context)
.replace("{CODE}", &prompt.code),
.map(|m| {
ChatMessage::new(
m.role.to_owned(),
m.content
.replace("{CONTEXT}", &prompt.context)
.replace("{CODE}", &prompt.code),
)
})
.collect()
}
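
A worked sketch of the substitution above (assumes ChatMessage, ContextAndCodePrompt, and format_chat_messages from this crate are in scope; the message text is a placeholder in the style of the tests):

// Sketch: {CONTEXT} and {CODE} in each message are replaced with the prompt's fields.
fn substitution_example() -> Vec<ChatMessage> {
    let prompt = ContextAndCodePrompt::new(
        "def test_context():\n    pass".to_string(),
        "def test_code():\n    <CURSOR>".to_string(),
    );
    let messages = vec![ChatMessage::new(
        "user".to_string(),
        "Test {CONTEXT} - {CODE}".to_string(),
    )];
    format_chat_messages(&messages, &prompt)
}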