From 64a55681e615d71ae5146d16b1d6f28d2237819c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 2 Oct 2023 14:32:13 +0200 Subject: [PATCH] Summarize the contents of a file using the embedding query --- crates/assistant/src/assistant_panel.rs | 1 - crates/assistant/src/prompts.rs | 480 ++++++++++++---------- crates/language/src/buffer.rs | 12 +- crates/zed/src/languages/rust/summary.scm | 6 - 4 files changed, 264 insertions(+), 235 deletions(-) delete mode 100644 crates/zed/src/languages/rust/summary.scm diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 37d0d729fe..816047e325 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -578,7 +578,6 @@ impl AssistantPanel { language_name, &snapshot, language_range, - cx, codegen_kind, ); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index ee58090d04..8699c77cd1 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,86 +1,118 @@ -use gpui::AppContext; +use crate::codegen::CodegenKind; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use std::cmp; use std::ops::Range; use std::{fmt::Write, iter}; -use crate::codegen::CodegenKind; - -fn outline_for_prompt( - buffer: &BufferSnapshot, - range: Range, - cx: &AppContext, -) -> Option { - let indent = buffer - .language_indent_size_at(0, cx) - .chars() - .collect::(); - let outline = buffer.outline(None)?; - let range = range.to_offset(buffer); - - let mut text = String::new(); - let mut items = outline.items.into_iter().peekable(); - - let mut intersected = false; - let mut intersection_indent = 0; - let mut extended_range = range.clone(); - - while let Some(item) = items.next() { - let item_range = item.range.to_offset(buffer); - if item_range.end < range.start || item_range.start > range.end { - text.extend(iter::repeat(indent.as_str()).take(item.depth)); - text.push_str(&item.text); - text.push('\n'); - } else { - intersected = true; - let is_terminal = items - .peek() - .map_or(true, |next_item| next_item.depth <= item.depth); - if is_terminal { - if item_range.start <= extended_range.start { - extended_range.start = item_range.start; - intersection_indent = item.depth; - } - extended_range.end = cmp::max(extended_range.end, item_range.end); - } else { - let name_start = item_range.start + item.name_ranges.first().unwrap().start; - let name_end = item_range.start + item.name_ranges.last().unwrap().end; - - if range.start > name_end { - text.extend(iter::repeat(indent.as_str()).take(item.depth)); - text.push_str(&item.text); - text.push('\n'); - } else { - if name_start <= extended_range.start { - extended_range.start = item_range.start; - intersection_indent = item.depth; - } - extended_range.end = cmp::max(extended_range.end, name_end); - } - } - } - - if intersected - && items.peek().map_or(true, |next_item| { - next_item.range.start.to_offset(buffer) > range.end - }) - { - intersected = false; - text.extend(iter::repeat(indent.as_str()).take(intersection_indent)); - text.extend(buffer.text_for_range(extended_range.start..range.start)); - text.push_str("<|START|"); - text.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - text.push_str("|END|>"); - } else { - text.push_str(">"); - } - text.extend(buffer.text_for_range(range.end..extended_range.end)); - text.push('\n'); - } +fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { + #[derive(Debug)] + struct Match { + collapse: Range, + keep: Vec>, } - Some(text) + let selected_range = selected_range.to_offset(buffer); + let mut matches = buffer.matches(0..buffer.len(), |grammar| { + Some(&grammar.embedding_config.as_ref()?.query) + }); + let configs = matches + .grammars() + .iter() + .map(|g| g.embedding_config.as_ref().unwrap()) + .collect::>(); + let mut matches = iter::from_fn(move || { + while let Some(mat) = matches.peek() { + let config = &configs[mat.grammar_index]; + if let Some(collapse) = mat.captures.iter().find_map(|cap| { + if Some(cap.index) == config.collapse_capture_ix { + Some(cap.node.byte_range()) + } else { + None + } + }) { + let mut keep = Vec::new(); + for capture in mat.captures.iter() { + if Some(capture.index) == config.keep_capture_ix { + keep.push(capture.node.byte_range()); + } else { + continue; + } + } + matches.advance(); + return Some(Match { collapse, keep }); + } else { + matches.advance(); + } + } + None + }) + .peekable(); + + let mut summary = String::new(); + let mut offset = 0; + let mut flushed_selection = false; + while let Some(mut mat) = matches.next() { + // Keep extending the collapsed range if the next match surrounds + // the current one. + while let Some(next_mat) = matches.peek() { + if next_mat.collapse.start <= mat.collapse.start + && next_mat.collapse.end >= mat.collapse.end + { + mat = matches.next().unwrap(); + } else { + break; + } + } + + if offset >= mat.collapse.start { + // Skip collapsed nodes that have already been summarized. + offset = cmp::max(offset, mat.collapse.end); + continue; + } + + if offset <= selected_range.start && selected_range.start <= mat.collapse.end { + if !flushed_selection { + // The collapsed node ends after the selection starts, so we'll flush the selection first. + summary.extend(buffer.text_for_range(offset..selected_range.start)); + summary.push_str("<|START|"); + if selected_range.end == selected_range.start { + summary.push_str(">"); + } else { + summary.extend(buffer.text_for_range(selected_range.clone())); + summary.push_str("|END|>"); + } + offset = selected_range.end; + flushed_selection = true; + } + + // If the selection intersects the collapsed node, we won't collapse it. + if selected_range.end >= mat.collapse.start { + continue; + } + } + + summary.extend(buffer.text_for_range(offset..mat.collapse.start)); + for keep in mat.keep { + summary.extend(buffer.text_for_range(keep)); + } + offset = mat.collapse.end; + } + + // Flush selection if we haven't already done so. + if !flushed_selection && offset <= selected_range.start { + summary.extend(buffer.text_for_range(offset..selected_range.start)); + summary.push_str("<|START|"); + if selected_range.end == selected_range.start { + summary.push_str(">"); + } else { + summary.extend(buffer.text_for_range(selected_range.clone())); + summary.push_str("|END|>"); + } + offset = selected_range.end; + } + + summary.extend(buffer.text_for_range(offset..buffer.len())); + summary } pub fn generate_content_prompt( @@ -88,7 +120,6 @@ pub fn generate_content_prompt( language_name: Option<&str>, buffer: &BufferSnapshot, range: Range, - cx: &AppContext, kind: CodegenKind, ) -> String { let mut prompt = String::new(); @@ -100,19 +131,17 @@ pub fn generate_content_prompt( writeln!(prompt, "You're an expert engineer.\n").unwrap(); } - let outline = outline_for_prompt(buffer, range.clone(), cx); - if let Some(outline) = outline { - writeln!( - prompt, - "The file you are currently working on has the following outline:" - ) - .unwrap(); - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap(); - } else { - writeln!(prompt, "```\n{outline}\n```").unwrap(); - } + let outline = summarize(buffer, range.clone()); + writeln!( + prompt, + "The file you are currently working on has the following outline:" + ) + .unwrap(); + if let Some(language_name) = language_name { + let language_name = language_name.to_lowercase(); + writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap(); + } else { + writeln!(prompt, "```\n{outline}\n```").unwrap(); } // Assume for now that we are just generating @@ -183,39 +212,37 @@ pub(crate) mod tests { }, Some(tree_sitter_rust::language()), ) - .with_indents_query( + .with_embedding_query( r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - .with_outline_query( - r#" - (struct_item - "struct" @context - name: (_) @name) @item - (enum_item - "enum" @context - name: (_) @name) @item - (enum_variant - name: (_) @name) @item - (field_declaration - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @name - "for"? @context - type: (_) @name) @item - (function_item - "fn" @context - name: (_) @name) @item - (mod_item - "mod" @context - name: (_) @name) @item - "#, + ( + [(line_comment) (attribute_item)]* @context + . + [ + (struct_item + name: (_) @name) + + (enum_item + name: (_) @name) + + (impl_item + trait: (_)? @name + "for"? @name + type: (_) @name) + + (trait_item + name: (_) @name) + + (function_item + name: (_) @name + body: (block + "{" @keep + "}" @keep) @collapse) + + (macro_definition + name: (_) @name) + ] @item + ) + "#, ) .unwrap() } @@ -251,132 +278,133 @@ pub(crate) mod tests { cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); let snapshot = buffer.read(cx).snapshot(); - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)), - cx, - ); assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - <|START|>a: usize - b - impl X - fn new - fn a - fn b - "}) + summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)), + indoc! {" + struct X { + <|START|>a: usize, + b: usize, + } + + impl X { + + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} ); - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)), - cx, - ); assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - a - b - impl X + summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)), + indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + fn new() -> Self { let <|START|a |END|>= 1; let b = 2; Self { a, b } } - fn a - fn b - "}) + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} ); - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)), - cx, - ); assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - a - b - impl X + summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)), + indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { <|START|> - fn new - fn a - fn b - "}) + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} ); - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)), - cx, - ); assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - a - b - impl X - fn new() -> Self { - let <|START|a = 1; - let b = 2; - Self { a, b } - } + summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)), + indoc! {" + struct X { + a: usize, + b: usize, + } - pub f|END|>n a(&self, param: bool) -> usize { - self.a - } - fn b - "}) + impl X { + + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + <|START|>"} ); - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)), - cx, - ); + // Ensure nested functions get collapsed properly. + let text = indoc! {" + struct X { + a: usize, + b: usize, + } + + impl X { + + fn new() -> Self { + let a = 1; + let b = 2; + Self { a, b } + } + + pub fn a(&self, param: bool) -> usize { + let a = 30; + fn nested() -> usize { + 3 + } + self.a + nested() + } + + pub fn b(&self) -> usize { + self.b + } + } + "}; + buffer.update(cx, |buffer, cx| buffer.set_text(text, cx)); + let snapshot = buffer.read(cx).snapshot(); assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - a - b - impl X<|START| { + summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)), + indoc! {" + <|START|>struct X { + a: usize, + b: usize, + } - fn new() -> Self { - let a = 1; - let b = 2; - Self { a, b } - } - |END|> - fn a - fn b - "}) - ); + impl X { - let outline = outline_for_prompt( - &snapshot, - snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)), - cx, - ); - assert_eq!( - outline.as_deref(), - Some(indoc! {" - struct X - a - b - impl X - fn new - fn a - pub fn b(&self) -> usize { - <|START|>self.b - } - "}) + fn new() -> Self {} + + pub fn a(&self, param: bool) -> usize {} + + pub fn b(&self) -> usize {} + } + "} ); } } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 38b2842c12..27b01543e1 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -8,8 +8,8 @@ use crate::{ language_settings::{language_settings, LanguageSettings}, outline::OutlineItem, syntax_map::{ - SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxSnapshot, - ToTreeSitterPoint, + SyntaxLayerInfo, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches, + SyntaxSnapshot, ToTreeSitterPoint, }, CodeLabel, LanguageScope, Outline, }; @@ -2467,6 +2467,14 @@ impl BufferSnapshot { Some(items) } + pub fn matches( + &self, + range: Range, + query: fn(&Grammar) -> Option<&tree_sitter::Query>, + ) -> SyntaxMapMatches { + self.syntax.matches(range, self, query) + } + /// Returns bracket range pairs overlapping or adjacent to `range` pub fn bracket_ranges<'a, T: ToOffset>( &'a self, diff --git a/crates/zed/src/languages/rust/summary.scm b/crates/zed/src/languages/rust/summary.scm deleted file mode 100644 index 7174eec3c3..0000000000 --- a/crates/zed/src/languages/rust/summary.scm +++ /dev/null @@ -1,6 +0,0 @@ -(function_item - body: (block - "{" @keep - "}" @keep) @collapse) - -(use_declaration) @collapse