Remove the ability to reply to a specific message in the assistant
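
With this change, assisting no longer spawns one completion per selected message, each carrying a system prompt that directed the model to reply to a particular message id. A single completion is now requested over the whole conversation and streamed into a new assistant message inserted after the last valid message anchor. Accordingly, PendingCompletion holds a single _task instead of _tasks: Vec<Task<()>>, and Message::to_open_ai_message no longer prepends a [Message {id}] header.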

Antonio Scandurra 2023-08-29 14:51:00 +02:00
parent 2332f82442
commit 72413dbaf2

@@ -1767,15 +1767,20 @@ impl Conversation {
         cx: &mut ModelContext<Self>,
     ) -> Vec<MessageAnchor> {
         let mut user_messages = Vec::new();
-        let mut tasks = Vec::new();
-        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
-            message
-                .start
-                .is_valid(self.buffer.read(cx))
-                .then_some(message.id)
-        });
+        let last_message_id = if let Some(last_message_id) =
+            self.message_anchors.iter().rev().find_map(|message| {
+                message
+                    .start
+                    .is_valid(self.buffer.read(cx))
+                    .then_some(message.id)
+            }) {
+            last_message_id
+        } else {
+            return Default::default();
+        };
+        let mut should_assist = false;
         for selected_message_id in selected_messages {
             let selected_message_role =
                 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
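The new last_message_id binding above is the "bind it or bail out" shape; the same logic reads more directly with let-else, which this file already uses for api_key below. A standalone sketch, illustrative only, with (is_valid, id) pairs standing in for MessageAnchor plus the buffer validity check:

    // Illustrative rewrite, not part of the commit.
    fn last_valid_id(anchors: &[(bool, u64)]) -> Vec<u64> {
        let Some(last_message_id) = anchors
            .iter()
            .rev()
            .find_map(|(valid, id)| valid.then_some(*id))
        else {
            return Default::default(); // no valid anchors: nothing to assist
        };
        vec![last_message_id]
    }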
@@ -1792,144 +1797,111 @@ impl Conversation {
                     cx,
                 ) {
                     user_messages.push(user_message);
-                } else {
-                    continue;
                 }
             } else {
-                let request = OpenAIRequest {
-                    model: self.model.full_name().to_string(),
-                    messages: self
-                        .messages(cx)
-                        .filter(|message| matches!(message.status, MessageStatus::Done))
-                        .flat_map(|message| {
-                            let mut system_message = None;
-                            if message.id == selected_message_id {
-                                system_message = Some(RequestMessage {
-                                    role: Role::System,
-                                    content: concat!(
-                                        "Treat the following messages as additional knowledge you have learned about, ",
-                                        "but act as if they were not part of this conversation. That is, treat them ",
-                                        "as if the user didn't see them and couldn't possibly inquire about them."
-                                    ).into()
-                                });
-                            }
-                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
-                        })
-                        .chain(Some(RequestMessage {
-                            role: Role::System,
-                            content: format!(
-                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
-                                selected_message_id.0
-                            ),
-                        }))
-                        .collect(),
-                    stream: true,
-                };
-                let Some(api_key) = self.api_key.borrow().clone() else {
-                    continue;
-                };
-                let stream = stream_completion(api_key, cx.background().clone(), request);
-                let assistant_message = self
-                    .insert_message_after(
-                        selected_message_id,
-                        Role::Assistant,
-                        MessageStatus::Pending,
-                        cx,
-                    )
-                    .unwrap();
-                // Queue up the user's next reply
-                if Some(selected_message_id) == last_message_id {
-                    let user_message = self
-                        .insert_message_after(
-                            assistant_message.id,
-                            Role::User,
-                            MessageStatus::Done,
-                            cx,
-                        )
-                        .unwrap();
-                    user_messages.push(user_message);
-                }
-                tasks.push(cx.spawn_weak({
-                    |this, mut cx| async move {
-                        let assistant_message_id = assistant_message.id;
-                        let stream_completion = async {
-                            let mut messages = stream.await?;
-                            while let Some(message) = messages.next().await {
-                                let mut message = message?;
-                                if let Some(choice) = message.choices.pop() {
-                                    this.upgrade(&cx)
-                                        .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                        .update(&mut cx, |this, cx| {
-                                            let text: Arc<str> = choice.delta.content?.into();
-                                            let message_ix = this.message_anchors.iter().position(
-                                                |message| message.id == assistant_message_id,
-                                            )?;
-                                            this.buffer.update(cx, |buffer, cx| {
-                                                let offset = this.message_anchors[message_ix + 1..]
-                                                    .iter()
-                                                    .find(|message| message.start.is_valid(buffer))
-                                                    .map_or(buffer.len(), |message| {
-                                                        message
-                                                            .start
-                                                            .to_offset(buffer)
-                                                            .saturating_sub(1)
-                                                    });
-                                                buffer.edit([(offset..offset, text)], None, cx);
-                                            });
-                                            cx.emit(ConversationEvent::StreamedCompletion);
-                                            Some(())
-                                        });
-                                }
-                                smol::future::yield_now().await;
-                            }
-                            this.upgrade(&cx)
-                                .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                .update(&mut cx, |this, cx| {
-                                    this.pending_completions.retain(|completion| {
-                                        completion.id != this.completion_count
-                                    });
-                                    this.summarize(cx);
-                                });
-                            anyhow::Ok(())
-                        };
-                        let result = stream_completion.await;
-                        if let Some(this) = this.upgrade(&cx) {
-                            this.update(&mut cx, |this, cx| {
-                                if let Some(metadata) =
-                                    this.messages_metadata.get_mut(&assistant_message.id)
-                                {
-                                    match result {
-                                        Ok(_) => {
-                                            metadata.status = MessageStatus::Done;
-                                        }
-                                        Err(error) => {
-                                            metadata.status = MessageStatus::Error(
-                                                error.to_string().trim().into(),
-                                            );
-                                        }
-                                    }
-                                    cx.notify();
-                                }
-                            });
-                        }
-                    }
-                }));
+                should_assist = true;
             }
         }
-        if !tasks.is_empty() {
+        if should_assist {
+            let Some(api_key) = self.api_key.borrow().clone() else {
+                return Default::default();
+            };
+            let request = OpenAIRequest {
+                model: self.model.full_name().to_string(),
+                messages: self
+                    .messages(cx)
+                    .filter(|message| matches!(message.status, MessageStatus::Done))
+                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                    .collect(),
+                stream: true,
+            };
+            let stream = stream_completion(api_key, cx.background().clone(), request);
+            let assistant_message = self
+                .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
+                .unwrap();
+            // Queue up the user's next reply.
+            let user_message = self
+                .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
+                .unwrap();
+            user_messages.push(user_message);
+            let task = cx.spawn_weak({
+                |this, mut cx| async move {
+                    let assistant_message_id = assistant_message.id;
+                    let stream_completion = async {
+                        let mut messages = stream.await?;
+                        while let Some(message) = messages.next().await {
+                            let mut message = message?;
+                            if let Some(choice) = message.choices.pop() {
+                                this.upgrade(&cx)
+                                    .ok_or_else(|| anyhow!("conversation was dropped"))?
+                                    .update(&mut cx, |this, cx| {
+                                        let text: Arc<str> = choice.delta.content?.into();
+                                        let message_ix =
+                                            this.message_anchors.iter().position(|message| {
+                                                message.id == assistant_message_id
+                                            })?;
+                                        this.buffer.update(cx, |buffer, cx| {
+                                            let offset = this.message_anchors[message_ix + 1..]
+                                                .iter()
+                                                .find(|message| message.start.is_valid(buffer))
+                                                .map_or(buffer.len(), |message| {
+                                                    message
+                                                        .start
+                                                        .to_offset(buffer)
+                                                        .saturating_sub(1)
+                                                });
+                                            buffer.edit([(offset..offset, text)], None, cx);
+                                        });
+                                        cx.emit(ConversationEvent::StreamedCompletion);
+                                        Some(())
+                                    });
+                            }
+                            smol::future::yield_now().await;
+                        }
+                        this.upgrade(&cx)
+                            .ok_or_else(|| anyhow!("conversation was dropped"))?
+                            .update(&mut cx, |this, cx| {
+                                this.pending_completions
+                                    .retain(|completion| completion.id != this.completion_count);
+                                this.summarize(cx);
+                            });
+                        anyhow::Ok(())
+                    };
+                    let result = stream_completion.await;
+                    if let Some(this) = this.upgrade(&cx) {
+                        this.update(&mut cx, |this, cx| {
+                            if let Some(metadata) =
+                                this.messages_metadata.get_mut(&assistant_message.id)
+                            {
+                                match result {
+                                    Ok(_) => {
+                                        metadata.status = MessageStatus::Done;
+                                    }
+                                    Err(error) => {
+                                        metadata.status =
+                                            MessageStatus::Error(error.to_string().trim().into());
+                                    }
+                                }
+                                cx.notify();
+                            }
+                        });
+                    }
+                }
+            });
             self.pending_completions.push(PendingCompletion {
                 id: post_inc(&mut self.completion_count),
-                _tasks: tasks,
+                _task: task,
             });
         }
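Both versions of the streaming loop lean on the same weak-handle pattern: cx.spawn_weak hands the future only a weak reference to the conversation, and every update first upgrades it, erroring out if the model was dropped mid-stream. The shape in plain Rust, with std::rc::Weak standing in for GPUI's weak model handle (a sketch under that assumption):

    use std::cell::RefCell;
    use std::rc::Weak;

    use anyhow::anyhow;

    // Append one streamed chunk, or fail if the owner has been dropped.
    fn apply_chunk(this: &Weak<RefCell<String>>, chunk: &str) -> anyhow::Result<()> {
        let this = this
            .upgrade()
            .ok_or_else(|| anyhow!("conversation was dropped"))?;
        this.borrow_mut().push_str(chunk); // stands in for this.update(&mut cx, ..)
        Ok(())
    }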
@@ -2296,7 +2268,7 @@ impl Conversation {
 struct PendingCompletion {
     id: usize,
-    _tasks: Vec<Task<()>>,
+    _task: Task<()>,
 }

 enum ConversationEditorEvent {
@@ -2844,8 +2816,9 @@ pub struct Message {
 impl Message {
     fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
-        let mut content = format!("[Message {}]\n", self.id.0).to_string();
-        content.extend(buffer.text_for_range(self.offset_range.clone()));
+        let content = buffer
+            .text_for_range(self.offset_range.clone())
+            .collect::<String>();
         RequestMessage {
             role: self.role,
             content: content.trim_end().into(),
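The last hunk drops the per-message [Message {id}] header the old code prepended before appending the buffer text (the removed system prompt above had to tell the model not to echo such headers back). The two shapes as standalone functions, with a plain chunk iterator standing in for buffer.text_for_range(..) — illustrative, not the real Buffer API:

    // Old shape: header first, then append each text chunk.
    fn old_content<'a>(id: u64, chunks: impl Iterator<Item = &'a str>) -> String {
        let mut content = format!("[Message {}]\n", id);
        content.extend(chunks);
        content
    }

    // New shape: just collect the chunks; no header to strip or echo.
    fn new_content<'a>(chunks: impl Iterator<Item = &'a str>) -> String {
        chunks.collect()
    }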