Maintain scroll bottom when streaming assistant responses

This commit is contained in:
Antonio Scandurra 2023-06-07 15:01:50 +02:00
parent 43500dbf60
commit d26cc2c897
6 changed files with 176 additions and 75 deletions

View File

@ -5,13 +5,21 @@ use crate::{
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use collections::{HashMap, HashSet};
use editor::{Anchor, Editor, ExcerptId, ExcerptRange, MultiBuffer};
use editor::{
display_map::ToDisplayPoint,
scroll::{
autoscroll::{Autoscroll, AutoscrollStrategy},
ScrollAnchor,
},
Anchor, DisplayPoint, Editor, ExcerptId, ExcerptRange, MultiBuffer,
};
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use gpui::{
actions,
elements::*,
executor::Background,
geometry::vector::vec2f,
platform::{CursorStyle, MouseButton},
Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
@ -414,6 +422,7 @@ impl Panel for AssistantPanel {
enum AssistantEvent {
MessagesEdited { ids: Vec<ExcerptId> },
SummaryChanged,
StreamedCompletion,
}
struct Assistant {
@ -531,7 +540,7 @@ impl Assistant {
cx.notify();
}
fn assist(&mut self, cx: &mut ModelContext<Self>) {
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
let messages = self
.messages
.iter()
@ -548,24 +557,30 @@ impl Assistant {
stream: true,
};
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let stream = stream_completion(api_key, cx.background().clone(), request);
let (excerpt_id, content) =
self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
self.insert_message_after(ExcerptId::max(), Role::User, cx);
let task = cx.spawn_weak(|this, mut cx| async move {
let api_key = self.api_key.borrow().clone()?;
let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
let task = cx.spawn_weak({
let assistant_message = assistant_message.clone();
|this, mut cx| async move {
let assistant_message = assistant_message;
let stream_completion = async {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
content.update(&mut cx, |content, cx| {
assistant_message.content.update(&mut cx, |content, cx| {
let text: Arc<str> = choice.delta.content?.into();
content.edit([(content.len()..content.len(), text)], None, cx);
Some(())
});
this.upgrade(&cx)
.ok_or_else(|| anyhow!("assistant was dropped"))?
.update(&mut cx, |_, cx| {
cx.emit(AssistantEvent::StreamedCompletion);
});
}
}
@ -580,23 +595,28 @@ impl Assistant {
anyhow::Ok(())
};
if let Err(error) = stream_completion.await {
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) = this
.messages_metadata
.get_mut(&assistant_message.excerpt_id)
{
metadata.error = Some(error.to_string().trim().into());
cx.notify();
}
});
}
}
});
}
});
}
});
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
_task: task,
});
}
self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count),
_task: task,
});
Some((assistant_message, user_message))
}
fn cancel_last_assist(&mut self) -> bool {
@ -646,7 +666,7 @@ impl Assistant {
excerpt_id: ExcerptId,
role: Role,
cx: &mut ModelContext<Self>,
) -> (ExcerptId, ModelHandle<Buffer>) {
) -> Message {
let content = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
let markdown = self.languages.language_for_name("Markdown");
@ -684,13 +704,11 @@ impl Assistant {
.iter()
.position(|message| message.excerpt_id == excerpt_id)
.map_or(self.messages.len(), |ix| ix + 1);
self.messages.insert(
ix,
Message {
excerpt_id: new_excerpt_id,
content: content.clone(),
},
);
let message = Message {
excerpt_id: new_excerpt_id,
content: content.clone(),
};
self.messages.insert(ix, message.clone());
self.messages_metadata.insert(
new_excerpt_id,
MessageMetadata {
@ -699,7 +717,7 @@ impl Assistant {
error: None,
},
);
(new_excerpt_id, content)
message
}
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@ -766,6 +784,7 @@ enum AssistantEditorEvent {
struct AssistantEditor {
assistant: ModelHandle<Assistant>,
editor: ViewHandle<Editor>,
scroll_bottom: ScrollAnchor,
_subscriptions: Vec<Subscription>,
}
@ -875,37 +894,64 @@ impl AssistantEditor {
let _subscriptions = vec![
cx.observe(&assistant, |_, _, cx| cx.notify()),
cx.subscribe(&assistant, Self::handle_assistant_event),
cx.subscribe(&editor, Self::handle_editor_event),
];
Self {
assistant,
editor,
scroll_bottom: ScrollAnchor {
offset: Default::default(),
anchor: Anchor::max(),
},
_subscriptions,
}
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
self.assistant.update(cx, |assistant, cx| {
let user_message = self.assistant.update(cx, |assistant, cx| {
let editor = self.editor.read(cx);
let newest_selection = editor.selections.newest_anchor();
let excerpt_id = if newest_selection.head() == Anchor::min() {
assistant.messages.first().map(|message| message.excerpt_id)
assistant
.messages
.first()
.map(|message| message.excerpt_id)?
} else if newest_selection.head() == Anchor::max() {
assistant.messages.last().map(|message| message.excerpt_id)
assistant
.messages
.last()
.map(|message| message.excerpt_id)?
} else {
Some(newest_selection.head().excerpt_id())
newest_selection.head().excerpt_id()
};
if let Some(excerpt_id) = excerpt_id {
if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
if metadata.role == Role::User {
assistant.assist(cx);
} else {
assistant.insert_message_after(excerpt_id, Role::User, cx);
}
}
}
let metadata = assistant.messages_metadata.get(&excerpt_id)?;
let user_message = if metadata.role == Role::User {
let (_, user_message) = assistant.assist(cx)?;
user_message
} else {
let user_message = assistant.insert_message_after(excerpt_id, Role::User, cx);
user_message
};
Some(user_message)
});
if let Some(user_message) = user_message {
self.editor.update(cx, |editor, cx| {
let cursor = editor
.buffer()
.read(cx)
.snapshot(cx)
.anchor_in_excerpt(user_message.excerpt_id, language::Anchor::MIN);
editor.change_selections(
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
cx,
|selections| selections.select_anchor_ranges([cursor..cursor]),
);
});
self.update_scroll_bottom(cx);
}
}
fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
@ -919,7 +965,7 @@ impl AssistantEditor {
fn handle_assistant_event(
&mut self,
assistant: ModelHandle<Assistant>,
_: ModelHandle<Assistant>,
event: &AssistantEvent,
cx: &mut ViewContext<Self>,
) {
@ -931,16 +977,70 @@ impl AssistantEditor {
.map(|selection| selection.head())
.collect::<HashSet<usize>>();
let ids = ids.iter().copied().collect::<HashSet<_>>();
assistant.update(cx, |assistant, cx| {
self.assistant.update(cx, |assistant, cx| {
assistant.remove_empty_messages(ids, selection_heads, cx)
});
}
AssistantEvent::SummaryChanged => {
cx.emit(AssistantEditorEvent::TabContentChanged);
}
AssistantEvent::StreamedCompletion => {
self.editor.update(cx, |editor, cx| {
let snapshot = editor.snapshot(cx);
let scroll_bottom_row = self
.scroll_bottom
.anchor
.to_display_point(&snapshot.display_snapshot)
.row();
let scroll_bottom = scroll_bottom_row as f32 + self.scroll_bottom.offset.y();
let visible_line_count = editor.visible_line_count().unwrap_or(0.);
let scroll_top = scroll_bottom - visible_line_count;
editor
.set_scroll_position(vec2f(self.scroll_bottom.offset.x(), scroll_top), cx);
});
}
}
}
fn handle_editor_event(
&mut self,
_: ViewHandle<Editor>,
event: &editor::Event,
cx: &mut ViewContext<Self>,
) {
match event {
editor::Event::ScrollPositionChanged { .. } => self.update_scroll_bottom(cx),
_ => {}
}
}
fn update_scroll_bottom(&mut self, cx: &mut ViewContext<Self>) {
self.editor.update(cx, |editor, cx| {
let snapshot = editor.snapshot(cx);
let scroll_position = editor
.scroll_manager
.anchor()
.scroll_position(&snapshot.display_snapshot);
let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
let scroll_bottom_point = cmp::min(
DisplayPoint::new(scroll_bottom.floor() as u32, 0),
snapshot.display_snapshot.max_point(),
);
let scroll_bottom_anchor = snapshot
.buffer_snapshot
.anchor_after(scroll_bottom_point.to_point(&snapshot.display_snapshot));
let scroll_bottom_offset = vec2f(
scroll_position.x(),
scroll_bottom - scroll_bottom_point.row() as f32,
);
self.scroll_bottom = ScrollAnchor {
anchor: scroll_bottom_anchor,
offset: scroll_bottom_offset,
};
});
}
fn quote_selection(
workspace: &mut Workspace,
_: &QuoteSelection,
@ -1155,7 +1255,7 @@ impl Item for AssistantEditor {
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct Message {
excerpt_id: ExcerptId,
content: ModelHandle<Buffer>,
@ -1265,15 +1365,16 @@ mod tests {
cx.add_model(|cx| {
let mut assistant = Assistant::new(Default::default(), registry, cx);
let (excerpt_1, _) =
assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
let message_1 = assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
let message_2 = assistant.insert_message_after(message_1.excerpt_id, Role::User, cx);
let message_3 = assistant.insert_message_after(message_1.excerpt_id, Role::User, cx);
assistant.remove_empty_messages(
HashSet::from_iter([excerpt_2, excerpt_3]),
HashSet::from_iter([message_2.excerpt_id, message_3.excerpt_id]),
Default::default(),
cx,
);
assert_eq!(assistant.messages.len(), 1);
assert_eq!(assistant.messages[0].excerpt_id, message_1.excerpt_id);
assistant
});
}

View File

@ -579,7 +579,7 @@ async fn test_navigation_history(cx: &mut TestAppContext) {
assert_eq!(editor.scroll_manager.anchor(), original_scroll_position);
// Ensure we don't panic when navigation data contains invalid anchors *and* points.
let mut invalid_anchor = editor.scroll_manager.anchor().top_anchor;
let mut invalid_anchor = editor.scroll_manager.anchor().anchor;
invalid_anchor.text_anchor.buffer_id = Some(999);
let invalid_point = Point::new(9999, 0);
editor.navigate(
@ -587,7 +587,7 @@ async fn test_navigation_history(cx: &mut TestAppContext) {
cursor_anchor: invalid_anchor,
cursor_position: invalid_point,
scroll_anchor: ScrollAnchor {
top_anchor: invalid_anchor,
anchor: invalid_anchor,
offset: Default::default(),
},
scroll_top_row: invalid_point.row,
@ -5815,7 +5815,7 @@ async fn test_following(cx: &mut gpui::TestAppContext) {
let top_anchor = follower.buffer().read(cx).read(cx).anchor_after(0);
follower.set_scroll_anchor(
ScrollAnchor {
top_anchor,
anchor: top_anchor,
offset: vec2f(0.0, 0.5),
},
cx,

View File

@ -196,7 +196,7 @@ impl FollowableItem for Editor {
singleton: buffer.is_singleton(),
title: (!buffer.is_singleton()).then(|| buffer.title(cx).into()),
excerpts,
scroll_top_anchor: Some(serialize_anchor(&scroll_anchor.top_anchor)),
scroll_top_anchor: Some(serialize_anchor(&scroll_anchor.anchor)),
scroll_x: scroll_anchor.offset.x(),
scroll_y: scroll_anchor.offset.y(),
selections: self
@ -253,7 +253,7 @@ impl FollowableItem for Editor {
}
Event::ScrollPositionChanged { .. } => {
let scroll_anchor = self.scroll_manager.anchor();
update.scroll_top_anchor = Some(serialize_anchor(&scroll_anchor.top_anchor));
update.scroll_top_anchor = Some(serialize_anchor(&scroll_anchor.anchor));
update.scroll_x = scroll_anchor.offset.x();
update.scroll_y = scroll_anchor.offset.y();
true
@ -412,7 +412,7 @@ async fn update_editor_from_message(
} else if let Some(scroll_top_anchor) = scroll_top_anchor {
editor.set_scroll_anchor_remote(
ScrollAnchor {
top_anchor: scroll_top_anchor,
anchor: scroll_top_anchor,
offset: vec2f(message.scroll_x, message.scroll_y),
},
cx,
@ -510,8 +510,8 @@ impl Item for Editor {
};
let mut scroll_anchor = data.scroll_anchor;
if !buffer.can_resolve(&scroll_anchor.top_anchor) {
scroll_anchor.top_anchor = buffer.anchor_before(
if !buffer.can_resolve(&scroll_anchor.anchor) {
scroll_anchor.anchor = buffer.anchor_before(
buffer.clip_point(Point::new(data.scroll_top_row, 0), Bias::Left),
);
}

View File

@ -36,21 +36,21 @@ pub struct ScrollbarAutoHide(pub bool);
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ScrollAnchor {
pub offset: Vector2F,
pub top_anchor: Anchor,
pub anchor: Anchor,
}
impl ScrollAnchor {
fn new() -> Self {
Self {
offset: Vector2F::zero(),
top_anchor: Anchor::min(),
anchor: Anchor::min(),
}
}
pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> Vector2F {
let mut scroll_position = self.offset;
if self.top_anchor != Anchor::min() {
let scroll_top = self.top_anchor.to_display_point(snapshot).row() as f32;
if self.anchor != Anchor::min() {
let scroll_top = self.anchor.to_display_point(snapshot).row() as f32;
scroll_position.set_y(scroll_top + scroll_position.y());
} else {
scroll_position.set_y(0.);
@ -59,7 +59,7 @@ impl ScrollAnchor {
}
pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 {
self.top_anchor.to_point(buffer).row
self.anchor.to_point(buffer).row
}
}
@ -179,7 +179,7 @@ impl ScrollManager {
let (new_anchor, top_row) = if scroll_position.y() <= 0. {
(
ScrollAnchor {
top_anchor: Anchor::min(),
anchor: Anchor::min(),
offset: scroll_position.max(vec2f(0., 0.)),
},
0,
@ -193,7 +193,7 @@ impl ScrollManager {
(
ScrollAnchor {
top_anchor,
anchor: top_anchor,
offset: vec2f(
scroll_position.x(),
scroll_position.y() - top_anchor.to_display_point(&map).row() as f32,
@ -322,7 +322,7 @@ impl Editor {
hide_hover(self, cx);
let workspace_id = self.workspace.as_ref().map(|workspace| workspace.1);
let top_row = scroll_anchor
.top_anchor
.anchor
.to_point(&self.buffer().read(cx).snapshot(cx))
.row;
self.scroll_manager
@ -337,7 +337,7 @@ impl Editor {
hide_hover(self, cx);
let workspace_id = self.workspace.as_ref().map(|workspace| workspace.1);
let top_row = scroll_anchor
.top_anchor
.anchor
.to_point(&self.buffer().read(cx).snapshot(cx))
.row;
self.scroll_manager
@ -377,7 +377,7 @@ impl Editor {
let screen_top = self
.scroll_manager
.anchor
.top_anchor
.anchor
.to_display_point(&snapshot);
if screen_top > newest_head {
@ -408,7 +408,7 @@ impl Editor {
.anchor_at(Point::new(top_row as u32, 0), Bias::Left);
let scroll_anchor = ScrollAnchor {
offset: Vector2F::new(x, y),
top_anchor,
anchor: top_anchor,
};
self.set_scroll_anchor(scroll_anchor, cx);
}

View File

@ -86,7 +86,7 @@ impl Editor {
editor.set_scroll_anchor(
ScrollAnchor {
top_anchor: new_anchor,
anchor: new_anchor,
offset: Default::default(),
},
cx,
@ -113,7 +113,7 @@ impl Editor {
editor.set_scroll_anchor(
ScrollAnchor {
top_anchor: new_anchor,
anchor: new_anchor,
offset: Default::default(),
},
cx,
@ -143,7 +143,7 @@ impl Editor {
editor.set_scroll_anchor(
ScrollAnchor {
top_anchor: new_anchor,
anchor: new_anchor,
offset: Default::default(),
},
cx,

View File

@ -400,7 +400,7 @@ fn scroll(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContext<Edito
};
let scroll_margin_rows = editor.vertical_scroll_margin() as u32;
let top_anchor = editor.scroll_manager.anchor().top_anchor;
let top_anchor = editor.scroll_manager.anchor().anchor;
editor.change_selections(None, cx, |s| {
s.replace_cursors_with(|snapshot| {