reorganize AI crates to structure future development (#3015)

Reorganized assistant/semantic_index crates and introduced AI crate to
include shared functionality.

Release Notes:

- Moved most of the Assistant functionality from ai crate to assistant
crate
- Moved interaction with embedding providers from semantic_index to ai
crate
This commit is contained in:
Kyle Caverly 2023-09-22 09:54:46 -04:00 committed by GitHub
commit e84339ef4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 555 additions and 528 deletions

220
Cargo.lock generated
View File

@ -79,9 +79,9 @@ dependencies = [
[[package]]
name = "aho-corasick"
version = "1.1.0"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f2135563fb5c609d2b2b87c1e8ce7bc41b0b45430fa9661f457981503dd5bf0"
checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab"
dependencies = [
"memchr",
]
@ -91,36 +91,25 @@ name = "ai"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"client",
"collections",
"ctor",
"editor",
"env_logger 0.9.3",
"fs",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui",
"indoc",
"isahc",
"language",
"lazy_static",
"log",
"menu",
"matrixmultiply",
"ordered-float",
"parking_lot 0.11.2",
"project",
"parse_duration",
"postage",
"rand 0.8.5",
"regex",
"schemars",
"search",
"rusqlite",
"serde",
"serde_json",
"settings",
"smol",
"theme",
"tiktoken-rs 0.4.5",
"tiktoken-rs 0.5.4",
"util",
"uuid 1.4.1",
"workspace",
]
[[package]]
@ -305,6 +294,44 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
[[package]]
name = "assistant"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
"chrono",
"client",
"collections",
"ctor",
"editor",
"env_logger 0.9.3",
"fs",
"futures 0.3.28",
"gpui",
"indoc",
"isahc",
"language",
"log",
"menu",
"ordered-float",
"parking_lot 0.11.2",
"project",
"rand 0.8.5",
"regex",
"schemars",
"search",
"serde",
"serde_json",
"settings",
"smol",
"theme",
"tiktoken-rs 0.4.5",
"util",
"uuid 1.4.1",
"workspace",
]
[[package]]
name = "async-broadcast"
version = "0.4.1"
@ -2141,7 +2168,7 @@ dependencies = [
"convert_case 0.4.0",
"proc-macro2",
"quote",
"rustc_version 0.4.0",
"rustc_version",
"syn 1.0.109",
]
@ -3234,7 +3261,7 @@ dependencies = [
"indexmap 1.9.3",
"slab",
"tokio",
"tokio-util 0.7.8",
"tokio-util 0.7.9",
"tracing",
]
@ -3355,9 +3382,9 @@ dependencies = [
[[package]]
name = "hermit-abi"
version = "0.3.2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
[[package]]
name = "hex"
@ -3651,7 +3678,7 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2"
dependencies = [
"hermit-abi 0.3.2",
"hermit-abi 0.3.3",
"libc",
"windows-sys",
]
@ -3708,8 +3735,8 @@ version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
dependencies = [
"hermit-abi 0.3.2",
"rustix 0.38.13",
"hermit-abi 0.3.3",
"rustix 0.38.14",
"windows-sys",
]
@ -4277,9 +4304,9 @@ checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]]
name = "matrixmultiply"
version = "0.3.7"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77"
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
dependencies = [
"autocfg",
"rawpointer",
@ -4806,7 +4833,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi 0.3.2",
"hermit-abi 0.3.3",
"libc",
]
@ -4843,7 +4870,7 @@ dependencies = [
"rmp",
"rmpv",
"tokio",
"tokio-util 0.7.8",
"tokio-util 0.7.9",
]
[[package]]
@ -5144,11 +5171,11 @@ dependencies = [
[[package]]
name = "pathfinder_simd"
version = "0.5.1"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39fe46acc5503595e5949c17b818714d26fdf9b4920eacf3b2947f0199f4a6ff"
checksum = "0444332826c70dc47be74a7c6a5fc44e23a7905ad6858d4162b658320455ef93"
dependencies = [
"rustc_version 0.3.3",
"rustc_version",
]
[[package]]
@ -5183,17 +5210,6 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "pest"
version = "2.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7a4d085fd991ac8d5b05a147b437791b4260b76326baf0fc60cf7c9c27ecd33"
dependencies = [
"memchr",
"thiserror",
"ucd-trie",
]
[[package]]
name = "petgraph"
version = "0.6.4"
@ -5728,7 +5744,7 @@ dependencies = [
name = "quick_action_bar"
version = "0.1.0"
dependencies = [
"ai",
"assistant",
"editor",
"gpui",
"search",
@ -5864,9 +5880,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.7.0"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b"
checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1"
dependencies = [
"either",
"rayon-core",
@ -5874,14 +5890,12 @@ dependencies = [
[[package]]
name = "rayon-core"
version = "1.11.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d"
checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"num_cpus",
]
[[package]]
@ -6331,22 +6345,13 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_version"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee"
dependencies = [
"semver 0.11.0",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver 1.0.18",
"semver",
]
[[package]]
@ -6381,9 +6386,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.13"
version = "0.38.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662"
checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f"
dependencies = [
"bitflags 2.4.0",
"errno 0.3.3",
@ -6732,9 +6737,9 @@ dependencies = [
name = "semantic_index"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
"async-trait",
"bincode",
"client",
"collections",
"ctor",
@ -6743,15 +6748,12 @@ dependencies = [
"futures 0.3.28",
"globset",
"gpui",
"isahc",
"language",
"lazy_static",
"log",
"matrixmultiply",
"node_runtime",
"ordered-float",
"parking_lot 0.11.2",
"parse_duration",
"picker",
"postage",
"pretty_assertions",
@ -6785,30 +6787,12 @@ dependencies = [
"zed",
]
[[package]]
name = "semver"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6"
dependencies = [
"semver-parser",
]
[[package]]
name = "semver"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918"
[[package]]
name = "semver-parser"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7"
dependencies = [
"pest",
]
[[package]]
name = "seq-macro"
version = "0.2.2"
@ -6978,9 +6962,9 @@ dependencies = [
[[package]]
name = "sha1"
version = "0.10.5"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if 1.0.0",
"cpufeatures",
@ -7153,9 +7137,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.11.0"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a"
[[package]]
name = "smol"
@ -7438,15 +7422,15 @@ dependencies = [
[[package]]
name = "sval"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b031320a434d3e9477ccf9b5756d57d4272937b8d22cb88af80b7633a1b78b1"
checksum = "05d11eec9fbe2bc8bc71e7349f0e7534db9a96d961fb9f302574275b7880ad06"
[[package]]
name = "sval_buffer"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf7e9412af26b342f3f2cc5cc4122b0105e9d16eb76046cd14ed10106cf6028"
checksum = "6b7451f69a93c5baf2653d5aa8bb4178934337f16c22830a50b06b386f72d761"
dependencies = [
"sval",
"sval_ref",
@ -7454,18 +7438,18 @@ dependencies = [
[[package]]
name = "sval_dynamic"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0ef628e8a77a46ed3338db8d1b08af77495123cc229453084e47cd716d403cf"
checksum = "c34f5a2cc12b4da2adfb59d5eedfd9b174a23cc3fae84cec71dcbcd9302068f5"
dependencies = [
"sval",
]
[[package]]
name = "sval_fmt"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dc09e9364c2045ab5fa38f7b04d077b3359d30c4c2b3ec4bae67a358bd64326"
checksum = "2f578b2301341e246d00b35957f2952c4ec554ad9c7cfaee10bc86bc92896578"
dependencies = [
"itoa",
"ryu",
@ -7474,9 +7458,9 @@ dependencies = [
[[package]]
name = "sval_json"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ada6f627e38cbb8860283649509d87bc4a5771141daa41c78fd31f2b9485888d"
checksum = "8346c00f5dc6efe18bea8d13c1f7ca4f112b20803434bf3657ac17c0f74cbc4b"
dependencies = [
"itoa",
"ryu",
@ -7485,18 +7469,18 @@ dependencies = [
[[package]]
name = "sval_ref"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "703ca1942a984bd0d9b5a4c0a65ab8b4b794038d080af4eb303c71bc6bf22d7c"
checksum = "6617cc89952f792aebc0f4a1a76bc51e80c70b18c491bd52215c7989c4c3dd06"
dependencies = [
"sval",
]
[[package]]
name = "sval_serde"
version = "2.6.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830926cd0581f7c3e5d51efae4d35c6b6fc4db583842652891ba2f1bed8db046"
checksum = "fe3d1e59f023341d9af75d86f3bc148a6704f3f831eef0dd90bbe9cb445fa024"
dependencies = [
"serde",
"sval",
@ -7647,7 +7631,7 @@ dependencies = [
"cfg-if 1.0.0",
"fastrand 2.0.0",
"redox_syscall 0.3.5",
"rustix 0.38.13",
"rustix 0.38.14",
"windows-sys",
]
@ -8049,9 +8033,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.8"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d"
checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d"
dependencies = [
"bytes 1.5.0",
"futures-core",
@ -8150,7 +8134,7 @@ dependencies = [
"rand 0.8.5",
"slab",
"tokio",
"tokio-util 0.7.8",
"tokio-util 0.7.9",
"tower-layer",
"tower-service",
"tracing",
@ -8595,12 +8579,6 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "ucd-trie"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
[[package]]
name = "ui"
version = "0.1.0"
@ -8680,9 +8658,9 @@ checksum = "b1d386ff53b415b7fe27b50bb44679e2cc4660272694b7b6f3326d8480823a94"
[[package]]
name = "unicode-width"
version = "0.1.10"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
[[package]]
name = "unicode_categories"
@ -9395,7 +9373,7 @@ dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.13",
"rustix 0.38.14",
]
[[package]]
@ -9480,9 +9458,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
dependencies = [
"winapi 0.3.9",
]
@ -9808,8 +9786,8 @@ name = "zed"
version = "0.106.0"
dependencies = [
"activity_indicator",
"ai",
"anyhow",
"assistant",
"async-compression",
"async-recursion 0.3.2",
"async-tar",

View File

@ -2,6 +2,7 @@
members = [
"crates/activity_indicator",
"crates/ai",
"crates/assistant",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",

View File

@ -9,39 +9,26 @@ path = "src/ai.rs"
doctest = false
[dependencies]
client = { path = "../client" }
collections = { path = "../collections"}
editor = { path = "../editor" }
fs = { path = "../fs" }
gpui = { path = "../gpui" }
language = { path = "../language" }
menu = { path = "../menu" }
search = { path = "../search" }
settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
uuid = { version = "1.1.2", features = ["v4"] }
workspace = { path = "../workspace" }
async-trait.workspace = true
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
indoc.workspace = true
isahc.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4"
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
log.workspace = true
rand.workspace = true
gpui = { path = "../gpui", features = ["test-support"] }

View File

@ -1,294 +1,2 @@
pub mod assistant;
mod assistant_settings;
mod codegen;
mod streaming_diff;
use anyhow::{anyhow, Result};
pub use assistant::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use gpui::{executor::Background, AppContext};
use isahc::{http::StatusCode, Request, RequestExt};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
cmp::Reverse,
ffi::OsStr,
fmt::{self, Display},
io,
path::PathBuf,
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR;
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
// Data types for chat completion requests
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,
stream: bool,
}
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
model: OpenAIModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
role: Role,
content: String,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
role: Option<Role>,
content: Option<String>,
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
prompt_tokens: u64,
completion_tokens: u64,
total_tokens: u64,
}
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
text: String,
index: u32,
logprobs: Option<serde_json::Value>,
finish_reason: Option<String>,
}
pub fn init(cx: &mut AppContext) {
assistant::init(cx);
}
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}
pub mod completion;
pub mod embedding;

212
crates/ai/src/completion.rs Normal file
View File

@ -0,0 +1,212 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
}

View File

@ -27,8 +27,30 @@ lazy_static! {
}
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
@ -63,24 +85,24 @@ impl Embedding {
}
}
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
// impl FromSql for Embedding {
// fn column_result(value: ValueRef) -> FromSqlResult<Self> {
// let bytes = value.as_blob()?;
// let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
// if embedding.is_err() {
// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
// }
// Ok(Embedding(embedding.unwrap()))
// }
// }
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
// impl ToSql for Embedding {
// fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
// let bytes = bincode::serialize(&self.0)
// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
// }
// }
#[derive(Clone)]
pub struct OpenAIEmbeddings {

View File

@ -0,0 +1,48 @@
[package]
name = "assistant"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/assistant.rs"
doctest = false
[dependencies]
ai = { path = "../ai" }
client = { path = "../client" }
collections = { path = "../collections"}
editor = { path = "../editor" }
fs = { path = "../fs" }
gpui = { path = "../gpui" }
language = { path = "../language" }
menu = { path = "../menu" }
search = { path = "../search" }
settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
uuid = { version = "1.1.2", features = ["v4"] }
workspace = { path = "../workspace" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
indoc.workspace = true
isahc.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
log.workspace = true
rand.workspace = true

View File

@ -0,0 +1,112 @@
pub mod assistant_panel;
mod assistant_settings;
mod codegen;
mod streaming_diff;
use ai::completion::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use gpui::AppContext;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
model: OpenAIModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx);
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

View File

@ -1,8 +1,11 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
codegen::{self, Codegen, CodegenKind},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
use ai::completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};

View File

@ -1,59 +1,14 @@
use crate::{
stream_completion,
streaming_diff::{Hunk, StreamingDiff},
OpenAIRequest,
};
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest};
use anyhow::Result;
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{
channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
};
use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range, sync::Arc};
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
}
pub enum Event {
Finished,
Undone,
@ -397,13 +352,17 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
use futures::stream;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
use smol::future::FutureExt;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(

View File

@ -9,7 +9,7 @@ path = "src/quick_action_bar.rs"
doctest = false
[dependencies]
ai = { path = "../ai" }
assistant = { path = "../assistant" }
editor = { path = "../editor" }
gpui = { path = "../gpui" }
search = { path = "../search" }

View File

@ -1,4 +1,4 @@
use ai::{assistant::InlineAssist, AssistantPanel};
use assistant::{assistant_panel::InlineAssist, AssistantPanel};
use editor::Editor;
use gpui::{
elements::{Empty, Flex, MouseEventHandler, ParentElement, Svg},

View File

@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
doctest = false
[dependencies]
ai = { path = "../ai" }
collections = { path = "../collections" }
gpui = { path = "../gpui" }
language = { path = "../language" }
@ -26,22 +27,18 @@ futures.workspace = true
ordered-float.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0"
parking_lot.workspace = true
rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
parse_duration = "2.1.1"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }

View File

@ -1,10 +1,10 @@
use ai::embedding::OpenAIEmbeddings;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
use gpui::{AsyncAppContext, ModelHandle, Task};
use language::LanguageRegistry;
use node_runtime::RealNodeRuntime;
use project::{Project, RealFs};
use semantic_index::embedding::OpenAIEmbeddings;
use semantic_index::semantic_index_settings::SemanticIndexSettings;
use semantic_index::{SearchResult, SemanticIndex};
use serde::{Deserialize, Serialize};

View File

@ -1,8 +1,8 @@
use crate::{
embedding::Embedding,
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
};
use ai::embedding::Embedding;
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;

View File

@ -1,4 +1,5 @@
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;

View File

@ -1,4 +1,4 @@
use crate::embedding::{Embedding, EmbeddingProvider};
use ai::embedding::{Embedding, EmbeddingProvider};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{

View File

@ -1,5 +1,4 @@
mod db;
pub mod embedding;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
@ -8,10 +7,10 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};

View File

@ -1,10 +1,10 @@
use crate::{
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, Task, TestAppContext};

View File

@ -50,7 +50,7 @@ language_selector = { path = "../language_selector" }
lsp = { path = "../lsp" }
language_tools = { path = "../language_tools" }
node_runtime = { path = "../node_runtime" }
ai = { path = "../ai" }
assistant = { path = "../assistant" }
outline = { path = "../outline" }
plugin_runtime = { path = "../plugin_runtime",optional = true }
project = { path = "../project" }

View File

@ -161,7 +161,7 @@ fn main() {
vim::init(cx);
terminal_view::init(cx);
copilot::init(copilot_language_server_id, http.clone(), node_runtime, cx);
ai::init(cx);
assistant::init(cx);
component_test::init(cx);
cx.spawn(|cx| watch_themes(fs.clone(), cx)).detach();

View File

@ -5,9 +5,9 @@ pub mod only_instance;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
use ai::AssistantPanel;
use anyhow::Context;
use assets::Assets;
use assistant::AssistantPanel;
use breadcrumbs::Breadcrumbs;
pub use client;
use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
@ -2418,7 +2418,7 @@ mod tests {
pane::init(cx);
project_panel::init((), cx);
terminal_view::init(cx);
ai::init(cx);
assistant::init(cx);
app_state
})
}