mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-20 02:47:34 +03:00
update the assistant panel to use new prompt templates
This commit is contained in:
parent
b9bb27512c
commit
aa1825681c
@ -90,10 +90,6 @@ impl PromptChain {
|
||||
if let Some((template_prompt, prompt_token_count)) =
|
||||
template.generate(&self.args, tokens_outstanding).log_err()
|
||||
{
|
||||
println!(
|
||||
"GENERATED PROMPT ({:?}): {:?}",
|
||||
&prompt_token_count, &template_prompt
|
||||
);
|
||||
if template_prompt != "" {
|
||||
prompts[idx] = template_prompt;
|
||||
|
||||
|
@ -44,22 +44,22 @@ impl PromptTemplate for FileContext {
|
||||
.unwrap();
|
||||
|
||||
if start == end {
|
||||
writeln!(prompt, "<|START|>").unwrap();
|
||||
write!(prompt, "<|START|>").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "<|START|").unwrap();
|
||||
write!(prompt, "<|START|").unwrap();
|
||||
}
|
||||
|
||||
writeln!(
|
||||
write!(
|
||||
prompt,
|
||||
"{}",
|
||||
buffer.text_for_range(start..end).collect::<String>()
|
||||
)
|
||||
.unwrap();
|
||||
if start != end {
|
||||
writeln!(prompt, "|END|>").unwrap();
|
||||
write!(prompt, "|END|>").unwrap();
|
||||
}
|
||||
|
||||
writeln!(
|
||||
write!(
|
||||
prompt,
|
||||
"{}",
|
||||
buffer.text_for_range(end..buffer.len()).collect::<String>()
|
||||
|
@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble {
|
||||
|
||||
if let Some(project_name) = args.project_name.clone() {
|
||||
prompts.push(format!(
|
||||
"You are currently working inside the '{project_name}' in Zed the code editor."
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -612,6 +612,18 @@ impl AssistantPanel {
|
||||
|
||||
let project = pending_assist.project.clone();
|
||||
|
||||
let project_name = if let Some(project) = project.upgrade(cx) {
|
||||
Some(
|
||||
project
|
||||
.read(cx)
|
||||
.worktree_root_names(cx)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.inline_prompt_history
|
||||
.retain(|prompt| prompt != user_prompt);
|
||||
self.inline_prompt_history.push_back(user_prompt.into());
|
||||
@ -649,7 +661,6 @@ impl AssistantPanel {
|
||||
None
|
||||
};
|
||||
|
||||
let codegen_kind = codegen.read(cx).kind().clone();
|
||||
let user_prompt = user_prompt.to_string();
|
||||
|
||||
let snippets = if retrieve_context {
|
||||
@ -692,11 +703,11 @@ impl AssistantPanel {
|
||||
generate_content_prompt(
|
||||
user_prompt,
|
||||
language_name,
|
||||
&buffer,
|
||||
buffer,
|
||||
range,
|
||||
codegen_kind,
|
||||
snippets,
|
||||
model_name,
|
||||
project_name,
|
||||
)
|
||||
});
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
use crate::codegen::CodegenKind;
|
||||
use ai::models::{LanguageModel, OpenAILanguageModel};
|
||||
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::templates::file_context::FileContext;
|
||||
use ai::templates::generate::GenerateInlineContent;
|
||||
use ai::templates::preamble::EngineerPreamble;
|
||||
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: &BufferSnapshot,
|
||||
range: Range<impl ToOffset>,
|
||||
kind: CodegenKind,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
search_results: Vec<PromptCodeSnippet>,
|
||||
model: &str,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
// Using new Prompt Templates
|
||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
|
||||
@ -141,146 +143,24 @@ pub fn generate_content_prompt(
|
||||
let args = PromptArguments {
|
||||
model: openai_model,
|
||||
language_name: lang_name.clone(),
|
||||
project_name: None,
|
||||
project_name,
|
||||
snippets: search_results.clone(),
|
||||
reserved_tokens: 1000,
|
||||
buffer: Some(buffer),
|
||||
selected_range: Some(range),
|
||||
user_prompt: Some(user_prompt.clone()),
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(PromptPriority::High, Box::new(EngineerPreamble {})),
|
||||
(PromptPriority::Low, Box::new(RepositoryContext {})),
|
||||
(PromptPriority::Medium, Box::new(FileContext {})),
|
||||
(PromptPriority::High, Box::new(GenerateInlineContent {})),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
let (prompt, _) = chain.generate(true)?;
|
||||
|
||||
let prompt = chain.generate(true)?;
|
||||
println!("{:?}", prompt);
|
||||
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
|
||||
|
||||
let mut prompts = Vec::new();
|
||||
let range = range.to_offset(buffer);
|
||||
|
||||
// General Preamble
|
||||
if let Some(language_name) = language_name.clone() {
|
||||
prompts.push(format!("You're an expert {language_name} engineer.\n"));
|
||||
} else {
|
||||
prompts.push("You're an expert engineer.\n".to_string());
|
||||
}
|
||||
|
||||
// Snippets
|
||||
let mut snippet_position = prompts.len() - 1;
|
||||
|
||||
let mut content = String::new();
|
||||
content.extend(buffer.text_for_range(0..range.start));
|
||||
if range.start == range.end {
|
||||
content.push_str("<|START|>");
|
||||
} else {
|
||||
content.push_str("<|START|");
|
||||
}
|
||||
content.extend(buffer.text_for_range(range.clone()));
|
||||
if range.start != range.end {
|
||||
content.push_str("|END|>");
|
||||
}
|
||||
content.extend(buffer.text_for_range(range.end..buffer.len()));
|
||||
|
||||
prompts.push("The file you are currently working on has the following content:\n".to_string());
|
||||
|
||||
if let Some(language_name) = language_name {
|
||||
let language_name = language_name.to_lowercase();
|
||||
prompts.push(format!("```{language_name}\n{content}\n```"));
|
||||
} else {
|
||||
prompts.push(format!("```\n{content}\n```"));
|
||||
}
|
||||
|
||||
match kind {
|
||||
CodegenKind::Generate { position: _ } => {
|
||||
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
|
||||
prompts
|
||||
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
|
||||
prompts.push(
|
||||
"Text can't be replaced, so assume your answer will be inserted at the cursor."
|
||||
.to_string(),
|
||||
);
|
||||
prompts.push(format!(
|
||||
"Generate text based on the users prompt: {user_prompt}"
|
||||
));
|
||||
}
|
||||
CodegenKind::Transform { range: _ } => {
|
||||
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
|
||||
prompts.push(format!(
|
||||
"Modify the users code selected text based upon the users prompt: '{user_prompt}'"
|
||||
));
|
||||
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(language_name) = language_name {
|
||||
prompts.push(format!(
|
||||
"Your answer MUST always and only be valid {language_name}"
|
||||
));
|
||||
}
|
||||
prompts.push("Never make remarks about the output.".to_string());
|
||||
prompts.push("Do not return any text, except the generated code.".to_string());
|
||||
prompts.push("Always wrap your code in a Markdown block".to_string());
|
||||
|
||||
let current_messages = [ChatCompletionRequestMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(prompts.join("\n")),
|
||||
function_call: None,
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let mut remaining_token_count = if let Ok(current_token_count) =
|
||||
tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
|
||||
{
|
||||
let max_token_count = tiktoken_rs::model::get_context_size(model);
|
||||
let intermediate_token_count = if max_token_count > current_token_count {
|
||||
max_token_count - current_token_count
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
|
||||
0
|
||||
} else {
|
||||
intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
|
||||
}
|
||||
} else {
|
||||
// If tiktoken fails to count token count, assume we have no space remaining.
|
||||
0
|
||||
};
|
||||
|
||||
// TODO:
|
||||
// - add repository name to snippet
|
||||
// - add file path
|
||||
// - add language
|
||||
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
|
||||
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
|
||||
|
||||
for search_result in search_results {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let snippet = search_result.to_string();
|
||||
writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
|
||||
|
||||
let token_count = encoding
|
||||
.encode_with_special_tokens(snippet_prompt.as_str())
|
||||
.len();
|
||||
if token_count <= remaining_token_count {
|
||||
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
prompts.insert(snippet_position, snippet_prompt);
|
||||
snippet_position += 1;
|
||||
remaining_token_count -= token_count;
|
||||
// If you have already added the template to the prompt, remove the template.
|
||||
template = "";
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(prompts.join("\n"))
|
||||
anyhow::Ok(prompt)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
Loading…
Reference in New Issue
Block a user