update the assistant panel to use new prompt templates

KCaverly 2023-10-18 14:20:12 -04:00
parent b9bb27512c
commit aa1825681c
5 changed files with 33 additions and 146 deletions

View File

@@ -90,10 +90,6 @@ impl PromptChain {
             if let Some((template_prompt, prompt_token_count)) =
                 template.generate(&self.args, tokens_outstanding).log_err()
             {
-                println!(
-                    "GENERATED PROMPT ({:?}): {:?}",
-                    &prompt_token_count, &template_prompt
-                );
                 if template_prompt != "" {
                     prompts[idx] = template_prompt;

View File

@@ -44,22 +44,22 @@ impl PromptTemplate for FileContext {
            .unwrap();

            if start == end {
-               writeln!(prompt, "<|START|>").unwrap();
+               write!(prompt, "<|START|>").unwrap();
            } else {
-               writeln!(prompt, "<|START|").unwrap();
+               write!(prompt, "<|START|").unwrap();
            }

-           writeln!(
+           write!(
                prompt,
                "{}",
                buffer.text_for_range(start..end).collect::<String>()
            )
            .unwrap();

            if start != end {
-               writeln!(prompt, "|END|>").unwrap();
+               write!(prompt, "|END|>").unwrap();
            }

-           writeln!(
+           write!(
                prompt,
                "{}",
                buffer.text_for_range(end..buffer.len()).collect::<String>()
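
Note on the markers above: FileContext embeds the whole file in the prompt and wraps the user's cursor or selection in marker spans. A minimal sketch of the intended output, assuming a hypothetical buffer containing `fn main() { }` (the example buffer and offsets are illustrative, not taken from this commit):

    // Cursor inside the braces, no selection (start == end):
    fn main() { <|START|> }

    // The braces selected (start != end):
    fn main() <|START|{ }|END|>

Switching from writeln! to write! keeps these markers inline with the surrounding file text instead of inserting extra newlines around them.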

View File

@@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble {
        if let Some(project_name) = args.project_name.clone() {
            prompts.push(format!(
-               "You are currently working inside the '{project_name}' in Zed the code editor."
+               "You are currently working inside the '{project_name}' project in code editor Zed."
            ));
        }

View File

@@ -612,6 +612,18 @@ impl AssistantPanel {
        let project = pending_assist.project.clone();

+       let project_name = if let Some(project) = project.upgrade(cx) {
+           Some(
+               project
+                   .read(cx)
+                   .worktree_root_names(cx)
+                   .collect::<Vec<&str>>()
+                   .join("/"),
+           )
+       } else {
+           None
+       };
+
        self.inline_prompt_history
            .retain(|prompt| prompt != user_prompt);
        self.inline_prompt_history.push_back(user_prompt.into());
@@ -649,7 +661,6 @@ impl AssistantPanel {
            None
        };

-       let codegen_kind = codegen.read(cx).kind().clone();
        let user_prompt = user_prompt.to_string();

        let snippets = if retrieve_context {
@@ -692,11 +703,11 @@ impl AssistantPanel {
                    generate_content_prompt(
                        user_prompt,
                        language_name,
-                       &buffer,
+                       buffer,
                        range,
-                       codegen_kind,
                        snippets,
                        model_name,
+                       project_name,
                    )
                });
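
For context on the project_name plumbing above: the panel now derives a display name by joining the project's worktree root names with "/", and threads it through to generate_content_prompt so the EngineerPreamble template can mention the project. A minimal sketch of the joining behavior, with hypothetical worktree names (not taken from this commit):

    fn main() {
        // Two worktree roots, e.g. a multi-root workspace (names are made up).
        let worktree_root_names = vec!["zed", "docs"];
        let project_name: Option<String> = Some(worktree_root_names.join("/"));
        assert_eq!(project_name.as_deref(), Some("zed/docs"));
    }

If the project handle can no longer be upgraded, project_name stays None and the preamble simply omits the project sentence.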

View File

@@ -1,6 +1,8 @@
 use crate::codegen::CodegenKind;
 use ai::models::{LanguageModel, OpenAILanguageModel};
 use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
 use ai::templates::preamble::EngineerPreamble;
 use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
@@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
 pub fn generate_content_prompt(
     user_prompt: String,
     language_name: Option<&str>,
-    buffer: &BufferSnapshot,
-    range: Range<impl ToOffset>,
-    kind: CodegenKind,
+    buffer: BufferSnapshot,
+    range: Range<usize>,
     search_results: Vec<PromptCodeSnippet>,
     model: &str,
+    project_name: Option<String>,
 ) -> anyhow::Result<String> {
     // Using new Prompt Templates
     let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
@@ -141,146 +143,24 @@ pub fn generate_content_prompt(
     let args = PromptArguments {
         model: openai_model,
         language_name: lang_name.clone(),
-        project_name: None,
+        project_name,
         snippets: search_results.clone(),
         reserved_tokens: 1000,
+        buffer: Some(buffer),
+        selected_range: Some(range),
+        user_prompt: Some(user_prompt.clone()),
     };

     let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
         (PromptPriority::High, Box::new(EngineerPreamble {})),
         (PromptPriority::Low, Box::new(RepositoryContext {})),
+        (PromptPriority::Medium, Box::new(FileContext {})),
+        (PromptPriority::High, Box::new(GenerateInlineContent {})),
     ];
     let chain = PromptChain::new(args, templates);

-    let prompt = chain.generate(true)?;
-
-    println!("{:?}", prompt);
-
-    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
-    let mut prompts = Vec::new();
-    let range = range.to_offset(buffer);
-
-    // General Preamble
-    if let Some(language_name) = language_name.clone() {
-        prompts.push(format!("You're an expert {language_name} engineer.\n"));
-    } else {
-        prompts.push("You're an expert engineer.\n".to_string());
-    }
-
-    // Snippets
-    let mut snippet_position = prompts.len() - 1;
-
-    let mut content = String::new();
-    content.extend(buffer.text_for_range(0..range.start));
-    if range.start == range.end {
-        content.push_str("<|START|>");
-    } else {
-        content.push_str("<|START|");
-    }
-    content.extend(buffer.text_for_range(range.clone()));
-    if range.start != range.end {
-        content.push_str("|END|>");
-    }
-    content.extend(buffer.text_for_range(range.end..buffer.len()));
-
-    prompts.push("The file you are currently working on has the following content:\n".to_string());
-    if let Some(language_name) = language_name {
-        let language_name = language_name.to_lowercase();
-        prompts.push(format!("```{language_name}\n{content}\n```"));
-    } else {
-        prompts.push(format!("```\n{content}\n```"));
-    }
-
-    match kind {
-        CodegenKind::Generate { position: _ } => {
-            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
-            prompts
-                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
-            prompts.push(
-                "Text can't be replaced, so assume your answer will be inserted at the cursor."
-                    .to_string(),
-            );
-            prompts.push(format!(
-                "Generate text based on the users prompt: {user_prompt}"
-            ));
-        }
-        CodegenKind::Transform { range: _ } => {
-            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
-            prompts.push(format!(
-                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
-            ));
-            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
-        }
-    }
-
-    if let Some(language_name) = language_name {
-        prompts.push(format!(
-            "Your answer MUST always and only be valid {language_name}"
-        ));
-    }
-    prompts.push("Never make remarks about the output.".to_string());
-    prompts.push("Do not return any text, except the generated code.".to_string());
-    prompts.push("Always wrap your code in a Markdown block".to_string());
-
-    let current_messages = [ChatCompletionRequestMessage {
-        role: "user".to_string(),
-        content: Some(prompts.join("\n")),
-        function_call: None,
-        name: None,
-    }];
-
-    let mut remaining_token_count = if let Ok(current_token_count) =
-        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
-    {
-        let max_token_count = tiktoken_rs::model::get_context_size(model);
-        let intermediate_token_count = if max_token_count > current_token_count {
-            max_token_count - current_token_count
-        } else {
-            0
-        };
-
-        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
-            0
-        } else {
-            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
-        }
-    } else {
-        // If tiktoken fails to count token count, assume we have no space remaining.
-        0
-    };
-
-    // TODO:
-    //   - add repository name to snippet
-    //   - add file path
-    //   - add language
-    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
-        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
-        for search_result in search_results {
-            let mut snippet_prompt = template.to_string();
-            let snippet = search_result.to_string();
-            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
-            let token_count = encoding
-                .encode_with_special_tokens(snippet_prompt.as_str())
-                .len();
-            if token_count <= remaining_token_count {
-                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
-                    prompts.insert(snippet_position, snippet_prompt);
-                    snippet_position += 1;
-                    remaining_token_count -= token_count;
-
-                    // If you have already added the template to the prompt, remove the template.
-                    template = "";
-                }
-            } else {
-                break;
-            }
-        }
-    }
-
-    anyhow::Ok(prompts.join("\n"))
+    let (prompt, _) = chain.generate(true)?;
+
+    anyhow::Ok(prompt)
 }

 #[cfg(test)]
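
Taken together, the new generate_content_prompt body replaces the hand-rolled prompt assembly with the template chain. A condensed sketch of the flow after this commit, using only names visible in the diff (the PromptArguments fields not shown here and the exact return shape of PromptChain::generate are assumptions):

    // Assemble the arguments once, then let the chain fit templates to the
    // model's context window by priority.
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        (PromptPriority::High, Box::new(EngineerPreamble {})),
        (PromptPriority::Low, Box::new(RepositoryContext {})),
        (PromptPriority::Medium, Box::new(FileContext {})),
        (PromptPriority::High, Box::new(GenerateInlineContent {})),
    ];
    let chain = PromptChain::new(args, templates);
    // generate(true) returns the assembled prompt plus a second value
    // (likely a token count) that is discarded here.
    let (prompt, _token_count) = chain.generate(true)?;

Token budgeting also moves out of this function: the chain hands each template a tokens_outstanding budget (see the first hunk), so the tiktoken-based snippet trimming deleted above no longer lives here.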