mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-28 23:41:42 +03:00
Get back to a compiling state with Buffer
backing the assistant
This commit is contained in:
parent
7db690b713
commit
2ae8b558b9
@ -11,7 +11,7 @@ use editor::{
|
||||
autoscroll::{Autoscroll, AutoscrollStrategy},
|
||||
ScrollAnchor,
|
||||
},
|
||||
Anchor, DisplayPoint, Editor, ExcerptId,
|
||||
Anchor, DisplayPoint, Editor, ToOffset as _,
|
||||
};
|
||||
use fs::Fs;
|
||||
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
@ -25,10 +25,13 @@ use gpui::{
|
||||
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
|
||||
};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
|
||||
use serde::Deserialize;
|
||||
use settings::SettingsStore;
|
||||
use std::{borrow::Cow, cell::RefCell, cmp, fmt::Write, io, rc::Rc, sync::Arc, time::Duration};
|
||||
use std::{
|
||||
borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
|
||||
use workspace::{
|
||||
dock::{DockPosition, Panel},
|
||||
@ -507,16 +510,16 @@ impl Assistant {
|
||||
|
||||
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let messages = self
|
||||
.messages
|
||||
.iter()
|
||||
.open_ai_request_messages(cx)
|
||||
.into_iter()
|
||||
.filter_map(|message| {
|
||||
Some(tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match self.messages_metadata.get(&message.excerpt_id)?.role {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: message.content.read(cx).text(),
|
||||
content: message.content,
|
||||
name: None,
|
||||
})
|
||||
})
|
||||
@ -554,45 +557,47 @@ impl Assistant {
|
||||
}
|
||||
|
||||
fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
|
||||
let messages = self
|
||||
.messages
|
||||
.iter()
|
||||
.filter_map(|message| {
|
||||
Some(RequestMessage {
|
||||
role: self.messages_metadata.get(&message.excerpt_id)?.role,
|
||||
content: message.content.read(cx).text(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let request = OpenAIRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
messages: self.open_ai_request_messages(cx),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let api_key = self.api_key.borrow().clone()?;
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
|
||||
let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
|
||||
let assistant_message =
|
||||
self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
|
||||
let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
|
||||
let task = cx.spawn_weak({
|
||||
let assistant_message = assistant_message.clone();
|
||||
|this, mut cx| async move {
|
||||
let assistant_message = assistant_message;
|
||||
let assistant_message_id = assistant_message.id;
|
||||
let stream_completion = async {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
assistant_message.content.update(&mut cx, |content, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
content.edit([(content.len()..content.len(), text)], None, cx);
|
||||
Some(())
|
||||
});
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("assistant was dropped"))?
|
||||
.update(&mut cx, |_, cx| {
|
||||
cx.emit(AssistantEvent::StreamedCompletion);
|
||||
.update(&mut cx, |this, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
let message_ix = this
|
||||
.messages
|
||||
.iter()
|
||||
.position(|message| message.id == assistant_message_id)?;
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let offset = if message_ix + 1 == this.messages.len() {
|
||||
buffer.len()
|
||||
} else {
|
||||
this.messages[message_ix + 1]
|
||||
.start
|
||||
.to_offset(buffer)
|
||||
.saturating_sub(1)
|
||||
};
|
||||
buffer.edit([(offset..offset, text)], None, cx);
|
||||
});
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -612,9 +617,8 @@ impl Assistant {
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Err(error) = result {
|
||||
if let Some(metadata) = this
|
||||
.messages_metadata
|
||||
.get_mut(&assistant_message.excerpt_id)
|
||||
if let Some(metadata) =
|
||||
this.messages_metadata.get_mut(&assistant_message.id)
|
||||
{
|
||||
metadata.error = Some(error.to_string().trim().into());
|
||||
cx.notify();
|
||||
@ -642,33 +646,33 @@ impl Assistant {
|
||||
protected_offsets: HashSet<usize>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let mut offset = 0;
|
||||
let mut excerpts_to_remove = Vec::new();
|
||||
self.messages.retain(|message| {
|
||||
let range = offset..offset + message.content.read(cx).len();
|
||||
offset = range.end + 1;
|
||||
if range.is_empty()
|
||||
&& !protected_offsets.contains(&range.start)
|
||||
&& messages.contains(&message.id)
|
||||
{
|
||||
excerpts_to_remove.push(message.excerpt_id);
|
||||
self.messages_metadata.remove(&message.excerpt_id);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
// let mut offset = 0;
|
||||
// let mut excerpts_to_remove = Vec::new();
|
||||
// self.messages.retain(|message| {
|
||||
// let range = offset..offset + message.content.read(cx).len();
|
||||
// offset = range.end + 1;
|
||||
// if range.is_empty()
|
||||
// && !protected_offsets.contains(&range.start)
|
||||
// && messages.contains(&message.id)
|
||||
// {
|
||||
// excerpts_to_remove.push(message.excerpt_id);
|
||||
// self.messages_metadata.remove(&message.excerpt_id);
|
||||
// false
|
||||
// } else {
|
||||
// true
|
||||
// }
|
||||
// });
|
||||
|
||||
if !excerpts_to_remove.is_empty() {
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.remove_excerpts(excerpts_to_remove, cx)
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
// if !excerpts_to_remove.is_empty() {
|
||||
// self.buffer.update(cx, |buffer, cx| {
|
||||
// buffer.remove_excerpts(excerpts_to_remove, cx)
|
||||
// });
|
||||
// cx.notify();
|
||||
// }
|
||||
}
|
||||
|
||||
fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
|
||||
if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
|
||||
fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
|
||||
if let Some(metadata) = self.messages_metadata.get_mut(&id) {
|
||||
metadata.role.cycle();
|
||||
cx.notify();
|
||||
}
|
||||
@ -686,15 +690,18 @@ impl Assistant {
|
||||
.position(|message| message.id == message_id)
|
||||
{
|
||||
let start = self.buffer.update(cx, |buffer, cx| {
|
||||
let len = buffer.len();
|
||||
buffer.edit([(len..len, "\n")], None, cx);
|
||||
buffer.anchor_before(len + 1)
|
||||
let offset = self
|
||||
.messages
|
||||
.get(prev_message_ix + 1)
|
||||
.map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
|
||||
buffer.edit([(offset..offset, "\n")], None, cx);
|
||||
buffer.anchor_before(offset + 1)
|
||||
});
|
||||
let message = Message {
|
||||
id: MessageId(post_inc(&mut self.next_message_id.0)),
|
||||
start,
|
||||
};
|
||||
self.messages.insert(prev_message_ix, message.clone());
|
||||
self.messages.insert(prev_message_ix + 1, message.clone());
|
||||
self.messages_metadata.insert(
|
||||
message.id,
|
||||
MessageMetadata {
|
||||
@ -713,24 +720,13 @@ impl Assistant {
|
||||
if self.messages.len() >= 2 && self.summary.is_none() {
|
||||
let api_key = self.api_key.borrow().clone();
|
||||
if let Some(api_key) = api_key {
|
||||
// let messages = self
|
||||
// .messages
|
||||
// .iter()
|
||||
// .take(2)
|
||||
// .filter_map(|message| {
|
||||
// Some(RequestMessage {
|
||||
// role: self.messages_metadata.get(&message.id)?.role,
|
||||
// content: message.content.read(cx).text(),
|
||||
// })
|
||||
// })
|
||||
// .chain(Some(RequestMessage {
|
||||
// role: Role::User,
|
||||
// content:
|
||||
// "Summarize the conversation into a short title without punctuation"
|
||||
// .into(),
|
||||
// }))
|
||||
// .collect();
|
||||
let messages = todo!();
|
||||
let mut messages = self.open_ai_request_messages(cx);
|
||||
messages.truncate(2);
|
||||
messages.push(RequestMessage {
|
||||
role: Role::User,
|
||||
content: "Summarize the conversation into a short title without punctuation"
|
||||
.into(),
|
||||
});
|
||||
let request = OpenAIRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
@ -760,6 +756,44 @@ impl Assistant {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
|
||||
let buffer = self.buffer.read(cx);
|
||||
self.messages(cx)
|
||||
.map(|(message, metadata, range)| RequestMessage {
|
||||
role: metadata.role,
|
||||
content: buffer.text_for_range(range).collect(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
|
||||
Some(
|
||||
self.messages(cx)
|
||||
.find(|(_, _, range)| range.contains(&offset))
|
||||
.map(|(message, _, _)| message)
|
||||
.or(self.messages.last())?
|
||||
.id,
|
||||
)
|
||||
}
|
||||
|
||||
fn messages<'a>(
|
||||
&'a self,
|
||||
cx: &'a AppContext,
|
||||
) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
|
||||
let buffer = self.buffer.read(cx);
|
||||
let mut messages = self.messages.iter().peekable();
|
||||
iter::from_fn(move || {
|
||||
let message = messages.next()?;
|
||||
let metadata = self.messages_metadata.get(&message.id)?;
|
||||
let message_start = message.start.to_offset(buffer);
|
||||
let message_end = messages
|
||||
.peek()
|
||||
.map_or(language::Anchor::MAX, |message| message.start)
|
||||
.to_offset(buffer);
|
||||
Some((message, metadata, message_start..message_end))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingCompletion {
|
||||
@ -812,16 +846,12 @@ impl AssistantEditor {
|
||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||
let user_message = self.assistant.update(cx, |assistant, cx| {
|
||||
let editor = self.editor.read(cx);
|
||||
let newest_selection = editor.selections.newest_anchor();
|
||||
let message_id = if newest_selection.head() == Anchor::min() {
|
||||
assistant.messages.first().map(|message| message.id)?
|
||||
} else if newest_selection.head() == Anchor::max() {
|
||||
assistant.messages.last().map(|message| message.id)?
|
||||
} else {
|
||||
todo!()
|
||||
// newest_selection.head().excerpt_id()
|
||||
};
|
||||
|
||||
let newest_selection = editor
|
||||
.selections
|
||||
.newest_anchor()
|
||||
.head()
|
||||
.to_offset(&editor.buffer().read(cx).snapshot(cx));
|
||||
let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
|
||||
let metadata = assistant.messages_metadata.get(&message_id)?;
|
||||
let user_message = if metadata.role == Role::User {
|
||||
let (_, user_message) = assistant.assist(cx)?;
|
||||
@ -834,16 +864,14 @@ impl AssistantEditor {
|
||||
});
|
||||
|
||||
if let Some(user_message) = user_message {
|
||||
let cursor = user_message
|
||||
.start
|
||||
.to_offset(&self.assistant.read(cx).buffer.read(cx));
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let cursor = editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.snapshot(cx)
|
||||
.anchor_in_excerpt(Default::default(), user_message.start);
|
||||
editor.change_selections(
|
||||
Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
|
||||
cx,
|
||||
|selections| selections.select_anchor_ranges([cursor..cursor]),
|
||||
|selections| selections.select_ranges([cursor..cursor]),
|
||||
);
|
||||
});
|
||||
self.update_scroll_bottom(cx);
|
||||
@ -1011,7 +1039,7 @@ impl AssistantEditor {
|
||||
let mut copied_text = String::new();
|
||||
let mut spanned_messages = 0;
|
||||
for message in &assistant.messages {
|
||||
// TODO
|
||||
todo!();
|
||||
// let message_range = offset..offset + message.content.read(cx).len() + 1;
|
||||
let message_range = offset..offset + 1;
|
||||
|
||||
@ -1260,28 +1288,100 @@ mod tests {
|
||||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
|
||||
let buffer = assistant.read(cx).buffer.clone();
|
||||
|
||||
cx.add_model(|cx| {
|
||||
let mut assistant = Assistant::new(Default::default(), registry, cx);
|
||||
let message_1 = assistant.messages[0].clone();
|
||||
let message_2 = assistant
|
||||
.insert_message_after(message_1.id, Role::Assistant, cx)
|
||||
.unwrap();
|
||||
let message_3 = assistant
|
||||
.insert_message_after(message_2.id, Role::User, cx)
|
||||
.unwrap();
|
||||
let message_4 = assistant
|
||||
.insert_message_after(message_2.id, Role::User, cx)
|
||||
.unwrap();
|
||||
assistant.remove_empty_messages(
|
||||
HashSet::from_iter([message_3.id, message_4.id]),
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
assert_eq!(assistant.messages.len(), 2);
|
||||
assert_eq!(assistant.messages[0].id, message_1.id);
|
||||
assert_eq!(assistant.messages[1].id, message_2.id);
|
||||
let message_1 = assistant.read(cx).messages[0].clone();
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![(message_1.id, Role::User, 0..0)]
|
||||
);
|
||||
|
||||
let message_2 = assistant.update(cx, |assistant, cx| {
|
||||
assistant
|
||||
.insert_message_after(message_1.id, Role::Assistant, cx)
|
||||
.unwrap()
|
||||
});
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..1),
|
||||
(message_2.id, Role::Assistant, 1..1)
|
||||
]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..2),
|
||||
(message_2.id, Role::Assistant, 2..3)
|
||||
]
|
||||
);
|
||||
|
||||
let message_3 = assistant.update(cx, |assistant, cx| {
|
||||
assistant
|
||||
.insert_message_after(message_2.id, Role::User, cx)
|
||||
.unwrap()
|
||||
});
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..2),
|
||||
(message_2.id, Role::Assistant, 2..4),
|
||||
(message_3.id, Role::User, 4..4)
|
||||
]
|
||||
);
|
||||
|
||||
let message_4 = assistant.update(cx, |assistant, cx| {
|
||||
assistant
|
||||
.insert_message_after(message_2.id, Role::User, cx)
|
||||
.unwrap()
|
||||
});
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..2),
|
||||
(message_2.id, Role::Assistant, 2..4),
|
||||
(message_4.id, Role::User, 4..5),
|
||||
(message_3.id, Role::User, 5..5),
|
||||
]
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..2),
|
||||
(message_2.id, Role::Assistant, 2..4),
|
||||
(message_4.id, Role::User, 4..6),
|
||||
(message_3.id, Role::User, 6..7),
|
||||
]
|
||||
);
|
||||
|
||||
// Deleting across message boundaries merges the messages.
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
|
||||
assert_eq!(
|
||||
messages(&assistant, cx),
|
||||
vec![
|
||||
(message_1.id, Role::User, 0..6),
|
||||
(message_3.id, Role::User, 6..7),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn messages(
|
||||
assistant: &ModelHandle<Assistant>,
|
||||
cx: &AppContext,
|
||||
) -> Vec<(MessageId, Role, Range<usize>)> {
|
||||
assistant
|
||||
.read(cx)
|
||||
.messages(cx)
|
||||
.map(|(message, metadata, range)| (message.id, metadata.role, range))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user