Extract a strip_markdown_codeblock function

This commit is contained in:
Antonio Scandurra 2023-08-28 11:24:55 +02:00
parent 55bf45d265
commit 937aabfdfd

View File

@ -16,7 +16,7 @@ use editor::{
Anchor, Editor, MultiBufferSnapshot, ToOffset, ToPoint,
};
use fs::Fs;
use futures::{channel::mpsc, SinkExt, StreamExt};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{
actions,
elements::*,
@ -620,7 +620,10 @@ impl AssistantPanel {
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
let mut messages = response.await?;
let chunks = strip_markdown_codeblock(response.await?.filter_map(
|message| async move { message.ok()?.choices.pop()?.delta.content },
));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let indentation_len;
@ -636,93 +639,21 @@ impl AssistantPanel {
indentation_text = "";
};
let mut inside_first_line = true;
let mut starts_with_fenced_code_block = None;
let mut has_pending_newline = false;
let mut new_text = String::new();
let mut new_text = indentation_text
.repeat(indentation_len.saturating_sub(selection_start.column) as usize);
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(mut choice) = message.choices.pop() {
if has_pending_newline {
has_pending_newline = false;
choice
.delta
.content
.get_or_insert(String::new())
.insert(0, '\n');
}
while let Some(message) = chunks.next().await {
let mut lines = message.split('\n');
if let Some(first_line) = lines.next() {
new_text.push_str(first_line);
}
// Buffer a trailing codeblock fence. Note that we don't stop
// right away because this may be an inner fence that we need
// to insert into the editor.
if starts_with_fenced_code_block.is_some()
&& choice.delta.content.as_deref() == Some("\n```")
{
new_text.push_str("\n```");
continue;
}
// If this was the last completion and we started with a codeblock
// fence and we ended with another codeblock fence, then we can
// stop right away. Otherwise, whatever text we buffered will be
// processed normally.
if choice.finish_reason.is_some()
&& starts_with_fenced_code_block.unwrap_or(false)
&& new_text == "\n```"
{
break;
}
if let Some(text) = choice.delta.content {
// Never push a newline if there's nothing after it. This is
// useful to detect if the newline was pushed because of a
// trailing codeblock fence.
let text = if let Some(prefix) = text.strip_suffix('\n') {
has_pending_newline = true;
prefix
} else {
text.as_str()
};
if text.is_empty() {
continue;
}
let mut lines = text.split('\n');
if let Some(line) = lines.next() {
if starts_with_fenced_code_block.is_none() {
starts_with_fenced_code_block =
Some(line.starts_with("```"));
}
// Avoid pushing the first line if it's the start of a fenced code block.
if !inside_first_line || !starts_with_fenced_code_block.unwrap()
{
new_text.push_str(&line);
}
}
for line in lines {
if inside_first_line && starts_with_fenced_code_block.unwrap() {
// If we were inside the first line and that line was the
// start of a fenced code block, we just need to push the
// leading indentation of the original selection.
new_text.push_str(&indentation_text.repeat(
indentation_len.saturating_sub(selection_start.column)
as usize,
));
} else {
// Otherwise, we need to push a newline and the base indentation.
new_text.push('\n');
new_text.push_str(
&indentation_text.repeat(indentation_len as usize),
);
}
new_text.push_str(line);
inside_first_line = false;
}
for line in lines {
new_text.push('\n');
if !line.is_empty() {
new_text
.push_str(&indentation_text.repeat(indentation_len as usize));
new_text.push_str(line);
}
}
@ -2919,10 +2850,58 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
/// Removes a Markdown code fence that wraps the text carried by `stream`.
///
/// Incoming chunks are accumulated in an internal buffer. If the very first
/// characters of the stream are "```", the whole opening fence line (up to and
/// including its first newline) is discarded. From then on, a trailing
/// "\n```" — or any prefix of it ("\n``", "\n`", "\n") — is withheld until a
/// later chunk shows what follows it. Withheld text that is still pending
/// when the stream ends is never emitted, which is what drops the closing
/// fence.
///
/// If the stream does not begin with "```", every chunk is passed through
/// unchanged (leading "`" / "``" chunks are only delayed until enough text
/// has arrived to rule out a fence).
///
/// Empty strings are never yielded.
fn strip_markdown_codeblock(stream: impl Stream<Item = String>) -> impl Stream<Item = String> {
    // True until we have decided whether the stream opens with a fence. The
    // decision is made exactly once, on the first content that is not a strict
    // prefix of "```"; later chunks that merely happen to start with backticks
    // must not be mistaken for the beginning of the stream.
    let mut first_line = true;
    let mut buffer = String::new();
    let mut starts_with_fenced_code_block = false;

    stream.filter_map(move |chunk| {
        buffer.push_str(&chunk);

        if first_line {
            if buffer == "" || buffer == "`" || buffer == "``" {
                // Could still turn into "```": wait for more input.
                return futures::future::ready(None);
            } else if buffer.starts_with("```") {
                starts_with_fenced_code_block = true;
                if let Some(newline_ix) = buffer.find('\n') {
                    // Drop the opening fence line, including its info string.
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    // The fence line is not complete yet.
                    return futures::future::ready(None);
                }
            } else {
                // Plain text: no fence to strip, and nothing arriving later
                // can change that.
                first_line = false;
            }
        }

        // Number of leading bytes of `buffer` that are safe to emit now. In
        // fenced mode a possible (partial) closing fence at the end is kept
        // back.
        let emit_len = if starts_with_fenced_code_block {
            buffer
                .strip_suffix("\n```")
                .or_else(|| buffer.strip_suffix("\n``"))
                .or_else(|| buffer.strip_suffix("\n`"))
                .or_else(|| buffer.strip_suffix('\n'))
                .unwrap_or(buffer.as_str())
                .len()
        } else {
            buffer.len()
        };

        // Keep the withheld tail in `buffer` and hand out the head without
        // copying it.
        let remainder = buffer.split_off(emit_len);
        let emitted = std::mem::replace(&mut buffer, remainder);
        let result = if emitted.is_empty() {
            None
        } else {
            Some(emitted)
        };
        futures::future::ready(result)
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::MessageId;
use futures::stream;
use gpui::AppContext;
#[gpui::test]
@ -3291,6 +3270,50 @@ mod tests {
);
}
#[gpui::test]
async fn test_strip_markdown_codeblock() {
    // Each case is (full text fed to the stream, text expected after stripping).
    // The input is always delivered two characters at a time.
    let cases = [
        // No fence at all.
        ("Lorem ipsum dolor", "Lorem ipsum dolor"),
        // Opening fence only.
        ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
        // Opening and closing fence.
        ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
        // Only the outermost fence is removed; the inner one is kept.
        (
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        ),
        // Two backticks are not a fence, so nothing is stripped.
        (
            "``\nLorem ipsum dolor\n```",
            "``\nLorem ipsum dolor\n```",
        ),
    ];

    for (input, expected) in cases {
        let stripped = strip_markdown_codeblock(chunks(input, 2))
            .collect::<String>()
            .await;
        assert_eq!(stripped, expected);
    }

    /// Splits `text` into consecutive pieces of at most `size` characters and
    /// exposes them as a stream.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = String> {
        let characters: Vec<char> = text.chars().collect();
        let pieces: Vec<String> = characters
            .chunks(size)
            .map(|piece| piece.iter().collect::<String>())
            .collect();
        stream::iter(pieces)
    }
}
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,