mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Allow for multi-cursor assist
and cycle_role
actions
Co-Authored-By: Nathan Sobo <nathan@zed.dev> Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
parent
9191a82447
commit
75e2329028
@ -22,7 +22,7 @@ use gpui::{
|
||||
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
|
||||
};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Selection, ToOffset as _};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
|
||||
use serde::Deserialize;
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
@ -591,106 +591,129 @@ impl Assistant {
|
||||
|
||||
fn assist(
|
||||
&mut self,
|
||||
selection: Selection<usize>,
|
||||
selected_messages: HashSet<MessageId>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Option<(MessageAnchor, MessageAnchor)> {
|
||||
let request = OpenAIRequest {
|
||||
model: self.model.clone(),
|
||||
messages: self
|
||||
.messages(cx)
|
||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||
.collect(),
|
||||
stream: true,
|
||||
};
|
||||
) -> Vec<MessageAnchor> {
|
||||
let mut user_messages = Vec::new();
|
||||
for selected_message_id in selected_messages {
|
||||
let selected_message_role =
|
||||
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
|
||||
metadata.role
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else {
|
||||
continue;
|
||||
};
|
||||
user_messages.push(user_message);
|
||||
if selected_message_role == Role::User {
|
||||
let request = OpenAIRequest {
|
||||
model: self.model.clone(),
|
||||
messages: self
|
||||
.messages(cx)
|
||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||
.chain(Some(RequestMessage {
|
||||
role: Role::System,
|
||||
content: format!(
|
||||
"Direct your reply to message with id {}. Do not include a [Message X] header.",
|
||||
selected_message_id.0
|
||||
),
|
||||
}))
|
||||
.collect(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let api_key = self.api_key.borrow().clone()?;
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
let assistant_message = self.insert_message_after(
|
||||
self.message_for_offset(selection.head(), cx)?.id,
|
||||
Role::Assistant,
|
||||
cx,
|
||||
)?;
|
||||
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
|
||||
let Some(api_key) = self.api_key.borrow().clone() else { continue };
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
let assistant_message = self
|
||||
.insert_message_after(selected_message_id, Role::Assistant, cx)
|
||||
.unwrap();
|
||||
|
||||
let task = cx.spawn_weak({
|
||||
|this, mut cx| async move {
|
||||
let assistant_message_id = assistant_message.id;
|
||||
let stream_completion = async {
|
||||
let mut messages = stream.await?;
|
||||
let task = cx.spawn_weak({
|
||||
|this, mut cx| async move {
|
||||
let assistant_message_id = assistant_message.id;
|
||||
let stream_completion = async {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
let message_ix = this.message_anchors.iter().position(
|
||||
|message| message.id == assistant_message_id,
|
||||
)?;
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let offset = if message_ix + 1
|
||||
== this.message_anchors.len()
|
||||
{
|
||||
buffer.len()
|
||||
} else {
|
||||
this.message_anchors[message_ix + 1]
|
||||
.start
|
||||
.to_offset(buffer)
|
||||
.saturating_sub(1)
|
||||
};
|
||||
buffer.edit([(offset..offset, text)], None, cx);
|
||||
});
|
||||
cx.emit(AssistantEvent::StreamedCompletion);
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
let message_ix = this
|
||||
.message_anchors
|
||||
.iter()
|
||||
.position(|message| message.id == assistant_message_id)?;
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let offset = if message_ix + 1 == this.message_anchors.len()
|
||||
{
|
||||
buffer.len()
|
||||
} else {
|
||||
this.message_anchors[message_ix + 1]
|
||||
.start
|
||||
.to_offset(buffer)
|
||||
.saturating_sub(1)
|
||||
};
|
||||
buffer.edit([(offset..offset, text)], None, cx);
|
||||
this.pending_completions.retain(|completion| {
|
||||
completion.id != this.completion_count
|
||||
});
|
||||
cx.emit(AssistantEvent::StreamedCompletion);
|
||||
|
||||
Some(())
|
||||
this.summarize(cx);
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
let result = stream_completion.await;
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Err(error) = result {
|
||||
if let Some(metadata) =
|
||||
this.messages_metadata.get_mut(&assistant_message.id)
|
||||
{
|
||||
metadata.error = Some(error.to_string().trim().into());
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
this.pending_completions
|
||||
.retain(|completion| completion.id != this.completion_count);
|
||||
this.summarize(cx);
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
let result = stream_completion.await;
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Err(error) = result {
|
||||
if let Some(metadata) =
|
||||
this.messages_metadata.get_mut(&assistant_message.id)
|
||||
{
|
||||
metadata.error = Some(error.to_string().trim().into());
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
self.pending_completions.push(PendingCompletion {
|
||||
id: post_inc(&mut self.completion_count),
|
||||
_task: task,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
self.pending_completions.push(PendingCompletion {
|
||||
id: post_inc(&mut self.completion_count),
|
||||
_task: task,
|
||||
});
|
||||
Some((assistant_message, user_message))
|
||||
user_messages
|
||||
}
|
||||
|
||||
fn cancel_last_assist(&mut self) -> bool {
|
||||
self.pending_completions.pop().is_some()
|
||||
}
|
||||
|
||||
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
|
||||
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
|
||||
metadata.role.cycle();
|
||||
cx.emit(AssistantEvent::MessagesEdited);
|
||||
cx.notify();
|
||||
fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
|
||||
for id in ids {
|
||||
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
|
||||
metadata.role.cycle();
|
||||
cx.emit(AssistantEvent::MessagesEdited);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -884,14 +907,39 @@ impl Assistant {
|
||||
}
|
||||
}
|
||||
|
||||
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
|
||||
fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
|
||||
self.messages_for_offsets([offset], cx).pop()
|
||||
}
|
||||
|
||||
fn messages_for_offsets(
|
||||
&self,
|
||||
offsets: impl IntoIterator<Item = usize>,
|
||||
cx: &AppContext,
|
||||
) -> Vec<Message> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
let buffer_len = self.buffer.read(cx).len();
|
||||
let mut messages = self.messages(cx).peekable();
|
||||
while let Some(message) = messages.next() {
|
||||
if message.range.contains(&offset) || messages.peek().is_none() {
|
||||
return Some(message);
|
||||
let mut offsets = offsets.into_iter().peekable();
|
||||
while let Some(offset) = offsets.next() {
|
||||
// Skip messages that start after the offset.
|
||||
while messages.peek().map_or(false, |message| {
|
||||
message.range.end < offset || (message.range.end == offset && offset < buffer_len)
|
||||
}) {
|
||||
messages.next();
|
||||
}
|
||||
let Some(message) = messages.peek() else { continue };
|
||||
|
||||
// Skip offsets that are in the same message.
|
||||
while offsets.peek().map_or(false, |offset| {
|
||||
message.range.contains(offset) || message.range.end == buffer_len
|
||||
}) {
|
||||
offsets.next();
|
||||
}
|
||||
|
||||
result.push(message.clone());
|
||||
}
|
||||
None
|
||||
result
|
||||
}
|
||||
|
||||
fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
|
||||
@ -983,24 +1031,32 @@ impl AssistantEditor {
|
||||
}
|
||||
|
||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||
let selection = self.editor.read(cx).selections.newest(cx);
|
||||
let user_message = self.assistant.update(cx, |assistant, cx| {
|
||||
let (_, user_message) = assistant.assist(selection, cx)?;
|
||||
Some(user_message)
|
||||
});
|
||||
let cursors = self.cursors(cx);
|
||||
|
||||
if let Some(user_message) = user_message {
|
||||
let cursor = user_message
|
||||
.start
|
||||
.to_offset(&self.assistant.read(cx).buffer.read(cx));
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.change_selections(
|
||||
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
|
||||
cx,
|
||||
|selections| selections.select_ranges([cursor..cursor]),
|
||||
);
|
||||
});
|
||||
}
|
||||
let user_messages = self.assistant.update(cx, |assistant, cx| {
|
||||
let selected_messages = assistant
|
||||
.messages_for_offsets(cursors, cx)
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assistant.assist(selected_messages, cx)
|
||||
});
|
||||
let new_selections = user_messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
let cursor = message
|
||||
.start
|
||||
.to_offset(self.assistant.read(cx).buffer.read(cx));
|
||||
cursor..cursor
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.change_selections(
|
||||
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
|
||||
cx,
|
||||
|selections| selections.select_ranges(new_selections),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
|
||||
@ -1013,14 +1069,25 @@ impl AssistantEditor {
|
||||
}
|
||||
|
||||
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
|
||||
let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
|
||||
let cursors = self.cursors(cx);
|
||||
self.assistant.update(cx, |assistant, cx| {
|
||||
if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
|
||||
assistant.cycle_message_role(message.id, cx);
|
||||
}
|
||||
let messages = assistant
|
||||
.messages_for_offsets(cursors, cx)
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect();
|
||||
assistant.cycle_message_roles(messages, cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn cursors(&self, cx: &AppContext) -> Vec<usize> {
|
||||
let selections = self.editor.read(cx).selections.all::<usize>(cx);
|
||||
selections
|
||||
.into_iter()
|
||||
.map(|selection| selection.head())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn handle_assistant_event(
|
||||
&mut self,
|
||||
_: ModelHandle<Assistant>,
|
||||
@ -1149,7 +1216,10 @@ impl AssistantEditor {
|
||||
let assistant = assistant.clone();
|
||||
move |_, _, cx| {
|
||||
assistant.update(cx, |assistant, cx| {
|
||||
assistant.cycle_message_role(message_id, cx)
|
||||
assistant.cycle_message_roles(
|
||||
HashSet::from_iter(Some(message_id)),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
});
|
||||
@ -1444,9 +1514,11 @@ pub struct Message {
|
||||
|
||||
impl Message {
|
||||
fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
|
||||
let mut content = format!("[Message {}]\n", self.id.0).to_string();
|
||||
content.extend(buffer.text_for_range(self.range.clone()));
|
||||
RequestMessage {
|
||||
role: self.role,
|
||||
content: buffer.text_for_range(self.range.clone()).collect(),
|
||||
content,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1761,6 +1833,66 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
|
||||
let buffer = assistant.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = assistant.read(cx).message_anchors[0].clone();
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![(message_1.id, Role::User, 0..0)]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
||||
let message_2 = assistant
|
||||
.update(cx, |assistant, cx| {
|
||||
assistant.insert_message_after(message_1.id, Role::User, cx)
|
||||
})
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
|
||||
|
||||
let message_3 = assistant
|
||||
.update(cx, |assistant, cx| {
|
||||
assistant.insert_message_after(message_2.id, Role::User, cx)
|
||||
})
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
|
||||
|
||||
assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..4),
|
||||
(message_2.id, Role::User, 4..8),
|
||||
(message_3.id, Role::User, 8..11)
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
|
||||
[message_1.id, message_2.id, message_3.id]
|
||||
);
|
||||
assert_eq!(
|
||||
message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
|
||||
[message_1.id, message_3.id]
|
||||
);
|
||||
|
||||
fn message_ids_for_offsets(
|
||||
assistant: &ModelHandle<Assistant>,
|
||||
offsets: &[usize],
|
||||
cx: &AppContext,
|
||||
) -> Vec<MessageId> {
|
||||
assistant
|
||||
.read(cx)
|
||||
.messages_for_offsets(offsets.iter().copied(), cx)
|
||||
.into_iter()
|
||||
.map(|message| message.id)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn messages(
|
||||
assistant: &ModelHandle<Assistant>,
|
||||
cx: &AppContext,
|
||||
|
Loading…
Reference in New Issue
Block a user