diff --git a/crates/ai/src/refactor.rs b/crates/ai/src/refactor.rs index 2821a1e845..9b36d760b7 100644 --- a/crates/ai/src/refactor.rs +++ b/crates/ai/src/refactor.rs @@ -1,13 +1,14 @@ use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role}; use collections::HashMap; -use editor::{Editor, ToOffset}; +use editor::{Editor, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, StreamExt}; use gpui::{ actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle, WeakViewHandle, }; +use language::{Point, Rope}; use menu::{Cancel, Confirm}; -use std::{env, sync::Arc}; +use std::{cmp, env, sync::Arc}; use util::TryFutureExt; use workspace::{Modal, Workspace}; @@ -36,7 +37,48 @@ impl RefactoringAssistant { let selection = editor.read(cx).selections.newest_anchor().clone(); let selected_text = snapshot .text_for_range(selection.start..selection.end) - .collect::(); + .collect::(); + + let mut normalized_selected_text = selected_text.clone(); + let mut base_indentation: Option = None; + let selection_start = selection.start.to_point(&snapshot); + let selection_end = selection.end.to_point(&snapshot); + if selection_start.row < selection_end.row { + for row in selection_start.row..=selection_end.row { + if snapshot.is_line_blank(row) { + continue; + } + + let line_indentation = snapshot.indent_size_for_line(row); + if let Some(base_indentation) = base_indentation.as_mut() { + if line_indentation.len < base_indentation.len { + *base_indentation = line_indentation; + } + } else { + base_indentation = Some(line_indentation); + } + } + } + + if let Some(base_indentation) = base_indentation { + for row in selection_start.row..=selection_end.row { + let selection_row = row - selection_start.row; + let line_start = + normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); + let indentation_len = if row == selection_start.row { + base_indentation.len.saturating_sub(selection_start.column) + } else { + let line_len = normalized_selected_text.line_len(selection_row); + cmp::min(line_len, base_indentation.len) + }; + let indentation_end = cmp::min( + line_start + indentation_len as usize, + normalized_selected_text.len(), + ); + normalized_selected_text.replace(line_start..indentation_end, ""); + } + } + let language_name = snapshot .language_at(selection.start) .map(|language| language.name()); @@ -47,7 +89,7 @@ impl RefactoringAssistant { RequestMessage { role: Role::User, content: format!( - "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line." + "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code." ), }], stream: true, @@ -64,21 +106,49 @@ impl RefactoringAssistant { let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff = cx.background().spawn(async move { let mut messages = response.await?.ready_chunks(4); - let mut diff = Diff::new(selected_text); + let mut diff = Diff::new(selected_text.to_string()); + let indentation_len; + let indentation_text; + if let Some(base_indentation) = base_indentation { + indentation_len = base_indentation.len; + indentation_text = match base_indentation.kind { + language::IndentKind::Space => " ", + language::IndentKind::Tab => "\t", + }; + } else { + indentation_len = 0; + indentation_text = ""; + }; + + let mut new_text = + indentation_text.repeat( + indentation_len.saturating_sub(selection_start.column) as usize, + ); while let Some(messages) = messages.next().await { - let mut new_text = String::new(); for message in messages { let mut message = message?; if let Some(choice) = message.choices.pop() { if let Some(text) = choice.delta.content { - new_text.push_str(&text); + let mut lines = text.split('\n'); + if let Some(first_line) = lines.next() { + new_text.push_str(&first_line); + } + + for line in lines { + new_text.push('\n'); + new_text.push_str( + &indentation_text.repeat(indentation_len as usize), + ); + new_text.push_str(line); + } } } } let hunks = diff.push_new(&new_text); hunks_tx.send(hunks).await?; + new_text.clear(); } hunks_tx.send(diff.finish()).await?; diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 2bfb090bb2..9c764c468e 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -384,6 +384,16 @@ impl<'a> From<&'a str> for Rope { } } +impl<'a> FromIterator<&'a str> for Rope { + fn from_iter>(iter: T) -> Self { + let mut rope = Rope::new(); + for chunk in iter { + rope.push(chunk); + } + rope + } +} + impl From for Rope { fn from(text: String) -> Self { Rope::from(text.as_str())