From 21e8e8763e5a28264a15a4f5ef3d19447247169e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 15 Jun 2023 13:59:01 +0200 Subject: [PATCH] Allow splitting of messages using `shift-enter` --- assets/keymaps/default.json | 3 +- crates/ai/src/assistant.rs | 158 ++++++++++++++++++++++++++++-------- 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 45e85fd04f..f6682a9f0b 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -200,7 +200,8 @@ "context": "AssistantEditor > Editor", "bindings": { "cmd-enter": "assistant::Assist", - "cmd->": "assistant::QuoteSelection" + "cmd->": "assistant::QuoteSelection", + "shift-enter": "assistant::Split" } }, { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e5702cb677..cd334d77b1 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -8,7 +8,7 @@ use collections::{HashMap, HashSet}; use editor::{ display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint}, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, ToOffset as _, + Anchor, Editor, }; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; @@ -40,7 +40,14 @@ const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; actions!( assistant, - [NewContext, Assist, QuoteSelection, ToggleFocus, ResetKey] + [ + NewContext, + Assist, + Split, + QuoteSelection, + ToggleFocus, + ResetKey + ] ); pub fn init(cx: &mut AppContext) { @@ -64,6 +71,7 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(AssistantEditor::cancel_last_assist); cx.add_action(AssistantEditor::quote_selection); cx.capture_action(AssistantEditor::copy); + cx.capture_action(AssistantEditor::split); cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::reset_api_key); cx.add_action( @@ -711,6 +719,67 @@ impl Assistant { } } + fn split_message( + &mut self, + range: Range, + cx: &mut ModelContext, + ) -> (Option, Option) { + let start_message = self.message_for_offset(range.start, cx); + let end_message = self.message_for_offset(range.end, cx); + if let Some((start_message, end_message)) = start_message.zip(end_message) { + let (start_message_ix, _, start_message_metadata) = start_message; + let (end_message_ix, _, _) = end_message; + + // Prevent splitting when range spans multiple messages. + if start_message_ix != end_message_ix { + return (None, None); + } + + let role = start_message_metadata.role; + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.end..range.end, "\n")], None, cx) + }); + let suffix = Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + }; + self.messages.insert(start_message_ix + 1, suffix.clone()); + self.messages_metadata.insert( + suffix.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + + if range.start == range.end { + (None, Some(suffix)) + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + let selection = Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.start + 1), + }; + self.messages + .insert(start_message_ix + 1, selection.clone()); + self.messages_metadata.insert( + selection.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + (Some(selection), Some(suffix)) + } + } else { + (None, None) + } + } + fn summarize(&mut self, cx: &mut ModelContext) { if self.messages.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); @@ -755,35 +824,39 @@ impl Assistant { fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { let buffer = self.buffer.read(cx); self.messages(cx) - .map(|(_message, metadata, range)| RequestMessage { + .map(|(_ix, _message, metadata, range)| RequestMessage { role: metadata.role, content: buffer.text_for_range(range).collect(), }) .collect() } - fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option { - Some( - self.messages(cx) - .find(|(_, _, range)| range.contains(&offset)) - .map(|(message, _, _)| message) - .or(self.messages.last())? - .id, - ) + fn message_for_offset<'a>( + &'a self, + offset: usize, + cx: &'a AppContext, + ) -> Option<(usize, &Message, &MessageMetadata)> { + let mut messages = self.messages(cx).peekable(); + while let Some((ix, message, metadata, range)) = messages.next() { + if range.contains(&offset) || messages.peek().is_none() { + return Some((ix, message, metadata)); + } + } + None } fn messages<'a>( &'a self, cx: &'a AppContext, - ) -> impl 'a + Iterator)> { + ) -> impl 'a + Iterator)> { let buffer = self.buffer.read(cx); - let mut messages = self.messages.iter().peekable(); + let mut messages = self.messages.iter().enumerate().peekable(); iter::from_fn(move || { - while let Some(message) = messages.next() { + while let Some((ix, message)) = messages.next() { let metadata = self.messages_metadata.get(&message.id)?; let message_start = message.start.to_offset(buffer); let mut message_end = None; - while let Some(next_message) = messages.peek() { + while let Some((_, next_message)) = messages.peek() { if next_message.start.is_valid(buffer) { message_end = Some(next_message.start); break; @@ -794,7 +867,7 @@ impl Assistant { let message_end = message_end .unwrap_or(language::Anchor::MAX) .to_offset(buffer); - return Some((message, metadata, message_start..message_end)); + return Some((ix, message, metadata, message_start..message_end)); } None }) @@ -857,21 +930,7 @@ impl AssistantEditor { fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { let user_message = self.assistant.update(cx, |assistant, cx| { - let editor = self.editor.read(cx); - let newest_selection = editor - .selections - .newest_anchor() - .head() - .to_offset(&editor.buffer().read(cx).snapshot(cx)); - let message_id = assistant.message_id_for_offset(newest_selection, cx)?; - let metadata = assistant.messages_metadata.get(&message_id)?; - let user_message = if metadata.role == Role::User { - let (_, user_message) = assistant.assist(cx)?; - user_message - } else { - let user_message = assistant.insert_message_after(message_id, Role::User, cx)?; - user_message - }; + let (_, user_message) = assistant.assist(cx)?; Some(user_message) }); @@ -982,7 +1041,7 @@ impl AssistantEditor { .assistant .read(cx) .messages(cx) - .map(|(message, metadata, _)| BlockProperties { + .map(|(_, message, metadata, _)| BlockProperties { position: buffer.anchor_in_excerpt(excerpt_id, message.start), height: 2, style: BlockStyle::Sticky, @@ -1147,7 +1206,7 @@ impl AssistantEditor { let selection = editor.selections.newest::(cx); let mut copied_text = String::new(); let mut spanned_messages = 0; - for (_message, metadata, message_range) in assistant.messages(cx) { + for (_ix, _message, metadata, message_range) in assistant.messages(cx) { if message_range.start >= selection.range().end { break; } else if message_range.end >= selection.range().start { @@ -1174,6 +1233,13 @@ impl AssistantEditor { cx.propagate_action(); } + fn split(&mut self, _: &Split, cx: &mut ViewContext) { + self.assistant.update(cx, |assistant, cx| { + let range = self.editor.read(cx).selections.newest::(cx).range(); + assistant.split_message(range, cx); + }); + } + fn cycle_model(&mut self, cx: &mut ViewContext) { self.assistant.update(cx, |assistant, cx| { let new_model = match assistant.model.as_str() { @@ -1510,6 +1576,30 @@ mod tests { (message_3.id, Role::User, 4..5) ] ); + + // Split a message into prefix, selection and suffix. + buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "3")], None, cx)); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_5.id, Role::System, 4..5), + (message_3.id, Role::User, 5..6) + ] + ); + let (message_6, message_7) = + assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx)); + let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap()); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_6.id, Role::User, 3..5), + (message_7.id, Role::User, 5..6), + (message_5.id, Role::System, 6..7), + (message_3.id, Role::User, 7..8) + ] + ); } fn messages( @@ -1519,7 +1609,7 @@ mod tests { assistant .read(cx) .messages(cx) - .map(|(message, metadata, range)| (message.id, metadata.role, range)) + .map(|(_, message, metadata, range)| (message.id, metadata.role, range)) .collect() } }