mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-20 02:47:34 +03:00
update the assistant panel to use new prompt templates
This commit is contained in:
parent
b9bb27512c
commit
aa1825681c
@ -90,10 +90,6 @@ impl PromptChain {
|
|||||||
if let Some((template_prompt, prompt_token_count)) =
|
if let Some((template_prompt, prompt_token_count)) =
|
||||||
template.generate(&self.args, tokens_outstanding).log_err()
|
template.generate(&self.args, tokens_outstanding).log_err()
|
||||||
{
|
{
|
||||||
println!(
|
|
||||||
"GENERATED PROMPT ({:?}): {:?}",
|
|
||||||
&prompt_token_count, &template_prompt
|
|
||||||
);
|
|
||||||
if template_prompt != "" {
|
if template_prompt != "" {
|
||||||
prompts[idx] = template_prompt;
|
prompts[idx] = template_prompt;
|
||||||
|
|
||||||
|
@ -44,22 +44,22 @@ impl PromptTemplate for FileContext {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
if start == end {
|
if start == end {
|
||||||
writeln!(prompt, "<|START|>").unwrap();
|
write!(prompt, "<|START|>").unwrap();
|
||||||
} else {
|
} else {
|
||||||
writeln!(prompt, "<|START|").unwrap();
|
write!(prompt, "<|START|").unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
writeln!(
|
write!(
|
||||||
prompt,
|
prompt,
|
||||||
"{}",
|
"{}",
|
||||||
buffer.text_for_range(start..end).collect::<String>()
|
buffer.text_for_range(start..end).collect::<String>()
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
if start != end {
|
if start != end {
|
||||||
writeln!(prompt, "|END|>").unwrap();
|
write!(prompt, "|END|>").unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
writeln!(
|
write!(
|
||||||
prompt,
|
prompt,
|
||||||
"{}",
|
"{}",
|
||||||
buffer.text_for_range(end..buffer.len()).collect::<String>()
|
buffer.text_for_range(end..buffer.len()).collect::<String>()
|
||||||
|
@ -25,7 +25,7 @@ impl PromptTemplate for EngineerPreamble {
|
|||||||
|
|
||||||
if let Some(project_name) = args.project_name.clone() {
|
if let Some(project_name) = args.project_name.clone() {
|
||||||
prompts.push(format!(
|
prompts.push(format!(
|
||||||
"You are currently working inside the '{project_name}' in Zed the code editor."
|
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -612,6 +612,18 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
let project = pending_assist.project.clone();
|
let project = pending_assist.project.clone();
|
||||||
|
|
||||||
|
let project_name = if let Some(project) = project.upgrade(cx) {
|
||||||
|
Some(
|
||||||
|
project
|
||||||
|
.read(cx)
|
||||||
|
.worktree_root_names(cx)
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("/"),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
self.inline_prompt_history
|
self.inline_prompt_history
|
||||||
.retain(|prompt| prompt != user_prompt);
|
.retain(|prompt| prompt != user_prompt);
|
||||||
self.inline_prompt_history.push_back(user_prompt.into());
|
self.inline_prompt_history.push_back(user_prompt.into());
|
||||||
@ -649,7 +661,6 @@ impl AssistantPanel {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let codegen_kind = codegen.read(cx).kind().clone();
|
|
||||||
let user_prompt = user_prompt.to_string();
|
let user_prompt = user_prompt.to_string();
|
||||||
|
|
||||||
let snippets = if retrieve_context {
|
let snippets = if retrieve_context {
|
||||||
@ -692,11 +703,11 @@ impl AssistantPanel {
|
|||||||
generate_content_prompt(
|
generate_content_prompt(
|
||||||
user_prompt,
|
user_prompt,
|
||||||
language_name,
|
language_name,
|
||||||
&buffer,
|
buffer,
|
||||||
range,
|
range,
|
||||||
codegen_kind,
|
|
||||||
snippets,
|
snippets,
|
||||||
model_name,
|
model_name,
|
||||||
|
project_name,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
use crate::codegen::CodegenKind;
|
use crate::codegen::CodegenKind;
|
||||||
use ai::models::{LanguageModel, OpenAILanguageModel};
|
use ai::models::{LanguageModel, OpenAILanguageModel};
|
||||||
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||||
|
use ai::templates::file_context::FileContext;
|
||||||
|
use ai::templates::generate::GenerateInlineContent;
|
||||||
use ai::templates::preamble::EngineerPreamble;
|
use ai::templates::preamble::EngineerPreamble;
|
||||||
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
|
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||||
@ -124,11 +126,11 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
|
|||||||
pub fn generate_content_prompt(
|
pub fn generate_content_prompt(
|
||||||
user_prompt: String,
|
user_prompt: String,
|
||||||
language_name: Option<&str>,
|
language_name: Option<&str>,
|
||||||
buffer: &BufferSnapshot,
|
buffer: BufferSnapshot,
|
||||||
range: Range<impl ToOffset>,
|
range: Range<usize>,
|
||||||
kind: CodegenKind,
|
|
||||||
search_results: Vec<PromptCodeSnippet>,
|
search_results: Vec<PromptCodeSnippet>,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
project_name: Option<String>,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
// Using new Prompt Templates
|
// Using new Prompt Templates
|
||||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
|
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
|
||||||
@ -141,146 +143,24 @@ pub fn generate_content_prompt(
|
|||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: openai_model,
|
model: openai_model,
|
||||||
language_name: lang_name.clone(),
|
language_name: lang_name.clone(),
|
||||||
project_name: None,
|
project_name,
|
||||||
snippets: search_results.clone(),
|
snippets: search_results.clone(),
|
||||||
reserved_tokens: 1000,
|
reserved_tokens: 1000,
|
||||||
|
buffer: Some(buffer),
|
||||||
|
selected_range: Some(range),
|
||||||
|
user_prompt: Some(user_prompt.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||||
(PromptPriority::High, Box::new(EngineerPreamble {})),
|
(PromptPriority::High, Box::new(EngineerPreamble {})),
|
||||||
(PromptPriority::Low, Box::new(RepositoryContext {})),
|
(PromptPriority::Low, Box::new(RepositoryContext {})),
|
||||||
|
(PromptPriority::Medium, Box::new(FileContext {})),
|
||||||
|
(PromptPriority::High, Box::new(GenerateInlineContent {})),
|
||||||
];
|
];
|
||||||
let chain = PromptChain::new(args, templates);
|
let chain = PromptChain::new(args, templates);
|
||||||
|
let (prompt, _) = chain.generate(true)?;
|
||||||
|
|
||||||
let prompt = chain.generate(true)?;
|
anyhow::Ok(prompt)
|
||||||
println!("{:?}", prompt);
|
|
||||||
|
|
||||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
|
||||||
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
|
|
||||||
|
|
||||||
let mut prompts = Vec::new();
|
|
||||||
let range = range.to_offset(buffer);
|
|
||||||
|
|
||||||
// General Preamble
|
|
||||||
if let Some(language_name) = language_name.clone() {
|
|
||||||
prompts.push(format!("You're an expert {language_name} engineer.\n"));
|
|
||||||
} else {
|
|
||||||
prompts.push("You're an expert engineer.\n".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Snippets
|
|
||||||
let mut snippet_position = prompts.len() - 1;
|
|
||||||
|
|
||||||
let mut content = String::new();
|
|
||||||
content.extend(buffer.text_for_range(0..range.start));
|
|
||||||
if range.start == range.end {
|
|
||||||
content.push_str("<|START|>");
|
|
||||||
} else {
|
|
||||||
content.push_str("<|START|");
|
|
||||||
}
|
|
||||||
content.extend(buffer.text_for_range(range.clone()));
|
|
||||||
if range.start != range.end {
|
|
||||||
content.push_str("|END|>");
|
|
||||||
}
|
|
||||||
content.extend(buffer.text_for_range(range.end..buffer.len()));
|
|
||||||
|
|
||||||
prompts.push("The file you are currently working on has the following content:\n".to_string());
|
|
||||||
|
|
||||||
if let Some(language_name) = language_name {
|
|
||||||
let language_name = language_name.to_lowercase();
|
|
||||||
prompts.push(format!("```{language_name}\n{content}\n```"));
|
|
||||||
} else {
|
|
||||||
prompts.push(format!("```\n{content}\n```"));
|
|
||||||
}
|
|
||||||
|
|
||||||
match kind {
|
|
||||||
CodegenKind::Generate { position: _ } => {
|
|
||||||
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
|
|
||||||
prompts
|
|
||||||
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
|
|
||||||
prompts.push(
|
|
||||||
"Text can't be replaced, so assume your answer will be inserted at the cursor."
|
|
||||||
.to_string(),
|
|
||||||
);
|
|
||||||
prompts.push(format!(
|
|
||||||
"Generate text based on the users prompt: {user_prompt}"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
CodegenKind::Transform { range: _ } => {
|
|
||||||
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
|
|
||||||
prompts.push(format!(
|
|
||||||
"Modify the users code selected text based upon the users prompt: '{user_prompt}'"
|
|
||||||
));
|
|
||||||
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(language_name) = language_name {
|
|
||||||
prompts.push(format!(
|
|
||||||
"Your answer MUST always and only be valid {language_name}"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
prompts.push("Never make remarks about the output.".to_string());
|
|
||||||
prompts.push("Do not return any text, except the generated code.".to_string());
|
|
||||||
prompts.push("Always wrap your code in a Markdown block".to_string());
|
|
||||||
|
|
||||||
let current_messages = [ChatCompletionRequestMessage {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: Some(prompts.join("\n")),
|
|
||||||
function_call: None,
|
|
||||||
name: None,
|
|
||||||
}];
|
|
||||||
|
|
||||||
let mut remaining_token_count = if let Ok(current_token_count) =
|
|
||||||
tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
|
|
||||||
{
|
|
||||||
let max_token_count = tiktoken_rs::model::get_context_size(model);
|
|
||||||
let intermediate_token_count = if max_token_count > current_token_count {
|
|
||||||
max_token_count - current_token_count
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
|
|
||||||
0
|
|
||||||
} else {
|
|
||||||
intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If tiktoken fails to count token count, assume we have no space remaining.
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO:
|
|
||||||
// - add repository name to snippet
|
|
||||||
// - add file path
|
|
||||||
// - add language
|
|
||||||
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
|
|
||||||
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
|
|
||||||
|
|
||||||
for search_result in search_results {
|
|
||||||
let mut snippet_prompt = template.to_string();
|
|
||||||
let snippet = search_result.to_string();
|
|
||||||
writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
|
|
||||||
|
|
||||||
let token_count = encoding
|
|
||||||
.encode_with_special_tokens(snippet_prompt.as_str())
|
|
||||||
.len();
|
|
||||||
if token_count <= remaining_token_count {
|
|
||||||
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
|
|
||||||
prompts.insert(snippet_position, snippet_prompt);
|
|
||||||
snippet_position += 1;
|
|
||||||
remaining_token_count -= token_count;
|
|
||||||
// If you have already added the template to the prompt, remove the template.
|
|
||||||
template = "";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(prompts.join("\n"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
Loading…
Reference in New Issue
Block a user