This commit is contained in:
Nathan Sobo 2023-05-22 23:11:22 -06:00
parent 7e6cccfa3d
commit 30de64845f
13 changed files with 86 additions and 120 deletions

1
Cargo.lock generated
View File

@ -104,6 +104,7 @@ dependencies = [
"editor",
"futures 0.3.28",
"gpui",
"indoc",
"isahc",
"pulldown-cmark",
"serde",

View File

@ -79,6 +79,7 @@ ctor = { version = "0.1" }
env_logger = { version = "0.9" }
futures = { version = "0.3" }
glob = { version = "0.3.1" }
indoc = "1"
isahc = "1.7.2"
lazy_static = { version = "1.4.0" }
log = { version = "0.4.16", features = ["kv_unstable_serde"] }

0
Untitled Normal file
View File

View File

@ -16,6 +16,7 @@ util = { path = "../util" }
serde.workspace = true
serde_json.workspace = true
anyhow.workspace = true
indoc.workspace = true
pulldown-cmark = "0.9.2"
futures.workspace = true
isahc.workspace = true

5
crates/ai/README.zmd Normal file
View File

@ -0,0 +1,5 @@
This is Zed Markdown.
Mention a language model with / at the start of any line, like this:
/ What do you think of this idea?

View File

@ -1,16 +1,14 @@
use std::io;
use std::rc::Rc;
use anyhow::{anyhow, Result};
use editor::Editor;
use futures::AsyncBufReadExt;
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
use gpui::executor::Foreground;
use gpui::executor::Background;
use gpui::{actions, AppContext, Task, ViewContext};
use indoc::indoc;
use isahc::prelude::*;
use isahc::{http::StatusCode, Request};
use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
use serde::{Deserialize, Serialize};
use std::{io, sync::Arc};
use util::ResultExt;
actions!(ai, [Assist]);
@ -93,99 +91,87 @@ fn assist(
) -> Option<Task<Result<()>>> {
let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
let markdown = editor.text(cx);
let prompt = parse_dialog(&markdown);
let response = stream_completion(api_key, prompt, cx.foreground().clone());
const SYSTEM_MESSAGE: &'static str = indoc! {r#"
You an AI language model embedded in a code editor named Zed, authored by Zed Industries.
The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor.
A model mention is indicated via a leading / on a line.
The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text.
In this sentence, the word ->->example<-<- is selected.
Respond to any selected model mention.
Summarize each mention in a single short sentence like:
> The user selected the word \"example\".
Then provide your response to that mention below its summary.
"#};
let range = editor.buffer().update(cx, |buffer, cx| {
let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
// Insert ->-> <-<- around selected text as described in the system prompt above.
let snapshot = buffer.snapshot(cx);
let chars = snapshot.reversed_chars_at(snapshot.len());
let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
let suffix = "\n".repeat(2 - trailing_newlines);
let end = snapshot.len();
buffer.edit([(end..end, suffix.clone())], None, cx);
let snapshot = buffer.snapshot(cx);
let start = snapshot.anchor_before(snapshot.len());
let end = snapshot.anchor_after(snapshot.len());
start..end
let mut user_message = String::new();
let mut buffer_offset = 0;
for selection in editor.selections.all(cx) {
user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
user_message.push_str("->->");
user_message.extend(snapshot.text_for_range(selection.start..selection.end));
buffer_offset = selection.end;
user_message.push_str("<-<-");
}
if buffer_offset < snapshot.len() {
user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
}
// Ensure the document ends with 4 trailing newlines.
let trailing_newline_count = snapshot
.reversed_chars_at(snapshot.len())
.take_while(|c| *c == '\n')
.take(4);
let suffix = "\n".repeat(4 - trailing_newline_count.count());
buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below.
(user_message, insertion_site)
});
let stream = stream_completion(
api_key,
cx.background_executor().clone(),
OpenAIRequest {
model: "gpt-4".to_string(),
messages: vec![
RequestMessage {
role: Role::System,
content: SYSTEM_MESSAGE.to_string(),
},
RequestMessage {
role: Role::User,
content: user_message,
},
],
stream: false,
},
);
let buffer = editor.buffer().clone();
Some(cx.spawn(|_, mut cx| async move {
let mut stream = response.await?;
let mut message = String::new();
while let Some(stream_event) = stream.next().await {
if let Some(choice) = stream_event?.choices.first() {
if let Some(content) = &choice.delta.content {
message.push_str(content);
}
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
buffer.update(&mut cx, |buffer, cx| {
let text: Arc<str> = choice.delta.content?.into();
buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
Some(())
});
}
buffer.update(&mut cx, |buffer, cx| {
buffer.edit([(range.clone(), message.clone())], None, cx);
});
}
Ok(())
}))
}
fn parse_dialog(markdown: &str) -> OpenAIRequest {
let parser = Parser::new(markdown);
let mut messages = Vec::new();
let mut current_role: Option<Role> = None;
let mut buffer = String::new();
for event in parser {
match event {
Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
if let Some(role) = current_role.take() {
if !buffer.is_empty() {
messages.push(RequestMessage {
role,
content: buffer.trim().to_string(),
});
buffer.clear();
}
}
}
Event::Text(text) => {
if current_role.is_some() {
buffer.push_str(&text);
} else {
// Determine the current role based on the H2 header text
let text = text.to_lowercase();
current_role = if text.contains("user") {
Some(Role::User)
} else if text.contains("assistant") {
Some(Role::Assistant)
} else if text.contains("system") {
Some(Role::System)
} else {
None
};
}
}
_ => (),
}
}
if let Some(role) = current_role {
messages.push(RequestMessage {
role,
content: buffer,
});
}
OpenAIRequest {
model: "gpt-4".into(),
messages,
stream: true,
}
}
async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
executor: Rc<Foreground>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
@ -240,32 +226,4 @@ async fn stream_completion(
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_dialog() {
use unindent::Unindent;
let test_input = r#"
## System
Hey there, welcome to Zed!
## Assintant
Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
"#.unindent();
let expected_output = vec![
RequestMessage {
role: Role::User,
content: "Hey there, welcome to Zed!".to_string(),
},
RequestMessage {
role: Role::Assistant,
content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
},
];
assert_eq!(parse_dialog(&test_input).messages, expected_output);
}
}
mod tests {}

