From cb0b8b4c4bf4dc7594ccddd0ffff4ba46033939b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 21 Jun 2024 17:41:43 +0200 Subject: [PATCH] Introduce multi-cursor inline transformations (#13368) https://github.com/zed-industries/zed/assets/482957/591def34-e5c8-4402-9c6b-372cbca720c3 Release Notes: - N/A --------- Co-authored-by: Richard Feldman --- Cargo.lock | 10 + crates/assistant/Cargo.toml | 1 + crates/assistant/src/inline_assistant.rs | 1621 ++++++++++++-------- crates/assistant/src/prompts.rs | 82 +- crates/collab/src/tests/editor_tests.rs | 4 +- crates/diagnostics/src/diagnostics.rs | 2 +- crates/editor/src/display_map.rs | 21 +- crates/editor/src/display_map/block_map.rs | 81 +- crates/editor/src/editor.rs | 24 +- crates/editor/src/element.rs | 66 +- crates/editor/src/items.rs | 30 +- crates/editor/src/scroll/autoscroll.rs | 6 +- crates/gpui/src/window.rs | 65 +- crates/language/src/buffer.rs | 10 +- crates/language/src/buffer_tests.rs | 2 +- crates/multi_buffer/src/multi_buffer.rs | 8 +- 16 files changed, 1335 insertions(+), 698 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 848752db02..eaf1cf3bb1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,7 @@ dependencies = [ "anthropic", "anyhow", "assistant_slash_command", + "async-watch", "cargo_toml", "chrono", "client", @@ -873,6 +874,15 @@ dependencies = [ "tungstenite 0.16.0", ] +[[package]] +name = "async-watch" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a078faf4e27c0c6cc0efb20e5da59dcccc04968ebf2801d8e0b2195124cdcdb2" +dependencies = [ + "event-listener 2.5.3", +] + [[package]] name = "async_zip" version = "0.0.17" diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index d3c52e7ab5..c8f84e3b9e 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -16,6 +16,7 @@ doctest = false anyhow.workspace = true anthropic = { workspace = true, features = ["schemars"] } assistant_slash_command.workspace = true +async-watch.workspace = true cargo_toml.workspace = true chrono.workspace = true client.workspace = true diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index b934be4e9c..e655b66e6c 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -2,34 +2,36 @@ use crate::{ prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk, LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff, }; -use anyhow::Result; +use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; use collections::{hash_map, HashMap, HashSet, VecDeque}; use editor::{ actions::{MoveDown, MoveUp, SelectAll}, display_map::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock, + ToDisplayPoint, }, - scroll::{Autoscroll, AutoscrollStrategy}, - Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorStyle, ExcerptRange, - GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, + Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, + ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, }; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{ - AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, + point, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WhiteSpace, WindowContext, }; -use language::{Buffer, Point, TransactionId}; +use language::{Buffer, Point, Selection, TransactionId}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; use settings::Settings; use similar::TextDiff; use std::{ - cmp, future, mem, + cmp, mem, ops::{Range, RangeInclusive}, + pin::Pin, sync::Arc, + task::{self, Poll}, time::Instant, }; use theme::ThemeSettings; @@ -45,8 +47,10 @@ const PROMPT_HISTORY_MAX_LEN: usize = 20; pub struct InlineAssistant { next_assist_id: InlineAssistId, - pending_assists: HashMap, - pending_assist_ids_by_editor: HashMap, Vec>, + next_assist_group_id: InlineAssistGroupId, + assists: HashMap, + assists_by_editor: HashMap, EditorInlineAssists>, + assist_groups: HashMap, prompt_history: VecDeque, telemetry: Option>, } @@ -57,8 +61,10 @@ impl InlineAssistant { pub fn new(telemetry: Arc) -> Self { Self { next_assist_id: InlineAssistId::default(), - pending_assists: HashMap::default(), - pending_assist_ids_by_editor: HashMap::default(), + next_assist_group_id: InlineAssistGroupId::default(), + assists: HashMap::default(), + assists_by_editor: HashMap::default(), + assist_groups: HashMap::default(), prompt_history: VecDeque::default(), telemetry: Some(telemetry), } @@ -71,380 +77,452 @@ impl InlineAssistant { include_context: bool, cx: &mut WindowContext, ) { - let selection = editor.read(cx).selections.newest_anchor().clone(); - if selection.start.excerpt_id != selection.end.excerpt_id { - return; - } let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - // Extend the selection to the start and the end of the line. - let mut point_selection = selection.map(|selection| selection.to_point(&snapshot)); - if point_selection.end > point_selection.start { - point_selection.start.column = 0; - // If the selection ends at the start of the line, we don't want to include it. - if point_selection.end.column == 0 { - point_selection.end.row -= 1; + let mut selections = Vec::>::new(); + let mut newest_selection = None; + for mut selection in editor.read(cx).selections.all::(cx) { + if selection.end > selection.start { + selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. + if selection.end.column == 0 { + selection.end.row -= 1; + } + selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row)); } - point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row)); + + if let Some(prev_selection) = selections.last_mut() { + if selection.start <= prev_selection.end { + prev_selection.end = selection.end; + continue; + } + } + + let latest_selection = newest_selection.get_or_insert_with(|| selection.clone()); + if selection.id > latest_selection.id { + *latest_selection = selection.clone(); + } + selections.push(selection); + } + let newest_selection = newest_selection.unwrap(); + + let mut codegen_ranges = Vec::new(); + for (excerpt_id, buffer, buffer_range) in + snapshot.excerpts_in_ranges(selections.iter().map(|selection| { + snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end) + })) + { + let start = Anchor { + buffer_id: Some(buffer.remote_id()), + excerpt_id, + text_anchor: buffer.anchor_before(buffer_range.start), + }; + let end = Anchor { + buffer_id: Some(buffer.remote_id()), + excerpt_id, + text_anchor: buffer.anchor_after(buffer_range.end), + }; + codegen_ranges.push(start..end); } - let codegen_kind = if point_selection.start == point_selection.end { - CodegenKind::Generate { - position: snapshot.anchor_after(point_selection.start), - } - } else { - CodegenKind::Transform { - range: snapshot.anchor_before(point_selection.start) - ..snapshot.anchor_after(point_selection.end), - } - }; + let assist_group_id = self.next_assist_group_id.post_inc(); + let prompt_buffer = cx.new_model(|cx| Buffer::local("", cx)); + let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx)); - let assist_id = self.next_assist_id.post_inc(); - let codegen = cx.new_model(|cx| { - Codegen::new( - editor.read(cx).buffer().clone(), - codegen_kind, - self.telemetry.clone(), - cx, - ) - }); - - let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); - let prompt_editor = cx.new_view(|cx| { - InlineAssistEditor::new( - assist_id, - gutter_dimensions.clone(), - self.prompt_history.clone(), - codegen.clone(), - workspace.clone(), - cx, - ) - }); - let (prompt_block_id, end_block_id) = editor.update(cx, |editor, cx| { - let start_anchor = snapshot.anchor_before(point_selection.start); - let end_anchor = snapshot.anchor_after(point_selection.end); - editor.change_selections(Some(Autoscroll::newest()), cx, |selections| { - selections.select_anchor_ranges([start_anchor..start_anchor]) + let mut assists = Vec::new(); + let mut assist_blocks = Vec::new(); + let mut assist_to_focus = None; + for range in codegen_ranges { + let assist_id = self.next_assist_id.post_inc(); + let codegen = cx.new_model(|cx| { + Codegen::new( + editor.read(cx).buffer().clone(), + range.clone(), + self.telemetry.clone(), + cx, + ) }); - let block_ids = editor.insert_blocks( - [ - BlockProperties { - style: BlockStyle::Sticky, - position: start_anchor, - height: prompt_editor.read(cx).height_in_lines, - render: build_inline_assist_editor_renderer( - &prompt_editor, - gutter_dimensions, - ), - disposition: BlockDisposition::Above, - }, - BlockProperties { - style: BlockStyle::Sticky, - position: end_anchor, - height: 1, - render: Box::new(|cx| { - v_flex() - .h_full() - .w_full() - .border_t_1() - .border_color(cx.theme().status().info_border) - .into_any_element() - }), - disposition: BlockDisposition::Below, - }, - ], - Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)), - cx, - ); - (block_ids[0], block_ids[1]) + + let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); + let prompt_editor = cx.new_view(|cx| { + PromptEditor::new( + assist_id, + gutter_dimensions.clone(), + self.prompt_history.clone(), + prompt_buffer.clone(), + codegen.clone(), + workspace.clone(), + cx, + ) + }); + + if assist_to_focus.is_none() { + let focus_assist = if newest_selection.reversed { + range.start.to_point(&snapshot) == newest_selection.start + } else { + range.end.to_point(&snapshot) == newest_selection.end + }; + if focus_assist { + assist_to_focus = Some(assist_id); + } + } + + assist_blocks.push(BlockProperties { + style: BlockStyle::Sticky, + position: range.start, + height: prompt_editor.read(cx).height_in_lines, + render: build_assist_editor_renderer(&prompt_editor), + disposition: BlockDisposition::Above, + }); + assist_blocks.push(BlockProperties { + style: BlockStyle::Sticky, + position: range.end, + height: 1, + render: Box::new(|cx| { + v_flex() + .h_full() + .w_full() + .border_t_1() + .border_color(cx.theme().status().info_border) + .into_any_element() + }), + disposition: BlockDisposition::Below, + }); + assists.push((assist_id, prompt_editor)); + } + + let assist_block_ids = editor.update(cx, |editor, cx| { + editor.insert_blocks(assist_blocks, None, cx) }); - self.pending_assists.insert( - assist_id, - PendingInlineAssist { - include_context, - editor: editor.downgrade(), - editor_decorations: Some(PendingInlineAssistDecorations { - prompt_block_id, - prompt_editor: prompt_editor.clone(), - removed_line_block_ids: HashSet::default(), - end_block_id, - }), - codegen: codegen.clone(), - workspace, - _subscriptions: vec![ - cx.subscribe(&prompt_editor, |inline_assist_editor, event, cx| { - InlineAssistant::update_global(cx, |this, cx| { - this.handle_inline_assistant_editor_event( - inline_assist_editor, - event, - cx, - ) - }) - }), - editor.update(cx, |editor, _cx| { - editor.register_action( - move |_: &editor::actions::Newline, cx: &mut WindowContext| { - InlineAssistant::update_global(cx, |this, cx| { - this.handle_editor_newline(assist_id, cx) - }) - }, - ) - }), - editor.update(cx, |editor, _cx| { - editor.register_action( - move |_: &editor::actions::Cancel, cx: &mut WindowContext| { - InlineAssistant::update_global(cx, |this, cx| { - this.handle_editor_cancel(assist_id, cx) - }) - }, - ) - }), - cx.subscribe(editor, move |editor, event, cx| { - InlineAssistant::update_global(cx, |this, cx| { - this.handle_editor_event(assist_id, editor, event, cx) - }) - }), - cx.observe(&codegen, { - let editor = editor.downgrade(); - move |_, cx| { - if let Some(editor) = editor.upgrade() { - InlineAssistant::update_global(cx, |this, cx| { - this.update_editor_highlights(&editor, cx); - this.update_editor_blocks(&editor, assist_id, cx); - }) - } - } - }), - cx.subscribe(&codegen, move |codegen, event, cx| { - InlineAssistant::update_global(cx, |this, cx| match event { - CodegenEvent::Undone => this.finish_inline_assist(assist_id, false, cx), - CodegenEvent::Finished => { - let pending_assist = if let Some(pending_assist) = - this.pending_assists.get(&assist_id) - { - pending_assist - } else { - return; - }; - - if let CodegenStatus::Error(error) = &codegen.read(cx).status { - if pending_assist.editor_decorations.is_none() { - if let Some(workspace) = pending_assist - .workspace - .as_ref() - .and_then(|workspace| workspace.upgrade()) - { - let error = - format!("Inline assistant error: {}", error); - workspace.update(cx, |workspace, cx| { - struct InlineAssistantError; - - let id = NotificationId::identified::< - InlineAssistantError, - >( - assist_id.0 - ); - - workspace.show_toast(Toast::new(id, error), cx); - }) - } - } - } - - if pending_assist.editor_decorations.is_none() { - this.finish_inline_assist(assist_id, false, cx); - } - } - }) - }), - ], - }, - ); - - self.pending_assist_ids_by_editor + let editor_assists = self + .assists_by_editor .entry(editor.downgrade()) - .or_default() - .push(assist_id); - self.update_editor_highlights(editor, cx); + .or_insert_with(|| EditorInlineAssists::new(&editor, cx)); + let mut assist_group = InlineAssistGroup::new(); + for ((assist_id, prompt_editor), block_ids) in + assists.into_iter().zip(assist_block_ids.chunks_exact(2)) + { + self.assists.insert( + assist_id, + InlineAssist::new( + assist_id, + assist_group_id, + include_context, + editor, + &prompt_editor, + block_ids[0], + block_ids[1], + prompt_editor.read(cx).codegen.clone(), + workspace.clone(), + cx, + ), + ); + assist_group.assist_ids.push(assist_id); + editor_assists.assist_ids.push(assist_id); + } + self.assist_groups.insert(assist_group_id, assist_group); + + if let Some(assist_id) = assist_to_focus { + self.focus_assist(assist_id, cx); + } } - fn handle_inline_assistant_editor_event( + fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + let assist = &self.assists[&assist_id]; + let Some(decorations) = assist.decorations.as_ref() else { + return; + }; + let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap(); + let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap(); + + assist_group.active_assist_id = Some(assist_id); + if assist_group.linked { + for assist_id in &assist_group.assist_ids { + if let Some(decorations) = self.assists[assist_id].decorations.as_ref() { + decorations.prompt_editor.update(cx, |prompt_editor, cx| { + prompt_editor.set_show_cursor_when_unfocused(true, cx) + }); + } + } + } + + assist + .editor + .update(cx, |editor, cx| { + let scroll_top = editor.scroll_position(cx).y; + let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.); + let prompt_row = editor + .row_for_block(decorations.prompt_block_id, cx) + .unwrap() + .0 as f32; + + if (scroll_top..scroll_bottom).contains(&prompt_row) { + editor_assists.scroll_lock = Some(InlineAssistScrollLock { + assist_id, + distance_from_top: prompt_row - scroll_top, + }); + } else { + editor_assists.scroll_lock = None; + } + }) + .ok(); + } + + fn handle_prompt_editor_focus_out( &mut self, - inline_assist_editor: View, - event: &InlineAssistEditorEvent, + assist_id: InlineAssistId, cx: &mut WindowContext, ) { - let assist_id = inline_assist_editor.read(cx).id; - match event { - InlineAssistEditorEvent::StartRequested => { - self.start_inline_assist(assist_id, cx); - } - InlineAssistEditorEvent::StopRequested => { - self.stop_inline_assist(assist_id, cx); - } - InlineAssistEditorEvent::ConfirmRequested => { - self.finish_inline_assist(assist_id, false, cx); - } - InlineAssistEditorEvent::CancelRequested => { - self.finish_inline_assist(assist_id, true, cx); - } - InlineAssistEditorEvent::DismissRequested => { - self.dismiss_inline_assist(assist_id, cx); - } - InlineAssistEditorEvent::Resized { height_in_lines } => { - self.resize_inline_assist(assist_id, *height_in_lines, cx); + let assist = &self.assists[&assist_id]; + let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap(); + if assist_group.active_assist_id == Some(assist_id) { + assist_group.active_assist_id = None; + if assist_group.linked { + for assist_id in &assist_group.assist_ids { + if let Some(decorations) = self.assists[assist_id].decorations.as_ref() { + decorations.prompt_editor.update(cx, |prompt_editor, cx| { + prompt_editor.set_show_cursor_when_unfocused(false, cx) + }); + } + } } } } - fn handle_editor_newline(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { - let Some(assist) = self.pending_assists.get(&assist_id) else { - return; - }; - let Some(editor) = assist.editor.upgrade() else { + fn handle_prompt_editor_event( + &mut self, + prompt_editor: View, + event: &PromptEditorEvent, + cx: &mut WindowContext, + ) { + let assist_id = prompt_editor.read(cx).id; + match event { + PromptEditorEvent::StartRequested => { + self.start_assist(assist_id, cx); + } + PromptEditorEvent::StopRequested => { + self.stop_assist(assist_id, cx); + } + PromptEditorEvent::ConfirmRequested => { + self.finish_assist(assist_id, false, cx); + } + PromptEditorEvent::CancelRequested => { + self.finish_assist(assist_id, true, cx); + } + PromptEditorEvent::DismissRequested => { + self.dismiss_assist(assist_id, cx); + } + PromptEditorEvent::Resized { height_in_lines } => { + self.resize_assist(assist_id, *height_in_lines, cx); + } + } + } + + fn handle_editor_newline(&mut self, editor: View, cx: &mut WindowContext) { + let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else { return; }; - let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); - let assist_range = assist.codegen.read(cx).range().to_offset(&buffer); let editor = editor.read(cx); if editor.selections.count() == 1 { let selection = editor.selections.newest::(cx); - if assist_range.contains(&selection.start) && assist_range.contains(&selection.end) { - if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) { - self.dismiss_inline_assist(assist_id, cx); - } else { - self.finish_inline_assist(assist_id, false, cx); - } + let buffer = editor.buffer().read(cx).snapshot(cx); + for assist_id in &editor_assists.assist_ids { + let assist = &self.assists[assist_id]; + let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); + if assist_range.contains(&selection.start) && assist_range.contains(&selection.end) + { + if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) { + self.dismiss_assist(*assist_id, cx); + } else { + self.finish_assist(*assist_id, false, cx); + } - return; + return; + } } } cx.propagate(); } - fn handle_editor_cancel(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { - let Some(assist) = self.pending_assists.get(&assist_id) else { - return; - }; - let Some(editor) = assist.editor.upgrade() else { + fn handle_editor_cancel(&mut self, editor: View, cx: &mut WindowContext) { + let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else { return; }; - let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); - let assist_range = assist.codegen.read(cx).range().to_offset(&buffer); - let propagate = editor.update(cx, |editor, cx| { - if let Some(decorations) = assist.editor_decorations.as_ref() { - if editor.selections.count() == 1 { - let selection = editor.selections.newest::(cx); - if assist_range.contains(&selection.start) - && assist_range.contains(&selection.end) - { - editor.change_selections(Some(Autoscroll::newest()), cx, |selections| { - selections.select_ranges([assist_range.start..assist_range.start]); - }); - decorations.prompt_editor.update(cx, |prompt_editor, cx| { - prompt_editor.editor.update(cx, |prompt_editor, cx| { - prompt_editor.select_all(&SelectAll, cx); - prompt_editor.focus(cx); - }); - }); - return false; - } + let editor = editor.read(cx); + if editor.selections.count() == 1 { + let selection = editor.selections.newest::(cx); + let buffer = editor.buffer().read(cx).snapshot(cx); + for assist_id in &editor_assists.assist_ids { + let assist = &self.assists[assist_id]; + let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); + if assist.decorations.is_some() + && assist_range.contains(&selection.start) + && assist_range.contains(&selection.end) + { + self.focus_assist(*assist_id, cx); + return; } } - true - }); - - if propagate { - cx.propagate(); } + + cx.propagate(); + } + + fn handle_editor_change(&mut self, editor: View, cx: &mut WindowContext) { + let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else { + return; + }; + let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else { + return; + }; + let assist = &self.assists[&scroll_lock.assist_id]; + let Some(decorations) = assist.decorations.as_ref() else { + return; + }; + + editor.update(cx, |editor, cx| { + let scroll_position = editor.scroll_position(cx); + let target_scroll_top = editor + .row_for_block(decorations.prompt_block_id, cx) + .unwrap() + .0 as f32 + - scroll_lock.distance_from_top; + if target_scroll_top != scroll_position.y { + editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx); + } + }); } fn handle_editor_event( &mut self, - assist_id: InlineAssistId, editor: View, event: &EditorEvent, cx: &mut WindowContext, ) { - let Some(assist) = self.pending_assists.get(&assist_id) else { + let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else { return; }; match event { - EditorEvent::SelectionsChanged { local } if *local => { - if let CodegenStatus::Idle = &assist.codegen.read(cx).status { - self.finish_inline_assist(assist_id, true, cx); - } - } EditorEvent::Saved => { - if let CodegenStatus::Done = &assist.codegen.read(cx).status { - self.finish_inline_assist(assist_id, false, cx) + for assist_id in editor_assists.assist_ids.clone() { + let assist = &self.assists[&assist_id]; + if let CodegenStatus::Done = &assist.codegen.read(cx).status { + self.finish_assist(assist_id, false, cx) + } } } - EditorEvent::Edited { transaction_id } - if matches!( - assist.codegen.read(cx).status, - CodegenStatus::Error(_) | CodegenStatus::Done - ) => - { + EditorEvent::Edited { transaction_id } => { let buffer = editor.read(cx).buffer().read(cx); let edited_ranges = buffer.edited_ranges_for_transaction::(*transaction_id, cx); - let assist_range = assist.codegen.read(cx).range().to_offset(&buffer.read(cx)); - if edited_ranges - .iter() - .any(|range| range.overlaps(&assist_range)) - { - self.finish_inline_assist(assist_id, false, cx); + let snapshot = buffer.snapshot(cx); + + for assist_id in editor_assists.assist_ids.clone() { + let assist = &self.assists[&assist_id]; + if matches!( + assist.codegen.read(cx).status, + CodegenStatus::Error(_) | CodegenStatus::Done + ) { + let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot); + if edited_ranges + .iter() + .any(|range| range.overlaps(&assist_range)) + { + self.finish_assist(assist_id, false, cx); + } + } } } + EditorEvent::ScrollPositionChanged { .. } => { + if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() { + let assist = &self.assists[&scroll_lock.assist_id]; + if let Some(decorations) = assist.decorations.as_ref() { + let distance_from_top = editor.update(cx, |editor, cx| { + let scroll_top = editor.scroll_position(cx).y; + let prompt_row = editor + .row_for_block(decorations.prompt_block_id, cx) + .unwrap() + .0 as f32; + prompt_row - scroll_top + }); + + if distance_from_top != scroll_lock.distance_from_top { + editor_assists.scroll_lock = None; + } + } + } + } + EditorEvent::SelectionsChanged { .. } => { + for assist_id in editor_assists.assist_ids.clone() { + let assist = &self.assists[&assist_id]; + if let Some(decorations) = assist.decorations.as_ref() { + if decorations.prompt_editor.focus_handle(cx).is_focused(cx) { + return; + } + } + } + + editor_assists.scroll_lock = None; + } _ => {} } } - fn finish_inline_assist( - &mut self, - assist_id: InlineAssistId, - undo: bool, - cx: &mut WindowContext, - ) { - self.dismiss_inline_assist(assist_id, cx); + fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) { + if let Some(assist) = self.assists.get(&assist_id) { + let assist_group_id = assist.group_id; + if self.assist_groups[&assist_group_id].linked { + for assist_id in self.unlink_assist_group(assist_group_id, cx) { + self.finish_assist(assist_id, undo, cx); + } + return; + } + } - if let Some(pending_assist) = self.pending_assists.remove(&assist_id) { - if let hash_map::Entry::Occupied(mut entry) = self - .pending_assist_ids_by_editor - .entry(pending_assist.editor.clone()) + self.dismiss_assist(assist_id, cx); + + if let Some(assist) = self.assists.remove(&assist_id) { + if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id) { - entry.get_mut().retain(|id| *id != assist_id); - if entry.get().is_empty() { + entry.get_mut().assist_ids.retain(|id| *id != assist_id); + if entry.get().assist_ids.is_empty() { entry.remove(); } } - if let Some(editor) = pending_assist.editor.upgrade() { - self.update_editor_highlights(&editor, cx); - - if undo { - pending_assist - .codegen - .update(cx, |codegen, cx| codegen.undo(cx)); + if let hash_map::Entry::Occupied(mut entry) = + self.assists_by_editor.entry(assist.editor.clone()) + { + entry.get_mut().assist_ids.retain(|id| *id != assist_id); + if entry.get().assist_ids.is_empty() { + entry.remove(); + if let Some(editor) = assist.editor.upgrade() { + self.update_editor_highlights(&editor, cx); + } + } else { + entry.get().highlight_updates.send(()).ok(); } } + + if undo { + assist.codegen.update(cx, |codegen, cx| codegen.undo(cx)); + } } } - fn dismiss_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool { - let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else { + fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool { + let Some(assist) = self.assists.get_mut(&assist_id) else { return false; }; - let Some(editor) = pending_assist.editor.upgrade() else { + let Some(editor) = assist.editor.upgrade() else { return false; }; - let Some(decorations) = pending_assist.editor_decorations.take() else { + let Some(decorations) = assist.decorations.take() else { return false; }; @@ -453,39 +531,136 @@ impl InlineAssistant { to_remove.insert(decorations.prompt_block_id); to_remove.insert(decorations.end_block_id); editor.remove_blocks(to_remove, None, cx); - if decorations - .prompt_editor - .focus_handle(cx) - .contains_focused(cx) - { - editor.focus(cx); - } }); - self.update_editor_highlights(&editor, cx); + if decorations + .prompt_editor + .focus_handle(cx) + .contains_focused(cx) + { + self.focus_next_assist(assist_id, cx); + } + + if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) { + if editor_assists + .scroll_lock + .as_ref() + .map_or(false, |lock| lock.assist_id == assist_id) + { + editor_assists.scroll_lock = None; + } + editor_assists.highlight_updates.send(()).ok(); + } + true } - fn resize_inline_assist( + fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + let Some(assist) = self.assists.get(&assist_id) else { + return; + }; + + let assist_group = &self.assist_groups[&assist.group_id]; + let assist_ix = assist_group + .assist_ids + .iter() + .position(|id| *id == assist_id) + .unwrap(); + let assist_ids = assist_group + .assist_ids + .iter() + .skip(assist_ix + 1) + .chain(assist_group.assist_ids.iter().take(assist_ix)); + + for assist_id in assist_ids { + let assist = &self.assists[assist_id]; + if assist.decorations.is_some() { + self.focus_assist(*assist_id, cx); + return; + } + } + + assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok(); + } + + fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + let assist = &self.assists[&assist_id]; + let Some(editor) = assist.editor.upgrade() else { + return; + }; + + if let Some(decorations) = assist.decorations.as_ref() { + decorations.prompt_editor.update(cx, |prompt_editor, cx| { + prompt_editor.editor.update(cx, |editor, cx| { + editor.focus(cx); + editor.select_all(&SelectAll, cx); + }) + }); + } + + let position = assist.codegen.read(cx).range.start; + editor.update(cx, |editor, cx| { + editor.change_selections(None, cx, |selections| { + selections.select_anchor_ranges([position..position]) + }); + + let mut scroll_target_top; + let mut scroll_target_bottom; + if let Some(decorations) = assist.decorations.as_ref() { + scroll_target_top = editor + .row_for_block(decorations.prompt_block_id, cx) + .unwrap() + .0 as f32; + scroll_target_bottom = editor + .row_for_block(decorations.end_block_id, cx) + .unwrap() + .0 as f32; + } else { + let snapshot = editor.snapshot(cx); + let codegen = assist.codegen.read(cx); + let start_row = codegen + .range + .start + .to_display_point(&snapshot.display_snapshot) + .row(); + scroll_target_top = start_row.0 as f32; + scroll_target_bottom = scroll_target_top + 1.; + } + scroll_target_top -= editor.vertical_scroll_margin() as f32; + scroll_target_bottom += editor.vertical_scroll_margin() as f32; + + let height_in_lines = editor.visible_line_count().unwrap_or(0.); + let scroll_top = editor.scroll_position(cx).y; + let scroll_bottom = scroll_top + height_in_lines; + + if scroll_target_top < scroll_top { + editor.set_scroll_position(point(0., scroll_target_top), cx); + } else if scroll_target_bottom > scroll_bottom { + if (scroll_target_bottom - scroll_target_top) <= height_in_lines { + editor + .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx); + } else { + editor.set_scroll_position(point(0., scroll_target_top), cx); + } + } + }); + } + + fn resize_assist( &mut self, assist_id: InlineAssistId, height_in_lines: u8, cx: &mut WindowContext, ) { - if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) { - if let Some(editor) = pending_assist.editor.upgrade() { - if let Some(decorations) = pending_assist.editor_decorations.as_ref() { - let gutter_dimensions = - decorations.prompt_editor.read(cx).gutter_dimensions.clone(); + if let Some(assist) = self.assists.get_mut(&assist_id) { + if let Some(editor) = assist.editor.upgrade() { + if let Some(decorations) = assist.decorations.as_ref() { let mut new_blocks = HashMap::default(); new_blocks.insert( decorations.prompt_block_id, ( Some(height_in_lines), - build_inline_assist_editor_renderer( - &decorations.prompt_editor, - gutter_dimensions, - ), + build_assist_editor_renderer(&decorations.prompt_editor), ), ); editor.update(cx, |editor, cx| { @@ -498,28 +673,51 @@ impl InlineAssistant { } } - fn start_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { - let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) - { - pending_assist + fn unlink_assist_group( + &mut self, + assist_group_id: InlineAssistGroupId, + cx: &mut WindowContext, + ) -> Vec { + let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap(); + assist_group.linked = false; + for assist_id in &assist_group.assist_ids { + let assist = self.assists.get_mut(assist_id).unwrap(); + if let Some(editor_decorations) = assist.decorations.as_ref() { + editor_decorations + .prompt_editor + .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx)); + } + } + assist_group.assist_ids.clone() + } + + fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + let assist = if let Some(assist) = self.assists.get_mut(&assist_id) { + assist } else { return; }; - pending_assist - .codegen - .update(cx, |codegen, cx| codegen.undo(cx)); + let assist_group_id = assist.group_id; + if self.assist_groups[&assist_group_id].linked { + for assist_id in self.unlink_assist_group(assist_group_id, cx) { + self.start_assist(assist_id, cx); + } + return; + } - let Some(user_prompt) = pending_assist - .editor_decorations + assist.codegen.update(cx, |codegen, cx| codegen.undo(cx)); + + let Some(user_prompt) = assist + .decorations .as_ref() .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx)) else { return; }; - let context = if pending_assist.include_context { - pending_assist.workspace.as_ref().and_then(|workspace| { + let context = if assist.include_context { + assist.workspace.as_ref().and_then(|workspace| { let workspace = workspace.upgrade()?.read(cx); let assistant_panel = workspace.panel::(cx)?; assistant_panel.read(cx).active_context(cx) @@ -528,13 +726,13 @@ impl InlineAssistant { None }; - let editor = if let Some(editor) = pending_assist.editor.upgrade() { + let editor = if let Some(editor) = assist.editor.upgrade() { editor } else { return; }; - let project_name = pending_assist.workspace.as_ref().and_then(|workspace| { + let project_name = assist.workspace.as_ref().and_then(|workspace| { let workspace = workspace.upgrade()?; Some( workspace @@ -553,9 +751,9 @@ impl InlineAssistant { self.prompt_history.pop_front(); } - let codegen = pending_assist.codegen.clone(); + let codegen = assist.codegen.clone(); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let range = codegen.read(cx).range(); + let range = codegen.read(cx).range.clone(); let start = snapshot.point_to_buffer_offset(range.start); let end = snapshot.point_to_buffer_offset(range.end); let (buffer, range) = if let Some((start, end)) = start.zip(end) { @@ -564,11 +762,11 @@ impl InlineAssistant { if start_buffer.remote_id() == end_buffer.remote_id() { (start_buffer.clone(), start_buffer_offset..end_buffer_offset) } else { - self.finish_inline_assist(assist_id, false, cx); + self.finish_assist(assist_id, false, cx); return; } } else { - self.finish_inline_assist(assist_id, false, cx); + self.finish_assist(assist_id, false, cx); return; }; @@ -629,17 +827,14 @@ impl InlineAssistant { .detach_and_log_err(cx); } - fn stop_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { - let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) - { - pending_assist + fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { + let assist = if let Some(assist) = self.assists.get_mut(&assist_id) { + assist } else { return; }; - pending_assist - .codegen - .update(cx, |codegen, cx| codegen.stop(cx)); + assist.codegen.update(cx, |codegen, cx| codegen.stop(cx)); } fn update_editor_highlights(&self, editor: &View, cx: &mut WindowContext) { @@ -649,24 +844,26 @@ impl InlineAssistant { let mut inserted_row_ranges = Vec::new(); let empty_assist_ids = Vec::new(); let assist_ids = self - .pending_assist_ids_by_editor + .assists_by_editor .get(&editor.downgrade()) - .unwrap_or(&empty_assist_ids); + .map_or(&empty_assist_ids, |editor_assists| { + &editor_assists.assist_ids + }); for assist_id in assist_ids { - if let Some(pending_assist) = self.pending_assists.get(assist_id) { - let codegen = pending_assist.codegen.read(cx); + if let Some(assist) = self.assists.get(assist_id) { + let codegen = assist.codegen.read(cx); foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); - if codegen.edit_position != codegen.range().end { - gutter_pending_ranges.push(codegen.edit_position..codegen.range().end); + if codegen.edit_position != codegen.range.end { + gutter_pending_ranges.push(codegen.edit_position..codegen.range.end); } - if codegen.range().start != codegen.edit_position { - gutter_transformed_ranges.push(codegen.range().start..codegen.edit_position); + if codegen.range.start != codegen.edit_position { + gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position); } - if pending_assist.editor_decorations.is_some() { + if assist.decorations.is_some() { inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned()); } } @@ -700,9 +897,9 @@ impl InlineAssistant { } if foreground_ranges.is_empty() { - editor.clear_highlights::(cx); + editor.clear_highlights::(cx); } else { - editor.highlight_text::( + editor.highlight_text::( foreground_ranges, HighlightStyle { fade_out: Some(0.6), @@ -712,9 +909,9 @@ impl InlineAssistant { ); } - editor.clear_row_highlights::(); + editor.clear_row_highlights::(); for row_range in inserted_row_ranges { - editor.highlight_rows::( + editor.highlight_rows::( row_range, Some(cx.theme().status().info_background), false, @@ -730,14 +927,14 @@ impl InlineAssistant { assist_id: InlineAssistId, cx: &mut WindowContext, ) { - let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else { + let Some(assist) = self.assists.get_mut(&assist_id) else { return; }; - let Some(decorations) = pending_assist.editor_decorations.as_mut() else { + let Some(decorations) = assist.decorations.as_mut() else { return; }; - let codegen = pending_assist.codegen.read(cx); + let codegen = assist.codegen.read(cx); let old_snapshot = codegen.snapshot.clone(); let old_buffer = codegen.old_buffer.clone(); let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone(); @@ -815,13 +1012,99 @@ impl InlineAssistant { } } -fn build_inline_assist_editor_renderer( - editor: &View, - gutter_dimensions: Arc>, -) -> RenderBlock { +struct EditorInlineAssists { + assist_ids: Vec, + scroll_lock: Option, + highlight_updates: async_watch::Sender<()>, + _update_highlights: Task>, + _subscriptions: Vec, +} + +struct InlineAssistScrollLock { + assist_id: InlineAssistId, + distance_from_top: f32, +} + +impl EditorInlineAssists { + #[allow(clippy::too_many_arguments)] + fn new(editor: &View, cx: &mut WindowContext) -> Self { + let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(()); + Self { + assist_ids: Vec::new(), + scroll_lock: None, + highlight_updates: highlight_updates_tx, + _update_highlights: cx.spawn(|mut cx| { + let editor = editor.downgrade(); + async move { + while let Ok(()) = highlight_updates_rx.changed().await { + let editor = editor.upgrade().context("editor was dropped")?; + cx.update_global(|assistant: &mut InlineAssistant, cx| { + assistant.update_editor_highlights(&editor, cx); + })?; + } + Ok(()) + } + }), + _subscriptions: vec![ + cx.observe(editor, move |editor, cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_editor_change(editor, cx) + }) + }), + cx.subscribe(editor, move |editor, event, cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_editor_event(editor, event, cx) + }) + }), + editor.update(cx, |editor, cx| { + let editor_handle = cx.view().downgrade(); + editor.register_action( + move |_: &editor::actions::Newline, cx: &mut WindowContext| { + InlineAssistant::update_global(cx, |this, cx| { + if let Some(editor) = editor_handle.upgrade() { + this.handle_editor_newline(editor, cx) + } + }) + }, + ) + }), + editor.update(cx, |editor, cx| { + let editor_handle = cx.view().downgrade(); + editor.register_action( + move |_: &editor::actions::Cancel, cx: &mut WindowContext| { + InlineAssistant::update_global(cx, |this, cx| { + if let Some(editor) = editor_handle.upgrade() { + this.handle_editor_cancel(editor, cx) + } + }) + }, + ) + }), + ], + } + } +} + +struct InlineAssistGroup { + assist_ids: Vec, + linked: bool, + active_assist_id: Option, +} + +impl InlineAssistGroup { + fn new() -> Self { + Self { + assist_ids: Vec::new(), + linked: true, + active_assist_id: None, + } + } +} + +fn build_assist_editor_renderer(editor: &View) -> RenderBlock { let editor = editor.clone(); Box::new(move |cx: &mut BlockContext| { - *gutter_dimensions.lock() = *cx.gutter_dimensions; + *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions; editor.clone().into_any_element() }) } @@ -837,7 +1120,18 @@ impl InlineAssistId { } } -enum InlineAssistEditorEvent { +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +struct InlineAssistGroupId(usize); + +impl InlineAssistGroupId { + fn post_inc(&mut self) -> InlineAssistGroupId { + let id = *self; + self.0 += 1; + id + } +} + +enum PromptEditorEvent { StartRequested, StopRequested, ConfirmRequested, @@ -846,7 +1140,7 @@ enum InlineAssistEditorEvent { Resized { height_in_lines: u8 }, } -struct InlineAssistEditor { +struct PromptEditor { id: InlineAssistId, height_in_lines: u8, editor: View, @@ -857,12 +1151,13 @@ struct InlineAssistEditor { pending_prompt: String, codegen: Model, workspace: Option>, - _subscriptions: Vec, + _codegen_subscription: Subscription, + editor_subscriptions: Vec, } -impl EventEmitter for InlineAssistEditor {} +impl EventEmitter for PromptEditor {} -impl Render for InlineAssistEditor { +impl Render for PromptEditor { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let gutter_dimensions = *self.gutter_dimensions.lock(); @@ -873,18 +1168,16 @@ impl Render for InlineAssistEditor { .icon_color(Color::Muted) .size(ButtonSize::None) .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) - .on_click(cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::CancelRequested) - })), + .on_click( + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), + ), IconButton::new("start", IconName::Sparkle) .icon_color(Color::Muted) .size(ButtonSize::None) .icon_size(IconSize::XSmall) .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx)) .on_click( - cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::StartRequested) - }), + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)), ), ] } @@ -894,9 +1187,9 @@ impl Render for InlineAssistEditor { .icon_color(Color::Muted) .size(ButtonSize::None) .tooltip(|cx| Tooltip::text("Cancel Assist", cx)) - .on_click(cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::CancelRequested) - })), + .on_click( + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), + ), IconButton::new("stop", IconName::Stop) .icon_color(Color::Error) .size(ButtonSize::None) @@ -910,7 +1203,7 @@ impl Render for InlineAssistEditor { ) }) .on_click( - cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::StopRequested)), + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)), ), ] } @@ -920,9 +1213,9 @@ impl Render for InlineAssistEditor { .icon_color(Color::Muted) .size(ButtonSize::None) .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx)) - .on_click(cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::CancelRequested) - })), + .on_click( + cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)), + ), if self.edited_since_done { IconButton::new("restart", IconName::RotateCw) .icon_color(Color::Info) @@ -937,7 +1230,7 @@ impl Render for InlineAssistEditor { ) }) .on_click(cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::StartRequested); + cx.emit(PromptEditorEvent::StartRequested); })) } else { IconButton::new("confirm", IconName::Check) @@ -945,7 +1238,7 @@ impl Render for InlineAssistEditor { .size(ButtonSize::None) .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx)) .on_click(cx.listener(|_, _, cx| { - cx.emit(InlineAssistEditorEvent::ConfirmRequested); + cx.emit(PromptEditorEvent::ConfirmRequested); })) }, ] @@ -1034,38 +1327,42 @@ impl Render for InlineAssistEditor { } } -impl FocusableView for InlineAssistEditor { +impl FocusableView for PromptEditor { fn focus_handle(&self, cx: &AppContext) -> FocusHandle { self.editor.focus_handle(cx) } } -impl InlineAssistEditor { +impl PromptEditor { const MAX_LINES: u8 = 8; - #[allow(clippy::too_many_arguments)] fn new( id: InlineAssistId, gutter_dimensions: Arc>, prompt_history: VecDeque, + prompt_buffer: Model, codegen: Model, workspace: Option>, cx: &mut ViewContext, ) -> Self { let prompt_editor = cx.new_view(|cx| { - let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx); + let mut editor = Editor::new( + EditorMode::AutoHeight { + max_lines: Self::MAX_LINES as usize, + }, + prompt_buffer, + None, + false, + cx, + ); editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); + // Since the prompt editors for all inline assistants are linked, + // always show the cursor (even when it isn't focused) because + // typing in one will make what you typed appear in all of them. + editor.set_show_cursor_when_unfocused(true, cx); editor.set_placeholder_text("Add a prompt…", cx); editor }); - cx.focus_view(&prompt_editor); - - let subscriptions = vec![ - cx.observe(&codegen, Self::handle_codegen_changed), - cx.observe(&prompt_editor, Self::handle_prompt_editor_changed), - cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), - ]; - let mut this = Self { id, height_in_lines: 1, @@ -1075,14 +1372,50 @@ impl InlineAssistEditor { prompt_history, prompt_history_ix: None, pending_prompt: String::new(), + _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed), + editor_subscriptions: Vec::new(), codegen, workspace, - _subscriptions: subscriptions, }; this.count_lines(cx); + this.subscribe_to_editor(cx); this } + fn subscribe_to_editor(&mut self, cx: &mut ViewContext) { + self.editor_subscriptions.clear(); + self.editor_subscriptions + .push(cx.observe(&self.editor, Self::handle_prompt_editor_changed)); + self.editor_subscriptions + .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events)); + } + + fn set_show_cursor_when_unfocused( + &mut self, + show_cursor_when_unfocused: bool, + cx: &mut ViewContext, + ) { + self.editor.update(cx, |editor, cx| { + editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx) + }); + } + + fn unlink(&mut self, cx: &mut ViewContext) { + let prompt = self.prompt(cx); + let focus = self.editor.focus_handle(cx).contains_focused(cx); + self.editor = cx.new_view(|cx| { + let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx); + editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); + editor.set_placeholder_text("Add a prompt…", cx); + editor.set_text(prompt, cx); + if focus { + editor.focus(cx); + } + editor + }); + self.subscribe_to_editor(cx); + } + fn prompt(&self, cx: &AppContext) -> String { self.editor.read(cx).text(cx) } @@ -1099,7 +1432,7 @@ impl InlineAssistEditor { if height_in_lines != self.height_in_lines { self.height_in_lines = height_in_lines; - cx.emit(InlineAssistEditorEvent::Resized { height_in_lines }); + cx.emit(PromptEditorEvent::Resized { height_in_lines }); } } @@ -1152,10 +1485,10 @@ impl InlineAssistEditor { fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext) { match &self.codegen.read(cx).status { CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => { - cx.emit(InlineAssistEditorEvent::CancelRequested); + cx.emit(PromptEditorEvent::CancelRequested); } CodegenStatus::Pending => { - cx.emit(InlineAssistEditorEvent::StopRequested); + cx.emit(PromptEditorEvent::StopRequested); } } } @@ -1163,16 +1496,16 @@ impl InlineAssistEditor { fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { match &self.codegen.read(cx).status { CodegenStatus::Idle => { - cx.emit(InlineAssistEditorEvent::StartRequested); + cx.emit(PromptEditorEvent::StartRequested); } CodegenStatus::Pending => { - cx.emit(InlineAssistEditorEvent::DismissRequested); + cx.emit(PromptEditorEvent::DismissRequested); } CodegenStatus::Done | CodegenStatus::Error(_) => { if self.edited_since_done { - cx.emit(InlineAssistEditorEvent::StartRequested); + cx.emit(PromptEditorEvent::StartRequested); } else { - cx.emit(InlineAssistEditorEvent::ConfirmRequested); + cx.emit(PromptEditorEvent::ConfirmRequested); } } } @@ -1249,18 +1582,121 @@ impl InlineAssistEditor { } } -struct PendingInlineAssist { +struct InlineAssist { + group_id: InlineAssistGroupId, editor: WeakView, - editor_decorations: Option, + decorations: Option, codegen: Model, _subscriptions: Vec, workspace: Option>, include_context: bool, } -struct PendingInlineAssistDecorations { +impl InlineAssist { + #[allow(clippy::too_many_arguments)] + fn new( + assist_id: InlineAssistId, + group_id: InlineAssistGroupId, + include_context: bool, + editor: &View, + prompt_editor: &View, + prompt_block_id: BlockId, + end_block_id: BlockId, + codegen: Model, + workspace: Option>, + cx: &mut WindowContext, + ) -> Self { + let prompt_editor_focus_handle = prompt_editor.focus_handle(cx); + InlineAssist { + group_id, + include_context, + editor: editor.downgrade(), + decorations: Some(InlineAssistDecorations { + prompt_block_id, + prompt_editor: prompt_editor.clone(), + removed_line_block_ids: HashSet::default(), + end_block_id, + }), + codegen: codegen.clone(), + workspace: workspace.clone(), + _subscriptions: vec![ + cx.on_focus_in(&prompt_editor_focus_handle, move |cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_prompt_editor_focus_in(assist_id, cx) + }) + }), + cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_prompt_editor_focus_out(assist_id, cx) + }) + }), + cx.subscribe(prompt_editor, |prompt_editor, event, cx| { + InlineAssistant::update_global(cx, |this, cx| { + this.handle_prompt_editor_event(prompt_editor, event, cx) + }) + }), + cx.observe(&codegen, { + let editor = editor.downgrade(); + move |_, cx| { + if let Some(editor) = editor.upgrade() { + InlineAssistant::update_global(cx, |this, cx| { + if let Some(editor_assists) = + this.assists_by_editor.get(&editor.downgrade()) + { + editor_assists.highlight_updates.send(()).ok(); + } + + this.update_editor_blocks(&editor, assist_id, cx); + }) + } + } + }), + cx.subscribe(&codegen, move |codegen, event, cx| { + InlineAssistant::update_global(cx, |this, cx| match event { + CodegenEvent::Undone => this.finish_assist(assist_id, false, cx), + CodegenEvent::Finished => { + let assist = if let Some(assist) = this.assists.get(&assist_id) { + assist + } else { + return; + }; + + if let CodegenStatus::Error(error) = &codegen.read(cx).status { + if assist.decorations.is_none() { + if let Some(workspace) = assist + .workspace + .as_ref() + .and_then(|workspace| workspace.upgrade()) + { + let error = format!("Inline assistant error: {}", error); + workspace.update(cx, |workspace, cx| { + struct InlineAssistantError; + + let id = + NotificationId::identified::( + assist_id.0, + ); + + workspace.show_toast(Toast::new(id, error), cx); + }) + } + } + } + + if assist.decorations.is_none() { + this.finish_assist(assist_id, false, cx); + } + } + }) + }), + ], + } + } +} + +struct InlineAssistDecorations { prompt_block_id: BlockId, - prompt_editor: View, + prompt_editor: View, removed_line_block_ids: HashSet, end_block_id: BlockId, } @@ -1271,26 +1707,11 @@ pub enum CodegenEvent { Undone, } -#[derive(Clone)] -pub enum CodegenKind { - Transform { range: Range }, - Generate { position: Anchor }, -} - -impl CodegenKind { - fn range(&self, snapshot: &MultiBufferSnapshot) -> Range { - match self { - CodegenKind::Transform { range } => range.clone(), - CodegenKind::Generate { position } => position.bias_left(snapshot)..*position, - } - } -} - pub struct Codegen { buffer: Model, old_buffer: Model, snapshot: MultiBufferSnapshot, - kind: CodegenKind, + range: Range, edit_position: Anchor, last_equal_ranges: Vec>, transaction_id: Option, @@ -1321,7 +1742,7 @@ impl EventEmitter for Codegen {} impl Codegen { pub fn new( buffer: Model, - kind: CodegenKind, + range: Range, telemetry: Option>, cx: &mut ModelContext, ) -> Self { @@ -1329,7 +1750,7 @@ impl Codegen { let (old_buffer, _, _) = buffer .read(cx) - .range_to_buffer_ranges(kind.range(&snapshot), cx) + .range_to_buffer_ranges(range.clone(), cx) .pop() .unwrap(); let old_buffer = cx.new_model(|cx| { @@ -1350,9 +1771,9 @@ impl Codegen { Self { buffer: buffer.clone(), old_buffer, - edit_position: kind.range(&snapshot).start, + edit_position: range.start, + range, snapshot, - kind, last_equal_ranges: Default::default(), transaction_id: Default::default(), status: CodegenStatus::Idle, @@ -1378,16 +1799,12 @@ impl Codegen { } } - pub fn range(&self) -> Range { - self.kind.range(&self.snapshot) - } - pub fn last_equal_ranges(&self) -> &[Range] { &self.last_equal_ranges } pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { - let range = self.range(); + let range = self.range.clone(); let snapshot = self.snapshot.clone(); let selected_text = snapshot .text_for_range(range.start..range.end) @@ -1417,7 +1834,7 @@ impl Codegen { let mut response_latency = None; let request_start = Instant::now(); let diff = async { - let chunks = strip_invalid_spans_from_codeblock(response.await?); + let chunks = StripInvalidSpans::new(response.await?); futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); @@ -1611,9 +2028,9 @@ impl Codegen { self.diff.should_update = false; let old_snapshot = self.snapshot.clone(); - let old_range = self.range().to_point(&old_snapshot); + let old_range = self.range.to_point(&old_snapshot); let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = self.range().to_point(&new_snapshot); + let new_range = self.range.to_point(&new_snapshot); self.diff.task = Some(cx.spawn(|this, mut cx| async move { let (deleted_row_ranges, inserted_row_ranges) = cx @@ -1704,90 +2121,136 @@ impl Codegen { } } -fn strip_invalid_spans_from_codeblock( - stream: impl Stream>, -) -> impl Stream> { - let mut first_line = true; - let mut buffer = String::new(); - let mut starts_with_markdown_codeblock = false; - let mut includes_start_or_end_span = false; - stream.filter_map(move |chunk| { - let chunk = match chunk { - Ok(chunk) => chunk, - Err(err) => return future::ready(Some(Err(err))), - }; - buffer.push_str(&chunk); +struct StripInvalidSpans { + stream: T, + stream_done: bool, + buffer: String, + first_line: bool, + line_end: bool, + starts_with_code_block: bool, +} - if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") { - includes_start_or_end_span = true; - - buffer = buffer - .strip_prefix("<|S|>") - .or_else(|| buffer.strip_prefix("<|S|")) - .unwrap_or(&buffer) - .to_string(); - } else if buffer.ends_with("|E|>") { - includes_start_or_end_span = true; - } else if buffer.starts_with("<|") - || buffer.starts_with("<|S") - || buffer.starts_with("<|S|") - || buffer.ends_with('|') - || buffer.ends_with("|E") - || buffer.ends_with("|E|") - { - return future::ready(None); +impl StripInvalidSpans +where + T: Stream>, +{ + fn new(stream: T) -> Self { + Self { + stream, + stream_done: false, + buffer: String::new(), + first_line: true, + line_end: false, + starts_with_code_block: false, } + } +} - if first_line { - if buffer.is_empty() || buffer == "`" || buffer == "``" { - return future::ready(None); - } else if buffer.starts_with("```") { - starts_with_markdown_codeblock = true; - if let Some(newline_ix) = buffer.find('\n') { - buffer.replace_range(..newline_ix + 1, ""); - first_line = false; - } else { - return future::ready(None); +impl Stream for StripInvalidSpans +where + T: Stream>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + const CODE_BLOCK_DELIMITER: &str = "```"; + const CURSOR_SPAN: &str = "<|CURSOR|>"; + + let this = unsafe { self.get_unchecked_mut() }; + loop { + if !this.stream_done { + let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) }; + match stream.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + this.buffer.push_str(&chunk); + } + Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))), + Poll::Ready(None) => { + this.stream_done = true; + } + Poll::Pending => return Poll::Pending, } } + + let mut chunk = String::new(); + let mut consumed = 0; + if !this.buffer.is_empty() { + let mut lines = this.buffer.split('\n').enumerate().peekable(); + while let Some((line_ix, line)) = lines.next() { + if line_ix > 0 { + this.first_line = false; + } + + if this.first_line { + let trimmed_line = line.trim(); + if lines.peek().is_some() { + if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) { + consumed += line.len() + 1; + this.starts_with_code_block = true; + continue; + } + } else if trimmed_line.is_empty() + || prefixes(CODE_BLOCK_DELIMITER) + .any(|prefix| trimmed_line.starts_with(prefix)) + { + break; + } + } + + let line_without_cursor = line.replace(CURSOR_SPAN, ""); + if lines.peek().is_some() { + if this.line_end { + chunk.push('\n'); + } + + chunk.push_str(&line_without_cursor); + this.line_end = true; + consumed += line.len() + 1; + } else if this.stream_done { + if !this.starts_with_code_block + || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER) + { + if this.line_end { + chunk.push('\n'); + } + + chunk.push_str(&line); + } + + consumed += line.len(); + } else { + let trimmed_line = line.trim(); + if trimmed_line.is_empty() + || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix)) + || prefixes(CODE_BLOCK_DELIMITER) + .any(|prefix| trimmed_line.ends_with(prefix)) + { + break; + } else { + if this.line_end { + chunk.push('\n'); + this.line_end = false; + } + + chunk.push_str(&line_without_cursor); + consumed += line.len(); + } + } + } + } + + this.buffer = this.buffer.split_off(consumed); + if !chunk.is_empty() { + return Poll::Ready(Some(Ok(chunk))); + } else if this.stream_done { + return Poll::Ready(None); + } } + } +} - let mut text = buffer.to_string(); - if starts_with_markdown_codeblock { - text = text - .strip_suffix("\n```\n") - .or_else(|| text.strip_suffix("\n```")) - .or_else(|| text.strip_suffix("\n``")) - .or_else(|| text.strip_suffix("\n`")) - .or_else(|| text.strip_suffix('\n')) - .unwrap_or(&text) - .to_string(); - } - - if includes_start_or_end_span { - text = text - .strip_suffix("|E|>") - .or_else(|| text.strip_suffix("E|>")) - .or_else(|| text.strip_prefix("|>")) - .or_else(|| text.strip_prefix('>')) - .unwrap_or(&text) - .to_string(); - }; - - if text.contains('\n') { - first_line = false; - } - - let remainder = buffer.split_off(text.len()); - let result = if buffer.is_empty() { - None - } else { - Some(Ok(buffer.clone())) - }; - - buffer = remainder; - future::ready(result) - }) +fn prefixes(text: &str) -> impl Iterator { + (0..text.len() - 1).map(|ix| &text[..ix + 1]) } fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { @@ -1857,9 +2320,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx) - }); + let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx)); let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); @@ -1912,13 +2373,11 @@ mod tests { let buffer = cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { + let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 6)) - }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) + snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) }); + let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx)); let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); @@ -1971,13 +2430,11 @@ mod tests { let buffer = cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let position = buffer.read_with(cx, |buffer, cx| { + let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 2)) - }); - let codegen = cx.new_model(|cx| { - Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx) + snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) }); + let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx)); let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); @@ -2014,81 +2471,33 @@ mod tests { #[gpui::test] async fn test_strip_invalid_spans_from_codeblock() { - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks( - "```html\n```js\nLorem ipsum dolor\n```\n```", - 2 - )) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "```js\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "``\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); + assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; + assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await; + assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await; + assert_chunks( + "```html\n```js\nLorem ipsum dolor\n```\n```", + "```js\nLorem ipsum dolor\n```", + ) + .await; + assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await; + assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await; + assert_chunks("Lorem ipsum", "Lorem ipsum").await; + assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await; - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); + async fn assert_chunks(text: &str, expected_text: &str) { + for chunk_size in 1..=text.len() { + let actual_text = StripInvalidSpans::new(chunks(text, chunk_size)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await; + assert_eq!( + actual_text, expected_text, + "failed to strip invalid spans, chunk size: {}", + chunk_size + ); + } + } - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); - assert_eq!( - strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum" - ); fn chunks(text: &str, size: usize) -> impl Stream> { stream::iter( text.chars() diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 80dfc45c4f..f8847733f7 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -33,35 +33,32 @@ pub fn generate_content_prompt( )?; } - // Include file content. - for chunk in buffer.text_for_range(0..range.start) { - prompt.push_str(chunk); - } - + writeln!( + prompt, + "The user has the following file open in the editor:" + )?; if range.is_empty() { - prompt.push_str("<|START|>"); - } else { - prompt.push_str("<|START|"); - } + write!(prompt, "```")?; + if let Some(language_name) = language_name { + write!(prompt, "{language_name}")?; + } - for chunk in buffer.text_for_range(range.clone()) { - prompt.push_str(chunk); - } + for chunk in buffer.as_rope().chunks_in_range(0..range.start) { + prompt.push_str(chunk); + } + prompt.push_str("<|CURSOR|>"); + for chunk in buffer.as_rope().chunks_in_range(range.start..buffer.len()) { + prompt.push_str(chunk); + } + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + writeln!(prompt, "```")?; + prompt.push('\n'); - if !range.is_empty() { - prompt.push_str("|END|>"); - } - - for chunk in buffer.text_for_range(range.end..buffer.len()) { - prompt.push_str(chunk); - } - - prompt.push('\n'); - - if range.is_empty() { writeln!( prompt, - "Assume the cursor is located where the `<|START|>` span is." + "Assume the cursor is located where the `<|CURSOR|>` span is." ) .unwrap(); writeln!( @@ -75,11 +72,42 @@ pub fn generate_content_prompt( ) .unwrap(); } else { - writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); - writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + write!(prompt, "```")?; + for chunk in buffer.as_rope().chunks() { + prompt.push_str(chunk); + } + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + writeln!(prompt, "```")?; + prompt.push('\n'); + writeln!( prompt, - "Double check that you only return code and not the '<|START|' and '|END|'> spans" + "In particular, the following piece of text is selected:" + )?; + write!(prompt, "```")?; + if let Some(language_name) = language_name { + write!(prompt, "{language_name}")?; + } + prompt.push('\n'); + for chunk in buffer.text_for_range(range.clone()) { + prompt.push_str(chunk); + } + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + writeln!(prompt, "```")?; + prompt.push('\n'); + + writeln!( + prompt, + "Modify the user's selected {content_type} based upon the users prompt: {user_prompt}" + ) + .unwrap(); + writeln!( + prompt, + "You must reply with only the adjusted {content_type}, not the entire file." ) .unwrap(); } diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index ee02862d10..74cb699e08 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -1204,7 +1204,7 @@ async fn test_share_project( buffer_a.read_with(cx_a, |buffer, _| { buffer .snapshot() - .remote_selections_in_range(text::Anchor::MIN..text::Anchor::MAX) + .selections_in_range(text::Anchor::MIN..text::Anchor::MAX, false) .count() == 1 }); @@ -1245,7 +1245,7 @@ async fn test_share_project( buffer_a.read_with(cx_a, |buffer, _| { buffer .snapshot() - .remote_selections_in_range(text::Anchor::MIN..text::Anchor::MAX) + .selections_in_range(text::Anchor::MIN..text::Anchor::MAX, false) .count() == 0 }); diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index b2eea68bb6..fa6d64937b 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -137,7 +137,7 @@ impl ProjectDiagnosticsEditor { this.summary = project.read(cx).diagnostic_summary(false, cx); cx.emit(EditorEvent::TitleChanged); - if this.editor.read(cx).is_focused(cx) || this.focus_handle.is_focused(cx) { + if this.editor.focus_handle(cx).contains_focused(cx) || this.focus_handle.contains_focused(cx) { log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. recording change"); } else { log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. updating excerpts"); diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index 42f3c34b15..3eb9b28e6a 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -169,7 +169,7 @@ impl DisplayMap { let (wrap_snapshot, edits) = self .wrap_map .update(cx, |map, cx| map.sync(tab_snapshot.clone(), edits, cx)); - let block_snapshot = self.block_map.read(wrap_snapshot.clone(), edits); + let block_snapshot = self.block_map.read(wrap_snapshot.clone(), edits).snapshot; DisplaySnapshot { buffer_snapshot: self.buffer.read(cx).snapshot(cx), @@ -348,6 +348,25 @@ impl DisplayMap { block_map.remove(ids); } + pub fn row_for_block( + &mut self, + block_id: BlockId, + cx: &mut ModelContext, + ) -> Option { + let snapshot = self.buffer.read(cx).snapshot(cx); + let edits = self.buffer_subscription.consume().into_inner(); + let tab_size = Self::tab_size(&self.buffer, cx); + let (snapshot, edits) = self.inlay_map.sync(snapshot, edits); + let (snapshot, edits) = self.fold_map.read(snapshot, edits); + let (snapshot, edits) = self.tab_map.sync(snapshot, edits, tab_size); + let (snapshot, edits) = self + .wrap_map + .update(cx, |map, cx| map.sync(snapshot, edits, cx)); + let block_map = self.block_map.read(snapshot, edits); + let block_row = block_map.row_for_block(block_id)?; + Some(DisplayRow(block_row.0)) + } + pub fn highlight_text( &mut self, type_id: TypeId, diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index e17333a195..40cd4a48d7 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -37,6 +37,11 @@ pub struct BlockMap { excerpt_footer_height: u8, } +pub struct BlockMapReader<'a> { + blocks: &'a Vec>, + pub snapshot: BlockSnapshot, +} + pub struct BlockMapWriter<'a>(&'a mut BlockMap); #[derive(Clone)] @@ -246,12 +251,15 @@ impl BlockMap { map } - pub fn read(&self, wrap_snapshot: WrapSnapshot, edits: Patch) -> BlockSnapshot { + pub fn read(&self, wrap_snapshot: WrapSnapshot, edits: Patch) -> BlockMapReader { self.sync(&wrap_snapshot, edits); *self.wrap_snapshot.borrow_mut() = wrap_snapshot.clone(); - BlockSnapshot { - wrap_snapshot, - transforms: self.transforms.borrow().clone(), + BlockMapReader { + blocks: &self.blocks, + snapshot: BlockSnapshot { + wrap_snapshot, + transforms: self.transforms.borrow().clone(), + }, } } @@ -606,6 +614,62 @@ impl std::ops::DerefMut for BlockPoint { } } +impl<'a> Deref for BlockMapReader<'a> { + type Target = BlockSnapshot; + + fn deref(&self) -> &Self::Target { + &self.snapshot + } +} + +impl<'a> DerefMut for BlockMapReader<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.snapshot + } +} + +impl<'a> BlockMapReader<'a> { + pub fn row_for_block(&self, block_id: BlockId) -> Option { + let block = self.blocks.iter().find(|block| block.id == block_id)?; + let buffer_row = block + .position + .to_point(self.wrap_snapshot.buffer_snapshot()) + .row; + let wrap_row = self + .wrap_snapshot + .make_wrap_point(Point::new(buffer_row, 0), Bias::Left) + .row(); + let start_wrap_row = WrapRow( + self.wrap_snapshot + .prev_row_boundary(WrapPoint::new(wrap_row, 0)), + ); + let end_wrap_row = WrapRow( + self.wrap_snapshot + .next_row_boundary(WrapPoint::new(wrap_row, 0)) + .unwrap_or(self.wrap_snapshot.max_point().row() + 1), + ); + + let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(); + cursor.seek(&start_wrap_row, Bias::Left, &()); + while let Some(transform) = cursor.item() { + if cursor.start().0 > end_wrap_row { + break; + } + + if let Some(BlockType::Custom(id)) = + transform.block.as_ref().map(|block| block.block_type()) + { + if id == block_id { + return Some(cursor.start().1); + } + } + cursor.next(&()); + } + + None + } +} + impl<'a> BlockMapWriter<'a> { pub fn insert( &mut self, @@ -1784,6 +1848,15 @@ mod tests { expected_block_positions ); + for (block_row, block) in expected_block_positions { + if let BlockType::Custom(block_id) = block.block_type() { + assert_eq!( + blocks_snapshot.row_for_block(block_id), + Some(BlockRow(block_row)) + ); + } + } + let mut expected_longest_rows = Vec::new(); let mut longest_line_len = -1_isize; for (row, line) in expected_lines.iter().enumerate() { diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 298b51d7bd..cde7562b6c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -457,6 +457,9 @@ pub struct Editor { pub display_map: Model, pub selections: SelectionsCollection, pub scroll_manager: ScrollManager, + /// When inline assist editors are linked, they all render cursors because + /// typing enters text into each of them, even the ones that aren't focused. + pub(crate) show_cursor_when_unfocused: bool, columnar_selection_tail: Option, add_selections_state: Option, select_next_state: Option, @@ -1635,7 +1638,7 @@ impl Editor { clone } - fn new( + pub fn new( mode: EditorMode, buffer: Model, project: Option>, @@ -1752,6 +1755,7 @@ impl Editor { let mut this = Self { focus_handle, + show_cursor_when_unfocused: false, last_focused_descendant: None, buffer: buffer.clone(), display_map: display_map.clone(), @@ -2220,7 +2224,7 @@ impl Editor { // Copy selections to primary selection buffer #[cfg(target_os = "linux")] if local { - let selections = &self.selections.disjoint; + let selections = self.selections.all::(cx); let buffer_handle = self.buffer.read(cx).read(cx); let mut text = String::new(); @@ -9964,6 +9968,15 @@ impl Editor { } } + pub fn row_for_block( + &self, + block_id: BlockId, + cx: &mut ViewContext, + ) -> Option { + self.display_map + .update(cx, |map, cx| map.row_for_block(block_id, cx)) + } + pub fn insert_creases( &mut self, creases: impl IntoIterator, @@ -10902,6 +10915,11 @@ impl Editor { && self.focus_handle.is_focused(cx) } + pub fn set_show_cursor_when_unfocused(&mut self, is_enabled: bool, cx: &mut ViewContext) { + self.show_cursor_when_unfocused = is_enabled; + cx.notify(); + } + fn on_buffer_changed(&mut self, _: Model, cx: &mut ViewContext) { cx.notify(); } @@ -11722,7 +11740,7 @@ impl EditorSnapshot { .map(|(_, collaborator)| (collaborator.replica_id, collaborator)) .collect::>(); self.buffer_snapshot - .remote_selections_in_range(range) + .selections_in_range(range, false) .filter_map(move |(replica_id, line_mode, cursor_shape, selection)| { let collaborator = collaborators_by_replica_id.get(&replica_id)?; let participant_index = participant_indices.get(&collaborator.user_id).copied(); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 0faec8f398..d7ea58c782 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -859,6 +859,28 @@ impl EditorElement { } selections.extend(remote_selections.into_values()); + } else if !editor.is_focused(cx) && editor.show_cursor_when_unfocused { + let player = if editor.read_only(cx) { + cx.theme().players().read_only() + } else { + self.style.local_player + }; + let layouts = snapshot + .buffer_snapshot + .selections_in_range(&(start_anchor..end_anchor), true) + .map(move |(_, line_mode, cursor_shape, selection)| { + SelectionLayout::new( + selection, + line_mode, + cursor_shape, + &snapshot.display_snapshot, + false, + false, + None, + ) + }) + .collect::>(); + selections.push((player, layouts)); } (selections, active_rows, newest_selection_head) } @@ -3631,12 +3653,12 @@ impl EditorElement { let forbid_vertical_scroll = editor.scroll_manager.forbid_vertical_scroll(); if forbid_vertical_scroll { scroll_position.y = current_scroll_position.y; - if scroll_position == current_scroll_position { - return; - } } - editor.scroll(scroll_position, axis, cx); - cx.stop_propagation(); + + if scroll_position != current_scroll_position { + editor.scroll(scroll_position, axis, cx); + cx.stop_propagation(); + } }); } } @@ -4621,13 +4643,29 @@ impl Element for EditorElement { let content_origin = text_hitbox.origin + point(gutter_dimensions.margin, Pixels::ZERO); + let height_in_lines = bounds.size.height / line_height; + let max_scroll_top = if matches!(snapshot.mode, EditorMode::AutoHeight { .. }) { + (snapshot.max_point().row().as_f32() - height_in_lines + 1.).max(0.) + } else { + let settings = EditorSettings::get_global(cx); + let max_row = snapshot.max_point().row().as_f32(); + match settings.scroll_beyond_last_line { + ScrollBeyondLastLine::OnePage => max_row, + ScrollBeyondLastLine::Off => (max_row - height_in_lines + 1.0).max(0.0), + ScrollBeyondLastLine::VerticalScrollMargin => { + (max_row - height_in_lines + 1.0 + settings.vertical_scroll_margin) + .max(0.0) + } + } + }; + let mut autoscroll_containing_element = false; let mut autoscroll_horizontally = false; self.editor.update(cx, |editor, cx| { autoscroll_containing_element = editor.autoscroll_requested() || editor.has_pending_selection(); autoscroll_horizontally = - editor.autoscroll_vertically(bounds, line_height, cx); + editor.autoscroll_vertically(bounds, line_height, max_scroll_top, cx); snapshot = editor.snapshot(cx); }); @@ -4635,7 +4673,6 @@ impl Element for EditorElement { // The scroll position is a fractional point, the whole number of which represents // the top of the window in terms of display rows. let start_row = DisplayRow(scroll_position.y as u32); - let height_in_lines = bounds.size.height / line_height; let max_row = snapshot.max_point().row(); let end_row = cmp::min( (scroll_position.y + height_in_lines).ceil() as u32, @@ -4817,22 +4854,9 @@ impl Element for EditorElement { cx, ); - let settings = EditorSettings::get_global(cx); - let scroll_max_row = max_row.as_f32(); - let scroll_max_row = match settings.scroll_beyond_last_line { - ScrollBeyondLastLine::OnePage => scroll_max_row, - ScrollBeyondLastLine::Off => { - (scroll_max_row - height_in_lines + 1.0).max(0.0) - } - ScrollBeyondLastLine::VerticalScrollMargin => (scroll_max_row - - height_in_lines - + 1.0 - + settings.vertical_scroll_margin) - .max(0.0), - }; let scroll_max = point( ((scroll_width - text_hitbox.size.width) / em_width).max(0.0), - scroll_max_row, + max_scroll_top, ); self.editor.update(cx, |editor, cx| { diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 2dc0b6c616..5874789ddf 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1201,20 +1201,22 @@ impl SearchableItem for Editor { for (excerpt_id, search_buffer, search_range) in buffer.excerpts_in_ranges(search_within_ranges) { - ranges.extend( - query - .search(&search_buffer, Some(search_range.clone())) - .await - .into_iter() - .map(|match_range| { - let start = search_buffer - .anchor_after(search_range.start + match_range.start); - let end = search_buffer - .anchor_before(search_range.start + match_range.end); - buffer.anchor_in_excerpt(excerpt_id, start).unwrap() - ..buffer.anchor_in_excerpt(excerpt_id, end).unwrap() - }), - ); + if !search_range.is_empty() { + ranges.extend( + query + .search(&search_buffer, Some(search_range.clone())) + .await + .into_iter() + .map(|match_range| { + let start = search_buffer + .anchor_after(search_range.start + match_range.start); + let end = search_buffer + .anchor_before(search_range.start + match_range.end); + buffer.anchor_in_excerpt(excerpt_id, start).unwrap() + ..buffer.anchor_in_excerpt(excerpt_id, end).unwrap() + }), + ); + } } }; diff --git a/crates/editor/src/scroll/autoscroll.rs b/crates/editor/src/scroll/autoscroll.rs index cde450f11b..deed5333f8 100644 --- a/crates/editor/src/scroll/autoscroll.rs +++ b/crates/editor/src/scroll/autoscroll.rs @@ -69,6 +69,7 @@ impl Editor { &mut self, bounds: Bounds, line_height: Pixels, + max_scroll_top: f32, cx: &mut ViewContext, ) -> bool { let viewport_height = bounds.size.height; @@ -84,11 +85,6 @@ impl Editor { } } } - let max_scroll_top = if matches!(self.mode, EditorMode::AutoHeight { .. }) { - (display_map.max_point().row().as_f32() - visible_lines + 1.).max(0.) - } else { - display_map.max_point().row().as_f32() - }; if scroll_position.y > max_scroll_top { scroll_position.y = max_scroll_top; } diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 5384367a29..0c7a980669 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -93,6 +93,16 @@ struct WindowFocusEvent { current_focus_path: SmallVec<[FocusId; 8]>, } +impl WindowFocusEvent { + pub fn is_focus_in(&self, focus_id: FocusId) -> bool { + !self.previous_focus_path.contains(&focus_id) && self.current_focus_path.contains(&focus_id) + } + + pub fn is_focus_out(&self, focus_id: FocusId) -> bool { + self.previous_focus_path.contains(&focus_id) && !self.current_focus_path.contains(&focus_id) + } +} + /// This is provided when subscribing for `ViewContext::on_focus_out` events. pub struct FocusOutEvent { /// A weak focus handle representing what was blurred. @@ -2883,6 +2893,53 @@ impl<'a> WindowContext<'a> { )); } + /// Register a listener to be called when the given focus handle or one of its descendants receives focus. + /// This does not fire if the given focus handle - or one of its descendants - was previously focused. + /// Returns a subscription and persists until the subscription is dropped. + pub fn on_focus_in( + &mut self, + handle: &FocusHandle, + mut listener: impl FnMut(&mut WindowContext) + 'static, + ) -> Subscription { + let focus_id = handle.id; + let (subscription, activate) = + self.window.new_focus_listener(Box::new(move |event, cx| { + if event.is_focus_in(focus_id) { + listener(cx); + } + true + })); + self.app.defer(move |_| activate()); + subscription + } + + /// Register a listener to be called when the given focus handle or one of its descendants loses focus. + /// Returns a subscription and persists until the subscription is dropped. + pub fn on_focus_out( + &mut self, + handle: &FocusHandle, + mut listener: impl FnMut(FocusOutEvent, &mut WindowContext) + 'static, + ) -> Subscription { + let focus_id = handle.id; + let (subscription, activate) = + self.window.new_focus_listener(Box::new(move |event, cx| { + if let Some(blurred_id) = event.previous_focus_path.last().copied() { + if event.is_focus_out(focus_id) { + let event = FocusOutEvent { + blurred: WeakFocusHandle { + id: blurred_id, + handles: Arc::downgrade(&cx.window.focus_handles), + }, + }; + listener(event, cx) + } + } + true + })); + self.app.defer(move |_| activate()); + subscription + } + fn reset_cursor_style(&self) { // Set the cursor only if we're the active window. if self.is_window_active() { @@ -4109,9 +4166,7 @@ impl<'a, V: 'static> ViewContext<'a, V> { let (subscription, activate) = self.window.new_focus_listener(Box::new(move |event, cx| { view.update(cx, |view, cx| { - if !event.previous_focus_path.contains(&focus_id) - && event.current_focus_path.contains(&focus_id) - { + if event.is_focus_in(focus_id) { listener(view, cx) } }) @@ -4175,9 +4230,7 @@ impl<'a, V: 'static> ViewContext<'a, V> { self.window.new_focus_listener(Box::new(move |event, cx| { view.update(cx, |view, cx| { if let Some(blurred_id) = event.previous_focus_path.last().copied() { - if event.previous_focus_path.contains(&focus_id) - && !event.current_focus_path.contains(&focus_id) - { + if event.is_focus_out(focus_id) { let event = FocusOutEvent { blurred: WeakFocusHandle { id: blurred_id, diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index ba3a961b1e..eda38f8a76 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1701,6 +1701,8 @@ impl Buffer { }, cx, ); + self.selections_update_count += 1; + cx.notify(); } /// Clears the selections, so that other replicas of the buffer do not see any selections for @@ -3355,9 +3357,10 @@ impl BufferSnapshot { /// Returns selections for remote peers intersecting the given range. #[allow(clippy::type_complexity)] - pub fn remote_selections_in_range( + pub fn selections_in_range( &self, range: Range, + include_local: bool, ) -> impl Iterator< Item = ( ReplicaId, @@ -3368,8 +3371,9 @@ impl BufferSnapshot { > + '_ { self.remote_selections .iter() - .filter(|(replica_id, set)| { - **replica_id != self.text.replica_id() && !set.selections.is_empty() + .filter(move |(replica_id, set)| { + (include_local || **replica_id != self.text.replica_id()) + && !set.selections.is_empty() }) .map(move |(replica_id, set)| { let start_ix = match set.selections.binary_search_by(|probe| { diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index c47ccbddd5..2c50e6dc9e 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -2416,7 +2416,7 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) { for buffer in &buffers { let buffer = buffer.read(cx).snapshot(); let actual_remote_selections = buffer - .remote_selections_in_range(Anchor::MIN..Anchor::MAX) + .selections_in_range(Anchor::MIN..Anchor::MAX, false) .map(|(replica_id, _, _, selections)| (replica_id, selections.collect::>())) .collect::>(); let expected_remote_selections = active_selections diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index ff05aac6e8..9564a3c169 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -3834,8 +3834,7 @@ impl MultiBufferSnapshot { return None; } - if range.as_ref().unwrap().is_empty() || *cursor.start() >= range.as_ref().unwrap().end - { + if *cursor.start() >= range.as_ref().unwrap().end { range = next_range(&mut cursor); if range.is_none() { return None; @@ -3867,9 +3866,10 @@ impl MultiBufferSnapshot { }) } - pub fn remote_selections_in_range<'a>( + pub fn selections_in_range<'a>( &'a self, range: &'a Range, + include_local: bool, ) -> impl 'a + Iterator)> { let mut cursor = self.excerpts.cursor::(); let start_locator = self.excerpt_locator_for_id(range.start.excerpt_id); @@ -3888,7 +3888,7 @@ impl MultiBufferSnapshot { excerpt .buffer - .remote_selections_in_range(query_range) + .selections_in_range(query_range, include_local) .flat_map(move |(replica_id, line_mode, cursor_shape, selections)| { selections.map(move |selection| { let mut start = Anchor {