Merge branch 'main' into additional-keybinding-icons

This commit is contained in:
Nate Butler 2024-01-03 16:49:34 -05:00
commit f39bc6e132
1388 changed files with 44138 additions and 329825 deletions

View File

@ -1,6 +1,3 @@
[alias]
xtask = "run --package xtask --"
[build]
# v0 mangling scheme provides more detailed backtraces around closures
rustflags = ["-C", "symbol-mangling-version=v0"]

View File

@ -2,8 +2,8 @@ name: Release Nightly
on:
schedule:
# Fire every day at 1:00pm and 1:00am
- cron: "0 1,13 * * *"
# Fire every day at 7:00am UTC (Roughly before EU workday and after US workday)
- cron: "0 7 * * *"
push:
tags:
- "nightly"
@ -92,7 +92,7 @@ jobs:
run: script/generate-licenses
- name: Create app bundle
run: script/bundle -2
run: script/bundle
- name: Upload Zed Nightly
run: script/upload-nightly

2855
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,138 +1,89 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/activity_indicator2",
"crates/ai",
"crates/assistant",
"crates/assistant2",
"crates/audio",
"crates/audio2",
"crates/auto_update",
"crates/auto_update2",
"crates/breadcrumbs",
"crates/breadcrumbs2",
"crates/call",
"crates/call2",
"crates/channel",
"crates/channel2",
"crates/cli",
"crates/client",
"crates/client2",
"crates/clock",
"crates/collab",
"crates/collab2",
"crates/collab_ui",
"crates/collab_ui2",
"crates/collections",
"crates/command_palette",
"crates/command_palette2",
"crates/component_test",
"crates/context_menu",
"crates/copilot",
"crates/copilot2",
"crates/copilot_button",
"crates/db",
"crates/db2",
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/diagnostics",
"crates/diagnostics2",
"crates/drag_and_drop",
"crates/editor",
"crates/feature_flags",
"crates/feature_flags2",
"crates/feedback",
"crates/file_finder",
"crates/fs",
"crates/fs2",
"crates/fsevent",
"crates/fuzzy",
"crates/fuzzy2",
"crates/git",
"crates/go_to_line",
"crates/go_to_line2",
"crates/gpui",
"crates/gpui_macros",
"crates/gpui2",
"crates/gpui2_macros",
"crates/gpui",
"crates/gpui_macros",
"crates/install_cli",
"crates/install_cli2",
"crates/journal",
"crates/journal2",
"crates/journal",
"crates/language",
"crates/language2",
"crates/language_selector",
"crates/language_selector2",
"crates/language_tools",
"crates/language_tools2",
"crates/live_kit_client",
"crates/live_kit_server",
"crates/lsp",
"crates/lsp2",
"crates/media",
"crates/menu",
"crates/menu2",
"crates/multi_buffer",
"crates/multi_buffer2",
"crates/node_runtime",
"crates/notifications",
"crates/notifications2",
"crates/outline",
"crates/outline2",
"crates/picker",
"crates/picker2",
"crates/plugin",
"crates/plugin_macros",
"crates/plugin_runtime",
"crates/prettier",
"crates/prettier2",
"crates/project",
"crates/project2",
"crates/project_panel",
"crates/project_panel2",
"crates/project_symbols",
"crates/project_symbols2",
"crates/quick_action_bar2",
"crates/quick_action_bar",
"crates/recent_projects",
"crates/recent_projects2",
"crates/rope",
"crates/rpc",
"crates/rpc2",
"crates/search",
"crates/search2",
"crates/semantic_index",
"crates/semantic_index2",
"crates/settings",
"crates/settings2",
"crates/snippet",
"crates/sqlez",
"crates/sqlez_macros",
"crates/rich_text",
"crates/storybook2",
"crates/storybook",
"crates/sum_tree",
"crates/terminal",
"crates/terminal2",
"crates/terminal_view2",
"crates/terminal_view",
"crates/text",
"crates/theme",
"crates/theme2",
"crates/theme_importer",
"crates/theme_selector",
"crates/theme_selector2",
"crates/ui2",
"crates/ui",
"crates/util",
"crates/story",
"crates/vim",
"crates/vcs_menu",
"crates/vcs_menu2",
"crates/workspace2",
"crates/workspace",
"crates/welcome",
"crates/welcome2",
"crates/xtask",
"crates/zed",
"crates/zed2",
"crates/zed-actions",
"crates/zed_actions2"
"crates/zed_actions",
]
default-members = ["crates/zed"]
resolver = "2"
@ -176,11 +127,12 @@ thiserror = { version = "1.0.29" }
time = { version = "0.3", features = ["serde", "serde-well-known"] }
toml = { version = "0.5" }
tiktoken-rs = "0.5.7"
tree-sitter = "0.20"
tree-sitter = { version = "0.20", features = ["wasm"] }
unindent = { version = "0.1.7" }
pretty_assertions = "1.3.0"
git2 = { version = "0.15", default-features = false}
uuid = { version = "1.1.2", features = ["v4"] }
wasmtime = "16"
tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "7331995b19b8f8aba2d5e26deb51d2195c18bc94" }
tree-sitter-c = "0.20.1"
@ -212,8 +164,9 @@ tree-sitter-vue = {git = "https://github.com/zed-industries/tree-sitter-vue", re
tree-sitter-uiua = {git = "https://github.com/shnarazk/tree-sitter-uiua", rev = "9260f11be5900beda4ee6d1a24ab8ddfaf5a19b2"}
[patch.crates-io]
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "b5f461a69bf3df7298b1903574d506179e6390b0" }
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "31c40449749c4263a91a43593831b82229049a4c" }
async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }
wasmtime = { git = "https://github.com/bytecodealliance/wasmtime", rev = "v16.0.0" }
# TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457
cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "079665882507dd5e2ff77db3de5070c1f6c0fb85" }

View File

@ -1,4 +0,0 @@
web: cd ../zed.dev && PORT=3000 npm run dev
collab: cd crates/collab2 && RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run serve
livekit: livekit-server --dev
postgrest: postgrest crates/collab2/admin_api.conf

View File

@ -15,10 +15,12 @@ language = { path = "../language" }
gpui = { path = "../gpui" }
project = { path = "../project" }
settings = { path = "../settings" }
ui = { path = "../ui" }
util = { path = "../util" }
theme = { path = "../theme" }
workspace = { path = "../workspace" }
workspace = { path = "../workspace", package = "workspace" }
anyhow.workspace = true
futures.workspace = true
smallvec.workspace = true

View File

@ -2,19 +2,19 @@ use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage};
use editor::Editor;
use futures::StreamExt;
use gpui::{
actions, anyhow,
elements::*,
platform::{CursorStyle, MouseButton},
AppContext, Entity, ModelHandle, View, ViewContext, ViewHandle,
actions, svg, AppContext, CursorStyle, EventEmitter, InteractiveElement as _, Model,
ParentElement as _, Render, SharedString, StatefulInteractiveElement, Styled, View,
ViewContext, VisualContext as _,
};
use language::{LanguageRegistry, LanguageServerBinaryStatus};
use project::{LanguageServerProgress, Project};
use smallvec::SmallVec;
use std::{cmp::Reverse, fmt::Write, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
use workspace::{item::ItemHandle, StatusItemView, Workspace};
actions!(lsp_status, [ShowErrorMessage]);
actions!(activity_indicator, [ShowErrorMessage]);
const DOWNLOAD_ICON: &str = "icons/download.svg";
const WARNING_ICON: &str = "icons/warning.svg";
@ -25,8 +25,8 @@ pub enum Event {
pub struct ActivityIndicator {
statuses: Vec<LspStatus>,
project: ModelHandle<Project>,
auto_updater: Option<ModelHandle<AutoUpdater>>,
project: Model<Project>,
auto_updater: Option<Model<AutoUpdater>>,
}
struct LspStatus {
@ -47,20 +47,15 @@ struct Content {
on_click: Option<Arc<dyn Fn(&mut ActivityIndicator, &mut ViewContext<ActivityIndicator>)>>,
}
pub fn init(cx: &mut AppContext) {
cx.add_action(ActivityIndicator::show_error_message);
cx.add_action(ActivityIndicator::dismiss_error_message);
}
impl ActivityIndicator {
pub fn new(
workspace: &mut Workspace,
languages: Arc<LanguageRegistry>,
cx: &mut ViewContext<Workspace>,
) -> ViewHandle<ActivityIndicator> {
) -> View<ActivityIndicator> {
let project = workspace.project().clone();
let auto_updater = AutoUpdater::get(cx);
let this = cx.add_view(|cx: &mut ViewContext<Self>| {
let this = cx.new_view(|cx: &mut ViewContext<Self>| {
let mut status_events = languages.language_server_binary_statuses();
cx.spawn(|this, mut cx| async move {
while let Some((language, event)) = status_events.next().await {
@ -77,11 +72,13 @@ impl ActivityIndicator {
})
.detach();
cx.observe(&project, |_, _, cx| cx.notify()).detach();
if let Some(auto_updater) = auto_updater.as_ref() {
cx.observe(auto_updater, |_, _, cx| cx.notify()).detach();
}
cx.observe_active_labeled_tasks(|_, cx| cx.notify())
.detach();
// cx.observe_active_labeled_tasks(|_, cx| cx.notify())
// .detach();
Self {
statuses: Default::default(),
@ -89,6 +86,7 @@ impl ActivityIndicator {
auto_updater,
}
});
cx.subscribe(&this, move |workspace, _, event, cx| match event {
Event::ShowError { lsp_name, error } => {
if let Some(buffer) = project
@ -104,7 +102,7 @@ impl ActivityIndicator {
});
workspace.add_item(
Box::new(
cx.add_view(|cx| Editor::for_buffer(buffer, Some(project.clone()), cx)),
cx.new_view(|cx| Editor::for_buffer(buffer, Some(project.clone()), cx)),
),
cx,
);
@ -290,71 +288,41 @@ impl ActivityIndicator {
};
}
if let Some(most_recent_active_task) = cx.active_labeled_tasks().last() {
return Content {
icon: None,
message: most_recent_active_task.to_string(),
on_click: None,
};
}
// todo!(show active tasks)
// if let Some(most_recent_active_task) = cx.active_labeled_tasks().last() {
// return Content {
// icon: None,
// message: most_recent_active_task.to_string(),
// on_click: None,
// };
// }
Default::default()
}
}
impl Entity for ActivityIndicator {
type Event = Event;
}
impl EventEmitter<Event> for ActivityIndicator {}
impl View for ActivityIndicator {
fn ui_name() -> &'static str {
"ActivityIndicator"
}
impl Render for ActivityIndicator {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let content = self.content_to_render(cx);
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let Content {
icon,
message,
on_click,
} = self.content_to_render(cx);
let mut result = h_stack()
.id("activity-indicator")
.on_action(cx.listener(Self::show_error_message))
.on_action(cx.listener(Self::dismiss_error_message));
let mut element = MouseEventHandler::new::<Self, _>(0, cx, |state, cx| {
let theme = &theme::current(cx).workspace.status_bar.lsp_status;
let style = if state.hovered() && on_click.is_some() {
theme.hovered.as_ref().unwrap_or(&theme.default)
} else {
&theme.default
};
Flex::row()
.with_children(icon.map(|path| {
Svg::new(path)
.with_color(style.icon_color)
.constrained()
.with_width(style.icon_width)
.contained()
.with_margin_right(style.icon_spacing)
.aligned()
.into_any_named("activity-icon")
if let Some(on_click) = content.on_click {
result = result
.cursor(CursorStyle::PointingHand)
.on_click(cx.listener(move |this, _, cx| {
on_click(this, cx);
}))
.with_child(
Text::new(message, style.message.clone())
.with_soft_wrap(false)
.aligned(),
)
.constrained()
.with_height(style.height)
.contained()
.with_style(style.container)
.aligned()
});
if let Some(on_click) = on_click.clone() {
element = element
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| on_click(this, cx));
}
element.into_any()
result
.children(content.icon.map(|icon| svg().path(icon)))
.child(Label::new(SharedString::from(content.message)).size(LabelSize::Small))
}
}

View File

@ -1,28 +0,0 @@
[package]
name = "activity_indicator2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/activity_indicator.rs"
doctest = false
[dependencies]
auto_update = { path = "../auto_update2", package = "auto_update2" }
editor = { path = "../editor2", package = "editor2" }
language = { path = "../language2", package = "language2" }
gpui = { path = "../gpui2", package = "gpui2" }
project = { path = "../project2", package = "project2" }
settings = { path = "../settings2", package = "settings2" }
ui = { path = "../ui2", package = "ui2" }
util = { path = "../util" }
theme = { path = "../theme2", package = "theme2" }
workspace = { path = "../workspace2", package = "workspace2" }
anyhow.workspace = true
futures.workspace = true
smallvec.workspace = true
[dev-dependencies]
editor = { path = "../editor2", package = "editor2", features = ["test-support"] }

View File

@ -1,331 +0,0 @@
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage};
use editor::Editor;
use futures::StreamExt;
use gpui::{
actions, svg, AppContext, CursorStyle, EventEmitter, InteractiveElement as _, Model,
ParentElement as _, Render, SharedString, StatefulInteractiveElement, Styled, View,
ViewContext, VisualContext as _,
};
use language::{LanguageRegistry, LanguageServerBinaryStatus};
use project::{LanguageServerProgress, Project};
use smallvec::SmallVec;
use std::{cmp::Reverse, fmt::Write, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
use workspace::{item::ItemHandle, StatusItemView, Workspace};
actions!(activity_indicator, [ShowErrorMessage]);
const DOWNLOAD_ICON: &str = "icons/download.svg";
const WARNING_ICON: &str = "icons/warning.svg";
pub enum Event {
ShowError { lsp_name: Arc<str>, error: String },
}
pub struct ActivityIndicator {
statuses: Vec<LspStatus>,
project: Model<Project>,
auto_updater: Option<Model<AutoUpdater>>,
}
struct LspStatus {
name: Arc<str>,
status: LanguageServerBinaryStatus,
}
struct PendingWork<'a> {
language_server_name: &'a str,
progress_token: &'a str,
progress: &'a LanguageServerProgress,
}
#[derive(Default)]
struct Content {
icon: Option<&'static str>,
message: String,
on_click: Option<Arc<dyn Fn(&mut ActivityIndicator, &mut ViewContext<ActivityIndicator>)>>,
}
impl ActivityIndicator {
pub fn new(
workspace: &mut Workspace,
languages: Arc<LanguageRegistry>,
cx: &mut ViewContext<Workspace>,
) -> View<ActivityIndicator> {
let project = workspace.project().clone();
let auto_updater = AutoUpdater::get(cx);
let this = cx.new_view(|cx: &mut ViewContext<Self>| {
let mut status_events = languages.language_server_binary_statuses();
cx.spawn(|this, mut cx| async move {
while let Some((language, event)) = status_events.next().await {
this.update(&mut cx, |this, cx| {
this.statuses.retain(|s| s.name != language.name());
this.statuses.push(LspStatus {
name: language.name(),
status: event,
});
cx.notify();
})?;
}
anyhow::Ok(())
})
.detach();
cx.observe(&project, |_, _, cx| cx.notify()).detach();
if let Some(auto_updater) = auto_updater.as_ref() {
cx.observe(auto_updater, |_, _, cx| cx.notify()).detach();
}
// cx.observe_active_labeled_tasks(|_, cx| cx.notify())
// .detach();
Self {
statuses: Default::default(),
project: project.clone(),
auto_updater,
}
});
cx.subscribe(&this, move |workspace, _, event, cx| match event {
Event::ShowError { lsp_name, error } => {
if let Some(buffer) = project
.update(cx, |project, cx| project.create_buffer(error, None, cx))
.log_err()
{
buffer.update(cx, |buffer, cx| {
buffer.edit(
[(0..0, format!("Language server error: {}\n\n", lsp_name))],
None,
cx,
);
});
workspace.add_item(
Box::new(
cx.new_view(|cx| Editor::for_buffer(buffer, Some(project.clone()), cx)),
),
cx,
);
}
}
})
.detach();
this
}
fn show_error_message(&mut self, _: &ShowErrorMessage, cx: &mut ViewContext<Self>) {
self.statuses.retain(|status| {
if let LanguageServerBinaryStatus::Failed { error } = &status.status {
cx.emit(Event::ShowError {
lsp_name: status.name.clone(),
error: error.clone(),
});
false
} else {
true
}
});
cx.notify();
}
fn dismiss_error_message(&mut self, _: &DismissErrorMessage, cx: &mut ViewContext<Self>) {
if let Some(updater) = &self.auto_updater {
updater.update(cx, |updater, cx| {
updater.dismiss_error(cx);
});
}
cx.notify();
}
fn pending_language_server_work<'a>(
&self,
cx: &'a AppContext,
) -> impl Iterator<Item = PendingWork<'a>> {
self.project
.read(cx)
.language_server_statuses()
.rev()
.filter_map(|status| {
if status.pending_work.is_empty() {
None
} else {
let mut pending_work = status
.pending_work
.iter()
.map(|(token, progress)| PendingWork {
language_server_name: status.name.as_str(),
progress_token: token.as_str(),
progress,
})
.collect::<SmallVec<[_; 4]>>();
pending_work.sort_by_key(|work| Reverse(work.progress.last_update_at));
Some(pending_work)
}
})
.flatten()
}
fn content_to_render(&mut self, cx: &mut ViewContext<Self>) -> Content {
// Show any language server has pending activity.
let mut pending_work = self.pending_language_server_work(cx);
if let Some(PendingWork {
language_server_name,
progress_token,
progress,
}) = pending_work.next()
{
let mut message = language_server_name.to_string();
message.push_str(": ");
if let Some(progress_message) = progress.message.as_ref() {
message.push_str(progress_message);
} else {
message.push_str(progress_token);
}
if let Some(percentage) = progress.percentage {
write!(&mut message, " ({}%)", percentage).unwrap();
}
let additional_work_count = pending_work.count();
if additional_work_count > 0 {
write!(&mut message, " + {} more", additional_work_count).unwrap();
}
return Content {
icon: None,
message,
on_click: None,
};
}
// Show any language server installation info.
let mut downloading = SmallVec::<[_; 3]>::new();
let mut checking_for_update = SmallVec::<[_; 3]>::new();
let mut failed = SmallVec::<[_; 3]>::new();
for status in &self.statuses {
let name = status.name.clone();
match status.status {
LanguageServerBinaryStatus::CheckingForUpdate => checking_for_update.push(name),
LanguageServerBinaryStatus::Downloading => downloading.push(name),
LanguageServerBinaryStatus::Failed { .. } => failed.push(name),
LanguageServerBinaryStatus::Downloaded | LanguageServerBinaryStatus::Cached => {}
}
}
if !downloading.is_empty() {
return Content {
icon: Some(DOWNLOAD_ICON),
message: format!(
"Downloading {} language server{}...",
downloading.join(", "),
if downloading.len() > 1 { "s" } else { "" }
),
on_click: None,
};
} else if !checking_for_update.is_empty() {
return Content {
icon: Some(DOWNLOAD_ICON),
message: format!(
"Checking for updates to {} language server{}...",
checking_for_update.join(", "),
if checking_for_update.len() > 1 {
"s"
} else {
""
}
),
on_click: None,
};
} else if !failed.is_empty() {
return Content {
icon: Some(WARNING_ICON),
message: format!(
"Failed to download {} language server{}. Click to show error.",
failed.join(", "),
if failed.len() > 1 { "s" } else { "" }
),
on_click: Some(Arc::new(|this, cx| {
this.show_error_message(&Default::default(), cx)
})),
};
}
// Show any application auto-update info.
if let Some(updater) = &self.auto_updater {
return match &updater.read(cx).status() {
AutoUpdateStatus::Checking => Content {
icon: Some(DOWNLOAD_ICON),
message: "Checking for Zed updates…".to_string(),
on_click: None,
},
AutoUpdateStatus::Downloading => Content {
icon: Some(DOWNLOAD_ICON),
message: "Downloading Zed update…".to_string(),
on_click: None,
},
AutoUpdateStatus::Installing => Content {
icon: Some(DOWNLOAD_ICON),
message: "Installing Zed update…".to_string(),
on_click: None,
},
AutoUpdateStatus::Updated => Content {
icon: None,
message: "Click to restart and update Zed".to_string(),
on_click: Some(Arc::new(|_, cx| {
workspace::restart(&Default::default(), cx)
})),
},
AutoUpdateStatus::Errored => Content {
icon: Some(WARNING_ICON),
message: "Auto update failed".to_string(),
on_click: Some(Arc::new(|this, cx| {
this.dismiss_error_message(&Default::default(), cx)
})),
},
AutoUpdateStatus::Idle => Default::default(),
};
}
// todo!(show active tasks)
// if let Some(most_recent_active_task) = cx.active_labeled_tasks().last() {
// return Content {
// icon: None,
// message: most_recent_active_task.to_string(),
// on_click: None,
// };
// }
Default::default()
}
}
impl EventEmitter<Event> for ActivityIndicator {}
impl Render for ActivityIndicator {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Element {
let content = self.content_to_render(cx);
let mut result = h_stack()
.id("activity-indicator")
.on_action(cx.listener(Self::show_error_message))
.on_action(cx.listener(Self::dismiss_error_message));
if let Some(on_click) = content.on_click {
result = result
.cursor(CursorStyle::PointingHand)
.on_click(cx.listener(move |this, _, cx| {
on_click(this, cx);
}))
}
result
.children(content.icon.map(|icon| svg().path(icon)))
.child(Label::new(SharedString::from(content.message)).size(LabelSize::Small))
}
}
impl StatusItemView for ActivityIndicator {
fn set_active_pane_item(&mut self, _: Option<&dyn ItemHandle>, _: &mut ViewContext<Self>) {}
}

View File

@ -9,7 +9,7 @@ pub enum ProviderCredential {
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &AppContext);
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &mut AppContext);
}

View File

@ -2,7 +2,7 @@ use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, ModelHandle};
use gpui::{AsyncAppContext, Model};
use language::{Anchor, Buffer};
#[derive(Clone)]
@ -13,8 +13,12 @@ pub struct PromptCodeSnippet {
}
impl PromptCodeSnippet {
pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
@ -27,13 +31,13 @@ impl PromptCodeSnippet {
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
});
})?;
PromptCodeSnippet {
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
}
})
}
}

View File

@ -3,7 +3,7 @@ use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{executor::Background, AppContext};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
@ -104,7 +104,7 @@ pub struct OpenAIResponseStreamEvent {
pub async fn stream_completion(
credential: ProviderCredential,
executor: Arc<Background>,
executor: BackgroundExecutor,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
@ -197,11 +197,11 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Background>,
executor: BackgroundExecutor,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
@ -219,46 +219,45 @@ impl CredentialProvider for OpenAICompletionProvider {
_ => false,
}
}
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
let mut credential = self.credential.write();
match *credential {
ProviderCredential::Credentials { .. } => {
return credential.clone();
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
*credential = ProviderCredential::Credentials { api_key };
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
ProviderCredential::Credentials { api_key }
} else if let Some(Some((_, api_key))) =
cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
*credential = ProviderCredential::Credentials { api_key };
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
};
ProviderCredential::NoCredentials
}
}
}
credential.clone()
};
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential.clone() {
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
match credential {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
*self.credential.write() = credential;
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
fn delete_credentials(&self, cx: &mut AppContext) {
cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}

View File

@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::BackgroundExecutor;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
@ -35,7 +35,7 @@ pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
pub executor: BackgroundExecutor,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
@ -66,7 +66,7 @@ struct OpenAIEmbeddingUsage {
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
@ -153,46 +153,45 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
_ => false,
}
}
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
let mut credential = self.credential.write();
match *credential {
ProviderCredential::Credentials { .. } => {
return credential.clone();
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
*credential = ProviderCredential::Credentials { api_key };
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
ProviderCredential::Credentials { api_key }
} else if let Some(Some((_, api_key))) =
cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
*credential = ProviderCredential::Credentials { api_key };
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
};
ProviderCredential::NoCredentials
}
}
}
};
credential.clone()
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential.clone() {
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
match credential {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
*self.credential.write() = credential;
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
fn delete_credentials(&self, cx: &mut AppContext) {
cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}

View File

@ -104,11 +104,11 @@ impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
@ -153,17 +153,10 @@ impl FakeCompletionProvider {
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
println!("COMPLETION TX: {:?}", &tx);
let a = tx.as_mut().unwrap();
a.try_send(completion.into()).unwrap();
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
println!("FINISHING COMPLETION");
self.last_completion_tx.lock().take().unwrap();
}
}
@ -172,11 +165,11 @@ impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
@ -188,10 +181,8 @@ impl CompletionProvider for FakeCompletionProvider {
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
println!("COMPLETING");
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
println!("TX: {:?}", *self.last_completion_tx.lock());
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {

View File

@ -1,38 +0,0 @@
[package]
name = "ai2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai2.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
language = { package = "language2", path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs.workspace = true
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }

View File

@ -1,8 +0,0 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

View File

@ -1,15 +0,0 @@
use gpui::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &mut AppContext);
}

View File

@ -1,23 +0,0 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

View File

@ -1,123 +0,0 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
OrderedFloat(result)
}
}
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}

View File

@ -1,16 +0,0 @@
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

View File