View File

@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
indoc = "1.0.4"
indoc.workspace = true
util = { path = "../util" }
lazy_static.workspace = true
sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }

View File

@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" }
sqlez_macros = { path = "../sqlez_macros" }
util = { path = "../util" }
anyhow.workspace = true
indoc = "1.0.4"
indoc.workspace = true
async-trait.workspace = true
lazy_static.workspace = true
log.workspace = true

View File

@ -50,7 +50,7 @@ aho-corasick = "0.7"
anyhow.workspace = true
futures.workspace = true
glob.workspace = true
indoc = "1.0.4"
indoc.workspace = true
itertools = "0.10"
lazy_static.workspace = true
log.workspace = true

View File

@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
indoc = "1.0.4"
indoc.workspace = true
rand.workspace = true
tree-sitter-embedded-template = "*"
tree-sitter-html = "*"

View File

@ -6,7 +6,7 @@ publish = false
[dependencies]
anyhow.workspace = true
indoc = "1.0.7"
indoc.workspace = true
libsqlite3-sys = { version = "0.24", features = ["bundled"] }
smol.workspace = true
thread_local = "1.1.4"

View File

@ -35,7 +35,7 @@ settings = { path = "../settings" }
workspace = { path = "../workspace" }
[dev-dependencies]
indoc = "1.0.4"
indoc.workspace = true
parking_lot.workspace = true
lazy_static.workspace = true

View File

@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
db = { path = "../db", features = ["test-support"] }
indoc = "1.0.4"
indoc.workspace = true
env_logger.workspace = true