Refactored config

SilasMarvin 2024-04-27 15:03:37 -07:00
parent a2457f77e6
commit 16966cba46
5 changed files with 140 additions and 124 deletions

Cargo.lock (generated)

@@ -1954,7 +1954,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pgml"
version = "1.0.1"
version = "1.0.0"
dependencies = [
"anyhow",
"async-trait",

@@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use crate::memory_backends::PromptForType;
const DEFAULT_LLAMA_CPP_N_CTX: usize = 1024;
const DEFAULT_OPENAI_MAX_CONTEXT: usize = 2048;
@@ -26,7 +28,8 @@ impl Default for ValidMemoryBackend {
}
#[derive(Debug, Clone, Deserialize)]
pub enum ValidTransformerBackend {
#[serde(tag = "type")]
pub enum ValidModel {
#[serde(rename = "llamacpp")]
LLaMACPP(LLaMACPP),
#[serde(rename = "openai")]
@@ -35,9 +38,9 @@ pub enum ValidTransformerBackend {
Anthropic(Anthropic),
}
impl Default for ValidTransformerBackend {
impl Default for ValidModel {
fn default() -> Self {
ValidTransformerBackend::LLaMACPP(LLaMACPP::default())
ValidModel::LLaMACPP(LLaMACPP::default())
}
}
@@ -102,13 +105,13 @@ pub struct LLaMACPP {
// The model to use
#[serde(flatten)]
pub model: Model,
// Fill in the middle support
pub fim: Option<FIM>,
// The maximum number of new tokens to generate
#[serde(default)]
pub max_tokens: MaxTokens,
// Chat args
pub chat: Option<Chat>,
// // Fill in the middle support
// pub fim: Option<FIM>,
// // The maximum number of new tokens to generate
// #[serde(default)]
// pub max_tokens: MaxTokens,
// // Chat args
// pub chat: Option<Chat>,
// Kwargs passed to LlamaCPP
#[serde(flatten)]
pub kwargs: Kwargs,
@@ -121,13 +124,13 @@ impl Default for LLaMACPP {
repository: "stabilityai/stable-code-3b".to_string(),
name: Some("stable-code-3b-Q5_K_M.gguf".to_string()),
},
fim: Some(FIM {
start: "<fim_prefix>".to_string(),
middle: "<fim_suffix>".to_string(),
end: "<fim_middle>".to_string(),
}),
max_tokens: MaxTokens::default(),
chat: None,
// fim: Some(FIM {
// start: "<fim_prefix>".to_string(),
// middle: "<fim_suffix>".to_string(),
// end: "<fim_middle>".to_string(),
// }),
// max_tokens: MaxTokens::default(),
// chat: None,
kwargs: Kwargs::default(),
}
}
@@ -137,23 +140,23 @@ const fn api_max_requests_per_second_default() -> f32 {
0.5
}
const fn openai_top_p_default() -> f32 {
const fn top_p_default() -> f32 {
0.95
}
const fn openai_presence_penalty_default() -> f32 {
const fn presence_penalty_default() -> f32 {
0.
}
const fn openai_frequency_penalty_default() -> f32 {
const fn frequency_penalty_default() -> f32 {
0.
}
const fn openai_temperature_default() -> f32 {
const fn temperature_default() -> f32 {
0.1
}
const fn openai_max_context_default() -> usize {
const fn max_context_default() -> usize {
DEFAULT_OPENAI_MAX_CONTEXT
}
@@ -171,24 +174,24 @@ pub struct OpenAI {
pub max_requests_per_second: f32,
// The model name
pub model: String,
// Fill in the middle support
pub fim: Option<FIM>,
// The maximum number of new tokens to generate
#[serde(default)]
pub max_tokens: MaxTokens,
// Chat args
pub chat: Option<Chat>,
// Other available args
#[serde(default = "openai_top_p_default")]
pub top_p: f32,
#[serde(default = "openai_presence_penalty_default")]
pub presence_penalty: f32,
#[serde(default = "openai_frequency_penalty_default")]
pub frequency_penalty: f32,
#[serde(default = "openai_temperature_default")]
pub temperature: f32,
#[serde(default = "openai_max_context_default")]
max_context: usize,
// // Fill in the middle support
// pub fim: Option<FIM>,
// // The maximum number of new tokens to generate
// #[serde(default)]
// pub max_tokens: MaxTokens,
// // Chat args
// pub chat: Option<Chat>,
// // Other available args
// #[serde(default = "top_p_default")]
// pub top_p: f32,
// #[serde(default = "presence_penalty_default")]
// pub presence_penalty: f32,
// #[serde(default = "frequency_penalty_default")]
// pub frequency_penalty: f32,
// #[serde(default = "temperature_default")]
// pub temperature: f32,
// #[serde(default = "max_context_default")]
// max_context: usize,
}
#[derive(Clone, Debug, Deserialize)]
@@ -205,18 +208,45 @@ pub struct Anthropic {
pub max_requests_per_second: f32,
// The model name
pub model: String,
// The maximum number of new tokens to generate
// // The maximum number of new tokens to generate
// #[serde(default)]
// pub max_tokens: MaxTokens,
// // Chat args
// pub chat: Chat,
// #[serde(default = "top_p_default")]
// pub top_p: f32,
// #[serde(default = "temperature_default")]
// pub temperature: f32,
// #[serde(default = "max_context_default")]
// max_context: usize,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Completion {
// The model key to use
pub model: String,
// Model args
#[serde(default)]
pub max_tokens: MaxTokens,
// Chat args
pub chat: Chat,
// System prompt
#[serde(default = "openai_top_p_default")]
#[serde(default = "presence_penalty_default")]
pub presence_penalty: f32,
#[serde(default = "frequency_penalty_default")]
pub frequency_penalty: f32,
#[serde(default = "top_p_default")]
pub top_p: f32,
#[serde(default = "openai_temperature_default")]
#[serde(default = "temperature_default")]
pub temperature: f32,
#[serde(default = "openai_max_context_default")]
#[serde(default = "max_context_default")]
max_context: usize,
// FIM args
pub fim: Option<FIM>,
// Chat args
pub chat: Option<Vec<ChatMessage>>,
pub chat_template: Option<String>,
pub chat_format: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Default)]
@@ -224,7 +254,9 @@ pub struct ValidConfig {
#[serde(default)]
pub memory: ValidMemoryBackend,
#[serde(default)]
pub transformer: ValidTransformerBackend,
pub transformer: ValidModel,
#[serde(default)]
pub models: HashMap<String, ValidModel>,
}
#[derive(Clone, Debug, Deserialize, Default)]
@@ -263,15 +295,15 @@ impl Config {
pub fn get_transformer_max_requests_per_second(&self) -> f32 {
match &self.config.transformer {
ValidTransformerBackend::LLaMACPP(_) => 1.,
ValidTransformerBackend::OpenAI(openai) => openai.max_requests_per_second,
ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_requests_per_second,
ValidModel::LLaMACPP(_) => 1.,
ValidModel::OpenAI(openai) => openai.max_requests_per_second,
ValidModel::Anthropic(anthropic) => anthropic.max_requests_per_second,
}
}
pub fn get_max_context_length(&self) -> usize {
match &self.config.transformer {
ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp
ValidModel::LLaMACPP(llama_cpp) => llama_cpp
.kwargs
.get("n_ctx")
.map(|v| {
@@ -280,24 +312,24 @@ impl Config {
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX)
})
.unwrap_or(DEFAULT_LLAMA_CPP_N_CTX),
ValidTransformerBackend::OpenAI(openai) => openai.max_context,
ValidTransformerBackend::Anthropic(anthropic) => anthropic.max_context,
ValidModel::OpenAI(openai) => openai.max_context,
ValidModel::Anthropic(anthropic) => anthropic.max_context,
}
}
pub fn get_fim(&self) -> Option<&FIM> {
match &self.config.transformer {
ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.fim.as_ref(),
ValidTransformerBackend::OpenAI(openai) => openai.fim.as_ref(),
ValidTransformerBackend::Anthropic(_) => None,
ValidModel::LLaMACPP(llama_cpp) => llama_cpp.fim.as_ref(),
ValidModel::OpenAI(openai) => openai.fim.as_ref(),
ValidModel::Anthropic(_) => None,
}
}
pub fn get_chat(&self) -> Option<&Chat> {
match &self.config.transformer {
ValidTransformerBackend::LLaMACPP(llama_cpp) => llama_cpp.chat.as_ref(),
ValidTransformerBackend::OpenAI(openai) => openai.chat.as_ref(),
ValidTransformerBackend::Anthropic(anthropic) => Some(&anthropic.chat),
ValidModel::LLaMACPP(llama_cpp) => llama_cpp.chat.as_ref(),
ValidModel::OpenAI(openai) => openai.chat.as_ref(),
ValidModel::Anthropic(anthropic) => Some(&anthropic.chat),
}
}
}
@@ -314,45 +346,23 @@ mod test {
"memory": {
"file_store": {}
},
"transformer": {
"llamacpp": {
"models": {
"model1": {
"type": "llamacpp",
"repository": "TheBloke/deepseek-coder-6.7B-instruct-GGUF",
"name": "deepseek-coder-6.7b-instruct.Q5_K_S.gguf",
"max_new_tokens": {
"completion": 32,
"generation": 256,
},
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
// "completion": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. Keep your response brief. Do not produce any text besides code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
// "generation": [
// {
// "role": "system",
// "content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{context}",
// },
// {
// "role": "user",
// "content": "Complete the following code: \n\n{code}"
// }
// ],
"chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"
},
"n_ctx": 2048,
"n_gpu_layers": 35,
"n_gpu_layers": 35
}
},
"completion": {
"model": "model1",
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"max_new_tokens": 32,
}
}
});
@@ -366,20 +376,27 @@ mod test {
"memory": {
"file_store": {}
},
"transformer": {
"openai": {
"models": {
"model1": {
"type": "openai",
"completions_endpoint": "https://api.fireworks.ai/inference/v1/completions",
"model": "accounts/fireworks/models/llama-v2-34b-code",
"auth_token_env_var_name": "FIREWORKS_API_KEY",
"chat": {
// Not sure what to do here yet
},
"max_tokens": {
"completion": 16,
"generation": 64
},
"max_context": 4096
},
},
"completion": {
"model": "model1",
"chat": [
{
"role": "system",
"content": "You are a code completion chatbot. Use the following context to complete the next segement of code. \n\n{CONTEXT}",
},
{
"role": "user",
"content": "Complete the following code: \n\n{CODE}"
}
],
"max_new_tokens": 32,
}
}
});
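
Editorial aside: the key change in this file is that models now live under a named "models" map, selected by a "type" field via serde's internally tagged enum representation, and referenced by key from a "completion" section. A minimal standalone sketch of that deserialization follows; the variant and struct fields here are simplified assumptions drawn from the tests above, not the crate's full definitions.

use std::collections::HashMap;

use serde::Deserialize;

// Simplified, assumed variants: the real LLaMACPP and OpenAI configs carry
// more fields (kwargs, auth_token_env_var_name, max_requests_per_second, ...).
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum ValidModel {
    #[serde(rename = "llamacpp")]
    LLaMACPP {
        repository: String,
        name: Option<String>,
    },
    #[serde(rename = "openai")]
    OpenAI {
        completions_endpoint: String,
        model: String,
    },
}

#[derive(Debug, Deserialize)]
struct Completion {
    // Key into the `models` map.
    model: String,
}

#[derive(Debug, Deserialize)]
struct ValidConfig {
    models: HashMap<String, ValidModel>,
    completion: Option<Completion>,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = serde_json::json!({
        "models": {
            "model1": {
                "type": "llamacpp",
                "repository": "stabilityai/stable-code-3b",
                "name": "stable-code-3b-Q5_K_M.gguf"
            }
        },
        "completion": { "model": "model1" }
    });
    // The "type" field selects the enum variant; unknown keys are ignored
    // because deny_unknown_fields is not set in this sketch.
    let config: ValidConfig = serde_json::from_value(raw)?;
    println!("{config:#?}");
    Ok(())
}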

@@ -15,25 +15,25 @@ use super::{MemoryBackend, Prompt, PromptForType};
pub struct FileStore {
_crawl: bool,
configuration: Config,
config: Config,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
}
impl FileStore {
pub fn new(file_store_config: config::FileStore, configuration: Config) -> Self {
pub fn new(file_store_config: config::FileStore, config: Config) -> Self {
Self {
_crawl: file_store_config.crawl,
configuration,
config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
}
pub fn new_without_crawl(configuration: Config) -> Self {
pub fn new_without_crawl(config: Config) -> Self {
Self {
_crawl: false,
configuration,
config,
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
}
@@ -111,18 +111,18 @@ impl FileStore {
let is_chat_enabled = match prompt_for_type {
PromptForType::Completion => self
.configuration
.config
.get_chat()
.map(|c| c.completion.is_some())
.unwrap_or(false),
PromptForType::Generate => self
.configuration
.config
.get_chat()
.map(|c| c.generation.is_some())
.unwrap_or(false),
};
Ok(match (is_chat_enabled, self.configuration.get_fim()) {
Ok(match (is_chat_enabled, self.config.get_fim()) {
r @ (true, _) | r @ (false, Some(_)) if rope.len_chars() != cursor_index => {
let max_length = tokens_to_estimated_characters(max_context_length);
let start = cursor_index.saturating_sub(max_length / 2);
@@ -192,10 +192,13 @@ impl MemoryBackend for FileStore {
position: &TextDocumentPositionParams,
prompt_for_type: PromptForType,
) -> anyhow::Result<Prompt> {
// TODO: Fix this
// we need to be subtracting the completion / generation tokens from max_context_length
// not sure if we should be doing that for the chat maybe leave a note here for that?
let code = self.build_code(
position,
prompt_for_type,
self.configuration.get_max_context_length(),
self.config.get_max_context_length(),
)?;
Ok(Prompt::new("".to_string(), code))
}
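
Editorial aside on the windowing logic referenced above (tokens_to_estimated_characters and the saturating_sub window in build_code): the prompt is built from a character window centered on the cursor. A rough standalone sketch of that idea, assuming roughly four characters per token, which may not match the crate's actual estimator:

// Assumed ratio of ~4 characters per token; the real helper may differ.
fn tokens_to_estimated_characters(tokens: usize) -> usize {
    tokens * 4
}

// Take a character window of at most the estimated size, centered on the cursor.
fn window_around_cursor(text: &str, cursor_index: usize, max_context_length: usize) -> &str {
    let char_offsets: Vec<usize> = text.char_indices().map(|(byte, _)| byte).collect();
    let max_chars = tokens_to_estimated_characters(max_context_length);
    let start = cursor_index.saturating_sub(max_chars / 2);
    let end = (cursor_index + max_chars / 2).min(char_offsets.len());
    let byte_start = char_offsets.get(start).copied().unwrap_or(text.len());
    let byte_end = char_offsets.get(end).copied().unwrap_or(text.len());
    &text[byte_start..byte_end]
}

fn main() {
    let text = "fn add(a: i32, b: i32) -> i32 { a + b }";
    // Cursor inside the function body, budget of 4 tokens (~16 characters).
    println!("{}", window_around_cursor(text, 34, 4));
}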

@@ -1,5 +1,5 @@
use crate::{
config::{Config, ValidTransformerBackend},
config::{Config, ValidModel},
memory_backends::Prompt,
transformer_worker::{
DoCompletionResponse, DoGenerationResponse, DoGenerationStreamResponse,
@@ -26,13 +26,9 @@ impl TryFrom<Config> for Box<dyn TransformerBackend + Send + Sync> {
fn try_from(configuration: Config) -> Result<Self, Self::Error> {
match configuration.config.transformer {
ValidTransformerBackend::LLaMACPP(model_gguf) => {
Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?))
}
ValidTransformerBackend::OpenAI(openai_config) => {
Ok(Box::new(openai::OpenAI::new(openai_config)))
}
ValidTransformerBackend::Anthropic(anthropic_config) => {
ValidModel::LLaMACPP(model_gguf) => Ok(Box::new(llama_cpp::LLaMACPP::new(model_gguf)?)),
ValidModel::OpenAI(openai_config) => Ok(Box::new(openai::OpenAI::new(openai_config))),
ValidModel::Anthropic(anthropic_config) => {
Ok(Box::new(anthropic::Anthropic::new(anthropic_config)))
}
}
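
The hunk above only tightens the match arms, but the underlying pattern, an enum of model configs converted into a boxed trait object via TryFrom, is worth spelling out. A self-contained toy sketch of that dispatch, using illustrative names rather than the crate's real types:

use anyhow::Result;

trait TransformerBackend {
    fn name(&self) -> &'static str;
}

struct LlamaCpp;
struct OpenAi;

impl TransformerBackend for LlamaCpp {
    fn name(&self) -> &'static str {
        "llamacpp"
    }
}

impl TransformerBackend for OpenAi {
    fn name(&self) -> &'static str {
        "openai"
    }
}

// Stand-in for ValidModel: each variant maps to one backend implementation.
enum ValidModel {
    LlamaCpp,
    OpenAi,
}

impl TryFrom<ValidModel> for Box<dyn TransformerBackend + Send + Sync> {
    type Error = anyhow::Error;

    fn try_from(model: ValidModel) -> Result<Self> {
        Ok(match model {
            ValidModel::LlamaCpp => Box::new(LlamaCpp),
            ValidModel::OpenAi => Box::new(OpenAi),
        })
    }
}

fn main() -> Result<()> {
    for model in [ValidModel::LlamaCpp, ValidModel::OpenAi] {
        let backend: Box<dyn TransformerBackend + Send + Sync> = model.try_into()?;
        println!("built backend: {}", backend.name());
    }
    Ok(())
}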

@@ -1 +1 @@
Subproject commit 47610cb174df8120fbc5255b221aa49a2484de6a
Subproject commit a16ff700c1d54582711d14c2f57341bb9bd82be2