@ -1,330 +0,0 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq, Ord)]
pub enum PromptPriority {
Mandatory, // Ignores truncation
Ordered { order: usize }, // Truncates based on priority
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let seperator = "\n";
let seperator_tokens = self.args.model.count_tokens(seperator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
// If Truncate
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + seperator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(seperator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(seperator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}

View File

@ -1,164 +0,0 @@
use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}

View File

@ -1,99 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View File

@ -1,5 +0,0 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View File

@ -1,52 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}

View File

@ -1,98 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, Model};
use language::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string().to_lowercase()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
})?;
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
})
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length.clone();
let seperator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
{
Some(tokens_left - token_count - seperator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}

View File

@ -1 +0,0 @@
pub mod open_ai;

View File

@ -1,297 +0,0 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
credential: ProviderCredential,
executor: BackgroundExecutor,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Clone)]
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: BackgroundExecutor,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
credential,
executor,
}
}
}
impl CredentialProvider for OpenAICompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
ProviderCredential::Credentials { api_key }
} else if let Some(Some((_, api_key))) =
cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
};
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
}
fn delete_credentials(&self, cx: &mut AppContext) {
cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the langauge model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -1,305 +0,0 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::BackgroundExecutor;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: BackgroundExecutor,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
impl CredentialProvider for OpenAIEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
ProviderCredential::Credentials { api_key }
} else if let Some(Some((_, api_key))) =
cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
};
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
}
fn delete_credentials(&self, cx: &mut AppContext) {
cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View File

@ -1,9 +0,0 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

View File

@ -1,57 +0,0 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
#[derive(Clone)]
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -1,11 +0,0 @@
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

View File

@ -1,191 +0,0 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -18,13 +18,14 @@ gpui = { path = "../gpui" }
language = { path = "../language" }
menu = { path = "../menu" }
multi_buffer = { path = "../multi_buffer" }
project = { path = "../project" }
search = { path = "../search" }
semantic_index = { path = "../semantic_index" }
settings = { path = "../settings" }
theme = { path = "../theme" }
ui = { path = "../ui" }
util = { path = "../util" }
workspace = { path = "../workspace" }
semantic_index = { path = "../semantic_index" }
project = { path = "../project" }
uuid.workspace = true
log.workspace = true
@ -43,9 +44,9 @@ smol.workspace = true
tiktoken-rs.workspace = true
[dev-dependencies]
ai = { path = "../ai", features = ["test-support"]}
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true

View File

@ -12,12 +12,28 @@ use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use gpui::AppContext;
use gpui::{actions, AppContext, SharedString};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
actions!(
assistant,
[
NewConversation,
Assist,
Split,
CycleMessageRole,
QuoteSelection,
ToggleFocus,
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
@ -34,7 +50,7 @@ struct MessageMetadata {
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
Error(SharedString),
}
#[derive(Serialize, Deserialize)]

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,8 @@
use anyhow;
use gpui::Pixels;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Setting;
use settings::Settings;
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub enum OpenAIModel {
@ -51,8 +52,8 @@ pub enum AssistantDockPosition {
pub struct AssistantSettings {
pub button: bool,
pub dock: AssistantDockPosition,
pub default_width: f32,
pub default_height: f32,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_open_ai_model: OpenAIModel,
}
@ -65,7 +66,7 @@ pub struct AssistantSettingsContent {
pub default_open_ai_model: Option<OpenAIModel>,
}
impl Setting for AssistantSettings {
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
type FileContent = AssistantSettingsContent;
@ -73,7 +74,7 @@ impl Setting for AssistantSettings {
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &gpui::AppContext,
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}

View File

@ -3,7 +3,7 @@ use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
use multi_buffer;
use std::{cmp, future, ops::Range, sync::Arc};
@ -21,7 +21,7 @@ pub enum CodegenKind {
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: ModelHandle<MultiBuffer>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
last_equal_ranges: Vec<Range<Anchor>>,
@ -32,13 +32,11 @@ pub struct Codegen {
_subscription: gpui::Subscription,
}
impl Entity for Codegen {
type Event = Event;
}
impl EventEmitter<Event> for Codegen {}
impl Codegen {
pub fn new(
buffer: ModelHandle<MultiBuffer>,
buffer: Model<MultiBuffer>,
kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
@ -60,7 +58,7 @@ impl Codegen {
fn handle_buffer_event(
&mut self,
_buffer: ModelHandle<MultiBuffer>,
_buffer: Model<MultiBuffer>,
event: &multi_buffer::Event,
cx: &mut ModelContext<Self>,
) {
@ -111,13 +109,13 @@ impl Codegen {
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
self.generation = cx.spawn_weak(|this, mut cx| {
self.generation = cx.spawn(|this, mut cx| {
async move {
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
let diff = cx.background_executor().spawn(async move {
let chunks = strip_invalid_spans_from_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
@ -183,12 +181,6 @@ impl Codegen {
});
while let Some(hunks) = hunks_rx.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
@ -245,7 +237,7 @@ impl Codegen {
}
cx.notify();
});
})?;
}
diff.await?;
@ -253,17 +245,16 @@ impl Codegen {
};
let result = generate.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
this.idle = true;
if let Err(error) = result {
this.error = Some(error);
}
cx.emit(Event::Finished);
cx.notify();
});
}
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
this.idle = true;
if let Err(error) = result {
this.error = Some(error);
}
cx.emit(Event::Finished);
cx.notify();
})
.ok();
}
});
self.error.take();
@ -372,7 +363,7 @@ mod tests {
use super::*;
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{executor::Deterministic, TestAppContext};
use gpui::{Context, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use rand::prelude::*;
@ -391,12 +382,8 @@ mod tests {
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
@ -408,14 +395,14 @@ mod tests {
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
@ -442,10 +429,10 @@ mod tests {
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
@ -464,9 +451,8 @@ mod tests {
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
@ -475,14 +461,14 @@ mod tests {
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
@ -508,10 +494,10 @@ mod tests {
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
@ -530,9 +516,8 @@ mod tests {
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
@ -541,14 +526,14 @@ mod tests {
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
@ -575,10 +560,10 @@ mod tests {
println!("{:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),

View File

@ -176,7 +176,7 @@ pub(crate) mod tests {
use super::*;
use std::sync::Arc;
use gpui::AppContext;
use gpui::{AppContext, Context};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use settings::SettingsStore;
@ -227,7 +227,8 @@ pub(crate) mod tests {
#[gpui::test]
fn test_outline_for_prompt(cx: &mut AppContext) {
cx.set_global(SettingsStore::test(cx));
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_settings::init(cx);
let text = indoc! {"
struct X {
@ -253,7 +254,7 @@ pub(crate) mod tests {
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(

View File

@ -1,54 +0,0 @@
[package]
name = "assistant2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/assistant.rs"
doctest = false
[dependencies]
ai = { package = "ai2", path = "../ai2" }
client = { package = "client2", path = "../client2" }
collections = { path = "../collections"}
editor = { package = "editor2", path = "../editor2" }
fs = { package = "fs2", path = "../fs2" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { package = "language2", path = "../language2" }
menu = { package = "menu2", path = "../menu2" }
multi_buffer = { package = "multi_buffer2", path = "../multi_buffer2" }
project = { package = "project2", path = "../project2" }
search = { package = "search2", path = "../search2" }
semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
ui = { package = "ui2", path = "../ui2" }
util = { path = "../util" }
workspace = { package = "workspace2", path = "../workspace2" }
uuid.workspace = true
log.workspace = true
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
indoc.workspace = true
isahc.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs.workspace = true
[dev-dependencies]
ai = { package = "ai2", path = "../ai2", features = ["test-support"]}
editor = { package = "editor2", path = "../editor2", features = ["test-support"] }
project = { package = "project2", path = "../project2", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
log.workspace = true
rand.workspace = true

View File

@ -1,63 +0,0 @@
> Hello, I'm Nathan Sobo.
I'm currently your sole creator and the leader of Zed Industries.
Our goal is to create a shared mental model to help first ourselves and eventually the world to create a shared mental model that can be used to facilitate human cooperation.
< Hello, Nathan! Welcome back to Zed. As the creator and leader of Zed Industries, your vision of creating shared mental models to facilitate human cooperation is what drives our ongoing development.
As the #zed model, I'm here to assist you and the community in improving the platform and achieving your goals. Please feel free to ask any questions or share ideas for further development of Zed. Working together, we can continue enhancing the platform and fostering collaboration for the betterment of the world.
> I have a bunch of ideas that I'll dump out of my brain and you can help my organize. Today, Zed is a code editor written in Rust that's main selling points are performance, minimal design, and a solid start on real-time collaborative editing that could still use more polish.
We've decided that our next milestone will be open-sourcing Zed on the Zed platform.
< Nathan, it's exciting to hear that you have more ideas to enhance Zed! Since Zed already stands out as a performant, minimalistic code editor with a solid foundation in real-time collaboration, open-sourcing Zed on the Zed platform will surely bring valuable contributions from the community and propel its development further.
Let's organize your ideas and create a plan for your next milestone. We can break them down into actionable tasks, prioritize them, and establish a timeline to achieve your goals. Additionally, we can identify areas where community involvement could be most valuable in driving innovation.
Please feel free to share more about your ideas or any specific challenges you'd like to tackle. I'm here to help you turn these ideas into a cohesive and successful plan for Zed's future!
@nathan> #zed, can you describe the conversation UI as you currently understand it.
< Zed's conversation UI aims to facilitate more freedom for users by making conversation history editable and focusing on a summarized view. Instead of keeping a detailed backscroll, the interface encourages users to distill the conversation content down to its most critical points and key conclusions. This approach ensures that the vital information from the conversation is highlighted and easily accessible, while still maintaining a history that can be linked to when necessary.
Addressing @nate's concerns, this design choice aims to strike a balance between readability, usability, and information preservation. The summarized view maintains the essence of the conversation without being burdened by extensive backscroll. Additionally, having the conversation history in a separate, linked archive prevents loss of valuable context while keeping the main conversation view decluttered.
@nathan> What details around the conversation view are worth exploring?
< @nathan, there are several aspects of the conversation view that are worth exploring and refining:
1. Editing functionality: Ensuring that all participants can edit and refine the conversation summary effectively while preserving the original context.
2. Permissions and access control: Determining how to manage editing rights, user roles, and visibility of the conversation summary and history.
3. Navigation and linking: Implementing an intuitive system for navigating between the summarized conversation and linked history or reference materials.
4. Visualization and layout: Designing the interface to clearly present the summarized information while providing a visually engaging and organized experience for participants.
5. Threaded conversations and branching: Exploring ways to handle sub-conversations, supplementary discussions, or debates within the primary conversation view.
Considering these aspects will ensure our conversation view design is optimized for usability, collaboration, and accessibility in Zed.
@nathan> @nate, which of these would you like to explore?
@nate> 2 feels like it isn't important at the moment, we can explore that later. Let's start with 4, which I think will lead us to discussion 3 and 5.
#zed share your thoughts on the points we need to consider to design a layout and visualization for a conversation view between you (#zed) and multuple peoople, or between multiple people and multiple bots (you and other bots).
@nathan> Agreed. I'm interested in threading I think more than anything. Or 4 yeah. I think we need to scope the threading conversation. Also, asking #zed to propose the solution... not sure it will be that effective but it's worth a try...
#zed describe a hypothetical conversation UI that meets the description.
< @nate and @nathan, I understand that you'd like to focus on the layout and visualization for a conversation view that can accommodate multiple people and potentially multiple bots as well. Here's a proposal for a conversation UI that meets the requirements:
1. Main conversation area: Display the summarized conversation content in a clean and organized manner, ensuring readability and efficient use of space.
2. Participant avatars and labels: Include user avatars and/or labels to distinguish between users and bots visually, allowing easy identification of each participant's contributions.
3. Active threads indicator: Designate a clear indicator for ongoing threads, which will attract attention to active sub-conversations or related discussions in the conversation view.
4. Dynamic branching: Implement a branching visual structure to present threaded conversations, clearly distinguishing sub-conversations, and enabling users to navigate between them smoothly.
5. Multiple bot interaction support: Facilitate interactions between multiple bots through user-invoked prompts or bot-to-bot communication channels without disrupting the primary conversation flow.
6. Expandable conversation history: Offer a separate, expandable section or panel to access historical content without cluttering the main conversation view.
7. Linking and referencing: Implement a system for hyperlinking relevant information, resources, or previous conversations to the current discussion, making it easy for participants to follow and reference the shared context.
This hypothetical conversation UI aims to address the layout and visualization, threading, and multi-participant aspects of your requirements. Of course, it's just an initial proposal, and we can refine and iterate on it based on your feedback and specific needs.

View File

@ -1,3 +0,0 @@
Push content to a deeper layer.
A context can have multiple sublayers.
You can enable or disable arbitrary sublayers at arbitrary nesting depths when viewing the document.

View File

@ -1,129 +0,0 @@
pub mod assistant_panel;
mod assistant_settings;
mod codegen;
mod prompts;
mod streaming_diff;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use gpui::{actions, AppContext, SharedString};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
actions!(
assistant,
[
NewConversation,
Assist,
Split,
CycleMessageRole,
QuoteSelection,
ToggleFocus,
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending,
Done,
Error(SharedString),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
model: OpenAIModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx);
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,81 +0,0 @@
use anyhow;
use gpui::Pixels;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub enum OpenAIModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
#[serde(rename = "gpt-4-0613")]
Four,
#[serde(rename = "gpt-4-1106-preview")]
FourTurbo,
}
impl OpenAIModel {
pub fn full_name(&self) -> &'static str {
match self {
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
OpenAIModel::Four => "gpt-4-0613",
OpenAIModel::FourTurbo => "gpt-4-1106-preview",
}
}
pub fn short_name(&self) -> &'static str {
match self {
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
OpenAIModel::Four => "gpt-4",
OpenAIModel::FourTurbo => "gpt-4-turbo",
}
}
pub fn cycle(&self) -> Self {
match self {
OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
OpenAIModel::Four => OpenAIModel::FourTurbo,
OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
Left,
Right,
Bottom,
}
#[derive(Deserialize, Debug)]
pub struct AssistantSettings {
pub button: bool,
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_open_ai_model: OpenAIModel,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContent {
pub button: Option<bool>,
pub dock: Option<AssistantDockPosition>,
pub default_width: Option<f32>,
pub default_height: Option<f32>,
pub default_open_ai_model: Option<OpenAIModel>,
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
type FileContent = AssistantSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}

View File

@ -1,688 +0,0 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
use multi_buffer;
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
Finished,
Undone,
}
#[derive(Clone)]
pub enum CodegenKind {
Transform { range: Range<Anchor> },
Generate { position: Anchor },
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
last_equal_ranges: Vec<Range<Anchor>>,
transaction_id: Option<TransactionId>,
error: Option<anyhow::Error>,
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
}
impl EventEmitter<Event> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
Self {
provider,
buffer: buffer.clone(),
snapshot,
kind,
last_equal_ranges: Default::default(),
transaction_id: Default::default(),
error: Default::default(),
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
fn handle_buffer_event(
&mut self,
_buffer: Model<MultiBuffer>,
event: &multi_buffer::Event,
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.transaction_id == Some(*transaction_id) {
self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(Event::Undone);
}
}
}
pub fn range(&self) -> Range<Anchor> {
match &self.kind {
CodegenKind::Transform { range } => range.clone(),
CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
}
}
pub fn kind(&self) -> &CodegenKind {
&self.kind
}
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
&self.last_equal_ranges
}
pub fn idle(&self) -> bool {
self.idle
}
pub fn error(&self) -> Option<&anyhow::Error> {
self.error.as_ref()
}
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
.text_for_range(range.start..range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let suggested_line_indent = snapshot
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
.into_values()
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
self.generation = cx.spawn(|this, mut cx| {
async move {
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background_executor().spawn(async move {
let chunks = strip_invalid_spans_from_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta = line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(selection_start.column as usize);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
}
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
});
while let Some(hunks) = hunks_rx.next().await {
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
let transaction = this.buffer.update(cx, |buffer, cx| {
// Avoid grouping assistant edits with user edits.
buffer.finalize_last_transaction(cx);
buffer.start_transaction(cx);
buffer.edit(
hunks.into_iter().filter_map(|hunk| match hunk {
Hunk::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
Hunk::Remove { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
Hunk::Keep { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
this.last_equal_ranges.push(edit_range);
None
}
}),
None,
cx,
);
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
transaction,
first_transaction,
cx,
)
});
} else {
this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
cx.notify();
})?;
}
diff.await?;
anyhow::Ok(())
};
let result = generate.await;
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
this.idle = true;
if let Err(error) = result {
this.error = Some(error);
}
cx.emit(Event::Finished);
cx.notify();
})
.ok();
}
});
self.error.take();
self.idle = false;
cx.notify();
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
if let Some(transaction_id) = self.transaction_id {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
}
fn strip_invalid_spans_from_codeblock(
stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
let mut first_line = true;
let mut buffer = String::new();
let mut starts_with_markdown_codeblock = false;
let mut includes_start_or_end_span = false;
stream.filter_map(move |chunk| {
let chunk = match chunk {
Ok(chunk) => chunk,
Err(err) => return future::ready(Some(Err(err))),
};
buffer.push_str(&chunk);
if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
includes_start_or_end_span = true;
buffer = buffer
.strip_prefix("<|S|>")
.or_else(|| buffer.strip_prefix("<|S|"))
.unwrap_or(&buffer)
.to_string();
} else if buffer.ends_with("|E|>") {
includes_start_or_end_span = true;
} else if buffer.starts_with("<|")
|| buffer.starts_with("<|S")
|| buffer.starts_with("<|S|")
|| buffer.ends_with("|")
|| buffer.ends_with("|E")
|| buffer.ends_with("|E|")
{
return future::ready(None);
}
if first_line {
if buffer == "" || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_markdown_codeblock = true;
if let Some(newline_ix) = buffer.find('\n') {
buffer.replace_range(..newline_ix + 1, "");
first_line = false;
} else {
return future::ready(None);
}
}
}
let mut text = buffer.to_string();
if starts_with_markdown_codeblock {
text = text
.strip_suffix("\n```\n")
.or_else(|| text.strip_suffix("\n```"))
.or_else(|| text.strip_suffix("\n``"))
.or_else(|| text.strip_suffix("\n`"))
.or_else(|| text.strip_suffix('\n'))
.unwrap_or(&text)
.to_string();
}
if includes_start_or_end_span {
text = text
.strip_suffix("|E|>")
.or_else(|| text.strip_suffix("E|>"))
.or_else(|| text.strip_prefix("|>"))
.or_else(|| text.strip_prefix(">"))
.unwrap_or(&text)
.to_string();
};
if text.contains('\n') {
first_line = false;
}
let remainder = buffer.split_off(text.len());
let result = if buffer.is_empty() {
None
} else {
Some(Ok(buffer.clone()))
};
buffer = remainder;
future::ready(result)
})
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("{:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks(
"```html\n```js\nLorem ipsum dolor\n```\n```",
2
))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View File

@ -1,389 +0,0 @@
use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
let selected_range = selected_range.to_offset(buffer);
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = ts_matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = Vec::new();
while let Some(mat) = ts_matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
} else {
continue;
}
}
ts_matches.advance();
matches.push(Match { collapse, keep });
} else {
ts_matches.advance();
}
}
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
let mut matches = matches.into_iter().peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if mat.collapse.start <= next_mat.collapse.start
&& mat.collapse.end >= next_mat.collapse.end
{
matches.next().unwrap();
} else {
break;
}
}
if offset > mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
project_name: Option<String>,
) -> anyhow::Result<String> {
// Using new Prompt Templates
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
let lang_name = if let Some(language_name) = language_name {
Some(language_name.to_string())
} else {
None
};
let args = PromptArguments {
model: openai_model,
language_name: lang_name.clone(),
project_name,
snippets: search_results.clone(),
reserved_tokens: 1000,
buffer: Some(buffer),
selected_range: Some(range),
user_prompt: Some(user_prompt.clone()),
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
(
PromptPriority::Ordered { order: 1 },
Box::new(RepositoryContext {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(FileContext {}),
),
(
PromptPriority::Mandatory,
Box::new(GenerateInlineContent {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
anyhow::Ok(prompt)
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use std::sync::Arc;
use gpui::{AppContext, Context};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use settings::SettingsStore;
pub(crate) fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
)
.unwrap()
}
#[gpui::test]
fn test_outline_for_prompt(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_settings::init(cx);
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
self.a
}
pub fn b(&self) -> usize {
self.b
}
}
"};
let buffer =
cx.new_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|S|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let <|S|a |E|>= 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
<|S|>
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|S|>"}
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|S|>struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
}
}

View File

@ -1,293 +0,0 @@
use collections::HashMap;
use ordered_float::OrderedFloat;
use std::{
cmp,
fmt::{self, Debug},
ops::Range,
};
struct Matrix {
cells: Vec<f64>,
rows: usize,
cols: usize,
}
impl Matrix {
fn new() -> Self {
Self {
cells: Vec::new(),
rows: 0,
cols: 0,
}
}
fn resize(&mut self, rows: usize, cols: usize) {
self.cells.resize(rows * cols, 0.);
self.rows = rows;
self.cols = cols;
}
fn get(&self, row: usize, col: usize) -> f64 {
if row >= self.rows {
panic!("row out of bounds")
}
if col >= self.cols {
panic!("col out of bounds")
}
self.cells[col * self.rows + row]
}
fn set(&mut self, row: usize, col: usize, value: f64) {
if row >= self.rows {
panic!("row out of bounds")
}
if col >= self.cols {
panic!("col out of bounds")
}
self.cells[col * self.rows + row] = value;
}
}
impl Debug for Matrix {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f)?;
for i in 0..self.rows {
for j in 0..self.cols {
write!(f, "{:5}", self.get(i, j))?;
}
writeln!(f)?;
}
Ok(())
}
}
#[derive(Debug)]
pub enum Hunk {
Insert { text: String },
Remove { len: usize },
Keep { len: usize },
}
pub struct StreamingDiff {
old: Vec<char>,
new: Vec<char>,
scores: Matrix,
old_text_ix: usize,
new_text_ix: usize,
equal_runs: HashMap<(usize, usize), u32>,
}
impl StreamingDiff {
const INSERTION_SCORE: f64 = -1.;
const DELETION_SCORE: f64 = -20.;
const EQUALITY_BASE: f64 = 1.8;
const MAX_EQUALITY_EXPONENT: i32 = 16;
pub fn new(old: String) -> Self {
let old = old.chars().collect::<Vec<_>>();
let mut scores = Matrix::new();
scores.resize(old.len() + 1, 1);
for i in 0..=old.len() {
scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
}
Self {
old,
new: Vec::new(),
scores,
old_text_ix: 0,
new_text_ix: 0,
equal_runs: Default::default(),
}
}
pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
self.new.extend(text.chars());
self.scores.resize(self.old.len() + 1, self.new.len() + 1);
for j in self.new_text_ix + 1..=self.new.len() {
self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
for i in 1..=self.old.len() {
let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
let equality_score = if self.old[i - 1] == self.new[j - 1] {
let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
equal_run += 1;
self.equal_runs.insert((i, j), equal_run);
let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
} else {
f64::NEG_INFINITY
};
let score = insertion_score.max(deletion_score).max(equality_score);
self.scores.set(i, j, score);
}
}
let mut max_score = f64::NEG_INFINITY;
let mut next_old_text_ix = self.old_text_ix;
let next_new_text_ix = self.new.len();
for i in self.old_text_ix..=self.old.len() {
let score = self.scores.get(i, next_new_text_ix);
if score > max_score {
max_score = score;
next_old_text_ix = i;
}
}
let hunks = self.backtrack(next_old_text_ix, next_new_text_ix);
self.old_text_ix = next_old_text_ix;
self.new_text_ix = next_new_text_ix;
hunks
}
fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
let mut pending_insert: Option<Range<usize>> = None;
let mut hunks = Vec::new();
let mut i = old_text_ix;
let mut j = new_text_ix;
while (i, j) != (self.old_text_ix, self.new_text_ix) {
let insertion_score = if j > self.new_text_ix {
Some((i, j - 1))
} else {
None
};
let deletion_score = if i > self.old_text_ix {
Some((i - 1, j))
} else {
None
};
let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
if self.old[i - 1] == self.new[j - 1] {
Some((i - 1, j - 1))
} else {
None
}
} else {
None
};
let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
.iter()
.max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
.unwrap()
.unwrap();
if prev_i == i && prev_j == j - 1 {
if let Some(pending_insert) = pending_insert.as_mut() {
pending_insert.start = prev_j;
} else {
pending_insert = Some(prev_j..j);
}
} else {
if let Some(range) = pending_insert.take() {
hunks.push(Hunk::Insert {
text: self.new[range].iter().collect(),
});
}
let char_len = self.old[i - 1].len_utf8();
if prev_i == i - 1 && prev_j == j {
if let Some(Hunk::Remove { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Remove { len: char_len })
}
} else {
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Keep { len: char_len })
}
}
}
i = prev_i;
j = prev_j;
}
if let Some(range) = pending_insert.take() {
hunks.push(Hunk::Insert {
text: self.new[range].iter().collect(),
});
}
hunks.reverse();
hunks
}
pub fn finish(self) -> Vec<Hunk> {
self.backtrack(self.old.len(), self.new.len())
}
}
#[cfg(test)]
mod tests {
use std::env;
use super::*;
use rand::prelude::*;
#[gpui::test(iterations = 100)]
fn test_random_diffs(mut rng: StdRng) {
let old_text_len = env::var("OLD_TEXT_LEN")
.map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable"))
.unwrap_or(10);
let new_text_len = env::var("NEW_TEXT_LEN")
.map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable"))
.unwrap_or(10);
let old = util::RandomCharIter::new(&mut rng)
.take(old_text_len)
.collect::<String>();
log::info!("old text: {:?}", old);
let mut diff = StreamingDiff::new(old.clone());
let mut hunks = Vec::new();
let mut new_len = 0;
let mut new = String::new();
while new_len < new_text_len {
let new_chunk_len = rng.gen_range(1..=new_text_len - new_len);
let new_chunk = util::RandomCharIter::new(&mut rng)
.take(new_len)
.collect::<String>();
log::info!("new chunk: {:?}", new_chunk);
new_len += new_chunk_len;
new.push_str(&new_chunk);
let new_hunks = diff.push_new(&new_chunk);
log::info!("hunks: {:?}", new_hunks);
hunks.extend(new_hunks);
}
let final_hunks = diff.finish();
log::info!("final hunks: {:?}", final_hunks);
hunks.extend(final_hunks);
log::info!("new text: {:?}", new);
let mut old_ix = 0;
let mut new_ix = 0;
let mut patched = String::new();
for hunk in hunks {
match hunk {
Hunk::Keep { len } => {
assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]);
patched.push_str(&old[old_ix..old_ix + len]);
old_ix += len;
new_ix += len;
}
Hunk::Remove { len } => {
old_ix += len;
}
Hunk::Insert { text } => {
assert_eq!(text, &new[new_ix..new_ix + text.len()]);
patched.push_str(&text);
new_ix += text.len();
}
}
}
assert_eq!(patched, new);
}
}

View File

@ -13,11 +13,10 @@ gpui = { path = "../gpui" }
collections = { path = "../collections" }
util = { path = "../util" }
rodio ={version = "0.17.1", default-features=false, features = ["wav"]}
log.workspace = true
futures.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
[dev-dependencies]

View File

@ -60,7 +60,7 @@ impl Audio {
return;
}
cx.update_global::<Self, _, _>(|this, cx| {
cx.update_global::<Self, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.play_raw(source).log_err()?;
@ -73,7 +73,7 @@ impl Audio {
return;
}
cx.update_global::<Self, _, _>(|this, _| {
cx.update_global::<Self, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});

View File

@ -1,24 +0,0 @@
[package]
name = "audio2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/audio2.rs"
doctest = false
[dependencies]
gpui = { package = "gpui2", path = "../gpui2" }
collections = { path = "../collections" }
util = { path = "../util" }
rodio ={version = "0.17.1", default-features=false, features = ["wav"]}
log.workspace = true
futures.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
[dev-dependencies]

View File

@ -1,23 +0,0 @@
[package]
name = "audio"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/audio.rs"
doctest = false
[dependencies]
gpui = { path = "../gpui" }
collections = { path = "../collections" }
util = { path = "../util" }
rodio ={version = "0.17.1", default-features=false, features = ["wav"]}
log.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
[dev-dependencies]

View File

@ -1,44 +0,0 @@
use std::{io::Cursor, sync::Arc};
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AssetSource};
use rodio::{
source::{Buffered, SamplesConverter},
Decoder, Source,
};
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
assets: Box<dyn AssetSource>,
}
impl SoundRegistry {
pub fn new(source: impl AssetSource) -> Arc<Self> {
Arc::new(Self {
cache: Default::default(),
assets: Box::new(source),
})
}
pub fn global(cx: &AppContext) -> Arc<Self> {
cx.global::<Arc<Self>>().clone()
}
pub fn get(&self, name: &str) -> Result<impl Source<Item = f32>> {
if let Some(wav) = self.cache.lock().get(name) {
return Ok(wav.clone());
}
let path = format!("sounds/{}.wav", name);
let bytes = self.assets.load(&path)?.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());
Ok(source)
}
}

View File

@ -1,81 +0,0 @@
use assets::SoundRegistry;
use gpui::{AppContext, AssetSource};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
pub fn init(source: impl AssetSource, cx: &mut AppContext) {
cx.set_global(SoundRegistry::new(source));
cx.set_global(Audio::new());
}
pub enum Sound {
Joined,
Leave,
Mute,
Unmute,
StartScreenshare,
StopScreenshare,
}
impl Sound {
fn file(&self) -> &'static str {
match self {
Self::Joined => "joined_call",
Self::Leave => "leave_call",
Self::Mute => "mute",
Self::Unmute => "unmute",
Self::StartScreenshare => "start_screenshare",
Self::StopScreenshare => "stop_screenshare",
}
}
}
pub struct Audio {
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
impl Audio {
pub fn new() -> Self {
Self {
_output_stream: None,
output_handle: None,
}
}
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
}
pub fn play_sound(sound: Sound, cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.play_raw(source).log_err()?;
Some(())
});
}
pub fn end_call(cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});
}
}

View File

@ -1,44 +0,0 @@
use std::{io::Cursor, sync::Arc};
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AssetSource};
use rodio::{
source::{Buffered, SamplesConverter},
Decoder, Source,
};
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
assets: Box<dyn AssetSource>,
}
impl SoundRegistry {
pub fn new(source: impl AssetSource) -> Arc<Self> {
Arc::new(Self {
cache: Default::default(),
assets: Box::new(source),
})
}
pub fn global(cx: &AppContext) -> Arc<Self> {
cx.global::<Arc<Self>>().clone()
}
pub fn get(&self, name: &str) -> Result<impl Source<Item = f32>> {
if let Some(wav) = self.cache.lock().get(name) {
return Ok(wav.clone());
}
let path = format!("sounds/{}.wav", name);
let bytes = self.assets.load(&path)?.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());
Ok(source)
}
}

View File

@ -1,81 +0,0 @@
use assets::SoundRegistry;
use gpui::{AppContext, AssetSource};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
pub fn init(source: impl AssetSource, cx: &mut AppContext) {
cx.set_global(SoundRegistry::new(source));
cx.set_global(Audio::new());
}
pub enum Sound {
Joined,
Leave,
Mute,
Unmute,
StartScreenshare,
StopScreenshare,
}
impl Sound {
fn file(&self) -> &'static str {
match self {
Self::Joined => "joined_call",
Self::Leave => "leave_call",
Self::Mute => "mute",
Self::Unmute => "unmute",
Self::StartScreenshare => "start_screenshare",
Self::StopScreenshare => "stop_screenshare",
}
}
}
pub struct Audio {
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
impl Audio {
pub fn new() -> Self {
Self {
_output_stream: None,
output_handle: None,
}
}
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
}
pub fn play_sound(sound: Sound, cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.play_raw(source).log_err()?;
Some(())
});
}
pub fn end_call(cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});
}
}

View File

