diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 2c2d7e774e..2e8eca80e3 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,6 @@ pub mod assistant; mod assistant_settings; +mod codegen; mod streaming_diff; use anyhow::{anyhow, Result}; diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 9b384252fc..1d56a6308c 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,9 +1,8 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - stream_completion, - streaming_diff::{Hunk, StreamingDiff}, - MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role, - SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, + codegen::{self, Codegen, OpenAICompletionProvider}, + stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, + Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -13,10 +12,10 @@ use editor::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, }, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint, + Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, }; use fs::Fs; -use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; +use futures::StreamExt; use gpui::{ actions, elements::{ @@ -30,17 +29,14 @@ use gpui::{ ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; -use language::{ - language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _, - TransactionId, -}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use search::BufferSearchBar; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, cmp, env, fmt::Write, - future, iter, + iter, ops::Range, path::{Path, PathBuf}, rc::Rc, @@ -266,10 +262,22 @@ impl AssistantPanel { } fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) { + let api_key = if let Some(api_key) = self.api_key.borrow().clone() { + api_key + } else { + return; + }; + let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let selection = editor.read(cx).selections.newest_anchor().clone(); let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot); + let provider = Arc::new(OpenAICompletionProvider::new( + api_key, + cx.background().clone(), + )); + let codegen = + cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx)); let assist_kind = if editor.read(cx).selections.newest::(cx).is_empty() { InlineAssistKind::Generate } else { @@ -283,6 +291,7 @@ impl AssistantPanel { measurements.clone(), self.include_conversation_in_next_inline_assist, self.inline_prompt_history.clone(), + codegen.clone(), cx, ); cx.focus_self(); @@ -323,46 +332,55 @@ impl AssistantPanel { PendingInlineAssist { kind: assist_kind, editor: editor.downgrade(), - range, - highlighted_ranges: Default::default(), inline_assistant: Some((block_id, inline_assistant.clone())), - code_generation: Task::ready(None), - transaction_id: None, + codegen: codegen.clone(), _subscriptions: vec![ cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), cx.subscribe(editor, { let inline_assistant = inline_assistant.downgrade(); - move |this, editor, event, cx| { + move |_, editor, event, cx| { if let Some(inline_assistant) = inline_assistant.upgrade(cx) { - match event { - editor::Event::SelectionsChanged { local } => { - if *local && inline_assistant.read(cx).has_focus { - cx.focus(&editor); - } + if let editor::Event::SelectionsChanged { local } = event { + if *local && inline_assistant.read(cx).has_focus { + cx.focus(&editor); } - editor::Event::TransactionUndone { - transaction_id: tx_id, - } => { - if let Some(pending_assist) = - this.pending_inline_assists.get(&inline_assist_id) - { - if pending_assist.transaction_id == Some(*tx_id) { - // Notice we are supplying `undo: false` here. This - // is because there's no need to undo the transaction - // because the user just did so. - this.close_inline_assist( - inline_assist_id, - false, - cx, - ); - } - } - } - _ => {} } } } }), + cx.subscribe(&codegen, move |this, codegen, event, cx| match event { + codegen::Event::Undone => { + this.finish_inline_assist(inline_assist_id, false, cx) + } + codegen::Event::Finished => { + let pending_assist = if let Some(pending_assist) = + this.pending_inline_assists.get(&inline_assist_id) + { + pending_assist + } else { + return; + }; + + let error = codegen + .read(cx) + .error() + .map(|error| format!("Inline assistant error: {}", error)); + if let Some(error) = error { + if pending_assist.inline_assistant.is_none() { + if let Some(workspace) = this.workspace.upgrade(cx) { + workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new(inline_assist_id, error), + cx, + ); + }) + } + } + } + + this.finish_inline_assist(inline_assist_id, false, cx); + } + }), ], }, ); @@ -388,7 +406,7 @@ impl AssistantPanel { self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); } InlineAssistantEvent::Canceled => { - self.close_inline_assist(assist_id, true, cx); + self.finish_inline_assist(assist_id, true, cx); } InlineAssistantEvent::Dismissed => { self.hide_inline_assist(assist_id, cx); @@ -417,7 +435,7 @@ impl AssistantPanel { .get(&editor.downgrade()) .and_then(|assist_ids| assist_ids.last().copied()) { - panel.close_inline_assist(assist_id, true, cx); + panel.finish_inline_assist(assist_id, true, cx); true } else { false @@ -432,7 +450,7 @@ impl AssistantPanel { cx.propagate_action(); } - fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { + fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { self.hide_inline_assist(assist_id, cx); if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) { @@ -450,13 +468,9 @@ impl AssistantPanel { self.update_highlights_for_editor(&editor, cx); if undo { - if let Some(transaction_id) = pending_assist.transaction_id { - editor.update(cx, |editor, cx| { - editor.buffer().update(cx, |buffer, cx| { - buffer.undo_transaction(transaction_id, cx) - }); - }); - } + pending_assist + .codegen + .update(cx, |codegen, cx| codegen.undo(cx)); } } } @@ -481,12 +495,6 @@ impl AssistantPanel { include_conversation: bool, cx: &mut ViewContext, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - let conversation = if include_conversation { self.active_editor() .map(|editor| editor.read(cx).conversation.clone()) @@ -514,56 +522,9 @@ impl AssistantPanel { self.inline_prompt_history.pop_front(); } - let range = pending_assist.range.clone(); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let selected_text = snapshot - .text_for_range(range.start..range.end) - .collect::(); - - let selection_start = range.start.to_point(&snapshot); - let selection_end = range.end.to_point(&snapshot); - - let mut base_indent: Option = None; - let mut start_row = selection_start.row; - if snapshot.is_line_blank(start_row) { - if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { - start_row = prev_non_blank_row; - } - } - for row in start_row..=selection_end.row { - if snapshot.is_line_blank(row) { - continue; - } - - let line_indent = snapshot.indent_size_for_line(row); - if let Some(base_indent) = base_indent.as_mut() { - if line_indent.len < base_indent.len { - *base_indent = line_indent; - } - } else { - base_indent = Some(line_indent); - } - } - - let mut normalized_selected_text = selected_text.clone(); - if let Some(base_indent) = base_indent { - for row in selection_start.row..=selection_end.row { - let selection_row = row - selection_start.row; - let line_start = - normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); - let indent_len = if row == selection_start.row { - base_indent.len.saturating_sub(selection_start.column) - } else { - let line_len = normalized_selected_text.line_len(selection_row); - cmp::min(line_len, base_indent.len) - }; - let indent_end = cmp::min( - line_start + indent_len as usize, - normalized_selected_text.len(), - ); - normalized_selected_text.replace(line_start..indent_end, ""); - } - } + let range = pending_assist.codegen.read(cx).range(); + let selected_text = snapshot.text_for_range(range.clone()).collect::(); let language = snapshot.language_at(range.start); let language_name = if let Some(language) = language.as_ref() { @@ -608,7 +569,7 @@ impl AssistantPanel { } else { writeln!(prompt, "```").unwrap(); } - writeln!(prompt, "{normalized_selected_text}").unwrap(); + writeln!(prompt, "{selected_text}").unwrap(); writeln!(prompt, "```").unwrap(); writeln!(prompt).unwrap(); writeln!( @@ -689,209 +650,9 @@ impl AssistantPanel { messages, stream: true, }; - let response = stream_completion(api_key, cx.background().clone(), request); - let editor = editor.downgrade(); - - pending_assist.code_generation = cx.spawn(|this, mut cx| { - async move { - let mut edit_start = range.start.to_offset(&snapshot); - - let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); - let diff = cx.background().spawn(async move { - let chunks = strip_markdown_codeblock(response.await?.filter_map( - |message| async move { - match message { - Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }, - )); - futures::pin_mut!(chunks); - let mut diff = StreamingDiff::new(selected_text.to_string()); - - let mut indent_len; - let indent_text; - if let Some(base_indent) = base_indent { - indent_len = base_indent.len; - indent_text = match base_indent.kind { - language::IndentKind::Space => " ", - language::IndentKind::Tab => "\t", - }; - } else { - indent_len = 0; - indent_text = ""; - }; - - let mut first_line_len = 0; - let mut first_line_non_whitespace_char_ix = None; - let mut first_line = true; - let mut new_text = String::new(); - - while let Some(chunk) = chunks.next().await { - let chunk = chunk?; - - let mut lines = chunk.split('\n'); - if let Some(mut line) = lines.next() { - if first_line { - if first_line_non_whitespace_char_ix.is_none() { - if let Some(mut char_ix) = - line.find(|ch: char| !ch.is_whitespace()) - { - line = &line[char_ix..]; - char_ix += first_line_len; - first_line_non_whitespace_char_ix = Some(char_ix); - let first_line_indent = char_ix - .saturating_sub(selection_start.column as usize) - as usize; - new_text.push_str(&indent_text.repeat(first_line_indent)); - indent_len = indent_len.saturating_sub(char_ix as u32); - } - } - first_line_len += line.len(); - } - - if first_line_non_whitespace_char_ix.is_some() { - new_text.push_str(line); - } - } - - for line in lines { - first_line = false; - new_text.push('\n'); - if !line.is_empty() { - new_text.push_str(&indent_text.repeat(indent_len as usize)); - } - new_text.push_str(line); - } - - let hunks = diff.push_new(&new_text); - hunks_tx.send(hunks).await?; - new_text.clear(); - } - hunks_tx.send(diff.finish()).await?; - - anyhow::Ok(()) - }); - - while let Some(hunks) = hunks_rx.next().await { - let editor = if let Some(editor) = editor.upgrade(&cx) { - editor - } else { - break; - }; - - let this = if let Some(this) = this.upgrade(&cx) { - this - } else { - break; - }; - - this.update(&mut cx, |this, cx| { - let pending_assist = if let Some(pending_assist) = - this.pending_inline_assists.get_mut(&inline_assist_id) - { - pending_assist - } else { - return; - }; - - pending_assist.highlighted_ranges.clear(); - editor.update(cx, |editor, cx| { - let transaction = editor.buffer().update(cx, |buffer, cx| { - // Avoid grouping assistant edits with user edits. - buffer.finalize_last_transaction(cx); - - buffer.start_transaction(cx); - buffer.edit( - hunks.into_iter().filter_map(|hunk| match hunk { - Hunk::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - Some((edit_start..edit_start, text)) - } - Hunk::Remove { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - Some((edit_range, String::new())) - } - Hunk::Keep { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start += len; - pending_assist.highlighted_ranges.push(edit_range); - None - } - }), - None, - cx, - ); - - buffer.end_transaction(cx) - }); - - if let Some(transaction) = transaction { - if let Some(first_transaction) = pending_assist.transaction_id { - // Group all assistant edits into the first transaction. - editor.buffer().update(cx, |buffer, cx| { - buffer.merge_transactions( - transaction, - first_transaction, - cx, - ) - }); - } else { - pending_assist.transaction_id = Some(transaction); - editor.buffer().update(cx, |buffer, cx| { - buffer.finalize_last_transaction(cx) - }); - } - } - }); - - this.update_highlights_for_editor(&editor, cx); - }); - } - - if let Err(error) = diff.await { - this.update(&mut cx, |this, cx| { - let pending_assist = if let Some(pending_assist) = - this.pending_inline_assists.get_mut(&inline_assist_id) - { - pending_assist - } else { - return; - }; - - if let Some((_, inline_assistant)) = - pending_assist.inline_assistant.as_ref() - { - inline_assistant.update(cx, |inline_assistant, cx| { - inline_assistant.set_error(error, cx); - }); - } else if let Some(workspace) = this.workspace.upgrade(cx) { - workspace.update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - inline_assist_id, - format!("Inline assistant error: {}", error), - ), - cx, - ); - }) - } - })?; - } else { - let _ = this.update(&mut cx, |this, cx| { - this.close_inline_assist(inline_assist_id, false, cx) - }); - } - - anyhow::Ok(()) - } - .log_err() - }); + pending_assist + .codegen + .update(cx, |codegen, cx| codegen.start(request, cx)); } fn update_highlights_for_editor( @@ -909,8 +670,9 @@ impl AssistantPanel { for inline_assist_id in inline_assist_ids { if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) { - background_ranges.push(pending_assist.range.clone()); - foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned()); + let codegen = pending_assist.codegen.read(cx); + background_ranges.push(codegen.range()); + foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); } } @@ -2900,11 +2662,11 @@ struct InlineAssistant { has_focus: bool, include_conversation: bool, measurements: Rc>, - error: Option, prompt_history: VecDeque, prompt_history_ix: Option, pending_prompt: String, - _subscription: Subscription, + codegen: ModelHandle, + _subscriptions: Vec, } impl Entity for InlineAssistant { @@ -2933,7 +2695,7 @@ impl View for InlineAssistant { .element() .aligned(), ) - .with_children(if let Some(error) = self.error.as_ref() { + .with_children(if let Some(error) = self.codegen.read(cx).error() { Some( Svg::new("icons/circle_x_mark_12.svg") .with_color(theme.assistant.error_icon.color) @@ -3011,6 +2773,7 @@ impl InlineAssistant { measurements: Rc>, include_conversation: bool, prompt_history: VecDeque, + codegen: ModelHandle, cx: &mut ViewContext, ) -> Self { let prompt_editor = cx.add_view(|cx| { @@ -3025,7 +2788,10 @@ impl InlineAssistant { editor.set_placeholder_text(placeholder, cx); editor }); - let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events); + let subscriptions = vec![ + cx.observe(&codegen, Self::handle_codegen_changed), + cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), + ]; Self { id, prompt_editor, @@ -3033,11 +2799,11 @@ impl InlineAssistant { has_focus: false, include_conversation, measurements, - error: None, prompt_history, prompt_history_ix: None, pending_prompt: String::new(), - _subscription: subscription, + codegen, + _subscriptions: subscriptions, } } @@ -3053,6 +2819,31 @@ impl InlineAssistant { } } + fn handle_codegen_changed(&mut self, _: ModelHandle, cx: &mut ViewContext) { + let is_read_only = !self.codegen.read(cx).idle(); + self.prompt_editor.update(cx, |editor, cx| { + let was_read_only = editor.read_only(); + if was_read_only != is_read_only { + if is_read_only { + editor.set_read_only(true); + editor.set_field_editor_style( + Some(Arc::new(|theme| { + theme.assistant.inline.disabled_editor.clone() + })), + cx, + ); + } else { + editor.set_read_only(false); + editor.set_field_editor_style( + Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), + cx, + ); + } + } + }); + cx.notify(); + } + fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { cx.emit(InlineAssistantEvent::Canceled); } @@ -3076,7 +2867,6 @@ impl InlineAssistant { include_conversation: self.include_conversation, }); self.confirmed = true; - self.error = None; cx.notify(); } } @@ -3093,19 +2883,6 @@ impl InlineAssistant { cx.notify(); } - fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext) { - self.error = Some(error); - self.confirmed = false; - self.prompt_editor.update(cx, |editor, cx| { - editor.set_read_only(false); - editor.set_field_editor_style( - Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), - cx, - ); - }); - cx.notify(); - } - fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { if let Some(ix) = self.prompt_history_ix { if ix > 0 { @@ -3154,11 +2931,8 @@ struct BlockMeasurements { struct PendingInlineAssist { kind: InlineAssistKind, editor: WeakViewHandle, - range: Range, - highlighted_ranges: Vec>, inline_assistant: Option<(BlockId, ViewHandle)>, - code_generation: Task>, - transaction_id: Option, + codegen: ModelHandle, _subscriptions: Vec, } @@ -3184,65 +2958,10 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } -fn strip_markdown_codeblock( - stream: impl Stream>, -) -> impl Stream> { - let mut first_line = true; - let mut buffer = String::new(); - let mut starts_with_fenced_code_block = false; - stream.filter_map(move |chunk| { - let chunk = match chunk { - Ok(chunk) => chunk, - Err(err) => return future::ready(Some(Err(err))), - }; - buffer.push_str(&chunk); - - if first_line { - if buffer == "" || buffer == "`" || buffer == "``" { - return future::ready(None); - } else if buffer.starts_with("```") { - starts_with_fenced_code_block = true; - if let Some(newline_ix) = buffer.find('\n') { - buffer.replace_range(..newline_ix + 1, ""); - first_line = false; - } else { - return future::ready(None); - } - } - } - - let text = if starts_with_fenced_code_block { - buffer - .strip_suffix("\n```\n") - .or_else(|| buffer.strip_suffix("\n```")) - .or_else(|| buffer.strip_suffix("\n``")) - .or_else(|| buffer.strip_suffix("\n`")) - .or_else(|| buffer.strip_suffix('\n')) - .unwrap_or(&buffer) - } else { - &buffer - }; - - if text.contains('\n') { - first_line = false; - } - - let remainder = buffer.split_off(text.len()); - let result = if buffer.is_empty() { - None - } else { - Some(Ok(buffer.clone())) - }; - buffer = remainder; - future::ready(result) - }) -} - #[cfg(test)] mod tests { use super::*; use crate::MessageId; - use futures::stream; use gpui::AppContext; #[gpui::test] @@ -3611,62 +3330,6 @@ mod tests { ); } - #[gpui::test] - async fn test_strip_markdown_codeblock() { - assert_eq!( - strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "```js\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "``\nLorem ipsum dolor\n```" - ); - - fn chunks(text: &str, size: usize) -> impl Stream> { - stream::iter( - text.chars() - .collect::>() - .chunks(size) - .map(|chunk| Ok(chunk.iter().collect::())) - .collect::>(), - ) - } - } - fn messages( conversation: &ModelHandle, cx: &AppContext, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs new file mode 100644 index 0000000000..b24c0f9435 --- /dev/null +++ b/crates/ai/src/codegen.rs @@ -0,0 +1,468 @@ +use crate::{ + stream_completion, + streaming_diff::{Hunk, StreamingDiff}, + OpenAIRequest, +}; +use anyhow::Result; +use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint}; +use futures::{ + channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, +}; +use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; +use language::{IndentSize, Point, Rope, TransactionId}; +use std::{cmp, future, ops::Range, sync::Arc}; + +pub trait CompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>>; +} + +pub struct OpenAICompletionProvider { + api_key: String, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(api_key: String, executor: Arc) -> Self { + Self { api_key, executor } + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } +} + +pub enum Event { + Finished, + Undone, +} + +pub struct Codegen { + provider: Arc, + buffer: ModelHandle, + range: Range, + last_equal_ranges: Vec>, + transaction_id: Option, + error: Option, + generation: Task<()>, + idle: bool, + _subscription: gpui::Subscription, +} + +impl Entity for Codegen { + type Event = Event; +} + +impl Codegen { + pub fn new( + buffer: ModelHandle, + range: Range, + provider: Arc, + cx: &mut ModelContext, + ) -> Self { + Self { + provider, + buffer: buffer.clone(), + range, + last_equal_ranges: Default::default(), + transaction_id: Default::default(), + error: Default::default(), + idle: true, + generation: Task::ready(()), + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), + } + } + + fn handle_buffer_event( + &mut self, + _buffer: ModelHandle, + event: &multi_buffer::Event, + cx: &mut ModelContext, + ) { + if let multi_buffer::Event::TransactionUndone { transaction_id } = event { + if self.transaction_id == Some(*transaction_id) { + self.transaction_id = None; + self.generation = Task::ready(()); + cx.emit(Event::Undone); + } + } + } + + pub fn range(&self) -> Range { + self.range.clone() + } + + pub fn last_equal_ranges(&self) -> &[Range] { + &self.last_equal_ranges + } + + pub fn idle(&self) -> bool { + self.idle + } + + pub fn error(&self) -> Option<&anyhow::Error> { + self.error.as_ref() + } + + pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { + let range = self.range.clone(); + let snapshot = self.buffer.read(cx).snapshot(cx); + let selected_text = snapshot + .text_for_range(range.start..range.end) + .collect::(); + + let selection_start = range.start.to_point(&snapshot); + let selection_end = range.end.to_point(&snapshot); + + let mut base_indent: Option = None; + let mut start_row = selection_start.row; + if snapshot.is_line_blank(start_row) { + if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { + start_row = prev_non_blank_row; + } + } + for row in start_row..=selection_end.row { + if snapshot.is_line_blank(row) { + continue; + } + + let line_indent = snapshot.indent_size_for_line(row); + if let Some(base_indent) = base_indent.as_mut() { + if line_indent.len < base_indent.len { + *base_indent = line_indent; + } + } else { + base_indent = Some(line_indent); + } + } + + let mut normalized_selected_text = selected_text.clone(); + if let Some(base_indent) = base_indent { + for row in selection_start.row..=selection_end.row { + let selection_row = row - selection_start.row; + let line_start = + normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); + let indent_len = if row == selection_start.row { + base_indent.len.saturating_sub(selection_start.column) + } else { + let line_len = normalized_selected_text.line_len(selection_row); + cmp::min(line_len, base_indent.len) + }; + let indent_end = cmp::min( + line_start + indent_len as usize, + normalized_selected_text.len(), + ); + normalized_selected_text.replace(line_start..indent_end, ""); + } + } + + let response = self.provider.complete(prompt); + self.generation = cx.spawn_weak(|this, mut cx| { + async move { + let generate = async { + let mut edit_start = range.start.to_offset(&snapshot); + + let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); + let diff = cx.background().spawn(async move { + let chunks = strip_markdown_codeblock(response.await?); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); + + let mut indent_len; + let indent_text; + if let Some(base_indent) = base_indent { + indent_len = base_indent.len; + indent_text = match base_indent.kind { + language::IndentKind::Space => " ", + language::IndentKind::Tab => "\t", + }; + } else { + indent_len = 0; + indent_text = ""; + }; + + let mut first_line_len = 0; + let mut first_line_non_whitespace_char_ix = None; + let mut first_line = true; + let mut new_text = String::new(); + + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + + let mut lines = chunk.split('\n'); + if let Some(mut line) = lines.next() { + if first_line { + if first_line_non_whitespace_char_ix.is_none() { + if let Some(mut char_ix) = + line.find(|ch: char| !ch.is_whitespace()) + { + line = &line[char_ix..]; + char_ix += first_line_len; + first_line_non_whitespace_char_ix = Some(char_ix); + let first_line_indent = char_ix + .saturating_sub(selection_start.column as usize) + as usize; + new_text + .push_str(&indent_text.repeat(first_line_indent)); + indent_len = indent_len.saturating_sub(char_ix as u32); + } + } + first_line_len += line.len(); + } + + if first_line_non_whitespace_char_ix.is_some() { + new_text.push_str(line); + } + } + + for line in lines { + first_line = false; + new_text.push('\n'); + if !line.is_empty() { + new_text.push_str(&indent_text.repeat(indent_len as usize)); + } + new_text.push_str(line); + } + + let hunks = diff.push_new(&new_text); + hunks_tx.send(hunks).await?; + new_text.clear(); + } + hunks_tx.send(diff.finish()).await?; + + anyhow::Ok(()) + }); + + while let Some(hunks) = hunks_rx.next().await { + let this = if let Some(this) = this.upgrade(&cx) { + this + } else { + break; + }; + + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + + let transaction = this.buffer.update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + + buffer.start_transaction(cx); + buffer.edit( + hunks.into_iter().filter_map(|hunk| match hunk { + Hunk::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + Hunk::Remove { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + Hunk::Keep { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start += len; + this.last_equal_ranges.push(edit_range); + None + } + }), + None, + cx, + ); + + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = this.transaction_id { + // Group all assistant edits into the first transaction. + this.buffer.update(cx, |buffer, cx| { + buffer.merge_transactions( + transaction, + first_transaction, + cx, + ) + }); + } else { + this.transaction_id = Some(transaction); + this.buffer.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(cx) + }); + } + } + + cx.notify(); + }); + } + + diff.await?; + anyhow::Ok(()) + }; + + let result = generate.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + this.idle = true; + if let Err(error) = result { + this.error = Some(error); + } + cx.emit(Event::Finished); + cx.notify(); + }); + } + } + }); + self.error.take(); + self.idle = false; + cx.notify(); + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + if let Some(transaction_id) = self.transaction_id { + self.buffer + .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); + } + } +} + +fn strip_markdown_codeblock( + stream: impl Stream>, +) -> impl Stream> { + let mut first_line = true; + let mut buffer = String::new(); + let mut starts_with_fenced_code_block = false; + stream.filter_map(move |chunk| { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => return future::ready(Some(Err(err))), + }; + buffer.push_str(&chunk); + + if first_line { + if buffer == "" || buffer == "`" || buffer == "``" { + return future::ready(None); + } else if buffer.starts_with("```") { + starts_with_fenced_code_block = true; + if let Some(newline_ix) = buffer.find('\n') { + buffer.replace_range(..newline_ix + 1, ""); + first_line = false; + } else { + return future::ready(None); + } + } + } + + let text = if starts_with_fenced_code_block { + buffer + .strip_suffix("\n```\n") + .or_else(|| buffer.strip_suffix("\n```")) + .or_else(|| buffer.strip_suffix("\n``")) + .or_else(|| buffer.strip_suffix("\n`")) + .or_else(|| buffer.strip_suffix('\n')) + .unwrap_or(&buffer) + } else { + &buffer + }; + + if text.contains('\n') { + first_line = false; + } + + let remainder = buffer.split_off(text.len()); + let result = if buffer.is_empty() { + None + } else { + Some(Ok(buffer.clone())) + }; + buffer = remainder; + future::ready(result) + }) +} + +#[cfg(test)] +mod tests { + use futures::stream; + + use super::*; + + #[gpui::test] + async fn test_strip_markdown_codeblock() { + assert_eq!( + strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "```js\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "``\nLorem ipsum dolor\n```" + ); + + fn chunks(text: &str, size: usize) -> impl Stream> { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| Ok(chunk.iter().collect::())) + .collect::>(), + ) + } + } +} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index bdd29b04fa..12df29df1d 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1734,6 +1734,10 @@ impl Editor { } } + pub fn read_only(&self) -> bool { + self.read_only + } + pub fn set_read_only(&mut self, read_only: bool) { self.read_only = read_only; } @@ -5103,9 +5107,6 @@ impl Editor { self.unmark_text(cx); self.refresh_copilot_suggestions(true, cx); cx.emit(Event::Edited); - cx.emit(Event::TransactionUndone { - transaction_id: tx_id, - }); } } @@ -8548,9 +8549,6 @@ pub enum Event { local: bool, autoscroll: bool, }, - TransactionUndone { - transaction_id: TransactionId, - }, Closed, } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index 74283fd778..c5d17dfd2e 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -70,6 +70,9 @@ pub enum Event { Edited { sigleton_buffer_edited: bool, }, + TransactionUndone { + transaction_id: TransactionId, + }, Reloaded, DiffBaseChanged, LanguageChanged, @@ -771,30 +774,36 @@ impl MultiBuffer { } pub fn undo(&mut self, cx: &mut ModelContext) -> Option { + let mut transaction_id = None; if let Some(buffer) = self.as_singleton() { - return buffer.update(cx, |buffer, cx| buffer.undo(cx)); - } + transaction_id = buffer.update(cx, |buffer, cx| buffer.undo(cx)); + } else { + while let Some(transaction) = self.history.pop_undo() { + let mut undone = false; + for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions { + if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) { + undone |= buffer.update(cx, |buffer, cx| { + let undo_to = *buffer_transaction_id; + if let Some(entry) = buffer.peek_undo_stack() { + *buffer_transaction_id = entry.transaction_id(); + } + buffer.undo_to_transaction(undo_to, cx) + }); + } + } - while let Some(transaction) = self.history.pop_undo() { - let mut undone = false; - for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions { - if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) { - undone |= buffer.update(cx, |buffer, cx| { - let undo_to = *buffer_transaction_id; - if let Some(entry) = buffer.peek_undo_stack() { - *buffer_transaction_id = entry.transaction_id(); - } - buffer.undo_to_transaction(undo_to, cx) - }); + if undone { + transaction_id = Some(transaction.id); + break; } } - - if undone { - return Some(transaction.id); - } } - None + if let Some(transaction_id) = transaction_id { + cx.emit(Event::TransactionUndone { transaction_id }); + } + + transaction_id } pub fn redo(&mut self, cx: &mut ModelContext) -> Option {