Working fim

This commit is contained in:
Silas Marvin 2024-02-21 14:23:29 -10:00
parent 81d730c56d
commit 2f2ff81043
4 changed files with 54 additions and 40 deletions

View File

@@ -24,9 +24,9 @@ pub enum ValidTransformerBackend {
// TODO: Review this for real lol
#[derive(Clone, Deserialize)]
pub struct FIM {
start: String,
middle: String,
end: String,
pub start: String,
pub middle: String,
pub end: String,
}
#[derive(Clone, Deserialize)]
@@ -180,8 +180,12 @@ impl Configuration {
}
}
pub fn supports_fim(&self) -> bool {
false
pub fn get_fim(&self) -> Option<&FIM> {
if let Some(model_gguf) = &self.valid_config.transformer.model_gguf {
model_gguf.fim.as_ref()
} else {
panic!("We currently only support gguf models using llama cpp")
}
}
}
@@ -207,11 +211,11 @@ mod tests {
"completion": 32,
"generation": 256,
},
// "fim": {
// "start": "",
// "middle": "",
// "end": ""
// },
"fim": {
"start": "<fim_prefix>",
"middle": "<fim_suffix>",
"end": "<fim_middle>"
},
"chat": {
"completion": [
{

View File

@@ -35,33 +35,46 @@ impl MemoryBackend for FileStore {
}
fn build_prompt(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
let mut rope = self
.file_map
.get(position.text_document.uri.as_str())
.context("Error file not found")?
.clone();
if self.configuration.supports_fim() {
// We will want to have some kind of infill support we add
// rope.insert(cursor_index, "<fim_hole>");
// rope.insert(0, "<fim_start>");
// rope.insert(rope.len_chars(), "<fim_end>");
// let prompt = rope.to_string();
unimplemented!()
} else {
// Convert rope to correct prompt for llm
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let cursor_index = rope.line_to_char(position.position.line as usize)
+ position.position.character as usize;
let start = cursor_index
.checked_sub(self.configuration.get_maximum_context_length())
.unwrap_or(0);
eprintln!("############ {start} - {cursor_index} #############");
Ok(rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?
.to_string())
// We only want to do FIM if the user has enabled it, and the cursor is not at the end of the file
match self.configuration.get_fim() {
Some(fim) if rope.len_chars() != cursor_index => {
let max_length = self.configuration.get_maximum_context_length();
let start = cursor_index.checked_sub(max_length / 2).unwrap_or(0);
let end = rope
.len_chars()
.min(cursor_index + (max_length - (start - cursor_index)));
rope.insert(end, &fim.end);
rope.insert(cursor_index, &fim.middle);
rope.insert(start, &fim.start);
let rope_slice = rope
.get_slice(
start
..end
+ fim.start.chars().count()
+ fim.middle.chars().count()
+ fim.end.chars().count(),
)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
_ => {
let start = cursor_index
.checked_sub(self.configuration.get_maximum_context_length())
.unwrap_or(0);
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Ok(rope_slice.to_string())
}
}
}

View File

@@ -94,7 +94,7 @@ impl Model {
)
}
let mut batch = LlamaBatch::new(512, 1);
let mut batch = LlamaBatch::new(n_cxt, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
@@ -107,11 +107,12 @@ impl Model {
.with_context(|| "llama_decode() failed")?;
// main loop
let n_start = batch.n_tokens();
let mut output: Vec<String> = vec![];
let mut n_cur = batch.n_tokens();
let mut n_cur = n_start;
let mut n_decode = 0;
let t_main_start = ggml_time_us();
while (n_cur as usize) <= max_new_tokens {
while (n_cur as usize) <= (n_start as usize + max_new_tokens) {
// sample the next token
{
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);

View File

@@ -98,16 +98,12 @@ impl Worker {
.memory_backend
.lock()
.get_filter_text(&request.params.text_document_position)?;
eprintln!("\nPROMPT\n****************{}***************\n\n", prompt);
eprintln!("\nPROMPT**************\n{}\n******************\n", prompt);
let response = self.transformer_backend.do_completion(&prompt)?;
eprintln!(
"\nINSERT TEXT\n****************{}***************\n\n",
"\nINSERT TEXT&&&&&&&&&&&&&&&&&&&\n{}\n&&&&&&&&&&&&&&&&&&\n",
response.insert_text
);
eprintln!(
"\nFILTER TEXT\n&&&*************{}***********&&&\n\n",
filter_text
);
let completion_text_edit = TextEdit::new(
Range::new(
Position::new(