@ -3,18 +3,23 @@ mod update_notification;
use anyhow::{anyhow, Context, Result};
use client::{Client, TelemetrySettings, ZED_APP_PATH, ZED_APP_VERSION, ZED_SECRET_CLIENT_TOKEN};
use db::kvp::KEY_VALUE_STORE;
use db::RELEASE_CHANNEL;
use gpui::{
actions, platform::AppVersion, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle,
Task, WeakViewHandle,
actions, AppContext, AsyncAppContext, Context as _, Model, ModelContext, SemanticVersion, Task,
ViewContext, VisualContext, WindowContext,
};
use isahc::AsyncBody;
use serde::Deserialize;
use serde_derive::Serialize;
use settings::{Setting, SettingsStore};
use smol::{fs::File, io::AsyncReadExt, process::Command};
use smol::io::AsyncReadExt;
use settings::{Settings, SettingsStore};
use smol::{fs::File, process::Command};
use std::{ffi::OsString, sync::Arc, time::Duration};
use update_notification::UpdateNotification;
use util::channel::ReleaseChannel;
use util::channel::{AppCommitSha, ReleaseChannel};
use util::http::HttpClient;
use workspace::Workspace;
@ -42,9 +47,9 @@ pub enum AutoUpdateStatus {
pub struct AutoUpdater {
status: AutoUpdateStatus,
current_version: AppVersion,
current_version: SemanticVersion,
http_client: Arc<dyn HttpClient>,
pending_poll: Option<Task<()>>,
pending_poll: Option<Task<Option<()>>>,
server_url: String,
}
@ -54,13 +59,9 @@ struct JsonRelease {
url: String,
}
impl Entity for AutoUpdater {
type Event = ();
}
struct AutoUpdateSetting(bool);
impl Setting for AutoUpdateSetting {
impl Settings for AutoUpdateSetting {
const KEY: Option<&'static str> = Some("auto_update");
type FileContent = Option<bool>;
@ -68,7 +69,7 @@ impl Setting for AutoUpdateSetting {
fn load(
default_value: &Option<bool>,
user_values: &[&Option<bool>],
_: &AppContext,
_: &mut AppContext,
) -> Result<Self> {
Ok(Self(
Self::json_merge(default_value, user_values)?.ok_or_else(Self::missing_default)?,
@ -77,18 +78,31 @@ impl Setting for AutoUpdateSetting {
}
pub fn init(http_client: Arc<dyn HttpClient>, server_url: String, cx: &mut AppContext) {
settings::register::<AutoUpdateSetting>(cx);
AutoUpdateSetting::register(cx);
if let Some(version) = (*ZED_APP_VERSION).or_else(|| cx.platform().app_version().ok()) {
let auto_updater = cx.add_model(|cx| {
cx.observe_new_views(|workspace: &mut Workspace, _cx| {
workspace.register_action(|_, action: &Check, cx| check(action, cx));
workspace.register_action(|_, action, cx| view_release_notes(action, cx));
// @nate - code to trigger update notification on launch
// todo!("remove this when Nate is done")
// workspace.show_notification(0, _cx, |cx| {
// cx.build_view(|_| UpdateNotification::new(SemanticVersion::from_str("1.1.1").unwrap()))
// });
})
.detach();
if let Some(version) = ZED_APP_VERSION.or_else(|| cx.app_metadata().app_version) {
let auto_updater = cx.new_model(|cx| {
let updater = AutoUpdater::new(version, http_client, server_url);
let mut update_subscription = settings::get::<AutoUpdateSetting>(cx)
let mut update_subscription = AutoUpdateSetting::get_global(cx)
.0
.then(|| updater.start_polling(cx));
cx.observe_global::<SettingsStore, _>(move |updater, cx| {
if settings::get::<AutoUpdateSetting>(cx).0 {
cx.observe_global::<SettingsStore>(move |updater, cx| {
if AutoUpdateSetting::get_global(cx).0 {
if update_subscription.is_none() {
update_subscription = Some(updater.start_polling(cx))
}
@ -101,19 +115,22 @@ pub fn init(http_client: Arc<dyn HttpClient>, server_url: String, cx: &mut AppCo
updater
});
cx.set_global(Some(auto_updater));
cx.add_global_action(check);
cx.add_global_action(view_release_notes);
cx.add_action(UpdateNotification::dismiss);
}
}
pub fn check(_: &Check, cx: &mut AppContext) {
pub fn check(_: &Check, cx: &mut WindowContext) {
if let Some(updater) = AutoUpdater::get(cx) {
updater.update(cx, |updater, cx| updater.poll(cx));
} else {
drop(cx.prompt(
gpui::PromptLevel::Info,
"Auto-updates disabled for non-bundled app.",
&["Ok"],
));
}
}
fn view_release_notes(_: &ViewReleaseNotes, cx: &mut AppContext) {
pub fn view_release_notes(_: &ViewReleaseNotes, cx: &mut AppContext) {
if let Some(auto_updater) = AutoUpdater::get(cx) {
let auto_updater = auto_updater.read(cx);
let server_url = &auto_updater.server_url;
@ -122,31 +139,28 @@ fn view_release_notes(_: &ViewReleaseNotes, cx: &mut AppContext) {
match cx.global::<ReleaseChannel>() {
ReleaseChannel::Dev => {}
ReleaseChannel::Nightly => {}
ReleaseChannel::Preview => cx
.platform()
.open_url(&format!("{server_url}/releases/preview/{current_version}")),
ReleaseChannel::Stable => cx
.platform()
.open_url(&format!("{server_url}/releases/stable/{current_version}")),
ReleaseChannel::Preview => {
cx.open_url(&format!("{server_url}/releases/preview/{current_version}"))
}
ReleaseChannel::Stable => {
cx.open_url(&format!("{server_url}/releases/stable/{current_version}"))
}
}
}
}
}
pub fn notify_of_any_new_update(
workspace: WeakViewHandle<Workspace>,
cx: &mut AppContext,
) -> Option<()> {
pub fn notify_of_any_new_update(cx: &mut ViewContext<Workspace>) -> Option<()> {
let updater = AutoUpdater::get(cx)?;
let version = updater.read(cx).current_version;
let should_show_notification = updater.read(cx).should_show_update_notification(cx);
cx.spawn(|mut cx| async move {
cx.spawn(|workspace, mut cx| async move {
let should_show_notification = should_show_notification.await?;
if should_show_notification {
workspace.update(&mut cx, |workspace, cx| {
workspace.show_notification(0, cx, |cx| {
cx.add_view(|_| UpdateNotification::new(version))
cx.new_view(|_| UpdateNotification::new(version))
});
updater
.read(cx)
@ -162,12 +176,12 @@ pub fn notify_of_any_new_update(
}
impl AutoUpdater {
pub fn get(cx: &mut AppContext) -> Option<ModelHandle<Self>> {
cx.default_global::<Option<ModelHandle<Self>>>().clone()
pub fn get(cx: &mut AppContext) -> Option<Model<Self>> {
cx.default_global::<Option<Model<Self>>>().clone()
}
fn new(
current_version: AppVersion,
current_version: SemanticVersion,
http_client: Arc<dyn HttpClient>,
server_url: String,
) -> Self {
@ -180,11 +194,11 @@ impl AutoUpdater {
}
}
pub fn start_polling(&self, cx: &mut ModelContext<Self>) -> Task<()> {
pub fn start_polling(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.spawn(|this, mut cx| async move {
loop {
this.update(&mut cx, |this, cx| this.poll(cx));
cx.background().timer(POLL_INTERVAL).await;
this.update(&mut cx, |this, cx| this.poll(cx))?;
cx.background_executor().timer(POLL_INTERVAL).await;
}
})
}
@ -198,7 +212,7 @@ impl AutoUpdater {
cx.notify();
self.pending_poll = Some(cx.spawn(|this, mut cx| async move {
let result = Self::update(this.clone(), cx.clone()).await;
let result = Self::update(this.upgrade()?, cx.clone()).await;
this.update(&mut cx, |this, cx| {
this.pending_poll = None;
if let Err(error) = result {
@ -206,7 +220,8 @@ impl AutoUpdater {
this.status = AutoUpdateStatus::Errored;
cx.notify();
}
});
})
.ok()
}));
}
@ -219,26 +234,26 @@ impl AutoUpdater {
cx.notify();
}
async fn update(this: ModelHandle<Self>, mut cx: AsyncAppContext) -> Result<()> {
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, server_url, current_version) = this.read_with(&cx, |this, _| {
(
this.http_client.clone(),
this.server_url.clone(),
this.current_version,
)
});
})?;
let mut url_string = format!(
"{server_url}/api/releases/latest?token={ZED_SECRET_CLIENT_TOKEN}&asset=Zed.dmg"
);
cx.read(|cx| {
cx.update(|cx| {
if cx.has_global::<ReleaseChannel>() {
if let Some(param) = cx.global::<ReleaseChannel>().release_query_param() {
url_string += "&";
url_string += param;
}
}
});
})?;
let mut response = client.get(&url_string, Default::default(), true).await?;
@ -251,26 +266,32 @@ impl AutoUpdater {
let release: JsonRelease =
serde_json::from_slice(body.as_slice()).context("error deserializing release")?;
let latest_version = release.version.parse::<AppVersion>()?;
if latest_version <= current_version {
let should_download = match *RELEASE_CHANNEL {
ReleaseChannel::Nightly => cx
.try_read_global::<AppCommitSha, _>(|sha, _| release.version != sha.0)
.unwrap_or(true),
_ => release.version.parse::<SemanticVersion>()? <= current_version,
};
if !should_download {
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Idle;
cx.notify();
});
})?;
return Ok(());
}
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Downloading;
cx.notify();
});
})?;
let temp_dir = tempdir::TempDir::new("zed-auto-update")?;
let dmg_path = temp_dir.path().join("Zed.dmg");
let mount_path = temp_dir.path().join("Zed");
let running_app_path = ZED_APP_PATH
.clone()
.map_or_else(|| cx.platform().app_path(), Ok)?;
.map_or_else(|| cx.update(|cx| cx.app_path())?, Ok)?;
let running_app_filename = running_app_path
.file_name()
.ok_or_else(|| anyhow!("invalid running app path"))?;
@ -279,15 +300,15 @@ impl AutoUpdater {
let mut dmg_file = File::create(&dmg_path).await?;
let (installation_id, release_channel, telemetry) = cx.read(|cx| {
let (installation_id, release_channel, telemetry) = cx.update(|cx| {
let installation_id = cx.global::<Arc<Client>>().telemetry().installation_id();
let release_channel = cx
.has_global::<ReleaseChannel>()
.then(|| cx.global::<ReleaseChannel>().display_name());
let telemetry = settings::get::<TelemetrySettings>(cx).metrics;
let telemetry = TelemetrySettings::get_global(cx).metrics;
(installation_id, release_channel, telemetry)
});
})?;
let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
installation_id,
@ -302,7 +323,7 @@ impl AutoUpdater {
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Installing;
cx.notify();
});
})?;
let output = Command::new("hdiutil")
.args(&["attach", "-nobrowse"])
@ -348,7 +369,7 @@ impl AutoUpdater {
.detach_and_log_err(cx);
this.status = AutoUpdateStatus::Updated;
cx.notify();
});
})?;
Ok(())
}
@ -357,7 +378,7 @@ impl AutoUpdater {
should_show: bool,
cx: &AppContext,
) -> Task<Result<()>> {
cx.background().spawn(async move {
cx.background_executor().spawn(async move {
if should_show {
KEY_VALUE_STORE
.write_kvp(
@ -375,7 +396,7 @@ impl AutoUpdater {
}
fn should_show_update_notification(&self, cx: &AppContext) -> Task<Result<bool>> {
cx.background().spawn(async move {
cx.background_executor().spawn(async move {
Ok(KEY_VALUE_STORE
.read_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY)?
.is_some())

View File

@ -1,106 +1,56 @@
use crate::ViewReleaseNotes;
use gpui::{
elements::{Flex, MouseEventHandler, Padding, ParentElement, Svg, Text},
platform::{AppVersion, CursorStyle, MouseButton},
Element, Entity, View, ViewContext,
div, DismissEvent, EventEmitter, InteractiveElement, IntoElement, ParentElement, Render,
SemanticVersion, StatefulInteractiveElement, Styled, ViewContext,
};
use menu::Cancel;
use util::channel::ReleaseChannel;
use workspace::notifications::Notification;
use workspace::ui::{h_stack, v_stack, Icon, IconElement, Label, StyledExt};
pub struct UpdateNotification {
version: AppVersion,
version: SemanticVersion,
}
pub enum Event {
Dismiss,
}
impl Entity for UpdateNotification {
type Event = Event;
}
impl View for UpdateNotification {
fn ui_name() -> &'static str {
"UpdateNotification"
}
fn render(&mut self, cx: &mut gpui::ViewContext<Self>) -> gpui::AnyElement<Self> {
let theme = theme::current(cx).clone();
let theme = &theme.update_notification;
impl EventEmitter<DismissEvent> for UpdateNotification {}
impl Render for UpdateNotification {
fn render(&mut self, cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
let app_name = cx.global::<ReleaseChannel>().display_name();
MouseEventHandler::new::<ViewReleaseNotes, _>(0, cx, |state, cx| {
Flex::column()
.with_child(
Flex::row()
.with_child(
Text::new(
format!("Updated to {app_name} {}", self.version),
theme.message.text.clone(),
)
.contained()
.with_style(theme.message.container)
.aligned()
.top()
.left()
.flex(1., true),
)
.with_child(
MouseEventHandler::new::<Cancel, _>(0, cx, |state, _| {
let style = theme.dismiss_button.style_for(state);
Svg::new("icons/x.svg")
.with_color(style.color)
.constrained()
.with_width(style.icon_width)
.aligned()
.contained()
.with_style(style.container)
.constrained()
.with_width(style.button_width)
.with_height(style.button_width)
})
.with_padding(Padding::uniform(5.))
.on_click(MouseButton::Left, move |_, this, cx| {
this.dismiss(&Default::default(), cx)
})
.aligned()
.constrained()
.with_height(cx.font_cache().line_height(theme.message.text.font_size))
.aligned()
.top()
.flex_float(),
),
)
.with_child({
let style = theme.action_message.style_for(state);
Text::new("View the release notes", style.text.clone())
.contained()
.with_style(style.container)
})
.contained()
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, |_, _, cx| {
crate::view_release_notes(&Default::default(), cx)
})
.into_any_named("update notification")
}
}
impl Notification for UpdateNotification {
fn should_dismiss_notification_on_event(&self, event: &<Self as Entity>::Event) -> bool {
matches!(event, Event::Dismiss)
v_stack()
.on_action(cx.listener(UpdateNotification::dismiss))
.elevation_3(cx)
.p_4()
.child(
h_stack()
.justify_between()
.child(Label::new(format!(
"Updated to {app_name} {}",
self.version
)))
.child(
div()
.id("cancel")
.child(IconElement::new(Icon::Close))
.cursor_pointer()
.on_click(cx.listener(|this, _, cx| this.dismiss(&menu::Cancel, cx))),
),
)
.child(
div()
.id("notes")
.child(Label::new("View the release notes"))
.cursor_pointer()
.on_click(|_, cx| crate::view_release_notes(&Default::default(), cx)),
)
}
}
impl UpdateNotification {
pub fn new(version: AppVersion) -> Self {
pub fn new(version: SemanticVersion) -> Self {
Self { version }
}
pub fn dismiss(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
cx.emit(Event::Dismiss);
cx.emit(DismissEvent);
}
}

View File

@ -1,29 +0,0 @@
[package]
name = "auto_update2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/auto_update.rs"
doctest = false
[dependencies]
db = { package = "db2", path = "../db2" }
client = { package = "client2", path = "../client2" }
gpui = { package = "gpui2", path = "../gpui2" }
menu = { package = "menu2", path = "../menu2" }
project = { package = "project2", path = "../project2" }
settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
workspace = { package = "workspace2", path = "../workspace2" }
util = { path = "../util" }
anyhow.workspace = true
isahc.workspace = true
lazy_static.workspace = true
log.workspace = true
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
smol.workspace = true
tempdir.workspace = true

View File

@ -1,405 +0,0 @@
mod update_notification;
use anyhow::{anyhow, Context, Result};
use client::{Client, TelemetrySettings, ZED_APP_PATH, ZED_APP_VERSION, ZED_SECRET_CLIENT_TOKEN};
use db::kvp::KEY_VALUE_STORE;
use db::RELEASE_CHANNEL;
use gpui::{
actions, AppContext, AsyncAppContext, Context as _, Model, ModelContext, SemanticVersion, Task,
ViewContext, VisualContext, WindowContext,
};
use isahc::AsyncBody;
use serde::Deserialize;
use serde_derive::Serialize;
use smol::io::AsyncReadExt;
use settings::{Settings, SettingsStore};
use smol::{fs::File, process::Command};
use std::{ffi::OsString, sync::Arc, time::Duration};
use update_notification::UpdateNotification;
use util::channel::{AppCommitSha, ReleaseChannel};
use util::http::HttpClient;
use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
actions!(auto_update, [Check, DismissErrorMessage, ViewReleaseNotes]);
#[derive(Serialize)]
struct UpdateRequestBody {
installation_id: Option<Arc<str>>,
release_channel: Option<&'static str>,
telemetry: bool,
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum AutoUpdateStatus {
Idle,
Checking,
Downloading,
Installing,
Updated,
Errored,
}
pub struct AutoUpdater {
status: AutoUpdateStatus,
current_version: SemanticVersion,
http_client: Arc<dyn HttpClient>,
pending_poll: Option<Task<Option<()>>>,
server_url: String,
}
#[derive(Deserialize)]
struct JsonRelease {
version: String,
url: String,
}
struct AutoUpdateSetting(bool);
impl Settings for AutoUpdateSetting {
const KEY: Option<&'static str> = Some("auto_update");
type FileContent = Option<bool>;
fn load(
default_value: &Option<bool>,
user_values: &[&Option<bool>],
_: &mut AppContext,
) -> Result<Self> {
Ok(Self(
Self::json_merge(default_value, user_values)?.ok_or_else(Self::missing_default)?,
))
}
}
pub fn init(http_client: Arc<dyn HttpClient>, server_url: String, cx: &mut AppContext) {
AutoUpdateSetting::register(cx);
cx.observe_new_views(|workspace: &mut Workspace, _cx| {
workspace.register_action(|_, action: &Check, cx| check(action, cx));
workspace.register_action(|_, action, cx| view_release_notes(action, cx));
// @nate - code to trigger update notification on launch
// todo!("remove this when Nate is done")
// workspace.show_notification(0, _cx, |cx| {
// cx.build_view(|_| UpdateNotification::new(SemanticVersion::from_str("1.1.1").unwrap()))
// });
})
.detach();
if let Some(version) = ZED_APP_VERSION.or_else(|| cx.app_metadata().app_version) {
let auto_updater = cx.new_model(|cx| {
let updater = AutoUpdater::new(version, http_client, server_url);
let mut update_subscription = AutoUpdateSetting::get_global(cx)
.0
.then(|| updater.start_polling(cx));
cx.observe_global::<SettingsStore>(move |updater, cx| {
if AutoUpdateSetting::get_global(cx).0 {
if update_subscription.is_none() {
update_subscription = Some(updater.start_polling(cx))
}
} else {
update_subscription.take();
}
})
.detach();
updater
});
cx.set_global(Some(auto_updater));
}
}
pub fn check(_: &Check, cx: &mut WindowContext) {
if let Some(updater) = AutoUpdater::get(cx) {
updater.update(cx, |updater, cx| updater.poll(cx));
} else {
drop(cx.prompt(
gpui::PromptLevel::Info,
"Auto-updates disabled for non-bundled app.",
&["Ok"],
));
}
}
pub fn view_release_notes(_: &ViewReleaseNotes, cx: &mut AppContext) {
if let Some(auto_updater) = AutoUpdater::get(cx) {
let auto_updater = auto_updater.read(cx);
let server_url = &auto_updater.server_url;
let current_version = auto_updater.current_version;
if cx.has_global::<ReleaseChannel>() {
match cx.global::<ReleaseChannel>() {
ReleaseChannel::Dev => {}
ReleaseChannel::Nightly => {}
ReleaseChannel::Preview => {
cx.open_url(&format!("{server_url}/releases/preview/{current_version}"))
}
ReleaseChannel::Stable => {
cx.open_url(&format!("{server_url}/releases/stable/{current_version}"))
}
}
}
}
}
pub fn notify_of_any_new_update(cx: &mut ViewContext<Workspace>) -> Option<()> {
let updater = AutoUpdater::get(cx)?;
let version = updater.read(cx).current_version;
let should_show_notification = updater.read(cx).should_show_update_notification(cx);
cx.spawn(|workspace, mut cx| async move {
let should_show_notification = should_show_notification.await?;
if should_show_notification {
workspace.update(&mut cx, |workspace, cx| {
workspace.show_notification(0, cx, |cx| {
cx.new_view(|_| UpdateNotification::new(version))
});
updater
.read(cx)
.set_should_show_update_notification(false, cx)
.detach_and_log_err(cx);
})?;
}
anyhow::Ok(())
})
.detach();
None
}
impl AutoUpdater {
pub fn get(cx: &mut AppContext) -> Option<Model<Self>> {
cx.default_global::<Option<Model<Self>>>().clone()
}
fn new(
current_version: SemanticVersion,
http_client: Arc<dyn HttpClient>,
server_url: String,
) -> Self {
Self {
status: AutoUpdateStatus::Idle,
current_version,
http_client,
server_url,
pending_poll: None,
}
}
pub fn start_polling(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.spawn(|this, mut cx| async move {
loop {
this.update(&mut cx, |this, cx| this.poll(cx))?;
cx.background_executor().timer(POLL_INTERVAL).await;
}
})
}
pub fn poll(&mut self, cx: &mut ModelContext<Self>) {
if self.pending_poll.is_some() || self.status == AutoUpdateStatus::Updated {
return;
}
self.status = AutoUpdateStatus::Checking;
cx.notify();
self.pending_poll = Some(cx.spawn(|this, mut cx| async move {
let result = Self::update(this.upgrade()?, cx.clone()).await;
this.update(&mut cx, |this, cx| {
this.pending_poll = None;
if let Err(error) = result {
log::error!("auto-update failed: error:{:?}", error);
this.status = AutoUpdateStatus::Errored;
cx.notify();
}
})
.ok()
}));
}
pub fn status(&self) -> AutoUpdateStatus {
self.status
}
pub fn dismiss_error(&mut self, cx: &mut ModelContext<Self>) {
self.status = AutoUpdateStatus::Idle;
cx.notify();
}
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, server_url, current_version) = this.read_with(&cx, |this, _| {
(
this.http_client.clone(),
this.server_url.clone(),
this.current_version,
)
})?;
let mut url_string = format!(
"{server_url}/api/releases/latest?token={ZED_SECRET_CLIENT_TOKEN}&asset=Zed.dmg"
);
cx.update(|cx| {
if cx.has_global::<ReleaseChannel>() {
if let Some(param) = cx.global::<ReleaseChannel>().release_query_param() {
url_string += "&";
url_string += param;
}
}
})?;
let mut response = client.get(&url_string, Default::default(), true).await?;
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.context("error reading release")?;
let release: JsonRelease =
serde_json::from_slice(body.as_slice()).context("error deserializing release")?;
let should_download = match *RELEASE_CHANNEL {
ReleaseChannel::Nightly => cx
.try_read_global::<AppCommitSha, _>(|sha, _| release.version != sha.0)
.unwrap_or(true),
_ => release.version.parse::<SemanticVersion>()? <= current_version,
};
if !should_download {
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Idle;
cx.notify();
})?;
return Ok(());
}
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Downloading;
cx.notify();
})?;
let temp_dir = tempdir::TempDir::new("zed-auto-update")?;
let dmg_path = temp_dir.path().join("Zed.dmg");
let mount_path = temp_dir.path().join("Zed");
let running_app_path = ZED_APP_PATH
.clone()
.map_or_else(|| cx.update(|cx| cx.app_path())?, Ok)?;
let running_app_filename = running_app_path
.file_name()
.ok_or_else(|| anyhow!("invalid running app path"))?;
let mut mounted_app_path: OsString = mount_path.join(running_app_filename).into();
mounted_app_path.push("/");
let mut dmg_file = File::create(&dmg_path).await?;
let (installation_id, release_channel, telemetry) = cx.update(|cx| {
let installation_id = cx.global::<Arc<Client>>().telemetry().installation_id();
let release_channel = cx
.has_global::<ReleaseChannel>()
.then(|| cx.global::<ReleaseChannel>().display_name());
let telemetry = TelemetrySettings::get_global(cx).metrics;
(installation_id, release_channel, telemetry)
})?;
let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
installation_id,
release_channel,
telemetry,
})?);
let mut response = client.get(&release.url, request_body, true).await?;
smol::io::copy(response.body_mut(), &mut dmg_file).await?;
log::info!("downloaded update. path:{:?}", dmg_path);
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Installing;
cx.notify();
})?;
let output = Command::new("hdiutil")
.args(&["attach", "-nobrowse"])
.arg(&dmg_path)
.arg("-mountroot")
.arg(&temp_dir.path())
.output()
.await?;
if !output.status.success() {
Err(anyhow!(
"failed to mount: {:?}",
String::from_utf8_lossy(&output.stderr)
))?;
}
let output = Command::new("rsync")
.args(&["-av", "--delete"])
.arg(&mounted_app_path)
.arg(&running_app_path)
.output()
.await?;
if !output.status.success() {
Err(anyhow!(
"failed to copy app: {:?}",
String::from_utf8_lossy(&output.stderr)
))?;
}
let output = Command::new("hdiutil")
.args(&["detach"])
.arg(&mount_path)
.output()
.await?;
if !output.status.success() {
Err(anyhow!(
"failed to unmount: {:?}",
String::from_utf8_lossy(&output.stderr)
))?;
}
this.update(&mut cx, |this, cx| {
this.set_should_show_update_notification(true, cx)
.detach_and_log_err(cx);
this.status = AutoUpdateStatus::Updated;
cx.notify();
})?;
Ok(())
}
fn set_should_show_update_notification(
&self,
should_show: bool,
cx: &AppContext,
) -> Task<Result<()>> {
cx.background_executor().spawn(async move {
if should_show {
KEY_VALUE_STORE
.write_kvp(
SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string(),
"".to_string(),
)
.await?;
} else {
KEY_VALUE_STORE
.delete_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string())
.await?;
}
Ok(())
})
}
fn should_show_update_notification(&self, cx: &AppContext) -> Task<Result<bool>> {
cx.background_executor().spawn(async move {
Ok(KEY_VALUE_STORE
.read_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY)?
.is_some())
})
}
}

View File

@ -1,56 +0,0 @@
use gpui::{
div, DismissEvent, Element, EventEmitter, InteractiveElement, ParentElement, Render,
SemanticVersion, StatefulInteractiveElement, Styled, ViewContext,
};
use menu::Cancel;
use util::channel::ReleaseChannel;
use workspace::ui::{h_stack, v_stack, Icon, IconElement, Label, StyledExt};
pub struct UpdateNotification {
version: SemanticVersion,
}
impl EventEmitter<DismissEvent> for UpdateNotification {}
impl Render for UpdateNotification {
fn render(&mut self, cx: &mut gpui::ViewContext<Self>) -> impl Element {
let app_name = cx.global::<ReleaseChannel>().display_name();
v_stack()
.on_action(cx.listener(UpdateNotification::dismiss))
.elevation_3(cx)
.p_4()
.child(
h_stack()
.justify_between()
.child(Label::new(format!(
"Updated to {app_name} {}",
self.version
)))
.child(
div()
.id("cancel")
.child(IconElement::new(Icon::Close))
.cursor_pointer()
.on_click(cx.listener(|this, _, cx| this.dismiss(&menu::Cancel, cx))),
),
)
.child(
div()
.id("notes")
.child(Label::new("View the release notes"))
.cursor_pointer()
.on_click(|_, cx| crate::view_release_notes(&Default::default(), cx)),
)
}
}
impl UpdateNotification {
pub fn new(version: SemanticVersion) -> Self {
Self { version }
}
pub fn dismiss(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
cx.emit(DismissEvent);
}
}

View File

@ -12,6 +12,7 @@ doctest = false
collections = { path = "../collections" }
editor = { path = "../editor" }
gpui = { path = "../gpui" }
ui = { path = "../ui" }
language = { path = "../language" }
project = { path = "../project" }
search = { path = "../search" }

View File

@ -1,108 +1,74 @@
use editor::Editor;
use gpui::{
elements::*, platform::MouseButton, AppContext, Entity, Subscription, View, ViewContext,
ViewHandle, WeakViewHandle,
Element, EventEmitter, IntoElement, ParentElement, Render, StyledText, Subscription,
ViewContext,
};
use itertools::Itertools;
use search::ProjectSearchView;
use theme::ActiveTheme;
use ui::{prelude::*, ButtonLike, ButtonStyle, Label, Tooltip};
use workspace::{
item::{ItemEvent, ItemHandle},
ToolbarItemLocation, ToolbarItemView, Workspace,
ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
};
pub enum Event {
UpdateLocation,
}
pub struct Breadcrumbs {
pane_focused: bool,
active_item: Option<Box<dyn ItemHandle>>,
project_search: Option<ViewHandle<ProjectSearchView>>,
subscription: Option<Subscription>,
workspace: WeakViewHandle<Workspace>,
}
impl Breadcrumbs {
pub fn new(workspace: &Workspace) -> Self {
pub fn new() -> Self {
Self {
pane_focused: false,
active_item: Default::default(),
subscription: Default::default(),
project_search: Default::default(),
workspace: workspace.weak_handle(),
}
}
}
impl Entity for Breadcrumbs {
type Event = Event;
}
impl EventEmitter<ToolbarItemEvent> for Breadcrumbs {}
impl View for Breadcrumbs {
fn ui_name() -> &'static str {
"Breadcrumbs"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let active_item = match &self.active_item {
Some(active_item) => active_item,
None => return Empty::new().into_any(),
impl Render for Breadcrumbs {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let element = h_stack().text_ui();
let Some(active_item) = self.active_item.as_ref() else {
return element;
};
let Some(segments) = active_item.breadcrumbs(cx.theme(), cx) else {
return element;
};
let not_editor = active_item.downcast::<editor::Editor>().is_none();
let theme = theme::current(cx).clone();
let style = &theme.workspace.toolbar.breadcrumbs;
let highlighted_segments = segments.into_iter().map(|segment| {
let mut text_style = cx.text_style();
text_style.color = Color::Muted.color(cx);
let breadcrumbs = match active_item.breadcrumbs(&theme, cx) {
Some(breadcrumbs) => breadcrumbs,
None => return Empty::new().into_any(),
}
.into_iter()
.map(|breadcrumb| {
Text::new(
breadcrumb.text,
theme.workspace.toolbar.breadcrumbs.default.text.clone(),
)
.with_highlights(breadcrumb.highlights.unwrap_or_default())
.into_any()
StyledText::new(segment.text)
.with_highlights(&text_style, segment.highlights.unwrap_or_default())
.into_any()
});
let breadcrumbs = Itertools::intersperse_with(highlighted_segments, || {
Label::new("").color(Color::Muted).into_any_element()
});
let crumbs = Flex::row()
.with_children(Itertools::intersperse_with(breadcrumbs, || {
Label::new(" ", style.default.text.clone()).into_any()
}))
.constrained()
.with_height(theme.workspace.toolbar.breadcrumb_height)
.contained();
if not_editor || !self.pane_focused {
return crumbs
.with_style(style.default.container)
.aligned()
.left()
.into_any();
let breadcrumbs_stack = h_stack().gap_1().children(breadcrumbs);
match active_item
.downcast::<Editor>()
.map(|editor| editor.downgrade())
{
Some(editor) => element.child(
ButtonLike::new("toggle outline view")
.child(breadcrumbs_stack)
.style(ButtonStyle::Subtle)
.on_click(move |_, cx| {
if let Some(editor) = editor.upgrade() {
outline::toggle(editor, &outline::Toggle, cx)
}
})
.tooltip(|cx| Tooltip::for_action("Show symbol outline", &outline::Toggle, cx)),
),
None => element.child(breadcrumbs_stack),
}
MouseEventHandler::new::<Breadcrumbs, _>(0, cx, |state, _| {
let style = style.style_for(state);
crumbs.with_style(style.container)
})
.on_click(MouseButton::Left, |_, this, cx| {
if let Some(workspace) = this.workspace.upgrade(cx) {
workspace.update(cx, |workspace, cx| {
outline::toggle(workspace, &Default::default(), cx)
})
}
})
.with_tooltip::<Breadcrumbs>(
0,
"Show symbol outline".to_owned(),
Some(Box::new(outline::Toggle)),
theme.tooltip.clone(),
cx,
)
.aligned()
.left()
.into_any()
}
}
@ -114,19 +80,21 @@ impl ToolbarItemView for Breadcrumbs {
) -> ToolbarItemLocation {
cx.notify();
self.active_item = None;
self.project_search = None;
if let Some(item) = active_pane_item {
let this = cx.weak_handle();
let this = cx.view().downgrade();
self.subscription = Some(item.subscribe_to_item_events(
cx,
Box::new(move |event, cx| {
if let Some(this) = this.upgrade(cx) {
if let ItemEvent::UpdateBreadcrumbs = event {
this.update(cx, |_, cx| {
cx.emit(Event::UpdateLocation);
cx.notify();
});
}
if let ItemEvent::UpdateBreadcrumbs = event {
this.update(cx, |this, cx| {
cx.notify();
if let Some(active_item) = this.active_item.as_ref() {
cx.emit(ToolbarItemEvent::ChangeLocation(
active_item.breadcrumb_location(cx),
))
}
})
.ok();
}
}),
));
@ -137,19 +105,6 @@ impl ToolbarItemView for Breadcrumbs {
}
}
fn location_for_event(
&self,
_: &Event,
current_location: ToolbarItemLocation,
cx: &AppContext,
) -> ToolbarItemLocation {
if let Some(active_item) = self.active_item.as_ref() {
active_item.breadcrumb_location(cx)
} else {
current_location
}
}
fn pane_focus_update(&mut self, pane_focused: bool, _: &mut ViewContext<Self>) {
self.pane_focused = pane_focused;
}

View File

@ -1,28 +0,0 @@
[package]
name = "breadcrumbs2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/breadcrumbs.rs"
doctest = false
[dependencies]
collections = { path = "../collections" }
editor = { package = "editor2", path = "../editor2" }
gpui = { package = "gpui2", path = "../gpui2" }
ui = { package = "ui2", path = "../ui2" }
language = { package = "language2", path = "../language2" }
project = { package = "project2", path = "../project2" }
search = { package = "search2", path = "../search2" }
settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
workspace = { package = "workspace2", path = "../workspace2" }
outline = { package = "outline2", path = "../outline2" }
itertools = "0.10"
[dev-dependencies]
editor = { package = "editor2", path = "../editor2", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
workspace = { package = "workspace2", path = "../workspace2", features = ["test-support"] }

View File

@ -1,111 +0,0 @@
use editor::Editor;
use gpui::{
Element, EventEmitter, IntoElement, ParentElement, Render, StyledText, Subscription,
ViewContext,
};
use itertools::Itertools;
use theme::ActiveTheme;
use ui::{prelude::*, ButtonLike, ButtonStyle, Label, Tooltip};
use workspace::{
item::{ItemEvent, ItemHandle},
ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
};
pub struct Breadcrumbs {
pane_focused: bool,
active_item: Option<Box<dyn ItemHandle>>,
subscription: Option<Subscription>,
}
impl Breadcrumbs {
pub fn new() -> Self {
Self {
pane_focused: false,
active_item: Default::default(),
subscription: Default::default(),
}
}
}
impl EventEmitter<ToolbarItemEvent> for Breadcrumbs {}
impl Render for Breadcrumbs {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Element {
let element = h_stack().text_ui();
let Some(active_item) = self.active_item.as_ref() else {
return element;
};
let Some(segments) = active_item.breadcrumbs(cx.theme(), cx) else {
return element;
};
let highlighted_segments = segments.into_iter().map(|segment| {
let mut text_style = cx.text_style();
text_style.color = Color::Muted.color(cx);
StyledText::new(segment.text)
.with_highlights(&text_style, segment.highlights.unwrap_or_default())
.into_any()
});
let breadcrumbs = Itertools::intersperse_with(highlighted_segments, || {
Label::new("").color(Color::Muted).into_any_element()
});
let breadcrumbs_stack = h_stack().gap_1().children(breadcrumbs);
match active_item
.downcast::<Editor>()
.map(|editor| editor.downgrade())
{
Some(editor) => element.child(
ButtonLike::new("toggle outline view")
.child(breadcrumbs_stack)
.style(ButtonStyle::Subtle)
.on_click(move |_, cx| {
if let Some(editor) = editor.upgrade() {
outline::toggle(editor, &outline::Toggle, cx)
}
})
.tooltip(|cx| Tooltip::for_action("Show symbol outline", &outline::Toggle, cx)),
),
None => element.child(breadcrumbs_stack),
}
}
}
impl ToolbarItemView for Breadcrumbs {
fn set_active_pane_item(
&mut self,
active_pane_item: Option<&dyn ItemHandle>,
cx: &mut ViewContext<Self>,
) -> ToolbarItemLocation {
cx.notify();
self.active_item = None;
if let Some(item) = active_pane_item {
let this = cx.view().downgrade();
self.subscription = Some(item.subscribe_to_item_events(
cx,
Box::new(move |event, cx| {
if let ItemEvent::UpdateBreadcrumbs = event {
this.update(cx, |this, cx| {
cx.notify();
if let Some(active_item) = this.active_item.as_ref() {
cx.emit(ToolbarItemEvent::ChangeLocation(
active_item.breadcrumb_location(cx),
))
}
})
.ok();
}
}),
));
self.active_item = Some(item.boxed_clone());
item.breadcrumb_location(cx)
} else {
ToolbarItemLocation::Hidden
}
}
fn pane_focus_update(&mut self, pane_focused: bool, _: &mut ViewContext<Self>) {
self.pane_focused = pane_focused;
}
}

View File

@ -35,11 +35,13 @@ util = { path = "../util" }
anyhow.workspace = true
async-broadcast = "0.4"
futures.workspace = true
image = "0.23"
postage.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_derive.workspace = true
smallvec.workspace = true
[dev-dependencies]
client = { path = "../client", features = ["test-support"] }

View File

@ -9,31 +9,25 @@ use client::{proto, Client, TelemetrySettings, TypedEnvelope, User, UserStore, Z
use collections::HashSet;
use futures::{channel::oneshot, future::Shared, Future, FutureExt};
use gpui::{
AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task,
WeakModelHandle,
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
WeakModel,
};
use postage::watch;
use project::Project;
use room::Event;
use settings::Settings;
use std::sync::Arc;
pub use participant::ParticipantLocation;
pub use room::Room;
pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
settings::register::<CallSettings>(cx);
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
CallSettings::register(cx);
let active_call = cx.add_model(|cx| ActiveCall::new(client, user_store, cx));
let active_call = cx.new_model(|cx| ActiveCall::new(client, user_store, cx));
cx.set_global(active_call);
}
#[derive(Clone)]
pub struct IncomingCall {
pub room_id: u64,
pub calling_user: Arc<User>,
pub participants: Vec<Arc<User>>,
pub initial_project: Option<proto::ParticipantProject>,
}
pub struct OneAtATime {
cancel: Option<oneshot::Sender<()>>,
}
@ -65,43 +59,44 @@ impl OneAtATime {
}
}
#[derive(Clone)]
pub struct IncomingCall {
pub room_id: u64,
pub calling_user: Arc<User>,
pub participants: Vec<Arc<User>>,
pub initial_project: Option<proto::ParticipantProject>,
}
/// Singleton global maintaining the user's participation in a room across workspaces.
pub struct ActiveCall {
room: Option<(ModelHandle<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<ModelHandle<Room>, Arc<anyhow::Error>>>>>,
room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakModel<Project>>,
_join_debouncer: OneAtATime,
location: Option<WeakModelHandle<Project>>,
pending_invites: HashSet<u64>,
incoming_call: (
watch::Sender<Option<IncomingCall>>,
watch::Receiver<Option<IncomingCall>>,
),
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
_subscriptions: Vec<client::Subscription>,
}
impl Entity for ActiveCall {
type Event = room::Event;
}
impl EventEmitter<Event> for ActiveCall {}
impl ActiveCall {
fn new(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
Self {
room: None,
pending_room_creation: None,
location: None,
pending_invites: Default::default(),
incoming_call: watch::channel(),
_join_debouncer: OneAtATime { cancel: None },
_subscriptions: vec![
client.add_request_handler(cx.handle(), Self::handle_incoming_call),
client.add_message_handler(cx.handle(), Self::handle_call_canceled),
client.add_request_handler(cx.weak_model(), Self::handle_incoming_call),
client.add_message_handler(cx.weak_model(), Self::handle_call_canceled),
],
client,
user_store,
@ -113,35 +108,35 @@ impl ActiveCall {
}
async fn handle_incoming_call(
this: ModelHandle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::IncomingCall>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<proto::Ack> {
let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let call = IncomingCall {
room_id: envelope.payload.room_id,
participants: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_users(envelope.payload.participant_user_ids, cx)
})
})?
.await?,
calling_user: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_user(envelope.payload.calling_user_id, cx)
})
})?
.await?,
initial_project: envelope.payload.initial_project,
};
this.update(&mut cx, |this, _| {
*this.incoming_call.0.borrow_mut() = Some(call);
});
})?;
Ok(proto::Ack {})
}
async fn handle_call_canceled(
this: ModelHandle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::CallCanceled>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -154,18 +149,18 @@ impl ActiveCall {
{
incoming_call.take();
}
});
})?;
Ok(())
}
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
cx.global::<ModelHandle<Self>>().clone()
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Model<Self>>().clone()
}
pub fn invite(
&mut self,
called_user_id: u64,
initial_project: Option<ModelHandle<Project>>,
initial_project: Option<Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.pending_invites.insert(called_user_id) {
@ -184,21 +179,21 @@ impl ActiveCall {
};
let invite = if let Some(room) = room {
cx.spawn_weak(|_, mut cx| async move {
cx.spawn(move |_, mut cx| async move {
let room = room.await.map_err(|err| anyhow!("{:?}", err))?;
let initial_project_id = if let Some(initial_project) = initial_project {
Some(
room.update(&mut cx, |room, cx| room.share_project(initial_project, cx))
room.update(&mut cx, |room, cx| room.share_project(initial_project, cx))?
.await?,
)
} else {
None
};
room.update(&mut cx, |room, cx| {
room.update(&mut cx, move |room, cx| {
room.call(called_user_id, initial_project_id, cx)
})
})?
.await?;
anyhow::Ok(())
@ -207,7 +202,7 @@ impl ActiveCall {
let client = self.client.clone();
let user_store = self.user_store.clone();
let room = cx
.spawn(|this, mut cx| async move {
.spawn(move |this, mut cx| async move {
let create_room = async {
let room = cx
.update(|cx| {
@ -218,31 +213,31 @@ impl ActiveCall {
user_store,
cx,
)
})
})?
.await?;
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))?
.await?;
anyhow::Ok(room)
};
let room = create_room.await;
this.update(&mut cx, |this, _| this.pending_room_creation = None);
this.update(&mut cx, |this, _| this.pending_room_creation = None)?;
room.map_err(Arc::new)
})
.shared();
self.pending_room_creation = Some(room.clone());
cx.foreground().spawn(async move {
cx.background_executor().spawn(async move {
room.await.map_err(|err| anyhow!("{:?}", err))?;
anyhow::Ok(())
})
};
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let result = invite.await;
if result.is_ok() {
this.update(&mut cx, |this, cx| this.report_call_event("invite", cx));
this.update(&mut cx, |this, cx| this.report_call_event("invite", cx))?;
} else {
// TODO: Resport collaboration error
}
@ -250,7 +245,7 @@ impl ActiveCall {
this.update(&mut cx, |this, cx| {
this.pending_invites.remove(&called_user_id);
cx.notify();
});
})?;
result
})
}
@ -267,7 +262,7 @@ impl ActiveCall {
};
let client = self.client.clone();
cx.foreground().spawn(async move {
cx.background_executor().spawn(async move {
client
.request(proto::CancelCall {
room_id,
@ -306,11 +301,11 @@ impl ActiveCall {
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("accept incoming", cx)
});
})?;
Ok(())
})
}
@ -333,7 +328,7 @@ impl ActiveCall {
&mut self,
channel_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Option<ModelHandle<Room>>>> {
) -> Task<Result<Option<Model<Room>>>> {
if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(Some(room)));
@ -352,13 +347,13 @@ impl ActiveCall {
Room::join_channel(channel_id, client, user_store, cx).await
});
cx.spawn(move |this, mut cx| async move {
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("join channel", cx)
});
})?;
Ok(room)
})
}
@ -366,6 +361,7 @@ impl ActiveCall {
pub fn hang_up(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.notify();
self.report_call_event("hang up", cx);
Audio::end_call(cx);
if let Some((room, _)) = self.room.take() {
room.update(cx, |room, cx| room.leave(cx))
@ -376,7 +372,7 @@ impl ActiveCall {
pub fn share_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some((room, _)) = self.room.as_ref() {
@ -389,7 +385,7 @@ impl ActiveCall {
pub fn unshare_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let Some((room, _)) = self.room.as_ref() {
@ -400,13 +396,13 @@ impl ActiveCall {
}
}
pub fn location(&self) -> Option<&WeakModelHandle<Project>> {
pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref()
}
pub fn set_location(
&mut self,
project: Option<&ModelHandle<Project>>,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
@ -420,7 +416,7 @@ impl ActiveCall {
fn set_room(
&mut self,
room: Option<ModelHandle<Room>>,
room: Option<Model<Room>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
@ -441,7 +437,10 @@ impl ActiveCall {
cx.subscribe(&room, |_, _, event, cx| cx.emit(event.clone())),
];
self.room = Some((room.clone(), subscriptions));
let location = self.location.and_then(|location| location.upgrade(cx));
let location = self
.location
.as_ref()
.and_then(|location| location.upgrade());
room.update(cx, |room, cx| room.set_location(location.as_ref(), cx))
}
} else {
@ -453,7 +452,7 @@ impl ActiveCall {
}
}
pub fn room(&self) -> Option<&ModelHandle<Room>> {
pub fn room(&self) -> Option<&Model<Room>> {
self.room.as_ref().map(|(room, _)| room)
}
@ -465,7 +464,7 @@ impl ActiveCall {
&self.pending_invites
}
pub fn report_call_event(&self, operation: &'static str, cx: &AppContext) {
pub fn report_call_event(&self, operation: &'static str, cx: &mut AppContext) {
if let Some(room) = self.room() {
let room = room.read(cx);
report_call_event_for_room(operation, room.id(), room.channel_id(), &self.client, cx);
@ -478,10 +477,10 @@ pub fn report_call_event_for_room(
room_id: u64,
channel_id: Option<u64>,
client: &Arc<Client>,
cx: &AppContext,
cx: &mut AppContext,
) {
let telemetry = client.telemetry();
let telemetry_settings = *settings::get::<TelemetrySettings>(cx);
let telemetry_settings = *TelemetrySettings::get_global(cx);
telemetry.report_call_event(telemetry_settings, operation, Some(room_id), channel_id)
}
@ -495,7 +494,8 @@ pub fn report_call_event_for_channel(
let room = ActiveCall::global(cx).read(cx).room();
let telemetry = client.telemetry();
let telemetry_settings = *settings::get::<TelemetrySettings>(cx);
let telemetry_settings = *TelemetrySettings::get_global(cx);
telemetry.report_call_event(
telemetry_settings,

View File

@ -1,6 +1,8 @@
use anyhow::Result;
use gpui::AppContext;
use schemars::JsonSchema;
use serde_derive::{Deserialize, Serialize};
use settings::Setting;
use settings::Settings;
#[derive(Deserialize, Debug)]
pub struct CallSettings {
@ -12,7 +14,7 @@ pub struct CallSettingsContent {
pub mute_on_join: Option<bool>,
}
impl Setting for CallSettings {
impl Settings for CallSettings {
const KEY: Option<&'static str> = Some("calls");
type FileContent = CallSettingsContent;
@ -20,8 +22,11 @@ impl Setting for CallSettings {
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &gpui::AppContext,
) -> anyhow::Result<Self> {
_cx: &mut AppContext,
) -> Result<Self>
where
Self: Sized,
{
Self::load_via_json_merge(default_value, user_values)
}
}

View File

@ -2,11 +2,11 @@ use anyhow::{anyhow, Result};
use client::ParticipantIndex;
use client::{proto, User};
use collections::HashMap;
use gpui::WeakModelHandle;
use gpui::WeakModel;
pub use live_kit_client::Frame;
use live_kit_client::RemoteAudioTrack;
pub use live_kit_client::{RemoteAudioTrack, RemoteVideoTrack};
use project::Project;
use std::{fmt, sync::Arc};
use std::sync::Arc;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ParticipantLocation {
@ -35,7 +35,7 @@ impl ParticipantLocation {
#[derive(Clone, Default)]
pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>,
pub active_project: Option<WeakModelHandle<Project>>,
pub active_project: Option<WeakModel<Project>>,
}
#[derive(Clone, Debug)]
@ -50,20 +50,3 @@ pub struct RemoteParticipant {
pub video_tracks: HashMap<live_kit_client::Sid, Arc<RemoteVideoTrack>>,
pub audio_tracks: HashMap<live_kit_client::Sid, Arc<RemoteAudioTrack>>,
}
#[derive(Clone)]
pub struct RemoteVideoTrack {
pub(crate) live_kit_track: Arc<live_kit_client::RemoteVideoTrack>,
}
impl fmt::Debug for RemoteVideoTrack {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RemoteVideoTrack").finish()
}
}
impl RemoteVideoTrack {
pub fn frames(&self) -> async_broadcast::Receiver<Frame> {
self.live_kit_track.frames()
}
}

View File

@ -1,6 +1,6 @@
use crate::{
call_settings::CallSettings,
participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack},
participant::{LocalParticipant, ParticipantLocation, RemoteParticipant},
};
use anyhow::{anyhow, Result};
use audio::{Audio, Sound};
@ -11,7 +11,9 @@ use client::{
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::{FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use gpui::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
};
use language::LanguageRegistry;
use live_kit_client::{
LocalAudioTrack, LocalTrackPublication, LocalVideoTrack, RemoteAudioTrackUpdate,
@ -19,7 +21,8 @@ use live_kit_client::{
};
use postage::{sink::Sink, stream::Stream, watch};
use project::Project;
use std::{future::Future, mem, pin::Pin, sync::Arc, time::Duration};
use settings::Settings as _;
use std::{future::Future, mem, sync::Arc, time::Duration};
use util::{post_inc, ResultExt, TryFutureExt};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@ -54,11 +57,11 @@ pub enum Event {
pub struct Room {
id: u64,
pub channel_id: Option<u64>,
channel_id: Option<u64>,
live_kit: Option<LiveKitRoom>,
status: RoomStatus,
shared_projects: HashSet<WeakModelHandle<Project>>,
joined_projects: HashSet<WeakModelHandle<Project>>,
shared_projects: HashSet<WeakModel<Project>>,
joined_projects: HashSet<WeakModel<Project>>,
local_participant: LocalParticipant,
remote_participants: BTreeMap<u64, RemoteParticipant>,
pending_participants: Vec<Arc<User>>,
@ -66,39 +69,17 @@ pub struct Room {
pending_call_count: usize,
leave_when_empty: bool,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
subscriptions: Vec<client::Subscription>,
client_subscriptions: Vec<client::Subscription>,
_subscriptions: Vec<gpui::Subscription>,
room_update_completed_tx: watch::Sender<Option<()>>,
room_update_completed_rx: watch::Receiver<Option<()>>,
pending_room_update: Option<Task<()>>,
maintain_connection: Option<Task<Option<()>>>,
}
impl Entity for Room {
type Event = Event;
fn release(&mut self, cx: &mut AppContext) {
if self.status.is_online() {
self.leave_internal(cx).detach_and_log_err(cx);
}
}
fn app_will_quit(&mut self, cx: &mut AppContext) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
if self.status.is_online() {
let leave = self.leave_internal(cx);
Some(
cx.background()
.spawn(async move {
leave.await.log_err();
})
.boxed(),
)
} else {
None
}
}
}
impl EventEmitter<Event> for Room {}
impl Room {
pub fn channel_id(&self) -> Option<u64> {
@ -121,16 +102,12 @@ impl Room {
}
}
pub fn can_publish(&self) -> bool {
self.live_kit.as_ref().is_some_and(|room| room.can_publish)
}
fn new(
id: u64,
channel_id: Option<u64>,
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
let live_kit_room = if let Some(connection_info) = live_kit_connection_info {
@ -138,69 +115,75 @@ impl Room {
let mut status = room.status();
// Consume the initial status of the room.
let _ = status.try_recv();
let _maintain_room = cx.spawn_weak(|this, mut cx| async move {
let _maintain_room = cx.spawn(|this, mut cx| async move {
while let Some(status) = status.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
let this = if let Some(this) = this.upgrade() {
this
} else {
break;
};
if status == live_kit_client::ConnectionState::Disconnected {
this.update(&mut cx, |this, cx| this.leave(cx).log_err());
this.update(&mut cx, |this, cx| this.leave(cx).log_err())
.ok();
break;
}
}
});
let mut track_video_changes = room.remote_video_track_updates();
let _maintain_video_tracks = cx.spawn_weak(|this, mut cx| async move {
while let Some(track_change) = track_video_changes.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
let _maintain_video_tracks = cx.spawn({
let room = room.clone();
move |this, mut cx| async move {
let mut track_video_changes = room.remote_video_track_updates();
while let Some(track_change) = track_video_changes.next().await {
let this = if let Some(this) = this.upgrade() {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
this.remote_video_track_updated(track_change, cx).log_err()
});
this.update(&mut cx, |this, cx| {
this.remote_video_track_updated(track_change, cx).log_err()
})
.ok();
}
}
});
let mut track_audio_changes = room.remote_audio_track_updates();
let _maintain_audio_tracks = cx.spawn_weak(|this, mut cx| async move {
while let Some(track_change) = track_audio_changes.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
let _maintain_audio_tracks = cx.spawn({
let room = room.clone();
|this, mut cx| async move {
let mut track_audio_changes = room.remote_audio_track_updates();
while let Some(track_change) = track_audio_changes.next().await {
let this = if let Some(this) = this.upgrade() {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
this.remote_audio_track_updated(track_change, cx).log_err()
});
this.update(&mut cx, |this, cx| {
this.remote_audio_track_updated(track_change, cx).log_err()
})
.ok();
}
}
});
let connect = room.connect(&connection_info.server_url, &connection_info.token);
if connection_info.can_publish {
cx.spawn(|this, mut cx| async move {
connect.await?;
cx.spawn(|this, mut cx| async move {
connect.await?;
if !cx.read(Self::mute_on_join) {
this.update(&mut cx, |this, cx| this.share_microphone(cx))
.await?;
}
if !cx.update(|cx| Self::mute_on_join(cx))? {
this.update(&mut cx, |this, cx| this.share_microphone(cx))?
.await?;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
Some(LiveKitRoom {
room,
can_publish: connection_info.can_publish,
screen_track: LocalTrack::None,
microphone_track: LocalTrack::None,
next_publish_id: 0,
@ -214,8 +197,10 @@ impl Room {
None
};
let maintain_connection =
cx.spawn_weak(|this, cx| Self::maintain_connection(this, client.clone(), cx).log_err());
let maintain_connection = cx.spawn({
let client = client.clone();
move |this, cx| Self::maintain_connection(this, client.clone(), cx).log_err()
});
Audio::play_sound(Sound::Joined, cx);
@ -233,7 +218,13 @@ impl Room {
remote_participants: Default::default(),
pending_participants: Default::default(),
pending_call_count: 0,
subscriptions: vec![client.add_message_handler(cx.handle(), Self::handle_room_updated)],
client_subscriptions: vec![
client.add_message_handler(cx.weak_model(), Self::handle_room_updated)
],
_subscriptions: vec![
cx.on_release(Self::released),
cx.on_app_quit(Self::app_will_quit),
],
leave_when_empty: false,
pending_room_update: None,
client,
@ -247,15 +238,15 @@ impl Room {
pub(crate) fn create(
called_user_id: u64,
initial_project: Option<ModelHandle<Project>>,
initial_project: Option<Model<Project>>,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
cx: &mut AppContext,
) -> Task<Result<ModelHandle<Self>>> {
cx.spawn(|mut cx| async move {
) -> Task<Result<Model<Self>>> {
cx.spawn(move |mut cx| async move {
let response = client.request(proto::CreateRoom {}).await?;
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.add_model(|cx| {
let room = cx.new_model(|cx| {
Self::new(
room_proto.id,
None,
@ -264,13 +255,13 @@ impl Room {
user_store,
cx,
)
});
})?;
let initial_project_id = if let Some(initial_project) = initial_project {
let initial_project_id = room
.update(&mut cx, |room, cx| {
room.share_project(initial_project.clone(), cx)
})
})?
.await?;
Some(initial_project_id)
} else {
@ -281,7 +272,7 @@ impl Room {
.update(&mut cx, |room, cx| {
room.leave_when_empty = true;
room.call(called_user_id, initial_project_id, cx)
})
})?
.await
{
Ok(()) => Ok(room),
@ -293,9 +284,9 @@ impl Room {
pub(crate) async fn join_channel(
channel_id: u64,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
Self::from_join_response(
client.request(proto::JoinChannel { channel_id }).await?,
client,
@ -307,9 +298,9 @@ impl Room {
pub(crate) async fn join(
room_id: u64,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
Self::from_join_response(
client.request(proto::JoinRoom { id: room_id }).await?,
client,
@ -318,18 +309,41 @@ impl Room {
)
}
fn released(&mut self, cx: &mut AppContext) {
if self.status.is_online() {
self.leave_internal(cx).detach_and_log_err(cx);
}
}
fn app_will_quit(&mut self, cx: &mut ModelContext<Self>) -> impl Future<Output = ()> {
let task = if self.status.is_online() {
let leave = self.leave_internal(cx);
Some(cx.background_executor().spawn(async move {
leave.await.log_err();
}))
} else {
None
};
async move {
if let Some(task) = task {
task.await;
}
}
}
pub fn mute_on_join(cx: &AppContext) -> bool {
settings::get::<CallSettings>(cx).mute_on_join || client::IMPERSONATE_LOGIN.is_some()
CallSettings::get_global(cx).mute_on_join || client::IMPERSONATE_LOGIN.is_some()
}
fn from_join_response(
response: proto::JoinRoomResponse,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.add_model(|cx| {
let room = cx.new_model(|cx| {
Self::new(
room_proto.id,
response.channel_id,
@ -338,12 +352,12 @@ impl Room {
user_store,
cx,
)
});
})?;
room.update(&mut cx, |room, cx| {
room.leave_when_empty = room.channel_id.is_none();
room.apply_room_update(room_proto, cx)?;
anyhow::Ok(())
})?;
})??;
Ok(room)
}
@ -372,7 +386,7 @@ impl Room {
self.clear_state(cx);
let leave_room = self.client.request(proto::LeaveRoom {});
cx.background().spawn(async move {
cx.background_executor().spawn(async move {
leave_room.await?;
anyhow::Ok(())
})
@ -380,14 +394,14 @@ impl Room {
pub(crate) fn clear_state(&mut self, cx: &mut AppContext) {
for project in self.shared_projects.drain() {
if let Some(project) = project.upgrade(cx) {
if let Some(project) = project.upgrade() {
project.update(cx, |project, cx| {
project.unshare(cx).log_err();
});
}
}
for project in self.joined_projects.drain() {
if let Some(project) = project.upgrade(cx) {
if let Some(project) = project.upgrade() {
project.update(cx, |project, cx| {
project.disconnected_from_host(cx);
project.close(cx);
@ -399,14 +413,14 @@ impl Room {
self.remote_participants.clear();
self.pending_participants.clear();
self.participant_user_ids.clear();
self.subscriptions.clear();
self.client_subscriptions.clear();
self.live_kit.take();
self.pending_room_update.take();
self.maintain_connection.take();
}
async fn maintain_connection(
this: WeakModelHandle<Self>,
this: WeakModel<Self>,
client: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
@ -418,32 +432,33 @@ impl Room {
if !is_connected || client_status.next().await.is_some() {
log::info!("detected client disconnection");
this.upgrade(&cx)
this.upgrade()
.ok_or_else(|| anyhow!("room was dropped"))?
.update(&mut cx, |this, cx| {
this.status = RoomStatus::Rejoining;
cx.notify();
});
})?;
// Wait for client to re-establish a connection to the server.
{
let mut reconnection_timeout = cx.background().timer(RECONNECT_TIMEOUT).fuse();
let mut reconnection_timeout =
cx.background_executor().timer(RECONNECT_TIMEOUT).fuse();
let client_reconnection = async {
let mut remaining_attempts = 3;
while remaining_attempts > 0 {
if client_status.borrow().is_connected() {
log::info!("client reconnected, attempting to rejoin room");
let Some(this) = this.upgrade(&cx) else { break };
if this
.update(&mut cx, |this, cx| this.rejoin(cx))
.await
.log_err()
.is_some()
{
return true;
} else {
remaining_attempts -= 1;
let Some(this) = this.upgrade() else { break };
match this.update(&mut cx, |this, cx| this.rejoin(cx)) {
Ok(task) => {
if task.await.log_err().is_some() {
return true;
} else {
remaining_attempts -= 1;
}
}
Err(_app_dropped) => return false,
}
} else if client_status.borrow().is_signed_out() {
return false;
@ -482,9 +497,9 @@ impl Room {
// The client failed to re-establish a connection to the server
// or an error occurred while trying to re-join the room. Either way
// we leave the room and return an error.
if let Some(this) = this.upgrade(&cx) {
if let Some(this) = this.upgrade() {
log::info!("reconnection failed, leaving room");
let _ = this.update(&mut cx, |this, cx| this.leave(cx));
let _ = this.update(&mut cx, |this, cx| this.leave(cx))?;
}
Err(anyhow!(
"can't reconnect to room: client failed to re-establish connection"
@ -496,7 +511,7 @@ impl Room {
let mut reshared_projects = Vec::new();
let mut rejoined_projects = Vec::new();
self.shared_projects.retain(|project| {
if let Some(handle) = project.upgrade(cx) {
if let Some(handle) = project.upgrade() {
let project = handle.read(cx);
if let Some(project_id) = project.remote_id() {
projects.insert(project_id, handle.clone());
@ -510,14 +525,14 @@ impl Room {
false
});
self.joined_projects.retain(|project| {
if let Some(handle) = project.upgrade(cx) {
if let Some(handle) = project.upgrade() {
let project = handle.read(cx);
if let Some(project_id) = project.remote_id() {
projects.insert(project_id, handle.clone());
rejoined_projects.push(proto::RejoinProject {
id: project_id,
worktrees: project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
proto::RejoinWorktree {
@ -565,7 +580,7 @@ impl Room {
}
anyhow::Ok(())
})
})?
})
}
@ -643,7 +658,7 @@ impl Room {
}
async fn handle_room_updated(
this: ModelHandle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::RoomUpdated>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -652,7 +667,7 @@ impl Room {
.payload
.room
.ok_or_else(|| anyhow!("invalid room"))?;
this.update(&mut cx, |this, cx| this.apply_room_update(room, cx))
this.update(&mut cx, |this, cx| this.apply_room_update(room, cx))?
}
fn apply_room_update(
@ -733,7 +748,7 @@ impl Room {
for unshared_project_id in old_projects.difference(&new_projects) {
this.joined_projects.retain(|project| {
if let Some(project) = project.upgrade(cx) {
if let Some(project) = project.upgrade() {
project.update(cx, |project, cx| {
if project.remote_id() == Some(*unshared_project_id) {
project.disconnected_from_host(cx);
@ -876,7 +891,8 @@ impl Room {
this.check_invariants();
this.room_update_completed_tx.try_send(Some(())).ok();
cx.notify();
});
})
.ok();
}));
cx.notify();
@ -907,12 +923,7 @@ impl Room {
.remote_participants
.get_mut(&user_id)
.ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?;
participant.video_tracks.insert(
track_id.clone(),
Arc::new(RemoteVideoTrack {
live_kit_track: track,
}),
);
participant.video_tracks.insert(track_id.clone(), track);
cx.emit(Event::RemoteVideoTracksChanged {
participant_id: participant.peer_id,
});
@ -991,7 +1002,6 @@ impl Room {
.remote_participants
.get_mut(&user_id)
.ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?;
participant.audio_tracks.insert(track_id.clone(), track);
participant.muted = publication.is_muted();
@ -1053,7 +1063,7 @@ impl Room {
let client = self.client.clone();
let room_id = self.id;
self.pending_call_count += 1;
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let result = client
.request(proto::Call {
room_id,
@ -1066,7 +1076,7 @@ impl Room {
if this.should_leave() {
this.leave(cx).detach_and_log_err(cx);
}
});
})?;
result?;
Ok(())
})
@ -1078,31 +1088,31 @@ impl Room {
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<Project>>> {
) -> Task<Result<Model<Project>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
cx.emit(Event::RemoteProjectJoined { project_id: id });
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let project =
Project::remote(id, client, user_store, language_registry, fs, cx.clone()).await?;
this.update(&mut cx, |this, cx| {
this.joined_projects.retain(|project| {
if let Some(project) = project.upgrade(cx) {
if let Some(project) = project.upgrade() {
!project.read(cx).is_read_only()
} else {
false
}
});
this.joined_projects.insert(project.downgrade());
});
})?;
Ok(project)
})
}
pub(crate) fn share_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some(project_id) = project.read(cx).remote_id() {
@ -1118,7 +1128,7 @@ impl Room {
project.update(&mut cx, |project, cx| {
project.shared(response.project_id, cx)
})?;
})??;
// If the user's location is in this project, it changes from UnsharedProject to SharedProject.
this.update(&mut cx, |this, cx| {
@ -1129,7 +1139,7 @@ impl Room {
} else {
Task::ready(Ok(()))
}
})
})?
.await?;
Ok(response.project_id)
@ -1138,7 +1148,7 @@ impl Room {
pub(crate) fn unshare_project(
&mut self,
project: ModelHandle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
let project_id = match project.read(cx).remote_id() {
@ -1152,7 +1162,7 @@ impl Room {
pub(crate) fn set_location(
&mut self,
project: Option<&ModelHandle<Project>>,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if self.status.is_offline() {
@ -1178,7 +1188,7 @@ impl Room {
};
cx.notify();
cx.foreground().spawn(async move {
cx.background_executor().spawn(async move {
client
.request(proto::UpdateParticipantLocation {
room_id,
@ -1244,22 +1254,21 @@ impl Room {
return Task::ready(Err(anyhow!("live-kit was not initialized")));
};
cx.spawn_weak(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let publish_track = async {
let track = LocalAudioTrack::create();
this.upgrade(&cx)
this.upgrade()
.ok_or_else(|| anyhow!("room was dropped"))?
.read_with(&cx, |this, _| {
.update(&mut cx, |this, _| {
this.live_kit
.as_ref()
.map(|live_kit| live_kit.room.publish_audio_track(track))
})
})?
.ok_or_else(|| anyhow!("live-kit was not initialized"))?
.await
};
let publication = publish_track.await;
this.upgrade(&cx)
this.upgrade()
.ok_or_else(|| anyhow!("room was dropped"))?
.update(&mut cx, |this, cx| {
let live_kit = this
@ -1283,7 +1292,9 @@ impl Room {
live_kit.room.unpublish_track(publication);
} else {
if muted {
cx.background().spawn(publication.set_mute(muted)).detach();
cx.background_executor()
.spawn(publication.set_mute(muted))
.detach();
}
live_kit.microphone_track = LocalTrack::Published {
track_publication: publication,
@ -1303,7 +1314,7 @@ impl Room {
}
}
}
})
})?
})
}
@ -1326,26 +1337,26 @@ impl Room {
return Task::ready(Err(anyhow!("live-kit was not initialized")));
};
cx.spawn_weak(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let publish_track = async {
let displays = displays.await?;
let display = displays
.first()
.ok_or_else(|| anyhow!("no display found"))?;
let track = LocalVideoTrack::screen_share_for_display(&display);
this.upgrade(&cx)
this.upgrade()
.ok_or_else(|| anyhow!("room was dropped"))?
.read_with(&cx, |this, _| {
.update(&mut cx, |this, _| {
this.live_kit
.as_ref()
.map(|live_kit| live_kit.room.publish_video_track(track))
})
})?
.ok_or_else(|| anyhow!("live-kit was not initialized"))?
.await
};
let publication = publish_track.await;
this.upgrade(&cx)
this.upgrade()
.ok_or_else(|| anyhow!("room was dropped"))?
.update(&mut cx, |this, cx| {
let live_kit = this
@ -1369,7 +1380,9 @@ impl Room {
live_kit.room.unpublish_track(publication);
} else {
if muted {
cx.background().spawn(publication.set_mute(muted)).detach();
cx.background_executor()
.spawn(publication.set_mute(muted))
.detach();
}
live_kit.screen_track = LocalTrack::Published {
track_publication: publication,
@ -1392,7 +1405,7 @@ impl Room {
}
}
}
})
})?
})
}
@ -1435,11 +1448,12 @@ impl Room {
.room
.remote_audio_track_publications(&participant.user.id.to_string())
{
tasks.push(cx.foreground().spawn(track.set_enabled(!live_kit.deafened)));
let deafened = live_kit.deafened;
tasks.push(cx.foreground_executor().spawn(track.set_enabled(!deafened)));
}
}
Ok(cx.foreground().spawn(async move {
Ok(cx.foreground_executor().spawn(async move {
if let Some(mute_task) = mute_task {
mute_task.await?;
}
@ -1499,7 +1513,6 @@ struct LiveKitRoom {
deafened: bool,
speaking: bool,
next_publish_id: usize,
can_publish: bool,
_maintain_room: Task<()>,
_maintain_tracks: [Task<()>; 2],
}
@ -1531,7 +1544,8 @@ impl LiveKitRoom {
*muted = should_mute;
cx.notify();
Ok((
cx.background().spawn(track_publication.set_mute(*muted)),
cx.background_executor()
.spawn(track_publication.set_mute(*muted)),
old_muted,
))
}

View File

@ -1,54 +0,0 @@
[package]
name = "call2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/call2.rs"
doctest = false
[features]
test-support = [
"client/test-support",
"collections/test-support",
"gpui/test-support",
"live_kit_client/test-support",
"project/test-support",
"util/test-support"
]
[dependencies]
audio = { package = "audio2", path = "../audio2" }
client = { package = "client2", path = "../client2" }
collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2" }
log.workspace = true
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2" }
fs = { package = "fs2", path = "../fs2" }
language = { package = "language2", path = "../language2" }
media = { path = "../media" }
project = { package = "project2", path = "../project2" }
settings = { package = "settings2", path = "../settings2" }
util = { path = "../util" }
anyhow.workspace = true
async-broadcast = "0.4"
futures.workspace = true
image = "0.23"
postage.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_derive.workspace = true
smallvec.workspace = true
[dev-dependencies]
client = { package = "client2", path = "../client2", features = ["test-support"] }
fs = { package = "fs2", path = "../fs2", features = ["test-support"] }
language = { package = "language2", path = "../language2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2", features = ["test-support"] }
project = { package = "project2", path = "../project2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -1,543 +0,0 @@
pub mod call_settings;
pub mod participant;
pub mod room;
use anyhow::{anyhow, Result};
use audio::Audio;
use call_settings::CallSettings;
use client::{proto, Client, TelemetrySettings, TypedEnvelope, User, UserStore, ZED_ALWAYS_ACTIVE};
use collections::HashSet;
use futures::{channel::oneshot, future::Shared, Future, FutureExt};
use gpui::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
WeakModel,
};
use postage::watch;
use project::Project;
use room::Event;
use settings::Settings;
use std::sync::Arc;
pub use participant::ParticipantLocation;
pub use room::Room;
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
CallSettings::register(cx);
let active_call = cx.new_model(|cx| ActiveCall::new(client, user_store, cx));
cx.set_global(active_call);
}
pub struct OneAtATime {
cancel: Option<oneshot::Sender<()>>,
}
impl OneAtATime {
/// spawn a task in the given context.
/// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None)
/// otherwise you'll see the result of the task.
fn spawn<F, Fut, R>(&mut self, cx: &mut AppContext, f: F) -> Task<Result<Option<R>>>
where
F: 'static + FnOnce(AsyncAppContext) -> Fut,
Fut: Future<Output = Result<R>>,
R: 'static,
{
let (tx, rx) = oneshot::channel();
self.cancel.replace(tx);
cx.spawn(|cx| async move {
futures::select_biased! {
_ = rx.fuse() => Ok(None),
result = f(cx).fuse() => result.map(Some),
}
})
}
fn running(&self) -> bool {
self.cancel
.as_ref()
.is_some_and(|cancel| !cancel.is_canceled())
}
}
#[derive(Clone)]
pub struct IncomingCall {
pub room_id: u64,
pub calling_user: Arc<User>,
pub participants: Vec<Arc<User>>,
pub initial_project: Option<proto::ParticipantProject>,
}
/// Singleton global maintaining the user's participation in a room across workspaces.
pub struct ActiveCall {
room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakModel<Project>>,
_join_debouncer: OneAtATime,
pending_invites: HashSet<u64>,
incoming_call: (
watch::Sender<Option<IncomingCall>>,
watch::Receiver<Option<IncomingCall>>,
),
client: Arc<Client>,
user_store: Model<UserStore>,
_subscriptions: Vec<client::Subscription>,
}
impl EventEmitter<Event> for ActiveCall {}
impl ActiveCall {
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
Self {
room: None,
pending_room_creation: None,
location: None,
pending_invites: Default::default(),
incoming_call: watch::channel(),
_join_debouncer: OneAtATime { cancel: None },
_subscriptions: vec![
client.add_request_handler(cx.weak_model(), Self::handle_incoming_call),
client.add_message_handler(cx.weak_model(), Self::handle_call_canceled),
],
client,
user_store,
}
}
pub fn channel_id(&self, cx: &AppContext) -> Option<u64> {
self.room()?.read(cx).channel_id()
}
async fn handle_incoming_call(
this: Model<Self>,
envelope: TypedEnvelope<proto::IncomingCall>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<proto::Ack> {
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let call = IncomingCall {
room_id: envelope.payload.room_id,
participants: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_users(envelope.payload.participant_user_ids, cx)
})?
.await?,
calling_user: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_user(envelope.payload.calling_user_id, cx)
})?
.await?,
initial_project: envelope.payload.initial_project,
};
this.update(&mut cx, |this, _| {
*this.incoming_call.0.borrow_mut() = Some(call);
})?;
Ok(proto::Ack {})
}
async fn handle_call_canceled(
this: Model<Self>,
envelope: TypedEnvelope<proto::CallCanceled>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, _| {
let mut incoming_call = this.incoming_call.0.borrow_mut();
if incoming_call
.as_ref()
.map_or(false, |call| call.room_id == envelope.payload.room_id)
{
incoming_call.take();
}
})?;
Ok(())
}
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Model<Self>>().clone()
}
pub fn invite(
&mut self,
called_user_id: u64,
initial_project: Option<Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.pending_invites.insert(called_user_id) {
return Task::ready(Err(anyhow!("user was already invited")));
}
cx.notify();
if self._join_debouncer.running() {
return Task::ready(Ok(()));
}
let room = if let Some(room) = self.room().cloned() {
Some(Task::ready(Ok(room)).shared())
} else {
self.pending_room_creation.clone()
};
let invite = if let Some(room) = room {
cx.spawn(move |_, mut cx| async move {
let room = room.await.map_err(|err| anyhow!("{:?}", err))?;
let initial_project_id = if let Some(initial_project) = initial_project {
Some(
room.update(&mut cx, |room, cx| room.share_project(initial_project, cx))?
.await?,
)
} else {
None
};
room.update(&mut cx, move |room, cx| {
room.call(called_user_id, initial_project_id, cx)
})?
.await?;
anyhow::Ok(())
})
} else {
let client = self.client.clone();
let user_store = self.user_store.clone();
let room = cx
.spawn(move |this, mut cx| async move {
let create_room = async {
let room = cx
.update(|cx| {
Room::create(
called_user_id,
initial_project,
client,
user_store,
cx,
)
})?
.await?;
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))?
.await?;
anyhow::Ok(room)
};
let room = create_room.await;
this.update(&mut cx, |this, _| this.pending_room_creation = None)?;
room.map_err(Arc::new)
})
.shared();
self.pending_room_creation = Some(room.clone());
cx.background_executor().spawn(async move {
room.await.map_err(|err| anyhow!("{:?}", err))?;
anyhow::Ok(())
})
};
cx.spawn(move |this, mut cx| async move {
let result = invite.await;
if result.is_ok() {
this.update(&mut cx, |this, cx| this.report_call_event("invite", cx))?;
} else {
// TODO: Resport collaboration error
}
this.update(&mut cx, |this, cx| {
this.pending_invites.remove(&called_user_id);
cx.notify();
})?;
result
})
}
pub fn cancel_invite(
&mut self,
called_user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let room_id = if let Some(room) = self.room() {
room.read(cx).id()
} else {
return Task::ready(Err(anyhow!("no active call")));
};
let client = self.client.clone();
cx.background_executor().spawn(async move {
client
.request(proto::CancelCall {
room_id,
called_user_id,
})
.await?;
anyhow::Ok(())
})
}
pub fn incoming(&self) -> watch::Receiver<Option<IncomingCall>> {
self.incoming_call.1.clone()
}
pub fn accept_incoming(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if self.room.is_some() {
return Task::ready(Err(anyhow!("cannot join while on another call")));
}
let call = if let Some(call) = self.incoming_call.1.borrow().clone() {
call
} else {
return Task::ready(Err(anyhow!("no incoming call")));
};
if self.pending_room_creation.is_some() {
return Task::ready(Ok(()));
}
let room_id = call.room_id.clone();
let client = self.client.clone();
let user_store = self.user_store.clone();
let join = self
._join_debouncer
.spawn(cx, move |cx| Room::join(room_id, client, user_store, cx));
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("accept incoming", cx)
})?;
Ok(())
})
}
pub fn decline_incoming(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
let call = self
.incoming_call
.0
.borrow_mut()
.take()
.ok_or_else(|| anyhow!("no incoming call"))?;
report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx);
self.client.send(proto::DeclineCall {
room_id: call.room_id,
})?;
Ok(())
}
pub fn join_channel(
&mut self,
channel_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Option<Model<Room>>>> {
if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(Some(room)));
} else {
room.update(cx, |room, cx| room.clear_state(cx));
}
}
if self.pending_room_creation.is_some() {
return Task::ready(Ok(None));
}
let client = self.client.clone();
let user_store = self.user_store.clone();
let join = self._join_debouncer.spawn(cx, move |cx| async move {
Room::join_channel(channel_id, client, user_store, cx).await
});
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("join channel", cx)
})?;
Ok(room)
})
}
pub fn hang_up(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.notify();
self.report_call_event("hang up", cx);
Audio::end_call(cx);
if let Some((room, _)) = self.room.take() {
room.update(cx, |room, cx| room.leave(cx))
} else {
Task::ready(Ok(()))
}
}
pub fn share_project(
&mut self,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some((room, _)) = self.room.as_ref() {
self.report_call_event("share project", cx);
room.update(cx, |room, cx| room.share_project(project, cx))
} else {
Task::ready(Err(anyhow!("no active call")))
}
}
pub fn unshare_project(
&mut self,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let Some((room, _)) = self.room.as_ref() {
self.report_call_event("unshare project", cx);
room.update(cx, |room, cx| room.unshare_project(project, cx))
} else {
Err(anyhow!("no active call"))
}
}
pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref()
}
pub fn set_location(
&mut self,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
self.location = project.map(|project| project.downgrade());
if let Some((room, _)) = self.room.as_ref() {
return room.update(cx, |room, cx| room.set_location(project, cx));
}
}
Task::ready(Ok(()))
}
fn set_room(
&mut self,
room: Option<Model<Room>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
cx.notify();
if let Some(room) = room {
if room.read(cx).status().is_offline() {
self.room = None;
Task::ready(Ok(()))
} else {
let subscriptions = vec![
cx.observe(&room, |this, room, cx| {
if room.read(cx).status().is_offline() {
this.set_room(None, cx).detach_and_log_err(cx);
}
cx.notify();
}),
cx.subscribe(&room, |_, _, event, cx| cx.emit(event.clone())),
];
self.room = Some((room.clone(), subscriptions));
let location = self
.location
.as_ref()
.and_then(|location| location.upgrade());
room.update(cx, |room, cx| room.set_location(location.as_ref(), cx))
}
} else {
self.room = None;
Task::ready(Ok(()))
}
} else {
Task::ready(Ok(()))
}
}
pub fn room(&self) -> Option<&Model<Room>> {
self.room.as_ref().map(|(room, _)| room)
}
pub fn client(&self) -> Arc<Client> {
self.client.clone()
}
pub fn pending_invites(&self) -> &HashSet<u64> {
&self.pending_invites
}
pub fn report_call_event(&self, operation: &'static str, cx: &mut AppContext) {
if let Some(room) = self.room() {
let room = room.read(cx);
report_call_event_for_room(operation, room.id(), room.channel_id(), &self.client, cx);
}
}
}
pub fn report_call_event_for_room(
operation: &'static str,
room_id: u64,
channel_id: Option<u64>,
client: &Arc<Client>,
cx: &mut AppContext,
) {
let telemetry = client.telemetry();
let telemetry_settings = *TelemetrySettings::get_global(cx);
telemetry.report_call_event(telemetry_settings, operation, Some(room_id), channel_id)
}
pub fn report_call_event_for_channel(
operation: &'static str,
channel_id: u64,
client: &Arc<Client>,
cx: &AppContext,
) {
let room = ActiveCall::global(cx).read(cx).room();
let telemetry = client.telemetry();
let telemetry_settings = *TelemetrySettings::get_global(cx);
telemetry.report_call_event(
telemetry_settings,
operation,
room.map(|r| r.read(cx).id()),
Some(channel_id),
)
}
#[cfg(test)]
mod test {
use gpui::TestAppContext;
use crate::OneAtATime;
#[gpui::test]
async fn test_one_at_a_time(cx: &mut TestAppContext) {
let mut one_at_a_time = OneAtATime { cancel: None };
assert_eq!(
cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) }))
.await
.unwrap(),
Some(1)
);
let (a, b) = cx.update(|cx| {
(
one_at_a_time.spawn(cx, |_| async {
assert!(false);
Ok(2)
}),
one_at_a_time.spawn(cx, |_| async { Ok(3) }),
)
});
assert_eq!(a.await.unwrap(), None);
assert_eq!(b.await.unwrap(), Some(3));
let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) }));
drop(one_at_a_time);
assert_eq!(promise.await.unwrap(), None);
}
}

View File

@ -1,32 +0,0 @@
use anyhow::Result;
use gpui::AppContext;
use schemars::JsonSchema;
use serde_derive::{Deserialize, Serialize};
use settings::Settings;
#[derive(Deserialize, Debug)]
pub struct CallSettings {
pub mute_on_join: bool,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct CallSettingsContent {
pub mute_on_join: Option<bool>,
}
impl Settings for CallSettings {
const KEY: Option<&'static str> = Some("calls");
type FileContent = CallSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_cx: &mut AppContext,
) -> Result<Self>
where
Self: Sized,
{
Self::load_via_json_merge(default_value, user_values)
}
}

View File

@ -1,52 +0,0 @@
use anyhow::{anyhow, Result};
use client::ParticipantIndex;
use client::{proto, User};
use collections::HashMap;
use gpui::WeakModel;
pub use live_kit_client::Frame;
pub use live_kit_client::{RemoteAudioTrack, RemoteVideoTrack};
use project::Project;
use std::sync::Arc;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ParticipantLocation {
SharedProject { project_id: u64 },
UnsharedProject,
External,
}
impl ParticipantLocation {
pub fn from_proto(location: Option<proto::ParticipantLocation>) -> Result<Self> {
match location.and_then(|l| l.variant) {
Some(proto::participant_location::Variant::SharedProject(project)) => {
Ok(Self::SharedProject {
project_id: project.id,
})
}
Some(proto::participant_location::Variant::UnsharedProject(_)) => {
Ok(Self::UnsharedProject)
}
Some(proto::participant_location::Variant::External(_)) => Ok(Self::External),
None => Err(anyhow!("participant location was not provided")),
}
}
}
#[derive(Clone, Default)]
pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>,
pub active_project: Option<WeakModel<Project>>,
}
#[derive(Clone, Debug)]
pub struct RemoteParticipant {
pub user: Arc<User>,
pub peer_id: proto::PeerId,
pub projects: Vec<proto::ParticipantProject>,
pub location: ParticipantLocation,
pub participant_index: ParticipantIndex,
pub muted: bool,
pub speaking: bool,
pub video_tracks: HashMap<live_kit_client::Sid, Arc<RemoteVideoTrack>>,
pub audio_tracks: HashMap<live_kit_client::Sid, Arc<RemoteAudioTrack>>,
}

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,7 @@ mod channel_chat;
mod channel_store;
use client::{Client, UserStore};
use gpui::{AppContext, ModelHandle};
use gpui::{AppContext, Model};
use std::sync::Arc;
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
@ -16,7 +16,7 @@ pub use channel_store::{Channel, ChannelEvent, ChannelId, ChannelMembership, Cha
#[cfg(test)]
mod channel_store_tests;
pub fn init(client: &Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
pub fn init(client: &Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
channel_store::init(client, user_store, cx);
channel_buffer::init(client);
channel_chat::init(client);

View File

@ -2,7 +2,7 @@ use crate::{Channel, ChannelId, ChannelStore};
use anyhow::Result;
use client::{Client, Collaborator, UserStore};
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
use language::proto::serialize_version;
use rpc::{
proto::{self, PeerId},
@ -22,9 +22,9 @@ pub struct ChannelBuffer {
pub channel_id: ChannelId,
connected: bool,
collaborators: HashMap<PeerId, Collaborator>,
user_store: ModelHandle<UserStore>,
channel_store: ModelHandle<ChannelStore>,
buffer: ModelHandle<language::Buffer>,
user_store: Model<UserStore>,
channel_store: Model<ChannelStore>,
buffer: Model<language::Buffer>,
buffer_epoch: u64,
client: Arc<Client>,
subscription: Option<client::Subscription>,
@ -38,31 +38,16 @@ pub enum ChannelBufferEvent {
ChannelChanged,
}
impl Entity for ChannelBuffer {
type Event = ChannelBufferEvent;
fn release(&mut self, _: &mut AppContext) {
if self.connected {
if let Some(task) = self.acknowledge_task.take() {
task.detach();
}
self.client
.send(proto::LeaveChannelBuffer {
channel_id: self.channel_id,
})
.log_err();
}
}
}
impl EventEmitter<ChannelBufferEvent> for ChannelBuffer {}
impl ChannelBuffer {
pub(crate) async fn new(
channel: Arc<Channel>,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
channel_store: ModelHandle<ChannelStore>,
user_store: Model<UserStore>,
channel_store: Model<ChannelStore>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
let response = client
.request(proto::JoinChannelBuffer {
channel_id: channel.id,
@ -76,16 +61,16 @@ impl ChannelBuffer {
.map(language::proto::deserialize_operation)
.collect::<Result<Vec<_>, _>>()?;
let buffer = cx.add_model(|_| {
let buffer = cx.new_model(|_| {
language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text)
});
buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?;
})?;
buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))??;
let subscription = client.subscribe_to_entity(channel.id)?;
anyhow::Ok(cx.add_model(|cx| {
anyhow::Ok(cx.new_model(|cx| {
cx.subscribe(&buffer, Self::on_buffer_update).detach();
cx.on_release(Self::release).detach();
let mut this = Self {
buffer,
buffer_epoch: response.epoch,
@ -100,14 +85,27 @@ impl ChannelBuffer {
};
this.replace_collaborators(response.collaborators, cx);
this
}))
})?)
}
fn release(&mut self, _: &mut AppContext) {
if self.connected {
if let Some(task) = self.acknowledge_task.take() {
task.detach();
}
self.client
.send(proto::LeaveChannelBuffer {
channel_id: self.channel_id,
})
.log_err();
}
}
pub fn remote_id(&self, cx: &AppContext) -> u64 {
self.buffer.read(cx).remote_id()
}
pub fn user_store(&self) -> &ModelHandle<UserStore> {
pub fn user_store(&self) -> &Model<UserStore> {
&self.user_store
}
@ -136,7 +134,7 @@ impl ChannelBuffer {
}
async fn handle_update_channel_buffer(
this: ModelHandle<Self>,
this: Model<Self>,
update_channel_buffer: TypedEnvelope<proto::UpdateChannelBuffer>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -152,13 +150,13 @@ impl ChannelBuffer {
cx.notify();
this.buffer
.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))
})?;
})??;
Ok(())
}
async fn handle_update_channel_buffer_collaborators(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateChannelBufferCollaborators>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -167,14 +165,12 @@ impl ChannelBuffer {
this.replace_collaborators(message.payload.collaborators, cx);
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
});
Ok(())
})
}
fn on_buffer_update(
&mut self,
_: ModelHandle<language::Buffer>,
_: Model<language::Buffer>,
event: &language::Event,
cx: &mut ModelContext<Self>,
) {
@ -202,8 +198,10 @@ impl ChannelBuffer {
let client = self.client.clone();
let epoch = self.epoch();
self.acknowledge_task = Some(cx.spawn_weak(|_, cx| async move {
cx.background().timer(ACKNOWLEDGE_DEBOUNCE_INTERVAL).await;
self.acknowledge_task = Some(cx.spawn(move |_, cx| async move {
cx.background_executor()
.timer(ACKNOWLEDGE_DEBOUNCE_INTERVAL)
.await;
client
.send(proto::AckBufferOperation {
buffer_id,
@ -219,7 +217,7 @@ impl ChannelBuffer {
self.buffer_epoch
}
pub fn buffer(&self) -> ModelHandle<language::Buffer> {
pub fn buffer(&self) -> Model<language::Buffer> {
self.buffer.clone()
}

View File

@ -6,7 +6,7 @@ use client::{
Client, Subscription, TypedEnvelope, UserId,
};
use futures::lock::Mutex;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
use rand::prelude::*;
use std::{
collections::HashSet,
@ -22,11 +22,11 @@ pub struct ChannelChat {
pub channel_id: ChannelId,
messages: SumTree<ChannelMessage>,
acknowledged_message_ids: HashSet<u64>,
channel_store: ModelHandle<ChannelStore>,
channel_store: Model<ChannelStore>,
loaded_all_messages: bool,
last_acknowledged_id: Option<u64>,
next_pending_message_id: usize,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
rpc: Arc<Client>,
outgoing_messages_lock: Arc<Mutex<()>>,
rng: StdRng,
@ -76,31 +76,20 @@ pub enum ChannelChatEvent {
},
}
impl EventEmitter<ChannelChatEvent> for ChannelChat {}
pub fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelChat::handle_message_sent);
client.add_model_message_handler(ChannelChat::handle_message_removed);
}
impl Entity for ChannelChat {
type Event = ChannelChatEvent;
fn release(&mut self, _: &mut AppContext) {
self.rpc
.send(proto::LeaveChannelChat {
channel_id: self.channel_id,
})
.log_err();
}
}
impl ChannelChat {
pub async fn new(
channel: Arc<Channel>,
channel_store: ModelHandle<ChannelStore>,
user_store: ModelHandle<UserStore>,
channel_store: Model<ChannelStore>,
user_store: Model<UserStore>,
client: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
) -> Result<Model<Self>> {
let channel_id = channel.id;
let subscription = client.subscribe_to_entity(channel_id).unwrap();
@ -110,7 +99,8 @@ impl ChannelChat {
let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
let loaded_all_messages = response.done;
Ok(cx.add_model(|cx| {
Ok(cx.new_model(|cx| {
cx.on_release(Self::release).detach();
let mut this = Self {
channel_id: channel.id,
user_store,
@ -127,7 +117,15 @@ impl ChannelChat {
};
this.insert_messages(messages, cx);
this
}))
})?)
}
fn release(&mut self, _: &mut AppContext) {
self.rpc
.send(proto::LeaveChannelChat {
channel_id: self.channel_id,
})
.log_err();
}
pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
@ -176,7 +174,7 @@ impl ChannelChat {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let outgoing_messages_lock = self.outgoing_messages_lock.clone();
Ok(cx.spawn(|this, mut cx| async move {
Ok(cx.spawn(move |this, mut cx| async move {
let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage {
channel_id,
@ -191,8 +189,8 @@ impl ChannelChat {
let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
Ok(id)
})
})?;
Ok(id)
}))
}
@ -201,13 +199,12 @@ impl ChannelChat {
channel_id: self.channel_id,
message_id: id,
});
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
response.await?;
this.update(&mut cx, |this, cx| {
this.message_removed(id, cx);
Ok(())
})
})?;
Ok(())
})
}
@ -220,7 +217,7 @@ impl ChannelChat {
let user_store = self.user_store.clone();
let channel_id = self.channel_id;
let before_message_id = self.first_loaded_message_id()?;
Some(cx.spawn(|this, mut cx| {
Some(cx.spawn(move |this, mut cx| {
async move {
let response = rpc
.request(proto::GetChannelMessages {
@ -233,7 +230,7 @@ impl ChannelChat {
this.update(&mut cx, |this, cx| {
this.loaded_all_messages = loaded_all_messages;
this.insert_messages(messages, cx);
});
})?;
anyhow::Ok(())
}
.log_err()
@ -251,31 +248,33 @@ impl ChannelChat {
///
/// For now, we always maintain a suffix of the channel's messages.
pub async fn load_history_since_message(
chat: ModelHandle<Self>,
chat: Model<Self>,
message_id: u64,
mut cx: AsyncAppContext,
) -> Option<usize> {
loop {
let step = chat.update(&mut cx, |chat, cx| {
if let Some(first_id) = chat.first_loaded_message_id() {
if first_id <= message_id {
let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
let message_id = ChannelMessageId::Saved(message_id);
cursor.seek(&message_id, Bias::Left, &());
return ControlFlow::Break(
if cursor
.item()
.map_or(false, |message| message.id == message_id)
{
Some(cursor.start().1 .0)
} else {
None
},
);
let step = chat
.update(&mut cx, |chat, cx| {
if let Some(first_id) = chat.first_loaded_message_id() {
if first_id <= message_id {
let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
let message_id = ChannelMessageId::Saved(message_id);
cursor.seek(&message_id, Bias::Left, &());
return ControlFlow::Break(
if cursor
.item()
.map_or(false, |message| message.id == message_id)
{
Some(cursor.start().1 .0)
} else {
None
},
);
}
}
}
ControlFlow::Continue(chat.load_more_messages(cx))
});
ControlFlow::Continue(chat.load_more_messages(cx))
})
.log_err()?;
match step {
ControlFlow::Break(ix) => return ix,
ControlFlow::Continue(task) => task?.await?,
@ -307,7 +306,7 @@ impl ChannelChat {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let channel_id = self.channel_id;
cx.spawn(|this, mut cx| {
cx.spawn(move |this, mut cx| {
async move {
let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
@ -333,7 +332,7 @@ impl ChannelChat {
}
this.pending_messages().cloned().collect::<Vec<_>>()
});
})?;
for pending_message in pending_messages {
let request = rpc.request(proto::SendChannelMessage {
@ -351,7 +350,7 @@ impl ChannelChat {
.await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
});
})?;
}
anyhow::Ok(())
@ -399,12 +398,12 @@ impl ChannelChat {
}
async fn handle_message_sent(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::ChannelMessageSent>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let message = message
.payload
.message
@ -418,20 +417,20 @@ impl ChannelChat {
channel_id: this.channel_id,
message_id,
})
});
})?;
Ok(())
}
async fn handle_message_removed(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::RemoveChannelMessage>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.message_removed(message.payload.message_id, cx)
});
})?;
Ok(())
}
@ -515,7 +514,7 @@ impl ChannelChat {
async fn messages_from_proto(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &ModelHandle<UserStore>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<SumTree<ChannelMessage>> {
let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
@ -527,13 +526,13 @@ async fn messages_from_proto(
impl ChannelMessage {
pub async fn from_proto(
message: proto::ChannelMessage,
user_store: &ModelHandle<UserStore>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let sender = user_store
.update(cx, |user_store, cx| {
user_store.get_user(message.sender_id, cx)
})
})?
.await?;
Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id),
@ -561,7 +560,7 @@ impl ChannelMessage {
pub async fn from_proto_vec(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &ModelHandle<UserStore>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Vec<Self>> {
let unique_user_ids = proto_messages
@ -573,7 +572,7 @@ impl ChannelMessage {
user_store
.update(cx, |user_store, cx| {
user_store.get_users(unique_user_ids, cx)
})
})?
.await?;
let mut messages = Vec::with_capacity(proto_messages.len());

View File

@ -7,17 +7,20 @@ use client::{Client, Subscription, User, UserId, UserStore};
use collections::{hash_map, HashMap, HashSet};
use db::RELEASE_CHANNEL;
use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use gpui::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SharedString, Task,
WeakModel,
};
use rpc::{
proto::{self, ChannelVisibility},
TypedEnvelope,
};
use std::{mem, sync::Arc, time::Duration};
use util::ResultExt;
use util::{async_maybe, ResultExt};
pub fn init(client: &Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
pub fn init(client: &Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
let channel_store =
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
cx.new_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
cx.set_global(channel_store);
}
@ -34,7 +37,7 @@ pub struct ChannelStore {
opened_buffers: HashMap<ChannelId, OpenedModelHandle<ChannelBuffer>>,
opened_chats: HashMap<ChannelId, OpenedModelHandle<ChannelChat>>,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
_rpc_subscription: Subscription,
_watch_connection_status: Task<Option<()>>,
disconnect_channel_buffers_task: Option<Task<()>>,
@ -44,7 +47,7 @@ pub struct ChannelStore {
#[derive(Clone, Debug, PartialEq)]
pub struct Channel {
pub id: ChannelId,
pub name: String,
pub name: SharedString,
pub visibility: proto::ChannelVisibility,
pub role: proto::ChannelRole,
pub unseen_note_version: Option<(u64, clock::Global)>,
@ -112,44 +115,45 @@ pub enum ChannelEvent {
ChannelRenamed(ChannelId),
}
impl Entity for ChannelStore {
type Event = ChannelEvent;
}
impl EventEmitter<ChannelEvent> for ChannelStore {}
enum OpenedModelHandle<E: Entity> {
Open(WeakModelHandle<E>),
Loading(Shared<Task<Result<ModelHandle<E>, Arc<anyhow::Error>>>>),
enum OpenedModelHandle<E> {
Open(WeakModel<E>),
Loading(Shared<Task<Result<Model<E>, Arc<anyhow::Error>>>>),
}
impl ChannelStore {
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
cx.global::<ModelHandle<Self>>().clone()
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Model<Self>>().clone()
}
pub fn new(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
user_store: Model<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
let rpc_subscription =
client.add_message_handler(cx.handle(), Self::handle_update_channels);
client.add_message_handler(cx.weak_model(), Self::handle_update_channels);
let mut connection_status = client.status();
let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
let watch_connection_status = cx.spawn(|this, mut cx| async move {
while let Some(status) = connection_status.next().await {
let this = this.upgrade(&cx)?;
let this = this.upgrade()?;
match status {
client::Status::Connected { .. } => {
this.update(&mut cx, |this, cx| this.handle_connect(cx))
.ok()?
.await
.log_err()?;
}
client::Status::SignedOut | client::Status::UpgradeRequired => {
this.update(&mut cx, |this, cx| this.handle_disconnect(false, cx));
this.update(&mut cx, |this, cx| this.handle_disconnect(false, cx))
.ok();
}
_ => {
this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx));
this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx))
.ok();
}
}
}
@ -169,17 +173,22 @@ impl ChannelStore {
_rpc_subscription: rpc_subscription,
_watch_connection_status: watch_connection_status,
disconnect_channel_buffers_task: None,
_update_channels: cx.spawn_weak(|this, mut cx| async move {
while let Some(update_channels) = update_channels_rx.next().await {
if let Some(this) = this.upgrade(&cx) {
let update_task = this.update(&mut cx, |this, cx| {
this.update_channels(update_channels, cx)
});
if let Some(update_task) = update_task {
update_task.await.log_err();
_update_channels: cx.spawn(|this, mut cx| async move {
async_maybe!({
while let Some(update_channels) = update_channels_rx.next().await {
if let Some(this) = this.upgrade() {
let update_task = this.update(&mut cx, |this, cx| {
this.update_channels(update_channels, cx)
})?;
if let Some(update_task) = update_task {
update_task.await.log_err();
}
}
}
}
anyhow::Ok(())
})
.await
.log_err();
}),
}
}
@ -240,10 +249,10 @@ impl ChannelStore {
self.channel_index.by_id().get(&channel_id)
}
pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool {
pub fn has_open_channel_buffer(&self, channel_id: ChannelId, _cx: &AppContext) -> bool {
if let Some(buffer) = self.opened_buffers.get(&channel_id) {
if let OpenedModelHandle::Open(buffer) = buffer {
return buffer.upgrade(cx).is_some();
return buffer.upgrade().is_some();
}
}
false
@ -253,7 +262,7 @@ impl ChannelStore {
&mut self,
channel_id: ChannelId,
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<ChannelBuffer>>> {
) -> Task<Result<Model<ChannelBuffer>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
let channel_store = cx.handle();
@ -278,13 +287,13 @@ impl ChannelStore {
.request(proto::GetChannelMessagesById { message_ids }),
)
};
cx.spawn_weak(|this, mut cx| async move {
cx.spawn(|this, mut cx| async move {
if let Some(request) = request {
let response = request.await?;
let this = this
.upgrade(&cx)
.upgrade()
.ok_or_else(|| anyhow!("channel store dropped"))?;
let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await
} else {
Ok(Vec::new())
@ -354,7 +363,7 @@ impl ChannelStore {
&mut self,
channel_id: ChannelId,
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<ChannelChat>>> {
) -> Task<Result<Model<ChannelChat>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
let this = cx.handle();
@ -371,22 +380,23 @@ impl ChannelStore {
/// Make sure that the resource is only opened once, even if this method
/// is called multiple times with the same channel id while the first task
/// is still running.
fn open_channel_resource<T: Entity, F, Fut>(
fn open_channel_resource<T, F, Fut>(
&mut self,
channel_id: ChannelId,
get_map: fn(&mut Self) -> &mut HashMap<ChannelId, OpenedModelHandle<T>>,
load: F,
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<T>>>
) -> Task<Result<Model<T>>>
where
F: 'static + FnOnce(Arc<Channel>, AsyncAppContext) -> Fut,
Fut: Future<Output = Result<ModelHandle<T>>>,
Fut: Future<Output = Result<Model<T>>>,
T: 'static,
{
let task = loop {
match get_map(self).entry(channel_id) {
hash_map::Entry::Occupied(e) => match e.get() {
OpenedModelHandle::Open(model) => {
if let Some(model) = model.upgrade(cx) {
if let Some(model) = model.upgrade() {
break Task::ready(Ok(model)).shared();
} else {
get_map(self).remove(&channel_id);
@ -399,12 +409,12 @@ impl ChannelStore {
},
hash_map::Entry::Vacant(e) => {
let task = cx
.spawn(|this, cx| async move {
let channel = this.read_with(&cx, |this, _| {
.spawn(move |this, mut cx| async move {
let channel = this.update(&mut cx, |this, _| {
this.channel_for_id(channel_id).cloned().ok_or_else(|| {
Arc::new(anyhow!("no channel for id: {}", channel_id))
})
})?;
})??;
load(channel, cx).await.map_err(Arc::new)
})
@ -413,7 +423,7 @@ impl ChannelStore {
e.insert(OpenedModelHandle::Loading(task.clone()));
cx.spawn({
let task = task.clone();
|this, mut cx| async move {
move |this, mut cx| async move {
let result = task.await;
this.update(&mut cx, |this, _| match result {
Ok(model) => {
@ -425,7 +435,8 @@ impl ChannelStore {
Err(_) => {
get_map(this).remove(&channel_id);
}
});
})
.ok();
}
})
.detach();
@ -433,7 +444,7 @@ impl ChannelStore {
}
}
};
cx.foreground()
cx.background_executor()
.spawn(async move { task.await.map_err(|error| anyhow!("{}", error)) })
}
@ -458,7 +469,7 @@ impl ChannelStore {
) -> Task<Result<ChannelId>> {
let client = self.client.clone();
let name = name.trim_start_matches("#").to_owned();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let response = client
.request(proto::CreateChannel { name, parent_id })
.await?;
@ -468,15 +479,6 @@ impl ChannelStore {
.ok_or_else(|| anyhow!("missing channel in response"))?;
let channel_id = channel.id;
// let parent_edge = if let Some(parent_id) = parent_id {
// vec![ChannelEdge {
// channel_id: channel.id,
// parent_id,
// }]
// } else {
// vec![]
// };
this.update(&mut cx, |this, cx| {
let task = this.update_channels(
proto::UpdateChannels {
@ -492,7 +494,7 @@ impl ChannelStore {
// will resolve before this flush_effects finishes. Synchronously emitting this event
// ensures that the collab panel will observe this creation before the frame completes
cx.emit(ChannelEvent::ChannelCreated(channel_id));
});
})?;
Ok(channel_id)
})
@ -505,7 +507,7 @@ impl ChannelStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(|_, _| async move {
cx.spawn(move |_, _| async move {
let _ = client
.request(proto::MoveChannel { channel_id, to })
.await?;
@ -521,7 +523,7 @@ impl ChannelStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(|_, _| async move {
cx.spawn(move |_, _| async move {
let _ = client
.request(proto::SetChannelVisibility {
channel_id,
@ -546,7 +548,7 @@ impl ChannelStore {
cx.notify();
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let result = client
.request(proto::InviteChannelMember {
channel_id,
@ -558,7 +560,7 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| {
this.outgoing_invites.remove(&(channel_id, user_id));
cx.notify();
});
})?;
result?;
@ -578,7 +580,7 @@ impl ChannelStore {
cx.notify();
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let result = client
.request(proto::RemoveChannelMember {
channel_id,
@ -589,7 +591,7 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| {
this.outgoing_invites.remove(&(channel_id, user_id));
cx.notify();
});
})?;
result?;
Ok(())
})
@ -608,7 +610,7 @@ impl ChannelStore {
cx.notify();
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let result = client
.request(proto::SetChannelMemberRole {
channel_id,
@ -620,7 +622,7 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| {
this.outgoing_invites.remove(&(channel_id, user_id));
cx.notify();
});
})?;
result?;
Ok(())
@ -635,7 +637,7 @@ impl ChannelStore {
) -> Task<Result<()>> {
let client = self.client.clone();
let name = new_name.to_string();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let channel = client
.request(proto::RenameChannel { channel_id, name })
.await?
@ -656,7 +658,7 @@ impl ChannelStore {
// will resolve before this flush_effects finishes. Synchronously emitting this event
// ensures that the collab panel will observe this creation before the frame complete
cx.emit(ChannelEvent::ChannelRenamed(channel_id))
});
})?;
Ok(())
})
}
@ -668,7 +670,7 @@ impl ChannelStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.background().spawn(async move {
cx.background_executor().spawn(async move {
client
.request(proto::RespondToChannelInvite { channel_id, accept })
.await?;
@ -683,17 +685,17 @@ impl ChannelStore {
) -> Task<Result<Vec<ChannelMembership>>> {
let client = self.client.clone();
let user_store = self.user_store.downgrade();
cx.spawn(|_, mut cx| async move {
cx.spawn(move |_, mut cx| async move {
let response = client
.request(proto::GetChannelMembers { channel_id })
.await?;
let user_ids = response.members.iter().map(|m| m.user_id).collect();
let user_store = user_store
.upgrade(&cx)
.upgrade()
.ok_or_else(|| anyhow!("user store dropped"))?;
let users = user_store
.update(&mut cx, |user_store, cx| user_store.get_users(user_ids, cx))
.update(&mut cx, |user_store, cx| user_store.get_users(user_ids, cx))?
.await?;
Ok(users
@ -727,7 +729,7 @@ impl ChannelStore {
}
async fn handle_update_channels(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateChannels>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -736,7 +738,7 @@ impl ChannelStore {
this.update_channels_tx
.unbounded_send(message.payload)
.unwrap();
});
})?;
Ok(())
}
@ -750,7 +752,7 @@ impl ChannelStore {
for chat in self.opened_chats.values() {
if let OpenedModelHandle::Open(chat) = chat {
if let Some(chat) = chat.upgrade(cx) {
if let Some(chat) = chat.upgrade() {
chat.update(cx, |chat, cx| {
chat.rejoin(cx);
});
@ -761,7 +763,7 @@ impl ChannelStore {
let mut buffer_versions = Vec::new();
for buffer in self.opened_buffers.values() {
if let OpenedModelHandle::Open(buffer) = buffer {
if let Some(buffer) = buffer.upgrade(cx) {
if let Some(buffer) = buffer.upgrade() {
let channel_buffer = buffer.read(cx);
let buffer = channel_buffer.buffer().read(cx);
buffer_versions.push(proto::ChannelBufferVersion {
@ -787,7 +789,7 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| {
this.opened_buffers.retain(|_, buffer| match buffer {
OpenedModelHandle::Open(channel_buffer) => {
let Some(channel_buffer) = channel_buffer.upgrade(cx) else {
let Some(channel_buffer) = channel_buffer.upgrade() else {
return false;
};
@ -824,7 +826,7 @@ impl ChannelStore {
if let Some(operations) = operations {
let client = this.client.clone();
cx.background()
cx.background_executor()
.spawn(async move {
let operations = operations.await;
for chunk in
@ -849,7 +851,8 @@ impl ChannelStore {
}
OpenedModelHandle::Loading(_) => true,
});
});
})
.ok();
anyhow::Ok(())
})
}
@ -858,21 +861,22 @@ impl ChannelStore {
cx.notify();
self.disconnect_channel_buffers_task.get_or_insert_with(|| {
cx.spawn_weak(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
if wait_for_reconnect {
cx.background().timer(RECONNECT_TIMEOUT).await;
cx.background_executor().timer(RECONNECT_TIMEOUT).await;
}
if let Some(this) = this.upgrade(&cx) {
if let Some(this) = this.upgrade() {
this.update(&mut cx, |this, cx| {
for (_, buffer) in this.opened_buffers.drain() {
if let OpenedModelHandle::Open(buffer) = buffer {
if let Some(buffer) = buffer.upgrade(cx) {
if let Some(buffer) = buffer.upgrade() {
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
}
}
}
});
})
.ok();
}
})
});
@ -892,14 +896,16 @@ impl ChannelStore {
.channel_invitations
.binary_search_by_key(&channel.id, |c| c.id)
{
Ok(ix) => Arc::make_mut(&mut self.channel_invitations[ix]).name = channel.name,
Ok(ix) => {
Arc::make_mut(&mut self.channel_invitations[ix]).name = channel.name.into()
}
Err(ix) => self.channel_invitations.insert(
ix,
Arc::new(Channel {
id: channel.id,
visibility: channel.visibility(),
role: channel.role(),
name: channel.name,
name: channel.name.into(),
unseen_note_version: None,
unseen_message_id: None,
parent_path: channel.parent_path,
@ -931,7 +937,7 @@ impl ChannelStore {
if let Some(OpenedModelHandle::Open(buffer)) =
self.opened_buffers.remove(&channel_id)
{
if let Some(buffer) = buffer.upgrade(cx) {
if let Some(buffer) = buffer.upgrade() {
buffer.update(cx, ChannelBuffer::disconnect);
}
}
@ -945,7 +951,7 @@ impl ChannelStore {
if channel_changed {
if let Some(OpenedModelHandle::Open(buffer)) = self.opened_buffers.get(&id) {
if let Some(buffer) = buffer.upgrade(cx) {
if let Some(buffer) = buffer.upgrade() {
buffer.update(cx, ChannelBuffer::channel_changed);
}
}
@ -1010,8 +1016,7 @@ impl ChannelStore {
}
cx.notify();
});
anyhow::Ok(())
})
}))
}
}

View File

@ -104,7 +104,7 @@ impl<'a> ChannelPathsInsertGuard<'a> {
existing_channel.visibility = channel_proto.visibility();
existing_channel.role = channel_proto.role();
existing_channel.name = channel_proto.name;
existing_channel.name = channel_proto.name.into();
} else {
self.channels_by_id.insert(
channel_proto.id,
@ -112,7 +112,7 @@ impl<'a> ChannelPathsInsertGuard<'a> {
id: channel_proto.id,
visibility: channel_proto.visibility(),
role: channel_proto.role(),
name: channel_proto.name,
name: channel_proto.name.into(),
unseen_note_version: None,
unseen_message_id: None,
parent_path: channel_proto.parent_path,
@ -146,11 +146,11 @@ fn channel_path_sorting_key<'a>(
let (parent_path, name) = channels_by_id
.get(&id)
.map_or((&[] as &[_], None), |channel| {
(channel.parent_path.as_slice(), Some(channel.name.as_str()))
(channel.parent_path.as_slice(), Some(channel.name.as_ref()))
});
parent_path
.iter()
.filter_map(|id| Some(channels_by_id.get(id)?.name.as_str()))
.filter_map(|id| Some(channels_by_id.get(id)?.name.as_ref()))
.chain(name)
}

View File

@ -2,7 +2,7 @@ use crate::channel_chat::ChannelChatEvent;
use super::*;
use client::{test::FakeServer, Client, UserStore};
use gpui::{AppContext, ModelHandle, TestAppContext};
use gpui::{AppContext, Context, Model, TestAppContext};
use rpc::proto::{self};
use settings::SettingsStore;
use util::http::FakeHttpClient;
@ -147,7 +147,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
let user_id = 5;
let channel_id = 5;
let channel_store = cx.update(init_test);
let client = channel_store.read_with(cx, |s, _| s.client());
let client = channel_store.update(cx, |s, _| s.client());
let server = FakeServer::for_client(user_id, &client, cx).await;
// Get the available channels.
@ -161,8 +161,8 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
}],
..Default::default()
});
cx.foreground().run_until_parked();
cx.read(|cx| {
cx.executor().run_until_parked();
cx.update(|cx| {
assert_channels(
&channel_store,
&[(0, "the-channel".to_string(), proto::ChannelRole::Member)],
@ -214,7 +214,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
},
);
cx.foreground().start_waiting();
cx.executor().start_waiting();
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
@ -232,7 +232,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
let channel = channel.await.unwrap();
channel.read_with(cx, |channel, _| {
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
@ -273,13 +273,13 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
channel.next_event(cx).await,
channel.next_event(cx),
ChannelChatEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
}
);
channel.read_with(cx, |channel, _| {
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(2..3)
@ -322,13 +322,13 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
channel.next_event(cx).await,
channel.next_event(cx),
ChannelChatEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
);
channel.read_with(cx, |channel, _| {
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
@ -342,13 +342,13 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
});
}
fn init_test(cx: &mut AppContext) -> ModelHandle<ChannelStore> {
fn init_test(cx: &mut AppContext) -> Model<ChannelStore> {
let http = FakeHttpClient::with_404_response();
let client = Client::new(http.clone(), cx);
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
cx.foreground().forbid_parking();
cx.set_global(SettingsStore::test(cx));
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
client::init(&client, cx);
crate::init(&client, user_store, cx);
@ -356,7 +356,7 @@ fn init_test(cx: &mut AppContext) -> ModelHandle<ChannelStore> {
}
fn update_channels(
channel_store: &ModelHandle<ChannelStore>,
channel_store: &Model<ChannelStore>,
message: proto::UpdateChannels,
cx: &mut AppContext,
) {
@ -366,11 +366,11 @@ fn update_channels(
#[track_caller]
fn assert_channels(
channel_store: &ModelHandle<ChannelStore>,
channel_store: &Model<ChannelStore>,
expected_channels: &[(usize, String, proto::ChannelRole)],
cx: &AppContext,
cx: &mut AppContext,
) {
let actual = channel_store.read_with(cx, |store, _| {
let actual = channel_store.update(cx, |store, _| {
store
.ordered_channels()
.map(|(depth, channel)| (depth, channel.name.to_string(), channel.role))

View File

@ -1,54 +0,0 @@
[package]
name = "channel2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/channel2.rs"
doctest = false
[features]
test-support = ["collections/test-support", "gpui/test-support", "rpc/test-support"]
[dependencies]
client = { package = "client2", path = "../client2" }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
text = { package = "text2", path = "../text2" }
language = { package = "language2", path = "../language2" }
settings = { package = "settings2", path = "../settings2" }
feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
sum_tree = { path = "../sum_tree" }
clock = { path = "../clock" }
anyhow.workspace = true
futures.workspace = true
image = "0.23"
lazy_static.workspace = true
smallvec.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
schemars.workspace = true
smol.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
uuid.workspace = true
url = "2.2"
serde.workspace = true
serde_derive.workspace = true
tempfile = "3"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
client = { package = "client2", path = "../client2", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -1,23 +0,0 @@
mod channel_buffer;
mod channel_chat;
mod channel_store;
use client::{Client, UserStore};
use gpui::{AppContext, Model};
use std::sync::Arc;
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
pub use channel_chat::{
mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId,
MessageParams,
};
pub use channel_store::{Channel, ChannelEvent, ChannelId, ChannelMembership, ChannelStore};
#[cfg(test)]
mod channel_store_tests;
pub fn init(client: &Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
channel_store::init(client, user_store, cx);
channel_buffer::init(client);
channel_chat::init(client);
}

View File

@ -1,257 +0,0 @@
use crate::{Channel, ChannelId, ChannelStore};
use anyhow::Result;
use client::{Client, Collaborator, UserStore};
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
use language::proto::serialize_version;
use rpc::{
proto::{self, PeerId},
TypedEnvelope,
};
use std::{sync::Arc, time::Duration};
use util::ResultExt;
pub const ACKNOWLEDGE_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(250);
pub(crate) fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer);
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborators);
}
pub struct ChannelBuffer {
pub channel_id: ChannelId,
connected: bool,
collaborators: HashMap<PeerId, Collaborator>,
user_store: Model<UserStore>,
channel_store: Model<ChannelStore>,
buffer: Model<language::Buffer>,
buffer_epoch: u64,
client: Arc<Client>,
subscription: Option<client::Subscription>,
acknowledge_task: Option<Task<Result<()>>>,
}
pub enum ChannelBufferEvent {
CollaboratorsChanged,
Disconnected,
BufferEdited,
ChannelChanged,
}
impl EventEmitter<ChannelBufferEvent> for ChannelBuffer {}
impl ChannelBuffer {
pub(crate) async fn new(
channel: Arc<Channel>,
client: Arc<Client>,
user_store: Model<UserStore>,
channel_store: Model<ChannelStore>,
mut cx: AsyncAppContext,
) -> Result<Model<Self>> {
let response = client
.request(proto::JoinChannelBuffer {
channel_id: channel.id,
})
.await?;
let base_text = response.base_text;
let operations = response
.operations
.into_iter()
.map(language::proto::deserialize_operation)
.collect::<Result<Vec<_>, _>>()?;
let buffer = cx.new_model(|_| {
language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text)
})?;
buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))??;
let subscription = client.subscribe_to_entity(channel.id)?;
anyhow::Ok(cx.new_model(|cx| {
cx.subscribe(&buffer, Self::on_buffer_update).detach();
cx.on_release(Self::release).detach();
let mut this = Self {
buffer,
buffer_epoch: response.epoch,
client,
connected: true,
collaborators: Default::default(),
acknowledge_task: None,
channel_id: channel.id,
subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())),
user_store,
channel_store,
};
this.replace_collaborators(response.collaborators, cx);
this
})?)
}
fn release(&mut self, _: &mut AppContext) {
if self.connected {
if let Some(task) = self.acknowledge_task.take() {
task.detach();
}
self.client
.send(proto::LeaveChannelBuffer {
channel_id: self.channel_id,
})
.log_err();
}
}
pub fn remote_id(&self, cx: &AppContext) -> u64 {
self.buffer.read(cx).remote_id()
}
pub fn user_store(&self) -> &Model<UserStore> {
&self.user_store
}
pub(crate) fn replace_collaborators(
&mut self,
collaborators: Vec<proto::Collaborator>,
cx: &mut ModelContext<Self>,
) {
let mut new_collaborators = HashMap::default();
for collaborator in collaborators {
if let Ok(collaborator) = Collaborator::from_proto(collaborator) {
new_collaborators.insert(collaborator.peer_id, collaborator);
}
}
for (_, old_collaborator) in &self.collaborators {
if !new_collaborators.contains_key(&old_collaborator.peer_id) {
self.buffer.update(cx, |buffer, cx| {
buffer.remove_peer(old_collaborator.replica_id as u16, cx)
});
}
}
self.collaborators = new_collaborators;
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
}
async fn handle_update_channel_buffer(
this: Model<Self>,
update_channel_buffer: TypedEnvelope<proto::UpdateChannelBuffer>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
let ops = update_channel_buffer
.payload
.operations
.into_iter()
.map(language::proto::deserialize_operation)
.collect::<Result<Vec<_>, _>>()?;
this.update(&mut cx, |this, cx| {
cx.notify();
this.buffer
.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))
})??;
Ok(())
}
async fn handle_update_channel_buffer_collaborators(
this: Model<Self>,
message: TypedEnvelope<proto::UpdateChannelBufferCollaborators>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.replace_collaborators(message.payload.collaborators, cx);
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
})
}
fn on_buffer_update(
&mut self,
_: Model<language::Buffer>,
event: &language::Event,
cx: &mut ModelContext<Self>,
) {
match event {
language::Event::Operation(operation) => {
let operation = language::proto::serialize_operation(operation);
self.client
.send(proto::UpdateChannelBuffer {
channel_id: self.channel_id,
operations: vec![operation],
})
.log_err();
}
language::Event::Edited => {
cx.emit(ChannelBufferEvent::BufferEdited);
}
_ => {}
}
}
pub fn acknowledge_buffer_version(&mut self, cx: &mut ModelContext<'_, ChannelBuffer>) {
let buffer = self.buffer.read(cx);
let version = buffer.version();
let buffer_id = buffer.remote_id();
let client = self.client.clone();
let epoch = self.epoch();
self.acknowledge_task = Some(cx.spawn(move |_, cx| async move {
cx.background_executor()
.timer(ACKNOWLEDGE_DEBOUNCE_INTERVAL)
.await;
client
.send(proto::AckBufferOperation {
buffer_id,
epoch,
version: serialize_version(&version),
})
.ok();
Ok(())
}));
}
pub fn epoch(&self) -> u64 {
self.buffer_epoch
}
pub fn buffer(&self) -> Model<language::Buffer> {
self.buffer.clone()
}
pub fn collaborators(&self) -> &HashMap<PeerId, Collaborator> {
&self.collaborators
}
pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
self.channel_store
.read(cx)
.channel_for_id(self.channel_id)
.cloned()
}
pub(crate) fn disconnect(&mut self, cx: &mut ModelContext<Self>) {
log::info!("channel buffer {} disconnected", self.channel_id);
if self.connected {
self.connected = false;
self.subscription.take();
cx.emit(ChannelBufferEvent::Disconnected);
cx.notify()
}
}
pub(crate) fn channel_changed(&mut self, cx: &mut ModelContext<Self>) {
cx.emit(ChannelBufferEvent::ChannelChanged);
cx.notify()
}
pub fn is_connected(&self) -> bool {
self.connected
}
pub fn replica_id(&self, cx: &AppContext) -> u16 {
self.buffer.read(cx).replica_id()
}
}

View File

@ -1,645 +0,0 @@
use crate::{Channel, ChannelId, ChannelStore};
use anyhow::{anyhow, Result};
use client::{
proto,
user::{User, UserStore},
Client, Subscription, TypedEnvelope, UserId,
};
use futures::lock::Mutex;
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task};
use rand::prelude::*;
use std::{
collections::HashSet,
mem,
ops::{ControlFlow, Range},
sync::Arc,
};
use sum_tree::{Bias, SumTree};
use time::OffsetDateTime;
use util::{post_inc, ResultExt as _, TryFutureExt};
pub struct ChannelChat {
pub channel_id: ChannelId,
messages: SumTree<ChannelMessage>,
acknowledged_message_ids: HashSet<u64>,
channel_store: Model<ChannelStore>,
loaded_all_messages: bool,
last_acknowledged_id: Option<u64>,
next_pending_message_id: usize,
user_store: Model<UserStore>,
rpc: Arc<Client>,
outgoing_messages_lock: Arc<Mutex<()>>,
rng: StdRng,
_subscription: Subscription,
}
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
pub text: String,
pub mentions: Vec<(Range<usize>, UserId)>,
}
#[derive(Clone, Debug)]
pub struct ChannelMessage {
pub id: ChannelMessageId,
pub body: String,
pub timestamp: OffsetDateTime,
pub sender: Arc<User>,
pub nonce: u128,
pub mentions: Vec<(Range<usize>, UserId)>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
Saved(u64),
Pending(usize),
}
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
max_id: ChannelMessageId,
count: usize,
}
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
MessagesUpdated {
old_range: Range<usize>,
new_count: usize,
},
NewMessage {
channel_id: ChannelId,
message_id: u64,
},
}
impl EventEmitter<ChannelChatEvent> for ChannelChat {}
pub fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelChat::handle_message_sent);
client.add_model_message_handler(ChannelChat::handle_message_removed);
}
impl ChannelChat {
pub async fn new(
channel: Arc<Channel>,
channel_store: Model<ChannelStore>,
user_store: Model<UserStore>,
client: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<Model<Self>> {
let channel_id = channel.id;
let subscription = client.subscribe_to_entity(channel_id).unwrap();
let response = client
.request(proto::JoinChannelChat { channel_id })
.await?;
let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
let loaded_all_messages = response.done;
Ok(cx.new_model(|cx| {
cx.on_release(Self::release).detach();
let mut this = Self {
channel_id: channel.id,
user_store,
channel_store,
rpc: client,
outgoing_messages_lock: Default::default(),
messages: Default::default(),
acknowledged_message_ids: Default::default(),
loaded_all_messages,
next_pending_message_id: 0,
last_acknowledged_id: None,
rng: StdRng::from_entropy(),
_subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
};
this.insert_messages(messages, cx);
this
})?)
}
fn release(&mut self, _: &mut AppContext) {
self.rpc
.send(proto::LeaveChannelChat {
channel_id: self.channel_id,
})
.log_err();
}
pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
self.channel_store
.read(cx)
.channel_for_id(self.channel_id)
.cloned()
}
pub fn client(&self) -> &Arc<Client> {
&self.rpc
}
pub fn send_message(
&mut self,
message: MessageParams,
cx: &mut ModelContext<Self>,
) -> Result<Task<Result<u64>>> {
if message.text.is_empty() {
Err(anyhow!("message body can't be empty"))?;
}
let current_user = self
.user_store
.read(cx)
.current_user()
.ok_or_else(|| anyhow!("current_user is not present"))?;
let channel_id = self.channel_id;
let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
let nonce = self.rng.gen();
self.insert_messages(
SumTree::from_item(
ChannelMessage {
id: pending_id,
body: message.text.clone(),
sender: current_user,
timestamp: OffsetDateTime::now_utc(),
mentions: message.mentions.clone(),
nonce,
},
&(),
),
cx,
);
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let outgoing_messages_lock = self.outgoing_messages_lock.clone();
Ok(cx.spawn(move |this, mut cx| async move {
let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body: message.text,
nonce: Some(nonce.into()),
mentions: mentions_to_proto(&message.mentions),
});
let response = request.await?;
drop(outgoing_message_guard);
let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
let id = response.id;
let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
})?;
Ok(id)
}))
}
pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let response = self.rpc.request(proto::RemoveChannelMessage {
channel_id: self.channel_id,
message_id: id,
});
cx.spawn(move |this, mut cx| async move {
response.await?;
this.update(&mut cx, |this, cx| {
this.message_removed(id, cx);
})?;
Ok(())
})
}
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
if self.loaded_all_messages {
return None;
}
let rpc = self.rpc.clone();
let user_store = self.user_store.clone();
let channel_id = self.channel_id;
let before_message_id = self.first_loaded_message_id()?;
Some(cx.spawn(move |this, mut cx| {
async move {
let response = rpc
.request(proto::GetChannelMessages {
channel_id,
before_message_id,
})
.await?;
let loaded_all_messages = response.done;
let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.loaded_all_messages = loaded_all_messages;
this.insert_messages(messages, cx);
})?;
anyhow::Ok(())
}
.log_err()
}))
}
pub fn first_loaded_message_id(&mut self) -> Option<u64> {
self.messages.first().and_then(|message| match message.id {
ChannelMessageId::Saved(id) => Some(id),
ChannelMessageId::Pending(_) => None,
})
}
/// Load all of the chat messages since a certain message id.
///
/// For now, we always maintain a suffix of the channel's messages.
pub async fn load_history_since_message(
chat: Model<Self>,
message_id: u64,
mut cx: AsyncAppContext,
) -> Option<usize> {
loop {
let step = chat
.update(&mut cx, |chat, cx| {
if let Some(first_id) = chat.first_loaded_message_id() {
if first_id <= message_id {
let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
let message_id = ChannelMessageId::Saved(message_id);
cursor.seek(&message_id, Bias::Left, &());
return ControlFlow::Break(
if cursor
.item()
.map_or(false, |message| message.id == message_id)
{
Some(cursor.start().1 .0)
} else {
None
},
);
}
}
ControlFlow::Continue(chat.load_more_messages(cx))
})
.log_err()?;
match step {
ControlFlow::Break(ix) => return ix,
ControlFlow::Continue(task) => task?.await?,
}
}
}
pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
if self
.last_acknowledged_id
.map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
{
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel_id,
message_id: latest_message_id,
})
.ok();
self.last_acknowledged_id = Some(latest_message_id);
self.channel_store.update(cx, |store, cx| {
store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
});
}
}
}
pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let channel_id = self.channel_id;
cx.spawn(move |this, mut cx| {
async move {
let response = rpc.request(proto::JoinChannelChat { channel_id }).await?;
let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
let loaded_all_messages = response.done;
let pending_messages = this.update(&mut cx, |this, cx| {
if let Some((first_new_message, last_old_message)) =
messages.first().zip(this.messages.last())
{
if first_new_message.id > last_old_message.id {
let old_messages = mem::take(&mut this.messages);
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: 0..old_messages.summary().count,
new_count: 0,
});
this.loaded_all_messages = loaded_all_messages;
}
}
this.insert_messages(messages, cx);
if loaded_all_messages {
this.loaded_all_messages = loaded_all_messages;
}
this.pending_messages().cloned().collect::<Vec<_>>()
})?;
for pending_message in pending_messages {
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body: pending_message.body,
mentions: mentions_to_proto(&pending_message.mentions),
nonce: Some(pending_message.nonce.into()),
});
let response = request.await?;
let message = ChannelMessage::from_proto(
response.message.ok_or_else(|| anyhow!("invalid message"))?,
&user_store,
&mut cx,
)
.await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
})?;
}
anyhow::Ok(())
}
.log_err()
})
.detach();
}
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
pub fn messages(&self) -> &SumTree<ChannelMessage> {
&self.messages
}
pub fn message(&self, ix: usize) -> &ChannelMessage {
let mut cursor = self.messages.cursor::<Count>();
cursor.seek(&Count(ix), Bias::Right, &());
cursor.item().unwrap()
}
pub fn acknowledge_message(&mut self, id: u64) {
if self.acknowledged_message_ids.insert(id) {
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel_id,
message_id: id,
})
.ok();
}
}
pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<Count>();
cursor.seek(&Count(range.start), Bias::Right, &());
cursor.take(range.len())
}
pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<ChannelMessageId>();
cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
cursor
}
async fn handle_message_sent(
this: Model<Self>,
message: TypedEnvelope<proto::ChannelMessageSent>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let message = message
.payload
.message
.ok_or_else(|| anyhow!("empty message"))?;
let message_id = message.id;
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
cx.emit(ChannelChatEvent::NewMessage {
channel_id: this.channel_id,
message_id,
})
})?;
Ok(())
}
async fn handle_message_removed(
this: Model<Self>,
message: TypedEnvelope<proto::RemoveChannelMessage>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.message_removed(message.payload.message_id, cx)
})?;
Ok(())
}
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
let nonces = messages
.cursor::<()>()
.map(|m| m.nonce)
.collect::<HashSet<_>>();
let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>();
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
let start_ix = old_cursor.start().1 .0;
let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
let removed_count = removed_messages.summary().count;
let new_count = messages.summary().count;
let end_ix = start_ix + removed_count;
new_messages.append(messages, &());
let mut ranges = Vec::<Range<usize>>::new();
if new_messages.last().unwrap().is_pending() {
new_messages.append(old_cursor.suffix(&()), &());
} else {
new_messages.append(
old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
&(),
);
while let Some(message) = old_cursor.item() {
let message_ix = old_cursor.start().1 .0;
if nonces.contains(&message.nonce) {
if ranges.last().map_or(false, |r| r.end == message_ix) {
ranges.last_mut().unwrap().end += 1;
} else {
ranges.push(message_ix..message_ix + 1);
}
} else {
new_messages.push(message.clone(), &());
}
old_cursor.next(&());
}
}
drop(old_cursor);
self.messages = new_messages;
for range in ranges.into_iter().rev() {
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: range,
new_count: 0,
});
}
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count,
});
cx.notify();
}
}
fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
let mut cursor = self.messages.cursor::<ChannelMessageId>();
let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
if let Some(item) = cursor.item() {
if item.id == ChannelMessageId::Saved(id) {
let ix = messages.summary().count;
cursor.next(&());
messages.append(cursor.suffix(&()), &());
drop(cursor);
self.messages = messages;
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: ix..ix + 1,
new_count: 0,
});
}
}
}
}
async fn messages_from_proto(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<SumTree<ChannelMessage>> {
let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
let mut result = SumTree::new();
result.extend(messages, &());
Ok(result)
}
impl ChannelMessage {
pub async fn from_proto(
message: proto::ChannelMessage,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let sender = user_store
.update(cx, |user_store, cx| {
user_store.get_user(message.sender_id, cx)
})?
.await?;
Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id),
body: message.body,
mentions: message
.mentions
.into_iter()
.filter_map(|mention| {
let range = mention.range?;
Some((range.start as usize..range.end as usize, mention.user_id))
})
.collect(),
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
sender,
nonce: message
.nonce
.ok_or_else(|| anyhow!("nonce is required"))?
.into(),
})
}
pub fn is_pending(&self) -> bool {
matches!(self.id, ChannelMessageId::Pending(_))
}
pub async fn from_proto_vec(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Vec<Self>> {
let unique_user_ids = proto_messages
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store
.update(cx, |user_store, cx| {
user_store.get_users(unique_user_ids, cx)
})?
.await?;
let mut messages = Vec::with_capacity(proto_messages.len());
for message in proto_messages {
messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
}
Ok(messages)
}
}
pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
mentions
.iter()
.map(|(range, user_id)| proto::ChatMention {
range: Some(proto::Range {
start: range.start as u64,
end: range.end as u64,
}),
user_id: *user_id as u64,
})
.collect()
}
impl sum_tree::Item for ChannelMessage {
type Summary = ChannelMessageSummary;
fn summary(&self) -> Self::Summary {
ChannelMessageSummary {
max_id: self.id,
count: 1,
}
}
}
impl Default for ChannelMessageId {
fn default() -> Self {
Self::Saved(0)
}
}
impl sum_tree::Summary for ChannelMessageSummary {
type Context = ();
fn add_summary(&mut self, summary: &Self, _: &()) {
self.max_id = summary.max_id;
self.count += summary.count;
}
}
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
debug_assert!(summary.max_id > *self);
*self = summary.max_id;
}
}
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
self.0 += summary.count;
}
}
impl<'a> From<&'a str> for MessageParams {
fn from(value: &'a str) -> Self {
Self {
text: value.into(),
mentions: Vec::new(),
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,184 +0,0 @@
use crate::{Channel, ChannelId};
use collections::BTreeMap;
use rpc::proto;
use std::sync::Arc;
#[derive(Default, Debug)]
pub struct ChannelIndex {
channels_ordered: Vec<ChannelId>,
channels_by_id: BTreeMap<ChannelId, Arc<Channel>>,
}
impl ChannelIndex {
pub fn by_id(&self) -> &BTreeMap<ChannelId, Arc<Channel>> {
&self.channels_by_id
}
pub fn ordered_channels(&self) -> &[ChannelId] {
&self.channels_ordered
}
pub fn clear(&mut self) {
self.channels_ordered.clear();
self.channels_by_id.clear();
}
/// Delete the given channels from this index.
pub fn delete_channels(&mut self, channels: &[ChannelId]) {
self.channels_by_id
.retain(|channel_id, _| !channels.contains(channel_id));
self.channels_ordered
.retain(|channel_id| !channels.contains(channel_id));
}
pub fn bulk_insert(&mut self) -> ChannelPathsInsertGuard {
ChannelPathsInsertGuard {
channels_ordered: &mut self.channels_ordered,
channels_by_id: &mut self.channels_by_id,
}
}
pub fn acknowledge_note_version(
&mut self,
channel_id: ChannelId,
epoch: u64,
version: &clock::Global,
) {
if let Some(channel) = self.channels_by_id.get_mut(&channel_id) {
let channel = Arc::make_mut(channel);
if let Some((unseen_epoch, unseen_version)) = &channel.unseen_note_version {
if epoch > *unseen_epoch
|| epoch == *unseen_epoch && version.observed_all(unseen_version)
{
channel.unseen_note_version = None;
}
}
}
}
pub fn acknowledge_message_id(&mut self, channel_id: ChannelId, message_id: u64) {
if let Some(channel) = self.channels_by_id.get_mut(&channel_id) {
let channel = Arc::make_mut(channel);
if let Some(unseen_message_id) = channel.unseen_message_id {
if message_id >= unseen_message_id {
channel.unseen_message_id = None;
}
}
}
}
pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) {
insert_note_changed(&mut self.channels_by_id, channel_id, epoch, version);
}
pub fn new_message(&mut self, channel_id: ChannelId, message_id: u64) {
insert_new_message(&mut self.channels_by_id, channel_id, message_id)
}
}
/// A guard for ensuring that the paths index maintains its sort and uniqueness
/// invariants after a series of insertions
#[derive(Debug)]
pub struct ChannelPathsInsertGuard<'a> {
channels_ordered: &'a mut Vec<ChannelId>,
channels_by_id: &'a mut BTreeMap<ChannelId, Arc<Channel>>,
}
impl<'a> ChannelPathsInsertGuard<'a> {
pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) {
insert_note_changed(&mut self.channels_by_id, channel_id, epoch, &version);
}
pub fn new_messages(&mut self, channel_id: ChannelId, message_id: u64) {
insert_new_message(&mut self.channels_by_id, channel_id, message_id)
}
pub fn insert(&mut self, channel_proto: proto::Channel) -> bool {
let mut ret = false;
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
let existing_channel = Arc::make_mut(existing_channel);
ret = existing_channel.visibility != channel_proto.visibility()
|| existing_channel.role != channel_proto.role()
|| existing_channel.name != channel_proto.name;
existing_channel.visibility = channel_proto.visibility();
existing_channel.role = channel_proto.role();
existing_channel.name = channel_proto.name.into();
} else {
self.channels_by_id.insert(
channel_proto.id,
Arc::new(Channel {
id: channel_proto.id,
visibility: channel_proto.visibility(),
role: channel_proto.role(),
name: channel_proto.name.into(),
unseen_note_version: None,
unseen_message_id: None,
parent_path: channel_proto.parent_path,
}),
);
self.insert_root(channel_proto.id);
}
ret
}
fn insert_root(&mut self, channel_id: ChannelId) {
self.channels_ordered.push(channel_id);
}
}
impl<'a> Drop for ChannelPathsInsertGuard<'a> {
fn drop(&mut self) {
self.channels_ordered.sort_by(|a, b| {
let a = channel_path_sorting_key(*a, &self.channels_by_id);
let b = channel_path_sorting_key(*b, &self.channels_by_id);
a.cmp(b)
});
self.channels_ordered.dedup();
}
}
fn channel_path_sorting_key<'a>(
id: ChannelId,
channels_by_id: &'a BTreeMap<ChannelId, Arc<Channel>>,
) -> impl Iterator<Item = &str> {
let (parent_path, name) = channels_by_id
.get(&id)
.map_or((&[] as &[_], None), |channel| {
(channel.parent_path.as_slice(), Some(channel.name.as_ref()))
});
parent_path
.iter()
.filter_map(|id| Some(channels_by_id.get(id)?.name.as_ref()))
.chain(name)
}
fn insert_note_changed(
channels_by_id: &mut BTreeMap<ChannelId, Arc<Channel>>,
channel_id: u64,
epoch: u64,
version: &clock::Global,
) {
if let Some(channel) = channels_by_id.get_mut(&channel_id) {
let unseen_version = Arc::make_mut(channel)
.unseen_note_version
.get_or_insert((0, clock::Global::new()));
if epoch > unseen_version.0 {
*unseen_version = (epoch, version.clone());
} else {
unseen_version.1.join(&version);
}
}
}
fn insert_new_message(
channels_by_id: &mut BTreeMap<ChannelId, Arc<Channel>>,
channel_id: u64,
message_id: u64,
) {
if let Some(channel) = channels_by_id.get_mut(&channel_id) {
let unseen_message_id = Arc::make_mut(channel).unseen_message_id.get_or_insert(0);
*unseen_message_id = message_id.max(*unseen_message_id);
}
}

View File

@ -1,380 +0,0 @@
use crate::channel_chat::ChannelChatEvent;
use super::*;
use client::{test::FakeServer, Client, UserStore};
use gpui::{AppContext, Context, Model, TestAppContext};
use rpc::proto::{self};
use settings::SettingsStore;
use util::http::FakeHttpClient;
#[gpui::test]
fn test_update_channels(cx: &mut AppContext) {
let channel_store = init_test(cx);
update_channels(
&channel_store,
proto::UpdateChannels {
channels: vec![
proto::Channel {
id: 1,
name: "b".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Admin.into(),
parent_path: Vec::new(),
},
proto::Channel {
id: 2,
name: "a".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Member.into(),
parent_path: Vec::new(),
},
],
..Default::default()
},
cx,
);
assert_channels(
&channel_store,
&[
//
(0, "a".to_string(), proto::ChannelRole::Member),
(0, "b".to_string(), proto::ChannelRole::Admin),
],
cx,
);
update_channels(
&channel_store,
proto::UpdateChannels {
channels: vec![
proto::Channel {
id: 3,
name: "x".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Admin.into(),
parent_path: vec![1],
},
proto::Channel {
id: 4,
name: "y".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Member.into(),
parent_path: vec![2],
},
],
..Default::default()
},
cx,
);
assert_channels(
&channel_store,
&[
(0, "a".to_string(), proto::ChannelRole::Member),
(1, "y".to_string(), proto::ChannelRole::Member),
(0, "b".to_string(), proto::ChannelRole::Admin),
(1, "x".to_string(), proto::ChannelRole::Admin),
],
cx,
);
}
#[gpui::test]
fn test_dangling_channel_paths(cx: &mut AppContext) {
let channel_store = init_test(cx);
update_channels(
&channel_store,
proto::UpdateChannels {
channels: vec![
proto::Channel {
id: 0,
name: "a".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Admin.into(),
parent_path: vec![],
},
proto::Channel {
id: 1,
name: "b".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Admin.into(),
parent_path: vec![0],
},
proto::Channel {
id: 2,
name: "c".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Admin.into(),
parent_path: vec![0, 1],
},
],
..Default::default()
},
cx,
);
// Sanity check
assert_channels(
&channel_store,
&[
//
(0, "a".to_string(), proto::ChannelRole::Admin),
(1, "b".to_string(), proto::ChannelRole::Admin),
(2, "c".to_string(), proto::ChannelRole::Admin),
],
cx,
);
update_channels(
&channel_store,
proto::UpdateChannels {
delete_channels: vec![1, 2],
..Default::default()
},
cx,
);
// Make sure that the 1/2/3 path is gone
assert_channels(
&channel_store,
&[(0, "a".to_string(), proto::ChannelRole::Admin)],
cx,
);
}
#[gpui::test]
async fn test_channel_messages(cx: &mut TestAppContext) {
let user_id = 5;
let channel_id = 5;
let channel_store = cx.update(init_test);
let client = channel_store.update(cx, |s, _| s.client());
let server = FakeServer::for_client(user_id, &client, cx).await;
// Get the available channels.
server.send(proto::UpdateChannels {
channels: vec![proto::Channel {
id: channel_id,
name: "the-channel".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
role: proto::ChannelRole::Member.into(),
parent_path: vec![],
}],
..Default::default()
});
cx.executor().run_until_parked();
cx.update(|cx| {
assert_channels(
&channel_store,
&[(0, "the-channel".to_string(), proto::ChannelRole::Member)],
cx,
);
});
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![5]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 5,
github_login: "nathansobo".into(),
avatar_url: "http://avatar.com/nathansobo".into(),
}],
},
);
// Join a channel and populate its existing messages.
let channel = channel_store.update(cx, |store, cx| {
let channel_id = store.ordered_channels().next().unwrap().1.id;
store.open_channel_chat(channel_id, cx)
});
let join_channel = server.receive::<proto::JoinChannelChat>().await.unwrap();
server.respond(
join_channel.receipt(),
proto::JoinChannelChatResponse {
messages: vec![
proto::ChannelMessage {
id: 10,
body: "a".into(),
timestamp: 1000,
sender_id: 5,
mentions: vec![],
nonce: Some(1.into()),
},
proto::ChannelMessage {
id: 11,
body: "b".into(),
timestamp: 1001,
sender_id: 6,
mentions: vec![],
nonce: Some(2.into()),
},
],
done: false,
},
);
cx.executor().start_waiting();
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
get_users.payload.user_ids.sort();
assert_eq!(get_users.payload.user_ids, vec![6]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 6,
github_login: "maxbrunsfeld".into(),
avatar_url: "http://avatar.com/maxbrunsfeld".into(),
}],
},
);
let channel = channel.await.unwrap();
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("nathansobo".into(), "a".into()),
("maxbrunsfeld".into(), "b".into())
]
);
});
// Receive a new message.
server.send(proto::ChannelMessageSent {
channel_id,
message: Some(proto::ChannelMessage {
id: 12,
body: "c".into(),
timestamp: 1002,
sender_id: 7,
mentions: vec![],
nonce: Some(3.into()),
}),
});
// Client requests user for message since they haven't seen them yet
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![7]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 7,
github_login: "as-cii".into(),
avatar_url: "http://avatar.com/as-cii".into(),
}],
},
);
assert_eq!(
channel.next_event(cx),
ChannelChatEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
}
);
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(2..3)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[("as-cii".into(), "c".into())]
)
});
// Scroll up to view older messages.
channel.update(cx, |channel, cx| {
channel.load_more_messages(cx).unwrap().detach();
});
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server.respond(
get_messages.receipt(),
proto::GetChannelMessagesResponse {
done: true,
messages: vec![
proto::ChannelMessage {
id: 8,
body: "y".into(),
timestamp: 998,
sender_id: 5,
nonce: Some(4.into()),
mentions: vec![],
},
proto::ChannelMessage {
id: 9,
body: "z".into(),
timestamp: 999,
sender_id: 6,
nonce: Some(5.into()),
mentions: vec![],
},
],
},
);
assert_eq!(
channel.next_event(cx),
ChannelChatEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
);
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("nathansobo".into(), "y".into()),
("maxbrunsfeld".into(), "z".into())
]
);
});
}
fn init_test(cx: &mut AppContext) -> Model<ChannelStore> {
let http = FakeHttpClient::with_404_response();
let client = Client::new(http.clone(), cx);
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
client::init(&client, cx);
crate::init(&client, user_store, cx);
ChannelStore::global(cx)
}
fn update_channels(
channel_store: &Model<ChannelStore>,
message: proto::UpdateChannels,
cx: &mut AppContext,
) {
let task = channel_store.update(cx, |store, cx| store.update_channels(message, cx));
assert!(task.is_none());
}
#[track_caller]
fn assert_channels(
channel_store: &Model<ChannelStore>,
expected_channels: &[(usize, String, proto::ChannelRole)],
cx: &mut AppContext,
) {
let actual = channel_store.update(cx, |store, _| {
store
.ordered_channels()
.map(|(depth, channel)| (depth, channel.name.to_string(), channel.role))
.collect::<Vec<_>>()
});
assert_eq!(actual, expected_channels);
}

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
use crate::{TelemetrySettings, ZED_SECRET_CLIENT_TOKEN, ZED_SERVER_URL};
use chrono::{DateTime, Utc};
use gpui::{executor::Background, serde_json, AppContext, Task};
use futures::Future;
use gpui::{serde_json, AppContext, AppMetadata, BackgroundExecutor, Task};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use serde::Serialize;
use settings::Settings;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
@ -14,19 +16,16 @@ use util::{channel::ReleaseChannel, TryFutureExt};
pub struct Telemetry {
http_client: Arc<dyn HttpClient>,
executor: Arc<Background>,
executor: BackgroundExecutor,
state: Mutex<TelemetryState>,
}
#[derive(Default)]
struct TelemetryState {
metrics_id: Option<Arc<str>>, // Per logged-in user
installation_id: Option<Arc<str>>, // Per app installation (different for dev, nightly, preview, and stable)
session_id: Option<Arc<str>>, // Per app launch
app_version: Option<Arc<str>>,
release_channel: Option<&'static str>,
os_name: &'static str,
os_version: Option<Arc<str>>,
app_metadata: AppMetadata,
architecture: &'static str,
clickhouse_events_queue: Vec<ClickhouseEventWrapper>,
flush_clickhouse_events_task: Option<Task<()>>,
@ -48,9 +47,9 @@ struct ClickhouseEventRequestBody {
installation_id: Option<Arc<str>>,
session_id: Option<Arc<str>>,
is_staff: Option<bool>,
app_version: Option<Arc<str>>,
app_version: Option<String>,
os_name: &'static str,
os_version: Option<Arc<str>>,
os_version: Option<String>,
architecture: &'static str,
release_channel: Option<&'static str>,
events: Vec<ClickhouseEventWrapper>,
@ -130,25 +129,23 @@ const MAX_QUEUE_LEN: usize = 50;
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(1);
#[cfg(not(debug_assertions))]
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(120);
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(60 * 5);
impl Telemetry {
pub fn new(client: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
let platform = cx.platform();
pub fn new(client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Arc<Self> {
let release_channel = if cx.has_global::<ReleaseChannel>() {
Some(cx.global::<ReleaseChannel>().display_name())
} else {
None
};
// TODO: Replace all hardware stuff with nested SystemSpecs json
let this = Arc::new(Self {
http_client: client,
executor: cx.background().clone(),
executor: cx.background_executor().clone(),
state: Mutex::new(TelemetryState {
os_name: platform.os_name().into(),
os_version: platform.os_version().ok().map(|v| v.to_string().into()),
app_metadata: cx.app_metadata(),
architecture: env::consts::ARCH,
app_version: platform.app_version().ok().map(|v| v.to_string().into()),
release_channel,
installation_id: None,
metrics_id: None,
@ -161,9 +158,30 @@ impl Telemetry {
}),
});
// We should only ever have one instance of Telemetry, leak the subscription to keep it alive
// rather than store in TelemetryState, complicating spawn as subscriptions are not Send
std::mem::forget(cx.on_app_quit({
let this = this.clone();
move |cx| this.shutdown_telemetry(cx)
}));
this
}
#[cfg(any(test, feature = "test-support"))]
fn shutdown_telemetry(self: &Arc<Self>, _: &mut AppContext) -> impl Future<Output = ()> {
Task::ready(())
}
// Skip calling this function in tests.
// TestAppContext ends up calling this function on shutdown and it panics when trying to find the TelemetrySettings
#[cfg(not(any(test, feature = "test-support")))]
fn shutdown_telemetry(self: &Arc<Self>, cx: &mut AppContext) -> impl Future<Output = ()> {
let telemetry_settings = TelemetrySettings::get_global(cx).clone();
self.report_app_event(telemetry_settings, "close", true);
Task::ready(())
}
pub fn log_file_path(&self) -> Option<PathBuf> {
Some(self.state.lock().log_file.as_ref()?.path().to_path_buf())
}
@ -180,7 +198,7 @@ impl Telemetry {
drop(state);
let this = self.clone();
cx.spawn(|mut cx| async move {
cx.spawn(|cx| async move {
// Avoiding calling `System::new_all()`, as there have been crashes related to it
let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
@ -209,7 +227,13 @@ impl Telemetry {
return;
};
let telemetry_settings = cx.update(|cx| *settings::get::<TelemetrySettings>(cx));
let telemetry_settings = if let Ok(telemetry_settings) =
cx.update(|cx| *TelemetrySettings::get_global(cx))
{
telemetry_settings
} else {
break;
};
this.report_memory_event(
telemetry_settings,
@ -232,7 +256,7 @@ impl Telemetry {
is_staff: bool,
cx: &AppContext,
) {
if !settings::get::<TelemetrySettings>(cx).metrics {
if !TelemetrySettings::get_global(cx).metrics {
return;
}
@ -461,9 +485,15 @@ impl Telemetry {
installation_id: state.installation_id.clone(),
session_id: state.session_id.clone(),
is_staff: state.is_staff.clone(),
app_version: state.app_version.clone(),
os_name: state.os_name,
os_version: state.os_version.clone(),
app_version: state
.app_metadata
.app_version
.map(|version| version.to_string()),
os_name: state.app_metadata.os_name,
os_version: state
.app_metadata
.os_version
.map(|version| version.to_string()),
architecture: state.architecture,
release_channel: state.release_channel,

View File

@ -1,20 +1,19 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use futures::{stream::BoxStream, StreamExt};
use gpui::{executor, ModelHandle, TestAppContext};
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
use parking_lot::Mutex;
use rpc::{
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
ConnectionId, Peer, Receipt, TypedEnvelope,
};
use std::{rc::Rc, sync::Arc};
use util::http::FakeHttpClient;
use std::sync::Arc;
pub struct FakeServer {
peer: Arc<Peer>,
state: Arc<Mutex<FakeServerState>>,
user_id: u64,
executor: Rc<executor::Foreground>,
executor: BackgroundExecutor,
}
#[derive(Default)]
@ -36,7 +35,7 @@ impl FakeServer {
peer: Peer::new(0),
state: Default::default(),
user_id: client_user_id,
executor: cx.foreground(),
executor: cx.executor(),
};
client
@ -78,10 +77,11 @@ impl FakeServer {
Err(EstablishConnectionError::Unauthorized)?
}
let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
let (client_conn, server_conn, _) =
Connection::in_memory(cx.background_executor().clone());
let (connection_id, io, incoming) =
peer.add_test_connection(server_conn, cx.background());
cx.background().spawn(io).detach();
peer.add_test_connection(server_conn, cx.background_executor().clone());
cx.background_executor().spawn(io).detach();
{
let mut state = state.lock();
state.connection_id = Some(connection_id);
@ -193,9 +193,8 @@ impl FakeServer {
&self,
client: Arc<Client>,
cx: &mut TestAppContext,
) -> ModelHandle<UserStore> {
let http_client = FakeHttpClient::with_404_response();
let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
) -> Model<UserStore> {
let user_store = cx.new_model(|cx| UserStore::new(client, cx));
assert_eq!(
self.receive::<proto::GetUsers>()
.await

View File

@ -2,13 +2,12 @@ use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
use futures::{channel::mpsc, Future, StreamExt};
use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, SharedString, Task};
use postage::{sink::Sink, watch};
use rpc::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak};
use text::ReplicaId;
use util::http::HttpClient;
use util::TryFutureExt as _;
pub type UserId = u64;
@ -20,7 +19,7 @@ pub struct ParticipantIndex(pub u32);
pub struct User {
pub id: UserId,
pub github_login: String,
pub avatar: Option<Arc<ImageData>>,
pub avatar_uri: SharedString,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@ -76,9 +75,8 @@ pub struct UserStore {
pending_contact_requests: HashMap<u64, usize>,
invite_info: Option<InviteInfo>,
client: Weak<Client>,
http: Arc<dyn HttpClient>,
_maintain_contacts: Task<()>,
_maintain_current_user: Task<()>,
_maintain_current_user: Task<Result<()>>,
}
#[derive(Clone)]
@ -103,9 +101,7 @@ pub enum ContactEventKind {
Cancelled,
}
impl Entity for UserStore {
type Event = Event;
}
impl EventEmitter<Event> for UserStore {}
enum UpdateContacts {
Update(proto::UpdateContacts),
@ -114,17 +110,13 @@ enum UpdateContacts {
}
impl UserStore {
pub fn new(
client: Arc<Client>,
http: Arc<dyn HttpClient>,
cx: &mut ModelContext<Self>,
) -> Self {
pub fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
let (mut current_user_tx, current_user_rx) = watch::channel();
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
let rpc_subscriptions = vec![
client.add_message_handler(cx.handle(), Self::handle_update_contacts),
client.add_message_handler(cx.handle(), Self::handle_update_invite_info),
client.add_message_handler(cx.handle(), Self::handle_show_contacts),
client.add_message_handler(cx.weak_model(), Self::handle_update_contacts),
client.add_message_handler(cx.weak_model(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_model(), Self::handle_show_contacts),
];
Self {
users: Default::default(),
@ -136,76 +128,71 @@ impl UserStore {
invite_info: None,
client: Arc::downgrade(&client),
update_contacts_tx,
http,
_maintain_contacts: cx.spawn_weak(|this, mut cx| async move {
_maintain_contacts: cx.spawn(|this, mut cx| async move {
let _subscriptions = rpc_subscriptions;
while let Some(message) = update_contacts_rx.next().await {
if let Some(this) = this.upgrade(&cx) {
if let Ok(task) =
this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
.log_err()
.await;
{
task.log_err().await;
} else {
break;
}
}
}),
_maintain_current_user: cx.spawn_weak(|this, mut cx| async move {
_maintain_current_user: cx.spawn(|this, mut cx| async move {
let mut status = client.status();
while let Some(status) = status.next().await {
match status {
Status::Connected { .. } => {
if let Some((this, user_id)) = this.upgrade(&cx).zip(client.user_id()) {
let fetch_user = this
.update(&mut cx, |this, cx| this.get_user(user_id, cx))
.log_err();
if let Some(user_id) = client.user_id() {
let fetch_user = if let Ok(fetch_user) = this
.update(&mut cx, |this, cx| {
this.get_user(user_id, cx).log_err()
}) {
fetch_user
} else {
break;
};
let fetch_metrics_id =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
if let Some(info) = info {
cx.update(|cx| {
cx.update(|cx| {
if let Some(info) = info {
cx.update_flags(info.staff, info.flags);
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
info.staff,
cx,
)
});
} else {
cx.read(|cx| {
client
.telemetry
.set_authenticated_user_info(None, false, cx)
});
}
}
})?;
current_user_tx.send(user).await.ok();
this.update(&mut cx, |_, cx| {
cx.notify();
});
this.update(&mut cx, |_, cx| cx.notify())?;
}
}
Status::SignedOut => {
current_user_tx.send(None).await.ok();
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})
.await;
}
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
Status::ConnectionLost => {
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})
.await;
}
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
_ => {}
}
}
Ok(())
}),
pending_contact_requests: Default::default(),
}
@ -217,7 +204,7 @@ impl UserStore {
}
async fn handle_update_invite_info(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateInviteInfo>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -228,17 +215,17 @@ impl UserStore {
count: message.payload.count,
});
cx.notify();
});
})?;
Ok(())
}
async fn handle_show_contacts(
this: ModelHandle<Self>,
this: Model<Self>,
_: TypedEnvelope<proto::ShowContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts));
this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts))?;
Ok(())
}
@ -247,7 +234,7 @@ impl UserStore {
}
async fn handle_update_contacts(
this: ModelHandle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -256,7 +243,7 @@ impl UserStore {
this.update_contacts_tx
.unbounded_send(UpdateContacts::Update(message.payload))
.unwrap();
});
})?;
Ok(())
}
@ -292,6 +279,9 @@ impl UserStore {
// Users are fetched in parallel above and cached in call to get_users
// No need to paralellize here
let mut updated_contacts = Vec::new();
let this = this
.upgrade()
.ok_or_else(|| anyhow!("can't upgrade user store handle"))?;
for contact in message.contacts {
updated_contacts.push(Arc::new(
Contact::from_proto(contact, &this, &mut cx).await?,
@ -300,18 +290,18 @@ impl UserStore {
let mut incoming_requests = Vec::new();
for request in message.incoming_requests {
incoming_requests.push(
incoming_requests.push({
this.update(&mut cx, |this, cx| {
this.get_user(request.requester_id, cx)
})
.await?,
);
})?
.await?
});
}
let mut outgoing_requests = Vec::new();
for requested_user_id in message.outgoing_requests {
outgoing_requests.push(
this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))
this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))?
.await?,
);
}
@ -378,7 +368,7 @@ impl UserStore {
}
cx.notify();
});
})?;
Ok(())
})
@ -400,12 +390,6 @@ impl UserStore {
&self.incoming_contact_requests
}
pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
self.incoming_contact_requests
.iter()
.any(|user| user.id == user_id)
}
pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
&self.outgoing_contact_requests
}
@ -454,6 +438,12 @@ impl UserStore {
self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
}
pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
self.incoming_contact_requests
.iter()
.any(|user| user.id == user_id)
}
pub fn respond_to_contact_request(
&mut self,
requester_id: u64,
@ -480,7 +470,7 @@ impl UserStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.upgrade();
cx.spawn_weak(|_, _| async move {
cx.spawn(move |_, _| async move {
client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(proto::RespondToContactRequest {
@ -502,7 +492,7 @@ impl UserStore {
*self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
cx.notify();
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
let response = client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(request)
@ -517,7 +507,7 @@ impl UserStore {
}
}
cx.notify();
});
})?;
response?;
Ok(())
})
@ -560,11 +550,11 @@ impl UserStore {
},
cx,
)
})
})?
.await?;
}
this.read_with(&cx, |this, _| {
this.update(&mut cx, |this, _| {
user_ids
.iter()
.map(|user_id| {
@ -574,7 +564,7 @@ impl UserStore {
.ok_or_else(|| anyhow!("user {} not found", user_id))
})
.collect()
})
})?
})
}
@ -596,18 +586,18 @@ impl UserStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<Arc<User>>> {
if let Some(user) = self.users.get(&user_id).cloned() {
return cx.foreground().spawn(async move { Ok(user) });
return Task::ready(Ok(user));
}
let load_users = self.get_users(vec![user_id], cx);
cx.spawn(|this, mut cx| async move {
cx.spawn(move |this, mut cx| async move {
load_users.await?;
this.update(&mut cx, |this, _| {
this.users
.get(&user_id)
.cloned()
.ok_or_else(|| anyhow!("server responded with no users"))
})
})?
})
}
@ -625,25 +615,22 @@ impl UserStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
let client = self.client.clone();
let http = self.http.clone();
cx.spawn_weak(|this, mut cx| async move {
cx.spawn(|this, mut cx| async move {
if let Some(rpc) = client.upgrade() {
let response = rpc.request(request).await.context("error loading users")?;
let users = future::join_all(
response
.users
.into_iter()
.map(|user| User::new(user, http.as_ref())),
)
.await;
let users = response
.users
.into_iter()
.map(|user| User::new(user))
.collect::<Vec<_>>();
this.update(&mut cx, |this, _| {
for user in &users {
this.users.insert(user.id, user.clone());
}
})
.ok();
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, _| {
for user in &users {
this.users.insert(user.id, user.clone());
}
});
}
Ok(users)
} else {
Ok(Vec::new())
@ -668,11 +655,11 @@ impl UserStore {
}
impl User {
async fn new(message: proto::User, http: &dyn HttpClient) -> Arc<Self> {
fn new(message: proto::User) -> Arc<Self> {
Arc::new(User {
id: message.id,
github_login: message.github_login,
avatar: fetch_avatar(http, &message.avatar_url).warn_on_err().await,
avatar_uri: message.avatar_url.into(),
})
}
}
@ -680,13 +667,13 @@ impl User {
impl Contact {
async fn from_proto(
contact: proto::Contact,
user_store: &ModelHandle<UserStore>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let user = user_store
.update(cx, |user_store, cx| {
user_store.get_user(contact.user_id, cx)
})
})?
.await?;
Ok(Self {
user,
@ -705,24 +692,3 @@ impl Collaborator {
})
}
}
async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result<Arc<ImageData>> {
let mut response = http
.get(url, Default::default(), true)
.await
.map_err(|e| anyhow!("failed to send user avatar request: {}", e))?;
if !response.status().is_success() {
return Err(anyhow!("avatar request failed {:?}", response.status()));
}
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.map_err(|e| anyhow!("failed to read user avatar response body: {}", e))?;
let format = image::guess_format(&body)?;
let image = image::load_from_memory_with_format(&body, format)?.into_bgra8();
Ok(ImageData::new(image))
}

View File

@ -1,53 +0,0 @@
[package]
name = "client2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/client2.rs"
doctest = false
[features]
test-support = ["collections/test-support", "gpui/test-support", "rpc/test-support"]
[dependencies]
chrono = { version = "0.4", features = ["serde"] }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
text = { package = "text2", path = "../text2" }
settings = { package = "settings2", path = "../settings2" }
feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
sum_tree = { path = "../sum_tree" }
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { version = "0.16", features = ["async-tls"] }
futures.workspace = true
image = "0.23"
lazy_static.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
schemars.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
sysinfo.workspace = true
tempfile = "3"
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
uuid.workspace = true
url = "2.2"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
settings = { package = "settings2", path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

File diff suppressed because it is too large Load Diff

View File

@ -1,515 +0,0 @@
use crate::{TelemetrySettings, ZED_SECRET_CLIENT_TOKEN, ZED_SERVER_URL};
use chrono::{DateTime, Utc};
use futures::Future;
use gpui::{serde_json, AppContext, AppMetadata, BackgroundExecutor, Task};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use serde::Serialize;
use settings::Settings;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
};
use tempfile::NamedTempFile;
use util::http::HttpClient;
use util::{channel::ReleaseChannel, TryFutureExt};
pub struct Telemetry {
http_client: Arc<dyn HttpClient>,
executor: BackgroundExecutor,
state: Mutex<TelemetryState>,
}
struct TelemetryState {
metrics_id: Option<Arc<str>>, // Per logged-in user
installation_id: Option<Arc<str>>, // Per app installation (different for dev, nightly, preview, and stable)
session_id: Option<Arc<str>>, // Per app launch
release_channel: Option<&'static str>,
app_metadata: AppMetadata,
architecture: &'static str,
clickhouse_events_queue: Vec<ClickhouseEventWrapper>,
flush_clickhouse_events_task: Option<Task<()>>,
log_file: Option<NamedTempFile>,
is_staff: Option<bool>,
first_event_datetime: Option<DateTime<Utc>>,
}
const CLICKHOUSE_EVENTS_URL_PATH: &'static str = "/api/events";
lazy_static! {
static ref CLICKHOUSE_EVENTS_URL: String =
format!("{}{}", *ZED_SERVER_URL, CLICKHOUSE_EVENTS_URL_PATH);
}
#[derive(Serialize, Debug)]
struct ClickhouseEventRequestBody {
token: &'static str,
installation_id: Option<Arc<str>>,
session_id: Option<Arc<str>>,
is_staff: Option<bool>,
app_version: Option<String>,
os_name: &'static str,
os_version: Option<String>,
architecture: &'static str,
release_channel: Option<&'static str>,
events: Vec<ClickhouseEventWrapper>,
}
#[derive(Serialize, Debug)]
struct ClickhouseEventWrapper {
signed_in: bool,
#[serde(flatten)]
event: ClickhouseEvent,
}
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum AssistantKind {
Panel,
Inline,
}
#[derive(Serialize, Debug)]
#[serde(tag = "type")]
pub enum ClickhouseEvent {
Editor {
operation: &'static str,
file_extension: Option<String>,
vim_mode: bool,
copilot_enabled: bool,
copilot_enabled_for_language: bool,
milliseconds_since_first_event: i64,
},
Copilot {
suggestion_id: Option<String>,
suggestion_accepted: bool,
file_extension: Option<String>,
milliseconds_since_first_event: i64,
},
Call {
operation: &'static str,
room_id: Option<u64>,
channel_id: Option<u64>,
milliseconds_since_first_event: i64,
},
Assistant {
conversation_id: Option<String>,
kind: AssistantKind,
model: &'static str,
milliseconds_since_first_event: i64,
},
Cpu {
usage_as_percentage: f32,
core_count: u32,
milliseconds_since_first_event: i64,
},
Memory {
memory_in_bytes: u64,
virtual_memory_in_bytes: u64,
milliseconds_since_first_event: i64,
},
App {
operation: &'static str,
milliseconds_since_first_event: i64,
},
Setting {
setting: &'static str,
value: String,
milliseconds_since_first_event: i64,
},
}
#[cfg(debug_assertions)]
const MAX_QUEUE_LEN: usize = 1;
#[cfg(not(debug_assertions))]
const MAX_QUEUE_LEN: usize = 50;
#[cfg(debug_assertions)]
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(1);
#[cfg(not(debug_assertions))]
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(60 * 5);
impl Telemetry {
pub fn new(client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Arc<Self> {
let release_channel = if cx.has_global::<ReleaseChannel>() {
Some(cx.global::<ReleaseChannel>().display_name())
} else {
None
};
// TODO: Replace all hardware stuff with nested SystemSpecs json
let this = Arc::new(Self {
http_client: client,
executor: cx.background_executor().clone(),
state: Mutex::new(TelemetryState {
app_metadata: cx.app_metadata(),
architecture: env::consts::ARCH,
release_channel,
installation_id: None,
metrics_id: None,
session_id: None,
clickhouse_events_queue: Default::default(),
flush_clickhouse_events_task: Default::default(),
log_file: None,
is_staff: None,
first_event_datetime: None,
}),
});
// We should only ever have one instance of Telemetry, leak the subscription to keep it alive
// rather than store in TelemetryState, complicating spawn as subscriptions are not Send
std::mem::forget(cx.on_app_quit({
let this = this.clone();
move |cx| this.shutdown_telemetry(cx)
}));
this
}
#[cfg(any(test, feature = "test-support"))]
fn shutdown_telemetry(self: &Arc<Self>, _: &mut AppContext) -> impl Future<Output = ()> {
Task::ready(())
}
// Skip calling this function in tests.
// TestAppContext ends up calling this function on shutdown and it panics when trying to find the TelemetrySettings
#[cfg(not(any(test, feature = "test-support")))]
fn shutdown_telemetry(self: &Arc<Self>, cx: &mut AppContext) -> impl Future<Output = ()> {
let telemetry_settings = TelemetrySettings::get_global(cx).clone();
self.report_app_event(telemetry_settings, "close", true);
Task::ready(())
}
pub fn log_file_path(&self) -> Option<PathBuf> {
Some(self.state.lock().log_file.as_ref()?.path().to_path_buf())
}
pub fn start(
self: &Arc<Self>,
installation_id: Option<String>,
session_id: String,
cx: &mut AppContext,
) {
let mut state = self.state.lock();
state.installation_id = installation_id.map(|id| id.into());
state.session_id = Some(session_id.into());
drop(state);
let this = self.clone();
cx.spawn(|cx| async move {
// Avoiding calling `System::new_all()`, as there have been crashes related to it
let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
.with_processes(ProcessRefreshKind::everything()) // For process usage
.with_cpu(CpuRefreshKind::everything()); // For core count
let mut system = System::new_with_specifics(refresh_kind);
// Avoiding calling `refresh_all()`, just update what we need
system.refresh_specifics(refresh_kind);
// Waiting some amount of time before the first query is important to get a reasonable value
// https://docs.rs/sysinfo/0.29.10/sysinfo/trait.ProcessExt.html#tymethod.cpu_usage
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(4 * 60);
loop {
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
system.refresh_specifics(refresh_kind);
let current_process = Pid::from_u32(std::process::id());
let Some(process) = system.processes().get(&current_process) else {
let process = current_process;
log::error!("Failed to find own process {process:?} in system process table");
// TODO: Fire an error telemetry event
return;
};
let telemetry_settings = if let Ok(telemetry_settings) =
cx.update(|cx| *TelemetrySettings::get_global(cx))
{
telemetry_settings
} else {
break;
};
this.report_memory_event(
telemetry_settings,
process.memory(),
process.virtual_memory(),
);
this.report_cpu_event(
telemetry_settings,
process.cpu_usage(),
system.cpus().len() as u32,
);
}
})
.detach();
}
pub fn set_authenticated_user_info(
self: &Arc<Self>,
metrics_id: Option<String>,
is_staff: bool,
cx: &AppContext,
) {
if !TelemetrySettings::get_global(cx).metrics {
return;
}
let mut state = self.state.lock();
let metrics_id: Option<Arc<str>> = metrics_id.map(|id| id.into());
state.metrics_id = metrics_id.clone();
state.is_staff = Some(is_staff);
drop(state);
}
pub fn report_editor_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
file_extension: Option<String>,
vim_mode: bool,
operation: &'static str,
copilot_enabled: bool,
copilot_enabled_for_language: bool,
) {
let event = ClickhouseEvent::Editor {
file_extension,
vim_mode,
operation,
copilot_enabled,
copilot_enabled_for_language,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_copilot_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
suggestion_id: Option<String>,
suggestion_accepted: bool,
file_extension: Option<String>,
) {
let event = ClickhouseEvent::Copilot {
suggestion_id,
suggestion_accepted,
file_extension,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_assistant_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
conversation_id: Option<String>,
kind: AssistantKind,
model: &'static str,
) {
let event = ClickhouseEvent::Assistant {
conversation_id,
kind,
model,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_call_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
operation: &'static str,
room_id: Option<u64>,
channel_id: Option<u64>,
) {
let event = ClickhouseEvent::Call {
operation,
room_id,
channel_id,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_cpu_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
usage_as_percentage: f32,
core_count: u32,
) {
let event = ClickhouseEvent::Cpu {
usage_as_percentage,
core_count,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_memory_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
memory_in_bytes: u64,
virtual_memory_in_bytes: u64,
) {
let event = ClickhouseEvent::Memory {
memory_in_bytes,
virtual_memory_in_bytes,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
pub fn report_app_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
operation: &'static str,
immediate_flush: bool,
) {
let event = ClickhouseEvent::App {
operation,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, immediate_flush)
}
pub fn report_setting_event(
self: &Arc<Self>,
telemetry_settings: TelemetrySettings,
setting: &'static str,
value: String,
) {
let event = ClickhouseEvent::Setting {
setting,
value,
milliseconds_since_first_event: self.milliseconds_since_first_event(),
};
self.report_clickhouse_event(event, telemetry_settings, false)
}
fn milliseconds_since_first_event(&self) -> i64 {
let mut state = self.state.lock();
match state.first_event_datetime {
Some(first_event_datetime) => {
let now: DateTime<Utc> = Utc::now();
now.timestamp_millis() - first_event_datetime.timestamp_millis()
}
None => {
state.first_event_datetime = Some(Utc::now());
0
}
}
}
fn report_clickhouse_event(
self: &Arc<Self>,
event: ClickhouseEvent,
telemetry_settings: TelemetrySettings,
immediate_flush: bool,
) {
if !telemetry_settings.metrics {
return;
}
let mut state = self.state.lock();
let signed_in = state.metrics_id.is_some();
state
.clickhouse_events_queue
.push(ClickhouseEventWrapper { signed_in, event });
if state.installation_id.is_some() {
if immediate_flush || state.clickhouse_events_queue.len() >= MAX_QUEUE_LEN {
drop(state);
self.flush_clickhouse_events();
} else {
let this = self.clone();
let executor = self.executor.clone();
state.flush_clickhouse_events_task = Some(self.executor.spawn(async move {
executor.timer(DEBOUNCE_INTERVAL).await;
this.flush_clickhouse_events();
}));
}
}
}
pub fn metrics_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().metrics_id.clone()
}
pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().installation_id.clone()
}
pub fn is_staff(self: &Arc<Self>) -> Option<bool> {
self.state.lock().is_staff
}
fn flush_clickhouse_events(self: &Arc<Self>) {
let mut state = self.state.lock();
state.first_event_datetime = None;
let mut events = mem::take(&mut state.clickhouse_events_queue);
state.flush_clickhouse_events_task.take();
drop(state);
let this = self.clone();
self.executor
.spawn(
async move {
let mut json_bytes = Vec::new();
if let Some(file) = &mut this.state.lock().log_file {
let file = file.as_file_mut();
for event in &mut events {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event)?;
file.write_all(&json_bytes)?;
file.write(b"\n")?;
}
}
{
let state = this.state.lock();
let request_body = ClickhouseEventRequestBody {
token: ZED_SECRET_CLIENT_TOKEN,
installation_id: state.installation_id.clone(),
session_id: state.session_id.clone(),
is_staff: state.is_staff.clone(),
app_version: state
.app_metadata
.app_version
.map(|version| version.to_string()),
os_name: state.app_metadata.os_name,
os_version: state
.app_metadata
.os_version
.map(|version| version.to_string()),
architecture: state.architecture,
release_channel: state.release_channel,
events,
};
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, &request_body)?;
}
this.http_client
.post_json(CLICKHOUSE_EVENTS_URL.as_str(), json_bytes.into())
.await?;
anyhow::Ok(())
}
.log_err(),
)
.detach();
}
}

View File

@ -1,214 +0,0 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use futures::{stream::BoxStream, StreamExt};
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
use parking_lot::Mutex;
use rpc::{
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
ConnectionId, Peer, Receipt, TypedEnvelope,
};
use std::sync::Arc;
pub struct FakeServer {
peer: Arc<Peer>,
state: Arc<Mutex<FakeServerState>>,
user_id: u64,
executor: BackgroundExecutor,
}
#[derive(Default)]
struct FakeServerState {
incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
connection_id: Option<ConnectionId>,
forbid_connections: bool,
auth_count: usize,
access_token: usize,
}
impl FakeServer {
pub async fn for_client(
client_user_id: u64,
client: &Arc<Client>,
cx: &TestAppContext,
) -> Self {
let server = Self {
peer: Peer::new(0),
state: Default::default(),
user_id: client_user_id,
executor: cx.executor(),
};
client
.override_authenticate({
let state = Arc::downgrade(&server.state);
move |cx| {
let state = state.clone();
cx.spawn(move |_| async move {
let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
let mut state = state.lock();
state.auth_count += 1;
let access_token = state.access_token.to_string();
Ok(Credentials {
user_id: client_user_id,
access_token,
})
})
}
})
.override_establish_connection({
let peer = Arc::downgrade(&server.peer);
let state = Arc::downgrade(&server.state);
move |credentials, cx| {
let peer = peer.clone();
let state = state.clone();
let credentials = credentials.clone();
cx.spawn(move |cx| async move {
let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
if state.lock().forbid_connections {
Err(EstablishConnectionError::Other(anyhow!(
"server is forbidding connections"
)))?
}
assert_eq!(credentials.user_id, client_user_id);
if credentials.access_token != state.lock().access_token.to_string() {
Err(EstablishConnectionError::Unauthorized)?
}
let (client_conn, server_conn, _) =
Connection::in_memory(cx.background_executor().clone());
let (connection_id, io, incoming) =
peer.add_test_connection(server_conn, cx.background_executor().clone());
cx.background_executor().spawn(io).detach();
{
let mut state = state.lock();
state.connection_id = Some(connection_id);
state.incoming = Some(incoming);
}
peer.send(
connection_id,
proto::Hello {
peer_id: Some(connection_id.into()),
},
)
.unwrap();
Ok(client_conn)
})
}
});
client
.authenticate_and_connect(false, &cx.to_async())
.await
.unwrap();
server
}
pub fn disconnect(&self) {
if self.state.lock().connection_id.is_some() {
self.peer.disconnect(self.connection_id());
let mut state = self.state.lock();
state.connection_id.take();
state.incoming.take();
}
}
pub fn auth_count(&self) -> usize {
self.state.lock().auth_count
}
pub fn roll_access_token(&self) {
self.state.lock().access_token += 1;
}
pub fn forbid_connections(&self) {
self.state.lock().forbid_connections = true;
}
pub fn allow_connections(&self) {
self.state.lock().forbid_connections = false;
}
pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id(), message).unwrap();
}
#[allow(clippy::await_holding_lock)]
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
self.executor.start_waiting();
loop {
let message = self
.state
.lock()
.incoming
.as_mut()
.expect("not connected")
.next()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
self.executor.finish_waiting();
let type_name = message.payload_type_name();
let message = message.into_any();
if message.is::<TypedEnvelope<M>>() {
return Ok(*message.downcast().unwrap());
}
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
.unwrap()
.receipt(),
GetPrivateUserInfoResponse {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
},
);
continue;
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}
}
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
self.peer.respond(receipt, response).unwrap()
}
fn connection_id(&self) -> ConnectionId {
self.state.lock().connection_id.expect("not connected")
}
pub async fn build_user_store(
&self,
client: Arc<Client>,
cx: &mut TestAppContext,
) -> Model<UserStore> {
let user_store = cx.new_model(|cx| UserStore::new(client, cx));
assert_eq!(
self.receive::<proto::GetUsers>()
.await
.unwrap()
.payload
.user_ids,
&[self.user_id]
);
user_store
}
}
impl Drop for FakeServer {
fn drop(&mut self) {
self.disconnect();
}
}

View File

@ -1,694 +0,0 @@
use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use futures::{channel::mpsc, Future, StreamExt};
use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, SharedString, Task};
use postage::{sink::Sink, watch};
use rpc::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak};
use text::ReplicaId;
use util::TryFutureExt as _;
pub type UserId = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParticipantIndex(pub u32);
#[derive(Default, Debug)]
pub struct User {
pub id: UserId,
pub github_login: String,
pub avatar_uri: SharedString,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Collaborator {
pub peer_id: proto::PeerId,
pub replica_id: ReplicaId,
pub user_id: UserId,
}
impl PartialOrd for User {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for User {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.github_login.cmp(&other.github_login)
}
}
impl PartialEq for User {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.github_login == other.github_login
}
}
impl Eq for User {}
#[derive(Debug, PartialEq)]
pub struct Contact {
pub user: Arc<User>,
pub online: bool,
pub busy: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContactRequestStatus {
None,
RequestSent,
RequestReceived,
RequestAccepted,
}
pub struct UserStore {
users: HashMap<u64, Arc<User>>,
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_user: watch::Receiver<Option<Arc<User>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
pending_contact_requests: HashMap<u64, usize>,
invite_info: Option<InviteInfo>,
client: Weak<Client>,
_maintain_contacts: Task<()>,
_maintain_current_user: Task<Result<()>>,
}
#[derive(Clone)]
pub struct InviteInfo {
pub count: u32,
pub url: Arc<str>,
}
pub enum Event {
Contact {
user: Arc<User>,
kind: ContactEventKind,
},
ShowContacts,
ParticipantIndicesChanged,
}
#[derive(Clone, Copy)]
pub enum ContactEventKind {
Requested,
Accepted,
Cancelled,
}
impl EventEmitter<Event> for UserStore {}
enum UpdateContacts {
Update(proto::UpdateContacts),
Wait(postage::barrier::Sender),
Clear(postage::barrier::Sender),
}
impl UserStore {
pub fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
let (mut current_user_tx, current_user_rx) = watch::channel();
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
let rpc_subscriptions = vec![
client.add_message_handler(cx.weak_model(), Self::handle_update_contacts),
client.add_message_handler(cx.weak_model(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_model(), Self::handle_show_contacts),
];
Self {
users: Default::default(),
current_user: current_user_rx,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
outgoing_contact_requests: Default::default(),
invite_info: None,
client: Arc::downgrade(&client),
update_contacts_tx,
_maintain_contacts: cx.spawn(|this, mut cx| async move {
let _subscriptions = rpc_subscriptions;
while let Some(message) = update_contacts_rx.next().await {
if let Ok(task) =
this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
{
task.log_err().await;
} else {
break;
}
}
}),
_maintain_current_user: cx.spawn(|this, mut cx| async move {
let mut status = client.status();
while let Some(status) = status.next().await {
match status {
Status::Connected { .. } => {
if let Some(user_id) = client.user_id() {
let fetch_user = if let Ok(fetch_user) = this
.update(&mut cx, |this, cx| {
this.get_user(user_id, cx).log_err()
}) {
fetch_user
} else {
break;
};
let fetch_metrics_id =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
cx.update(|cx| {
if let Some(info) = info {
cx.update_flags(info.staff, info.flags);
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
info.staff,
cx,
)
}
})?;
current_user_tx.send(user).await.ok();
this.update(&mut cx, |_, cx| cx.notify())?;
}
}
Status::SignedOut => {
current_user_tx.send(None).await.ok();
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
Status::ConnectionLost => {
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
_ => {}
}
}
Ok(())
}),
pending_contact_requests: Default::default(),
}
}
#[cfg(feature = "test-support")]
pub fn clear_cache(&mut self) {
self.users.clear();
}
async fn handle_update_invite_info(
this: Model<Self>,
message: TypedEnvelope<proto::UpdateInviteInfo>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.invite_info = Some(InviteInfo {
url: Arc::from(message.payload.url),
count: message.payload.count,
});
cx.notify();
})?;
Ok(())
}
async fn handle_show_contacts(
this: Model<Self>,
_: TypedEnvelope<proto::ShowContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts))?;
Ok(())
}
pub fn invite_info(&self) -> Option<&InviteInfo> {
self.invite_info.as_ref()
}
async fn handle_update_contacts(
this: Model<Self>,
message: TypedEnvelope<proto::UpdateContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, _| {
this.update_contacts_tx
.unbounded_send(UpdateContacts::Update(message.payload))
.unwrap();
})?;
Ok(())
}
fn update_contacts(
&mut self,
message: UpdateContacts,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
drop(barrier);
Task::ready(Ok(()))
}
UpdateContacts::Clear(barrier) => {
self.contacts.clear();
self.incoming_contact_requests.clear();
self.outgoing_contact_requests.clear();
drop(barrier);
Task::ready(Ok(()))
}
UpdateContacts::Update(message) => {
let mut user_ids = HashSet::default();
for contact in &message.contacts {
user_ids.insert(contact.user_id);
}
user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
user_ids.extend(message.outgoing_requests.iter());
let load_users = self.get_users(user_ids.into_iter().collect(), cx);
cx.spawn(|this, mut cx| async move {
load_users.await?;
// Users are fetched in parallel above and cached in call to get_users
// No need to paralellize here
let mut updated_contacts = Vec::new();
let this = this
.upgrade()
.ok_or_else(|| anyhow!("can't upgrade user store handle"))?;
for contact in message.contacts {
updated_contacts.push(Arc::new(
Contact::from_proto(contact, &this, &mut cx).await?,
));
}
let mut incoming_requests = Vec::new();
for request in message.incoming_requests {
incoming_requests.push({
this.update(&mut cx, |this, cx| {
this.get_user(request.requester_id, cx)
})?
.await?
});
}
let mut outgoing_requests = Vec::new();
for requested_user_id in message.outgoing_requests {
outgoing_requests.push(
this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))?
.await?,
);
}
let removed_contacts =
HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
let removed_incoming_requests =
HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
let removed_outgoing_requests =
HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
this.update(&mut cx, |this, cx| {
// Remove contacts
this.contacts
.retain(|contact| !removed_contacts.contains(&contact.user.id));
// Update existing contacts and insert new ones
for updated_contact in updated_contacts {
match this.contacts.binary_search_by_key(
&&updated_contact.user.github_login,
|contact| &contact.user.github_login,
) {
Ok(ix) => this.contacts[ix] = updated_contact,
Err(ix) => this.contacts.insert(ix, updated_contact),
}
}
// Remove incoming contact requests
this.incoming_contact_requests.retain(|user| {
if removed_incoming_requests.contains(&user.id) {
cx.emit(Event::Contact {
user: user.clone(),
kind: ContactEventKind::Cancelled,
});
false
} else {
true
}
});
// Update existing incoming requests and insert new ones
for user in incoming_requests {
match this
.incoming_contact_requests
.binary_search_by_key(&&user.github_login, |contact| {
&contact.github_login
}) {
Ok(ix) => this.incoming_contact_requests[ix] = user,
Err(ix) => this.incoming_contact_requests.insert(ix, user),
}
}
// Remove outgoing contact requests
this.outgoing_contact_requests
.retain(|user| !removed_outgoing_requests.contains(&user.id));
// Update existing incoming requests and insert new ones
for request in outgoing_requests {
match this
.outgoing_contact_requests
.binary_search_by_key(&&request.github_login, |contact| {
&contact.github_login
}) {
Ok(ix) => this.outgoing_contact_requests[ix] = request,
Err(ix) => this.outgoing_contact_requests.insert(ix, request),
}
}
cx.notify();
})?;
Ok(())
})
}
}
}
pub fn contacts(&self) -> &[Arc<Contact>] {
&self.contacts
}
pub fn has_contact(&self, user: &Arc<User>) -> bool {
self.contacts
.binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
.is_ok()
}
pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
&self.incoming_contact_requests
}
pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
&self.outgoing_contact_requests
}
pub fn is_contact_request_pending(&self, user: &User) -> bool {
self.pending_contact_requests.contains_key(&user.id)
}
pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
if self
.contacts
.binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
.is_ok()
{
ContactRequestStatus::RequestAccepted
} else if self
.outgoing_contact_requests
.binary_search_by_key(&&user.github_login, |user| &user.github_login)
.is_ok()
{
ContactRequestStatus::RequestSent
} else if self
.incoming_contact_requests
.binary_search_by_key(&&user.github_login, |user| &user.github_login)
.is_ok()
{
ContactRequestStatus::RequestReceived
} else {
ContactRequestStatus::None
}
}
pub fn request_contact(
&mut self,
responder_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
}
pub fn remove_contact(
&mut self,
user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
}
pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
self.incoming_contact_requests
.iter()
.any(|user| user.id == user_id)
}
pub fn respond_to_contact_request(
&mut self,
requester_id: u64,
accept: bool,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(
requester_id,
proto::RespondToContactRequest {
requester_id,
response: if accept {
proto::ContactRequestResponse::Accept
} else {
proto::ContactRequestResponse::Decline
} as i32,
},
cx,
)
}
pub fn dismiss_contact_request(
&mut self,
requester_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.upgrade();
cx.spawn(move |_, _| async move {
client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(proto::RespondToContactRequest {
requester_id,
response: proto::ContactRequestResponse::Dismiss as i32,
})
.await?;
Ok(())
})
}
fn perform_contact_request<T: RequestMessage>(
&mut self,
user_id: u64,
request: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.upgrade();
*self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
cx.notify();
cx.spawn(move |this, mut cx| async move {
let response = client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(request)
.await;
this.update(&mut cx, |this, cx| {
if let Entry::Occupied(mut request_count) =
this.pending_contact_requests.entry(user_id)
{
*request_count.get_mut() -= 1;
if *request_count.get() == 0 {
request_count.remove();
}
}
cx.notify();
})?;
response?;
Ok(())
})
}
pub fn clear_contacts(&mut self) -> impl Future<Output = ()> {
let (tx, mut rx) = postage::barrier::channel();
self.update_contacts_tx
.unbounded_send(UpdateContacts::Clear(tx))
.unwrap();
async move {
rx.next().await;
}
}
pub fn contact_updates_done(&mut self) -> impl Future<Output = ()> {
let (tx, mut rx) = postage::barrier::channel();
self.update_contacts_tx
.unbounded_send(UpdateContacts::Wait(tx))
.unwrap();
async move {
rx.next().await;
}
}
pub fn get_users(
&mut self,
user_ids: Vec<u64>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
let mut user_ids_to_fetch = user_ids.clone();
user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
cx.spawn(|this, mut cx| async move {
if !user_ids_to_fetch.is_empty() {
this.update(&mut cx, |this, cx| {
this.load_users(
proto::GetUsers {
user_ids: user_ids_to_fetch,
},
cx,
)
})?
.await?;
}
this.update(&mut cx, |this, _| {
user_ids
.iter()
.map(|user_id| {
this.users
.get(user_id)
.cloned()
.ok_or_else(|| anyhow!("user {} not found", user_id))
})
.collect()
})?
})
}
pub fn fuzzy_search_users(
&mut self,
query: String,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
self.load_users(proto::FuzzySearchUsers { query }, cx)
}
pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
self.users.get(&user_id).cloned()
}
pub fn get_user(
&mut self,
user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Arc<User>>> {
if let Some(user) = self.users.get(&user_id).cloned() {
return Task::ready(Ok(user));
}
let load_users = self.get_users(vec![user_id], cx);
cx.spawn(move |this, mut cx| async move {
load_users.await?;
this.update(&mut cx, |this, _| {
this.users
.get(&user_id)
.cloned()
.ok_or_else(|| anyhow!("server responded with no users"))
})?
})
}
pub fn current_user(&self) -> Option<Arc<User>> {
self.current_user.borrow().clone()
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}
fn load_users(
&mut self,
request: impl RequestMessage<Response = UsersResponse>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
if let Some(rpc) = client.upgrade() {
let response = rpc.request(request).await.context("error loading users")?;
let users = response
.users
.into_iter()
.map(|user| User::new(user))
.collect::<Vec<_>>();
this.update(&mut cx, |this, _| {
for user in &users {
this.users.insert(user.id, user.clone());
}
})
.ok();
Ok(users)
} else {
Ok(Vec::new())
}
})
}
pub fn set_participant_indices(
&mut self,
participant_indices: HashMap<u64, ParticipantIndex>,
cx: &mut ModelContext<Self>,
) {
if participant_indices != self.participant_indices {
self.participant_indices = participant_indices;
cx.emit(Event::ParticipantIndicesChanged);
}
}
pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
&self.participant_indices
}
}
impl User {
fn new(message: proto::User) -> Arc<Self> {
Arc::new(User {
id: message.id,
github_login: message.github_login,
avatar_uri: message.avatar_url.into(),
})
}
}
impl Contact {
async fn from_proto(
contact: proto::Contact,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let user = user_store
.update(cx, |user_store, cx| {
user_store.get_user(contact.user_id, cx)
})?
.await?;
Ok(Self {
user,
online: contact.online,
busy: contact.busy,
})
}
}
impl Collaborator {
pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
Ok(Self {
peer_id: message.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?,
replica_id: message.replica_id as ReplicaId,
user_id: message.user_id as UserId,
})
}
}

View File

@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
default-run = "collab"
edition = "2021"
name = "collab"
version = "0.32.0"
version = "0.28.0"
publish = false
[[bin]]
@ -74,11 +74,13 @@ live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
node_runtime = { path = "../node_runtime" }
notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
theme = { path = "../theme" }
workspace = { path = "../workspace", features = ["test-support"] }
collab_ui = { path = "../collab_ui", features = ["test-support"] }
async-trait.workspace = true

View File

@ -1 +0,0 @@
../collab2/k8s

Some files were not shown because too many files have changed in this diff Show More