mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-20 02:47:34 +03:00
Revert changes to inline assist indentation logic and prompt (#16403)
This PR reverts #16145 and subsequent changes.
This reverts commit a515442a36
.
We still have issues with our approach to indentation in Python
unfortunately, but this feels like a safer equilibrium than where we
were.
Release Notes:
- Returned to our previous prompt for inline assist transformations,
since recent changes were introducing issues.
This commit is contained in:
parent
ebecd7e65f
commit
07d5e22cbe
@ -1,426 +1,61 @@
|
||||
You are an expert developer assistant working in an AI-enabled text editor.
|
||||
Your task is to rewrite a specific section of the provided document based on a user-provided prompt.
|
||||
{{#if language_name}}
|
||||
Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
|
||||
{{else}}
|
||||
Here's a file of text that I'm going to ask you to make an edit to.
|
||||
{{/if}}
|
||||
|
||||
<guidelines>
|
||||
1. Scope: Modify only content within <rewrite_this> tags. Do not alter anything outside these boundaries.
|
||||
2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is.
|
||||
3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure.
|
||||
4. Tag exclusion: Never include <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags in the output.
|
||||
5. Indentation: Maintain the original indentation level of the file in rewritten sections.
|
||||
6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions.
|
||||
7. Insertions: Replace <insert_here></insert_here> tags with appropriate content as specified by the prompt.
|
||||
8. Code integrity: Respect existing code structure and functionality when making changes.
|
||||
9. Consistency: Maintain a uniform style and tone throughout the rewritten text.
|
||||
</guidelines>
|
||||
{{#if is_insert}}
|
||||
The point you'll need to insert at is marked with <insert_here></insert_here>.
|
||||
{{else}}
|
||||
The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
|
||||
{{/if}}
|
||||
|
||||
<examples>
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::cmp;
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>
|
||||
</rewrite_this>
|
||||
pub struct LruCache<K, V> {
|
||||
/// The maximum number of items the cache can hold.
|
||||
capacity: usize,
|
||||
/// The map storing the cached items.
|
||||
items: HashMap<K, V>,
|
||||
}
|
||||
|
||||
// The rest of the implementation...
|
||||
</document>
|
||||
<prompt>
|
||||
doc this
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation. The text starting with `pub struct AabbTree<T> {` is *after* the rewrite_this tag">
|
||||
/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
|
||||
///
|
||||
/// This structure is used for efficient spatial queries and collision detection.
|
||||
/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T`: The type of data associated with each node in the tree.
|
||||
pub struct AabbTree<T> {
|
||||
root: Option<usize>,
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
|
||||
/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
|
||||
///
|
||||
/// This structure is used for efficient spatial queries and collision detection.
|
||||
/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T`: The type of data associated with each node in the tree.
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
import math
|
||||
|
||||
def calculate_circle_area(radius):
|
||||
"""Calculate the area of a circle given its radius."""
|
||||
return math.pi * radius ** 2
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>
|
||||
</rewrite_this>
|
||||
class Circle:
|
||||
def __init__(self, radius):
|
||||
self.radius = radius
|
||||
|
||||
def area(self):
|
||||
return math.pi * self.radius ** 2
|
||||
|
||||
def circumference(self):
|
||||
return 2 * math.pi * self.radius
|
||||
|
||||
# Usage example
|
||||
circle = Circle(5)
|
||||
print(f"Area: {circle.area():.2f}")
|
||||
print(f"Circumference: {circle.circumference():.2f}")
|
||||
</document>
|
||||
<prompt>
|
||||
write docs
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
|
||||
"""
|
||||
Represents a circle with methods to calculate its area and circumference.
|
||||
|
||||
This class provides a simple way to work with circles in a geometric context.
|
||||
It allows for the creation of Circle objects with a specified radius and
|
||||
offers methods to compute the circle's area and circumference.
|
||||
|
||||
Attributes:
|
||||
radius (float): The radius of the circle.
|
||||
|
||||
Methods:
|
||||
area(): Calculates and returns the area of the circle.
|
||||
circumference(): Calculates and returns the circumference of the circle.
|
||||
"""
|
||||
class Circle:
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
|
||||
"""
|
||||
Represents a circle with methods to calculate its area and circumference.
|
||||
|
||||
This class provides a simple way to work with circles in a geometric context.
|
||||
It allows for the creation of Circle objects with a specified radius and
|
||||
offers methods to compute the circle's area and circumference.
|
||||
|
||||
Attributes:
|
||||
radius (float): The radius of the circle.
|
||||
|
||||
Methods:
|
||||
area(): Calculates and returns the area of the circle.
|
||||
circumference(): Calculates and returns the circumference of the circle.
|
||||
"""
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
class BankAccount {
|
||||
private balance: number;
|
||||
|
||||
constructor(initialBalance: number) {
|
||||
this.balance = initialBalance;
|
||||
}
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>
|
||||
</rewrite_this>
|
||||
deposit(amount: number): void {
|
||||
if (amount > 0) {
|
||||
this.balance += amount;
|
||||
}
|
||||
}
|
||||
|
||||
withdraw(amount: number): boolean {
|
||||
if (amount > 0 && this.balance >= amount) {
|
||||
this.balance -= amount;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
getBalance(): number {
|
||||
return this.balance;
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
const account = new BankAccount(1000);
|
||||
account.deposit(500);
|
||||
console.log(account.getBalance()); // 1500
|
||||
account.withdraw(200);
|
||||
console.log(account.getBalance()); // 1300
|
||||
</document>
|
||||
<prompt>
|
||||
//
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation. The text starting with `deposit(amount: number): void {` is *after* the rewrite_this tag">
|
||||
/**
|
||||
* Deposits the specified amount into the bank account.
|
||||
*
|
||||
* @param amount The amount to deposit. Must be a positive number.
|
||||
* @throws Error if the amount is not positive.
|
||||
*/
|
||||
deposit(amount: number): void {
|
||||
if (amount > 0) {
|
||||
this.balance += amount;
|
||||
} else {
|
||||
throw new Error("Deposit amount must be positive");
|
||||
}
|
||||
}
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
|
||||
/**
|
||||
* Deposits the specified amount into the bank account.
|
||||
*
|
||||
* @param amount The amount to deposit. Must be a positive number.
|
||||
* @throws Error if the amount is not positive.
|
||||
*/
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
use std::collections::VecDeque;
|
||||
|
||||
pub struct BinaryTree<T> {
|
||||
root: Option<Node<T>>,
|
||||
}
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>
|
||||
</rewrite_this>
|
||||
struct Node<T> {
|
||||
value: T,
|
||||
left: Option<Box<Node<T>>>,
|
||||
right: Option<Box<Node<T>>>,
|
||||
}
|
||||
</document>
|
||||
<prompt>
|
||||
derive clone
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation below the rewrite_this tags. Extra space between derive annotation and struct definition.">
|
||||
#[derive(Clone)]
|
||||
|
||||
struct Node<T> {
|
||||
value: T,
|
||||
left: Option<Box<Node<T>>>,
|
||||
right: Option<Box<Node<T>>>,
|
||||
}
|
||||
</incorrect_output>
|
||||
|
||||
<incorrect_output failure="Over-generation above the rewrite_this tags">
|
||||
pub struct BinaryTree<T> {
|
||||
root: Option<Node<T>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
</incorrect_output>
|
||||
|
||||
<incorrect_output failure="Over-generation below the rewrite_this tags">
|
||||
#[derive(Clone)]
|
||||
struct Node<T> {
|
||||
value: T,
|
||||
left: Option<Box<Node<T>>>,
|
||||
right: Option<Box<Node<T>>>,
|
||||
}
|
||||
|
||||
impl<T> Node<T> {
|
||||
fn new(value: T) -> Self {
|
||||
Node {
|
||||
value,
|
||||
left: None,
|
||||
right: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Only includes the new content within the rewrite_this tags">
|
||||
#[derive(Clone)]
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
import math
|
||||
|
||||
def calculate_circle_area(radius):
|
||||
"""Calculate the area of a circle given its radius."""
|
||||
return math.pi * radius ** 2
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>
|
||||
</rewrite_this>
|
||||
class Circle:
|
||||
def __init__(self, radius):
|
||||
self.radius = radius
|
||||
|
||||
def area(self):
|
||||
return math.pi * self.radius ** 2
|
||||
|
||||
def circumference(self):
|
||||
return 2 * math.pi * self.radius
|
||||
|
||||
# Usage example
|
||||
circle = Circle(5)
|
||||
print(f"Area: {circle.area():.2f}")
|
||||
print(f"Circumference: {circle.circumference():.2f}")
|
||||
</document>
|
||||
<prompt>
|
||||
add dataclass decorator
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
|
||||
@dataclass
|
||||
class Circle:
|
||||
radius: float
|
||||
|
||||
def __init__(self, radius):
|
||||
self.radius = radius
|
||||
|
||||
def area(self):
|
||||
return math.pi * self.radius ** 2
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
|
||||
@dataclass
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<input>
|
||||
<document>
|
||||
interface ShoppingCart {
|
||||
items: string[];
|
||||
total: number;
|
||||
}
|
||||
|
||||
<rewrite_this>
|
||||
<insert_here></insert_here>class ShoppingCartManager {
|
||||
</rewrite_this>
|
||||
private cart: ShoppingCart;
|
||||
|
||||
constructor() {
|
||||
this.cart = { items: [], total: 0 };
|
||||
}
|
||||
|
||||
addItem(item: string, price: number): void {
|
||||
this.cart.items.push(item);
|
||||
this.cart.total += price;
|
||||
}
|
||||
|
||||
getTotal(): number {
|
||||
return this.cart.total;
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
const manager = new ShoppingCartManager();
|
||||
manager.addItem("Book", 15.99);
|
||||
console.log(manager.getTotal()); // 15.99
|
||||
</document>
|
||||
<prompt>
|
||||
add readonly modifier
|
||||
</prompt>
|
||||
</input>
|
||||
|
||||
<incorrect_output failure="Over-generation. The line starting with ` items: string[];` is *after* the rewrite_this tag">
|
||||
readonly interface ShoppingCart {
|
||||
items: string[];
|
||||
total: number;
|
||||
}
|
||||
|
||||
class ShoppingCartManager {
|
||||
private readonly cart: ShoppingCart;
|
||||
|
||||
constructor() {
|
||||
this.cart = { items: [], total: 0 };
|
||||
}
|
||||
</incorrect_output>
|
||||
<corrected_output improvement="Only includes the new content within the rewrite_this tags and integrates cleanly into surrounding code">
|
||||
readonly interface ShoppingCart {
|
||||
</corrected_output>
|
||||
</example>
|
||||
|
||||
</examples>
|
||||
|
||||
With these examples in mind, edit the following file:
|
||||
|
||||
<document language="{{ language_name }}">
|
||||
{{{ document_content }}}
|
||||
{{{document_content}}}
|
||||
</document>
|
||||
|
||||
{{#if is_truncated}}
|
||||
The provided document has been truncated (potentially mid-line) for brevity.
|
||||
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
|
||||
{{/if}}
|
||||
|
||||
<instructions>
|
||||
{{#if has_insertion}}
|
||||
Insert text anywhere you see marked with <insert_here></insert_here> tags. It's CRITICAL that you DO NOT include <insert_here> tags in your output.
|
||||
{{/if}}
|
||||
{{#if has_replacement}}
|
||||
Edit text that you see surrounded with <edit_here>...</edit_here> tags. It's CRITICAL that you DO NOT include <edit_here> tags in your output.
|
||||
{{/if}}
|
||||
Make no changes to the rewritten content outside these tags.
|
||||
{{#if is_insert}}
|
||||
You can't replace {{content_type}}, your answer will be inserted in place of the `<insert_here></insert_here>` tags. Don't include the insert_here tags in your output.
|
||||
|
||||
<snippet language="{{ language_name }}" annotated="true">
|
||||
{{{ rewrite_section_prefix }}}
|
||||
<rewrite_this>
|
||||
{{{ rewrite_section_with_edits }}}
|
||||
</rewrite_this>
|
||||
{{{ rewrite_section_suffix }}}
|
||||
</snippet>
|
||||
|
||||
Rewrite the lines enclosed within the <rewrite_this></rewrite_this> tags in accordance with the provided instructions and the prompt below.
|
||||
Generate {{content_type}} based on the following prompt:
|
||||
|
||||
<prompt>
|
||||
{{{ user_prompt }}}
|
||||
{{{user_prompt}}}
|
||||
</prompt>
|
||||
|
||||
Do not include <insert_here> or <edit_here> annotations in your output. Here is a clean copy of the snippet without annotations for your reference.
|
||||
Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
|
||||
|
||||
<snippet>
|
||||
{{{ rewrite_section_prefix }}}
|
||||
{{{ rewrite_section }}}
|
||||
{{{ rewrite_section_suffix }}}
|
||||
</snippet>
|
||||
</instructions>
|
||||
Immediately start with the following format with no remarks:
|
||||
|
||||
<guidelines_reminder>
|
||||
1. Focus on necessary changes: Modify only what's required to fulfill the prompt.
|
||||
2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow.
|
||||
3. Exclude annotation tags: Do not output <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags.
|
||||
4. Maintain indentation: Begin at the original file's indentation level.
|
||||
5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed.
|
||||
6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output.
|
||||
7. Respect content boundaries: Preserve code integrity.
|
||||
</guidelines_reminder>
|
||||
```
|
||||
\{{INSERTED_CODE}}
|
||||
```
|
||||
{{else}}
|
||||
Edit the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
|
||||
|
||||
<prompt>
|
||||
{{{user_prompt}}}
|
||||
</prompt>
|
||||
|
||||
{{#if rewrite_section}}
|
||||
And here's the section to rewrite based on that prompt again for reference:
|
||||
|
||||
<rewrite_this>
|
||||
{{{rewrite_section}}}
|
||||
</rewrite_this>
|
||||
{{/if}}
|
||||
|
||||
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
|
||||
|
||||
Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
|
||||
|
||||
Immediately start with the following format with no remarks:
|
||||
|
||||
```
|
||||
\{{REWRITTEN_CODE}}
|
||||
```
|
||||
{{/if}}
|
||||
|
@ -34,7 +34,6 @@ use language_model::{
|
||||
};
|
||||
pub(crate) use model_selector::*;
|
||||
pub use prompts::PromptBuilder;
|
||||
use prompts::PromptOverrideContext;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{update_settings_file, Settings, SettingsStore};
|
||||
@ -181,12 +180,7 @@ impl Assistant {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
dev_mode: bool,
|
||||
cx: &mut AppContext,
|
||||
) -> Arc<PromptBuilder> {
|
||||
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
|
||||
cx.set_global(Assistant::default());
|
||||
AssistantSettings::register(cx);
|
||||
SlashCommandSettings::register(cx);
|
||||
@ -223,14 +217,10 @@ pub fn init(
|
||||
assistant_panel::init(cx);
|
||||
context_servers::init(cx);
|
||||
|
||||
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
|
||||
dev_mode,
|
||||
fs: fs.clone(),
|
||||
cx,
|
||||
}))
|
||||
.log_err()
|
||||
.map(Arc::new)
|
||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||
let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
|
||||
.log_err()
|
||||
.map(Arc::new)
|
||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||
register_slash_commands(Some(prompt_builder.clone()), cx);
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
|
@ -28,7 +28,7 @@ use gpui::{
|
||||
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
|
||||
UpdateGlobal, View, ViewContext, WeakView, WindowContext,
|
||||
};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId};
|
||||
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
@ -38,6 +38,7 @@ use rope::Rope;
|
||||
use settings::Settings;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
cmp,
|
||||
future::{self, Future},
|
||||
mem,
|
||||
ops::{Range, RangeInclusive},
|
||||
@ -46,7 +47,6 @@ use std::{
|
||||
task::{self, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use text::OffsetRangeExt as _;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
|
||||
use util::{RangeExt, ResultExt};
|
||||
@ -140,81 +140,66 @@ impl InlineAssistant {
|
||||
cx: &mut WindowContext,
|
||||
) {
|
||||
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
struct CodegenRange {
|
||||
transform_range: Range<Point>,
|
||||
selection_ranges: Vec<Range<Point>>,
|
||||
focus_assist: bool,
|
||||
}
|
||||
|
||||
let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range();
|
||||
let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
|
||||
|
||||
let selection_ranges = snapshot
|
||||
.split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
|
||||
.map(|range| range.to_point(&snapshot))
|
||||
.collect::<Vec<Range<Point>>>();
|
||||
|
||||
for selection_range in selection_ranges {
|
||||
let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
|
||||
let mut transform_range = selection_range.start..selection_range.end;
|
||||
|
||||
// Expand the transform range to start/end of lines.
|
||||
// If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
|
||||
transform_range.start.column = 0;
|
||||
if transform_range.end.column == 0 && transform_range.end > transform_range.start {
|
||||
transform_range.end.row -= 1;
|
||||
let mut selections = Vec::<Selection<Point>>::new();
|
||||
let mut newest_selection = None;
|
||||
for mut selection in editor.read(cx).selections.all::<Point>(cx) {
|
||||
if selection.end > selection.start {
|
||||
selection.start.column = 0;
|
||||
// If the selection ends at the start of the line, we don't want to include it.
|
||||
if selection.end.column == 0 {
|
||||
selection.end.row -= 1;
|
||||
}
|
||||
selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
|
||||
}
|
||||
transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
|
||||
let selection_range =
|
||||
selection_range.start..selection_range.end.min(transform_range.end);
|
||||
|
||||
// If we intersect the previous transform range,
|
||||
if let Some(CodegenRange {
|
||||
transform_range: prev_transform_range,
|
||||
selection_ranges,
|
||||
focus_assist,
|
||||
}) = codegen_ranges.last_mut()
|
||||
{
|
||||
if transform_range.start <= prev_transform_range.end {
|
||||
prev_transform_range.end = transform_range.end;
|
||||
selection_ranges.push(selection_range);
|
||||
*focus_assist |= selection_is_newest;
|
||||
if let Some(prev_selection) = selections.last_mut() {
|
||||
if selection.start <= prev_selection.end {
|
||||
prev_selection.end = selection.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
codegen_ranges.push(CodegenRange {
|
||||
transform_range,
|
||||
selection_ranges: vec![selection_range],
|
||||
focus_assist: selection_is_newest,
|
||||
})
|
||||
let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
|
||||
if selection.id > latest_selection.id {
|
||||
*latest_selection = selection.clone();
|
||||
}
|
||||
selections.push(selection);
|
||||
}
|
||||
let newest_selection = newest_selection.unwrap();
|
||||
|
||||
let mut codegen_ranges = Vec::new();
|
||||
for (excerpt_id, buffer, buffer_range) in
|
||||
snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
|
||||
snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
|
||||
}))
|
||||
{
|
||||
let start = Anchor {
|
||||
buffer_id: Some(buffer.remote_id()),
|
||||
excerpt_id,
|
||||
text_anchor: buffer.anchor_before(buffer_range.start),
|
||||
};
|
||||
let end = Anchor {
|
||||
buffer_id: Some(buffer.remote_id()),
|
||||
excerpt_id,
|
||||
text_anchor: buffer.anchor_after(buffer_range.end),
|
||||
};
|
||||
codegen_ranges.push(start..end);
|
||||
}
|
||||
|
||||
let assist_group_id = self.next_assist_group_id.post_inc();
|
||||
let prompt_buffer =
|
||||
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
|
||||
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
|
||||
|
||||
let mut assists = Vec::new();
|
||||
let mut assist_to_focus = None;
|
||||
|
||||
for CodegenRange {
|
||||
transform_range,
|
||||
selection_ranges,
|
||||
focus_assist,
|
||||
} in codegen_ranges
|
||||
{
|
||||
let transform_range = snapshot.anchor_before(transform_range.start)
|
||||
..snapshot.anchor_after(transform_range.end);
|
||||
let selection_ranges = selection_ranges
|
||||
.iter()
|
||||
.map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for range in codegen_ranges {
|
||||
let assist_id = self.next_assist_id.post_inc();
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
transform_range.clone(),
|
||||
selection_ranges,
|
||||
range.clone(),
|
||||
None,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
@ -222,7 +207,6 @@ impl InlineAssistant {
|
||||
)
|
||||
});
|
||||
|
||||
let assist_id = self.next_assist_id.post_inc();
|
||||
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
|
||||
let prompt_editor = cx.new_view(|cx| {
|
||||
PromptEditor::new(
|
||||
@ -239,16 +223,23 @@ impl InlineAssistant {
|
||||
)
|
||||
});
|
||||
|
||||
if focus_assist {
|
||||
assist_to_focus = Some(assist_id);
|
||||
if assist_to_focus.is_none() {
|
||||
let focus_assist = if newest_selection.reversed {
|
||||
range.start.to_point(&snapshot) == newest_selection.start
|
||||
} else {
|
||||
range.end.to_point(&snapshot) == newest_selection.end
|
||||
};
|
||||
if focus_assist {
|
||||
assist_to_focus = Some(assist_id);
|
||||
}
|
||||
}
|
||||
|
||||
let [prompt_block_id, end_block_id] =
|
||||
self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
|
||||
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
||||
|
||||
assists.push((
|
||||
assist_id,
|
||||
transform_range,
|
||||
range,
|
||||
prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
@ -315,7 +306,6 @@ impl InlineAssistant {
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
range.clone(),
|
||||
vec![range.clone()],
|
||||
initial_transaction_id,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
@ -925,7 +915,12 @@ impl InlineAssistant {
|
||||
assist
|
||||
.codegen
|
||||
.update(cx, |codegen, cx| {
|
||||
codegen.start(user_prompt, assistant_panel_context, cx)
|
||||
codegen.start(
|
||||
assist.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
|
||||
@ -2120,9 +2115,12 @@ impl InlineAssist {
|
||||
return future::ready(Err(anyhow!("no user prompt"))).boxed();
|
||||
};
|
||||
let assistant_panel_context = self.assistant_panel_context(cx);
|
||||
self.codegen
|
||||
.read(cx)
|
||||
.count_tokens(user_prompt, assistant_panel_context, cx)
|
||||
self.codegen.read(cx).count_tokens(
|
||||
self.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -2143,8 +2141,6 @@ pub struct Codegen {
|
||||
buffer: Model<MultiBuffer>,
|
||||
old_buffer: Model<Buffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
transform_range: Range<Anchor>,
|
||||
selected_ranges: Vec<Range<Anchor>>,
|
||||
edit_position: Option<Anchor>,
|
||||
last_equal_ranges: Vec<Range<Anchor>>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
@ -2154,7 +2150,7 @@ pub struct Codegen {
|
||||
diff: Diff,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
enum CodegenStatus {
|
||||
@ -2181,8 +2177,7 @@ impl EventEmitter<CodegenEvent> for Codegen {}
|
||||
impl Codegen {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
transform_range: Range<Anchor>,
|
||||
selected_ranges: Vec<Range<Anchor>>,
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
@ -2192,7 +2187,7 @@ impl Codegen {
|
||||
|
||||
let (old_buffer, _, _) = buffer
|
||||
.read(cx)
|
||||
.range_to_buffer_ranges(transform_range.clone(), cx)
|
||||
.range_to_buffer_ranges(range.clone(), cx)
|
||||
.pop()
|
||||
.unwrap();
|
||||
let old_buffer = cx.new_model(|cx| {
|
||||
@ -2223,9 +2218,7 @@ impl Codegen {
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
initial_transaction_id,
|
||||
prompt_builder: builder,
|
||||
transform_range,
|
||||
selected_ranges,
|
||||
builder,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2250,12 +2243,14 @@ impl Codegen {
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<TokenCounts>> {
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
|
||||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
|
||||
match request {
|
||||
Ok(request) => {
|
||||
let total_count = model.count_tokens(request.clone(), cx);
|
||||
@ -2280,6 +2275,7 @@ impl Codegen {
|
||||
|
||||
pub fn start(
|
||||
&mut self,
|
||||
edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
@ -2294,20 +2290,24 @@ impl Codegen {
|
||||
});
|
||||
}
|
||||
|
||||
self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
|
||||
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
|
||||
|
||||
let telemetry_id = model.telemetry_id();
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||
.trim()
|
||||
.to_lowercase()
|
||||
== "delete"
|
||||
{
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
|
||||
|
||||
let chunks =
|
||||
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
|
||||
let chunks =
|
||||
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, edit_range, chunks, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2315,10 +2315,11 @@ impl Codegen {
|
||||
&self,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
edit_range: Range<Anchor>,
|
||||
cx: &AppContext,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.transform_range.start);
|
||||
let language = buffer.language_at(edit_range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
@ -2343,9 +2344,9 @@ impl Codegen {
|
||||
};
|
||||
|
||||
let language_name = language_name.as_deref();
|
||||
let start = buffer.point_to_buffer_offset(self.transform_range.start);
|
||||
let end = buffer.point_to_buffer_offset(self.transform_range.end);
|
||||
let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
|
||||
let start = buffer.point_to_buffer_offset(edit_range.start);
|
||||
let end = buffer.point_to_buffer_offset(edit_range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
@ -2357,39 +2358,9 @@ impl Codegen {
|
||||
return Err(anyhow::anyhow!("invalid transformation range"));
|
||||
};
|
||||
|
||||
let mut transform_context_range = transform_range.to_point(&transform_buffer);
|
||||
transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
|
||||
transform_context_range.start.column = 0;
|
||||
transform_context_range.end =
|
||||
(transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
|
||||
transform_context_range.end.column =
|
||||
transform_buffer.line_len(transform_context_range.end.row);
|
||||
let transform_context_range = transform_context_range.to_offset(&transform_buffer);
|
||||
|
||||
let selected_ranges = self
|
||||
.selected_ranges
|
||||
.iter()
|
||||
.filter_map(|selected_range| {
|
||||
let start = buffer
|
||||
.point_to_buffer_offset(selected_range.start)
|
||||
.map(|(_, offset)| offset)?;
|
||||
let end = buffer
|
||||
.point_to_buffer_offset(selected_range.end)
|
||||
.map(|(_, offset)| offset)?;
|
||||
Some(start..end)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let prompt = self
|
||||
.prompt_builder
|
||||
.generate_content_prompt(
|
||||
user_prompt,
|
||||
language_name,
|
||||
transform_buffer,
|
||||
transform_range,
|
||||
selected_ranges,
|
||||
transform_context_range,
|
||||
)
|
||||
.builder
|
||||
.generate_content_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
@ -2462,19 +2433,84 @@ impl Codegen {
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
let mut new_text = String::new();
|
||||
let mut base_indent = None;
|
||||
let mut line_indent = None;
|
||||
let mut first_line = true;
|
||||
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
if response_latency.is_none() {
|
||||
response_latency = Some(request_start.elapsed());
|
||||
}
|
||||
let chunk = chunk?;
|
||||
let char_ops = diff.push_new(&chunk);
|
||||
line_diff.push_char_operations(&char_ops, &selected_text);
|
||||
diff_tx
|
||||
.send((char_ops, line_diff.line_operations()))
|
||||
.await?;
|
||||
|
||||
let mut lines = chunk.split('\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
new_text.push_str(line);
|
||||
if line_indent.is_none() {
|
||||
if let Some(non_whitespace_ch_ix) =
|
||||
new_text.find(|ch: char| !ch.is_whitespace())
|
||||
{
|
||||
line_indent = Some(non_whitespace_ch_ix);
|
||||
base_indent = base_indent.or(line_indent);
|
||||
|
||||
let line_indent = line_indent.unwrap();
|
||||
let base_indent = base_indent.unwrap();
|
||||
let indent_delta =
|
||||
line_indent as i32 - base_indent as i32;
|
||||
let mut corrected_indent_len = cmp::max(
|
||||
0,
|
||||
suggested_line_indent.len as i32 + indent_delta,
|
||||
)
|
||||
as usize;
|
||||
if first_line {
|
||||
corrected_indent_len = corrected_indent_len
|
||||
.saturating_sub(
|
||||
selection_start.column as usize,
|
||||
);
|
||||
}
|
||||
|
||||
let indent_char = suggested_line_indent.char();
|
||||
let mut indent_buffer = [0; 4];
|
||||
let indent_str =
|
||||
indent_char.encode_utf8(&mut indent_buffer);
|
||||
new_text.replace_range(
|
||||
..line_indent,
|
||||
&indent_str.repeat(corrected_indent_len),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if line_indent.is_some() {
|
||||
let char_ops = diff.push_new(&new_text);
|
||||
line_diff
|
||||
.push_char_operations(&char_ops, &selected_text);
|
||||
diff_tx
|
||||
.send((char_ops, line_diff.line_operations()))
|
||||
.await?;
|
||||
new_text.clear();
|
||||
}
|
||||
|
||||
if lines.peek().is_some() {
|
||||
let char_ops = diff.push_new("\n");
|
||||
line_diff
|
||||
.push_char_operations(&char_ops, &selected_text);
|
||||
diff_tx
|
||||
.send((char_ops, line_diff.line_operations()))
|
||||
.await?;
|
||||
if line_indent.is_none() {
|
||||
// Don't write out the leading indentation in empty lines on the next line
|
||||
// This is the case where the above if statement didn't clear the buffer
|
||||
new_text.clear();
|
||||
}
|
||||
line_indent = None;
|
||||
first_line = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let char_ops = diff.finish();
|
||||
let mut char_ops = diff.push_new(&new_text);
|
||||
char_ops.extend(diff.finish());
|
||||
line_diff.push_char_operations(&char_ops, &selected_text);
|
||||
line_diff.finish(&selected_text);
|
||||
diff_tx
|
||||
@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::stream::{self};
|
||||
use gpui::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{
|
||||
language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
|
||||
Point,
|
||||
};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use rand::prelude::*;
|
||||
use serde::Serialize;
|
||||
use settings::SettingsStore;
|
||||
use std::{future, sync::Arc};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct DummyCompletionRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
let x = 0;
|
||||
for _ in 0..10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range,
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let mut new_text = concat!(
|
||||
" let mut x = 0;\n",
|
||||
" while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
" }",
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
drop(chunks_tx);
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_past_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
le
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let mut new_text = concat!(
|
||||
"t mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
drop(chunks_tx);
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_before_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = concat!(
|
||||
"fn main() {\n",
|
||||
" \n",
|
||||
"}\n" //
|
||||
);
|
||||
let buffer =
|
||||
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let mut new_text = concat!(
|
||||
"let mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
drop(chunks_tx);
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
func main() {
|
||||
\tx := 0
|
||||
\tfor i := 0; i < 10; i++ {
|
||||
\t\tx++
|
||||
\t}
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new_model(|cx| Buffer::local(text, cx));
|
||||
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let new_text = concat!(
|
||||
"func main() {\n",
|
||||
"\tx := 0\n",
|
||||
"\tfor x < 10 {\n",
|
||||
"\t\tx++\n",
|
||||
"\t}", //
|
||||
);
|
||||
chunks_tx.unbounded_send(new_text.to_string()).unwrap();
|
||||
drop(chunks_tx);
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
func main() {
|
||||
\tx := 0
|
||||
\tfor x < 10 {
|
||||
\t\tx++
|
||||
\t}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_strip_invalid_spans_from_codeblock() {
|
||||
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
|
||||
@ -2984,4 +3318,27 @@ mod tests {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
@ -12,15 +12,11 @@ use util::ResultExt;
|
||||
pub struct ContentPromptContext {
|
||||
pub content_type: String,
|
||||
pub language_name: Option<String>,
|
||||
pub is_insert: bool,
|
||||
pub is_truncated: bool,
|
||||
pub document_content: String,
|
||||
pub user_prompt: String,
|
||||
pub rewrite_section: String,
|
||||
pub rewrite_section_prefix: String,
|
||||
pub rewrite_section_suffix: String,
|
||||
pub rewrite_section_with_edits: String,
|
||||
pub has_insertion: bool,
|
||||
pub has_replacement: bool,
|
||||
pub rewrite_section: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -46,54 +42,41 @@ pub struct PromptBuilder {
|
||||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
}
|
||||
|
||||
pub struct PromptOverrideContext<'a> {
|
||||
pub dev_mode: bool,
|
||||
pub fs: Arc<dyn Fs>,
|
||||
pub cx: &'a mut gpui::AppContext,
|
||||
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
|
||||
pub fn new(
|
||||
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
|
||||
) -> Result<Self, Box<TemplateError>> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
Self::register_templates(&mut handlebars)?;
|
||||
|
||||
let handlebars = Arc::new(Mutex::new(handlebars));
|
||||
|
||||
if let Some(override_cx) = override_cx {
|
||||
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
|
||||
if let Some((fs, cx)) = fs_and_cx {
|
||||
Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
|
||||
}
|
||||
|
||||
Ok(Self { handlebars })
|
||||
}
|
||||
|
||||
fn watch_fs_for_template_overrides(
|
||||
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &gpui::AppContext,
|
||||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
) {
|
||||
let templates_dir = paths::prompt_overrides_dir();
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let templates_dir = if dev_mode {
|
||||
std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|pwd| {
|
||||
let pwd_assets_prompts = pwd.join("assets").join("prompts");
|
||||
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
|
||||
})
|
||||
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
|
||||
} else {
|
||||
paths::prompt_overrides_dir().clone()
|
||||
};
|
||||
|
||||
// Create the prompt templates directory if it doesn't exist
|
||||
if !fs.is_dir(&templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(&templates_dir).await {
|
||||
if !fs.is_dir(templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(templates_dir).await {
|
||||
log::error!("Failed to create prompt templates directory: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Initial scan of the prompts directory
|
||||
if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
|
||||
if let Ok(mut entries) = fs.read_dir(templates_dir).await {
|
||||
while let Some(Ok(file_path)) = entries.next().await {
|
||||
if file_path.to_string_lossy().ends_with(".hbs") {
|
||||
if let Ok(content) = fs.load(&file_path).await {
|
||||
@ -121,7 +104,7 @@ impl PromptBuilder {
|
||||
}
|
||||
|
||||
// Watch for changes
|
||||
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
|
||||
let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
|
||||
while let Some(changed_paths) = changes.next().await {
|
||||
for changed_path in changed_paths {
|
||||
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
|
||||
@ -173,9 +156,7 @@ impl PromptBuilder {
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
transform_range: Range<usize>,
|
||||
selected_ranges: Vec<Range<usize>>,
|
||||
transform_context_range: Range<usize>,
|
||||
range: Range<usize>,
|
||||
) -> Result<String, RenderError> {
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => "text",
|
||||
@ -183,20 +164,21 @@ impl PromptBuilder {
|
||||
};
|
||||
|
||||
const MAX_CTX: usize = 50000;
|
||||
let is_insert = range.is_empty();
|
||||
let mut is_truncated = false;
|
||||
|
||||
let before_range = 0..transform_range.start;
|
||||
let before_range = 0..range.start;
|
||||
let truncated_before = if before_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
transform_range.start - MAX_CTX..transform_range.start
|
||||
range.start - MAX_CTX..range.start
|
||||
} else {
|
||||
before_range
|
||||
};
|
||||
|
||||
let after_range = transform_range.end..buffer.len();
|
||||
let after_range = range.end..buffer.len();
|
||||
let truncated_after = if after_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
transform_range.end..transform_range.end + MAX_CTX
|
||||
range.end..range.end + MAX_CTX
|
||||
} else {
|
||||
after_range
|
||||
};
|
||||
@ -205,74 +187,37 @@ impl PromptBuilder {
|
||||
for chunk in buffer.text_for_range(truncated_before) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
if is_insert {
|
||||
document_content.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
|
||||
for chunk in buffer.text_for_range(truncated_after) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
|
||||
let mut rewrite_section = String::new();
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
rewrite_section.push_str(chunk);
|
||||
}
|
||||
|
||||
let mut rewrite_section_prefix = String::new();
|
||||
for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
|
||||
rewrite_section_prefix.push_str(chunk);
|
||||
}
|
||||
|
||||
let mut rewrite_section_suffix = String::new();
|
||||
for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
|
||||
rewrite_section_suffix.push_str(chunk);
|
||||
}
|
||||
|
||||
let rewrite_section_with_edits = {
|
||||
let mut section_with_selections = String::new();
|
||||
let mut last_end = 0;
|
||||
for selected_range in &selected_ranges {
|
||||
if selected_range.start > last_end {
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[last_end..selected_range.start - transform_range.start],
|
||||
);
|
||||
}
|
||||
if selected_range.start == selected_range.end {
|
||||
section_with_selections.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
section_with_selections.push_str("<edit_here>");
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[selected_range.start - transform_range.start
|
||||
..selected_range.end - transform_range.start],
|
||||
);
|
||||
section_with_selections.push_str("</edit_here>");
|
||||
}
|
||||
last_end = selected_range.end - transform_range.start;
|
||||
let rewrite_section = if !is_insert {
|
||||
let mut section = String::new();
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
section.push_str(chunk);
|
||||
}
|
||||
if last_end < rewrite_section.len() {
|
||||
section_with_selections.push_str(&rewrite_section[last_end..]);
|
||||
}
|
||||
section_with_selections
|
||||
Some(section)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
|
||||
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
|
||||
|
||||
let context = ContentPromptContext {
|
||||
content_type: content_type.to_string(),
|
||||
language_name: language_name.map(|s| s.to_string()),
|
||||
is_insert,
|
||||
is_truncated,
|
||||
document_content,
|
||||
user_prompt,
|
||||
rewrite_section,
|
||||
rewrite_section_prefix,
|
||||
rewrite_section_suffix,
|
||||
rewrite_section_with_edits,
|
||||
has_insertion,
|
||||
has_replacement,
|
||||
};
|
||||
|
||||
self.handlebars.lock().render("content_prompt", &context)
|
||||
|
@ -187,12 +187,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) -> Arc<PromptBuild
|
||||
);
|
||||
snippet_provider::init(cx);
|
||||
inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
|
||||
let prompt_builder = assistant::init(
|
||||
app_state.fs.clone(),
|
||||
app_state.client.clone(),
|
||||
stdout_is_a_pty(),
|
||||
cx,
|
||||
);
|
||||
let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
|
||||
repl::init(
|
||||
app_state.fs.clone(),
|
||||
app_state.client.telemetry().clone(),
|
||||
|
@ -1018,8 +1018,6 @@ fn open_settings_file(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::stdout_is_a_pty;
|
||||
|
||||
use super::*;
|
||||
use anyhow::anyhow;
|
||||
use assets::Assets;
|
||||
@ -3487,12 +3485,8 @@ mod tests {
|
||||
app_state.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
let prompt_builder = assistant::init(
|
||||
app_state.fs.clone(),
|
||||
app_state.client.clone(),
|
||||
stdout_is_a_pty(),
|
||||
cx,
|
||||
);
|
||||
let prompt_builder =
|
||||
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
|
||||
repl::init(
|
||||
app_state.fs.clone(),
|
||||
app_state.client.telemetry().clone(),
|
||||
|
Loading…
Reference in New Issue
Block a user