progress on smarter truncation strategy for file context

KCaverly 2023-10-18 17:56:59 -04:00
parent 587fd707ba
commit 178a84bcf6
4 changed files with 124 additions and 37 deletions

View File

@@ -6,6 +6,7 @@ pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
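The new `truncate_start` method mirrors the existing `truncate`: `truncate` keeps the head of the content, while `truncate_start` is meant to keep its tail, so the text immediately before a selection survives trimming. A minimal sketch of that contract, assuming `length` is the number of tokens to keep in both cases (the token values below are invented):

    fn truncate(tokens: &[u32], length: usize) -> Vec<u32> {
        // Keep the first `length` tokens.
        tokens[..length.min(tokens.len())].to_vec()
    }

    fn truncate_start(tokens: &[u32], length: usize) -> Vec<u32> {
        // Keep the last `length` tokens.
        tokens[tokens.len().saturating_sub(length)..].to_vec()
    }

    fn main() {
        let tokens = vec![10, 20, 30, 40, 50];
        assert_eq!(truncate(&tokens, 2), vec![10, 20]); // head survives
        assert_eq!(truncate_start(&tokens, 2), vec![40, 50]); // tail survives
    }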
@@ -47,6 +48,18 @@ impl LanguageModel for OpenAILanguageModel {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
// Keep only the last `length` tokens, dropping the start of the content.
bpe.decode(tokens[tokens.len() - length..].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
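As a quick sanity check of the tokenizer plumbing, a standalone snippet along these lines (assuming the same `tiktoken_rs` crate and its `cl100k_base` encoding; not part of the commit) counts tokens and keeps only the tail, which is what `truncate_start` needs to do for the text before a selection:

    use tiktoken_rs::cl100k_base;

    fn main() -> anyhow::Result<()> {
        let bpe = cl100k_base()?;
        let tokens = bpe.encode_with_special_tokens("fn main() { println!(\"hello\"); }");
        println!("token count: {}", tokens.len());

        // Keep only the last three tokens, analogous to truncate_start.
        let tail = bpe.decode(tokens[tokens.len().saturating_sub(3)..].to_vec())?;
        println!("tail: {tail:?}");
        Ok(())
    }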

View File

@@ -190,6 +190,13 @@ pub(crate) mod tests {
.collect::<String>(),
)
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
// Keep only the last `length` characters (this fake model treats each character as a token).
let chars = content.chars().collect::<Vec<char>>();
let start = chars.len().saturating_sub(length);
anyhow::Ok(chars[start..].iter().collect::<String>())
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
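A sketch of a unit test exercising the fake model above; `FakeLanguageModel` is a hypothetical name (the struct's actual name is not shown in this hunk), and the model treats every character as one token:

    #[test]
    fn truncate_start_keeps_the_tail() {
        // Hypothetical test against the character-per-token fake model above.
        let model = FakeLanguageModel { capacity: 100 };
        assert_eq!(model.truncate("abcdef", 3).unwrap(), "abc");
        assert_eq!(model.truncate_start("abcdef", 3).unwrap(), "def");
    }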

View File

@@ -1,9 +1,103 @@
use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::templates::base::PromptArguments;
use crate::templates::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
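// The windows around the selection do not fit in the remaining budget: give the
// smaller window up to half of the budget (capped at its actual size), then hand
// whatever is left to the larger window.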
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate_start(&start_window, start_goal_tokens)?;
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// No max token count was provided: include the full windows around the selection.
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we don't have a selected range, include the entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count)?;
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
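A worked example of the budget split above; every token count is invented for illustration and does not come from the commit:

    fn main() {
        let max_token_count: usize = 1000;
        let selected_tokens: usize = 200; // tokens inside the <|START|...|END|> window
        let mut remaining_tokens = max_token_count - selected_tokens; // 800
        let start_window_tokens: usize = 300;
        let end_window_tokens: usize = 900;
        assert!(start_window_tokens + end_window_tokens > remaining_tokens);

        // The smaller (start) window gets up to half the budget, the end window the rest.
        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); // 300
        remaining_tokens -= start_goal_tokens; // 500
        let end_goal_tokens = remaining_tokens.min(end_window_tokens); // 500
        assert_eq!((start_goal_tokens, end_goal_tokens), (300, 500));

        // 300 + 200 + 500 tokens end up in the prompt, exactly the model budget.
        println!("keep {start_goal_tokens} tokens before and {end_goal_tokens} after the selection");
    }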
pub struct FileContext {}
@@ -28,53 +122,24 @@ impl PromptTemplate for FileContext {
.clone()
.unwrap_or("".to_string())
.to_lowercase();
writeln!(prompt, "```{language_name}").unwrap();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
writeln!(
prompt,
"{}",
buffer.text_for_range(0..start).collect::<String>()
)
.unwrap();
if start == end {
write!(prompt, "<|START|>").unwrap();
} else {
write!(prompt, "<|START|").unwrap();
}
write!(
prompt,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(prompt, "|END|>").unwrap();
}
write!(
prompt,
"{}",
buffer.text_for_range(end..buffer.len()).collect::<String>()
)
.unwrap();
writeln!(prompt, "```").unwrap();
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
writeln!(prompt, "```").unwrap();
}
// Really dumb truncation strategy
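For orientation, when a selection is present the template now renders the file into a fenced block with the cursor markers inline, roughly like this (the file contents are illustrative):

    ```rust
    fn main() {
        let total = <|START|add(1, 2)|END|>;
    }
    ```

When the cursor has no selection, only the `<|START|>` marker is inserted at the cursor position.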

View File

@@ -166,6 +166,8 @@ pub fn generate_content_prompt(
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
println!("PROMPT: {:?}", &prompt);
anyhow::Ok(prompt)
}