From f6c85b28d5c69d9231fd10b2808ce18a34a96647 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Thu, 11 Apr 2024 15:36:35 -0600 Subject: [PATCH] WIP: remoting (#10085) Release Notes: - Added private alpha support for remote development. Please reach out to hi@zed.dev if you'd like to be part of shaping this feature. --- Cargo.lock | 168 +-- Cargo.toml | 3 + assets/icons/server.svg | 5 + assets/icons/trash.svg | 1 + crates/call/src/room.rs | 2 +- crates/channel/src/channel.rs | 4 +- crates/channel/src/channel_store.rs | 205 +++- crates/client/src/client.rs | 3 +- crates/client/src/user.rs | 6 + crates/collab/.env.toml | 2 +- crates/collab/Cargo.toml | 5 +- .../20221109000000_test_schema.sql | 19 +- ...20240402155003_add_dev_server_projects.sql | 9 + crates/collab/src/auth.rs | 27 +- crates/collab/src/db.rs | 107 +- crates/collab/src/db/ids.rs | 16 + crates/collab/src/db/queries.rs | 1 + crates/collab/src/db/queries/channels.rs | 5 + crates/collab/src/db/queries/dev_servers.rs | 42 +- crates/collab/src/db/queries/projects.rs | 374 +++---- .../collab/src/db/queries/remote_projects.rs | 261 +++++ crates/collab/src/db/queries/rooms.rs | 378 +++---- crates/collab/src/db/tables.rs | 1 + crates/collab/src/db/tables/dev_server.rs | 12 + crates/collab/src/db/tables/project.rs | 15 +- crates/collab/src/db/tables/remote_project.rs | 42 + crates/collab/src/rpc.rs | 955 +++++++++++++----- crates/collab/src/rpc/connection_pool.rs | 95 +- crates/collab/src/tests.rs | 1 + crates/collab/src/tests/dev_server_tests.rs | 110 ++ crates/collab/src/tests/integration_tests.rs | 2 +- crates/collab/src/tests/test_server.rs | 125 +++ crates/collab_ui/Cargo.toml | 1 + crates/collab_ui/src/collab_panel.rs | 149 ++- .../src/collab_panel/dev_server_modal.rs | 622 ++++++++++++ crates/feature_flags/src/feature_flags.rs | 5 + crates/gpui/src/app/async_context.rs | 21 +- crates/gpui/src/platform.rs | 7 +- crates/gpui/src/platform/linux.rs | 2 + crates/gpui/src/platform/linux/headless.rs | 3 + .../src/platform/linux/headless/client.rs | 98 ++ crates/headless/Cargo.toml | 36 + crates/headless/LICENSE-GPL | 1 + crates/headless/src/headless.rs | 265 +++++ crates/project/src/connection_manager.rs | 212 ++++ crates/project/src/project.rs | 24 +- crates/rpc/proto/zed.proto | 97 +- crates/rpc/src/proto.rs | 19 + crates/ui/src/components.rs | 2 + crates/ui/src/components/icon.rs | 4 + crates/ui/src/components/modal.rs | 133 +++ crates/workspace/src/workspace.rs | 55 + crates/zed/Cargo.toml | 1 + crates/zed/src/main.rs | 118 ++- 54 files changed, 4117 insertions(+), 759 deletions(-) create mode 100644 assets/icons/server.svg create mode 100644 assets/icons/trash.svg create mode 100644 crates/collab/migrations/20240402155003_add_dev_server_projects.sql create mode 100644 crates/collab/src/db/queries/remote_projects.rs create mode 100644 crates/collab/src/db/tables/remote_project.rs create mode 100644 crates/collab/src/tests/dev_server_tests.rs create mode 100644 crates/collab_ui/src/collab_panel/dev_server_modal.rs create mode 100644 crates/gpui/src/platform/linux/headless.rs create mode 100644 crates/gpui/src/platform/linux/headless/client.rs create mode 100644 crates/headless/Cargo.toml create mode 120000 crates/headless/LICENSE-GPL create mode 100644 crates/headless/src/headless.rs create mode 100644 crates/project/src/connection_manager.rs create mode 100644 crates/ui/src/components/modal.rs diff --git a/Cargo.lock b/Cargo.lock index da08ad6363..ab51b79c5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -520,7 +520,7 @@ dependencies = [ "polling 3.3.2", "rustix 0.38.32", "slab", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "windows-sys 0.52.0", ] @@ -861,7 +861,7 @@ dependencies = [ "ring 0.17.7", "time", "tokio", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "zeroize", ] @@ -897,7 +897,7 @@ dependencies = [ "http-body", "percent-encoding", "pin-project-lite", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "uuid", ] @@ -926,7 +926,7 @@ dependencies = [ "once_cell", "percent-encoding", "regex-lite", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "url", ] @@ -949,7 +949,7 @@ dependencies = [ "http 0.2.9", "once_cell", "regex-lite", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -971,7 +971,7 @@ dependencies = [ "http 0.2.9", "once_cell", "regex-lite", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -994,7 +994,7 @@ dependencies = [ "http 0.2.9", "once_cell", "regex-lite", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1022,7 +1022,7 @@ dependencies = [ "sha2 0.10.7", "subtle", "time", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "zeroize", ] @@ -1055,7 +1055,7 @@ dependencies = [ "pin-project-lite", "sha1", "sha2 0.10.7", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1087,7 +1087,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "pin-utils", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1131,7 +1131,7 @@ dependencies = [ "pin-utils", "rustls", "tokio", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1146,7 +1146,7 @@ dependencies = [ "http 0.2.9", "pin-project-lite", "tokio", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "zeroize", ] @@ -1194,7 +1194,7 @@ dependencies = [ "aws-smithy-types", "http 0.2.9", "rustc_version", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1517,7 +1517,7 @@ dependencies = [ "futures-io", "futures-lite 2.2.0", "piper", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1877,6 +1877,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + [[package]] name = "channel" version = "0.1.0" @@ -2238,6 +2244,7 @@ dependencies = [ "git", "google_ai", "gpui", + "headless", "hex", "indoc", "language", @@ -2279,7 +2286,7 @@ dependencies = [ "toml 0.8.10", "tower", "tower-http 0.4.4", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "tracing-subscriber", "unindent", "util", @@ -2302,6 +2309,7 @@ dependencies = [ "editor", "emojis", "extensions_ui", + "feature_flags", "futures 0.3.28", "fuzzy", "gpui", @@ -2966,11 +2974,11 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.2" +version = "3.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b" +checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" dependencies = [ - "nix 0.27.1", + "nix 0.28.0", "windows-sys 0.52.0", ] @@ -4550,7 +4558,7 @@ dependencies = [ "slab", "tokio", "tokio-util", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -4624,6 +4632,26 @@ dependencies = [ "http 0.2.9", ] +[[package]] +name = "headless" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "ctrlc", + "fs", + "futures 0.3.28", + "gpui", + "language", + "log", + "node_runtime", + "postage", + "project", + "rpc", + "settings", + "util", +] + [[package]] name = "heck" version = "0.3.3" @@ -4806,7 +4834,7 @@ dependencies = [ "socket2 0.4.9", "tokio", "tower-service", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "want", ] @@ -5121,7 +5149,7 @@ dependencies = [ "polling 2.8.0", "slab", "sluice", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "tracing-futures", "url", "waker-fn", @@ -6100,6 +6128,18 @@ dependencies = [ "memoffset", ] +[[package]] +name = "nix" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +dependencies = [ + "bitflags 2.4.2", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "node_runtime" version = "0.1.0" @@ -7036,7 +7076,7 @@ dependencies = [ "concurrent-queue", "pin-project-lite", "rustix 0.38.32", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "windows-sys 0.52.0", ] @@ -7941,7 +7981,7 @@ dependencies = [ "serde", "serde_json", "strum", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "util", "zstd", ] @@ -8308,7 +8348,7 @@ dependencies = [ "strum", "thiserror", "time", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "url", "uuid", ] @@ -9059,7 +9099,7 @@ dependencies = [ "time", "tokio", "tokio-stream", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "url", "uuid", "webpki-roots", @@ -9146,7 +9186,7 @@ dependencies = [ "stringprep", "thiserror", "time", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "uuid", "whoami", ] @@ -9191,7 +9231,7 @@ dependencies = [ "stringprep", "thiserror", "time", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "uuid", "whoami", ] @@ -9216,7 +9256,7 @@ dependencies = [ "serde", "sqlx-core", "time", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "url", "uuid", ] @@ -10039,7 +10079,7 @@ dependencies = [ "futures-sink", "pin-project-lite", "tokio", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -10134,7 +10174,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -10171,7 +10211,7 @@ dependencies = [ "pin-project-lite", "tower-layer", "tower-service", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -10188,22 +10228,30 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-attributes", - "tracing-core", + "tracing-core 0.1.32 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18#8b7a1dde69797b33ecfa20da71e72eb5e61f0b25" +dependencies = [ + "pin-project-lite", + "tracing-core 0.1.32 (git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18)", ] [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -10212,9 +10260,17 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18#8b7a1dde69797b33ecfa20da71e72eb5e61f0b25" dependencies = [ "once_cell", "valuable", @@ -10227,35 +10283,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" dependencies = [ "pin-project", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "tracing-log" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +version = "0.2.0" +source = "git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18#8b7a1dde69797b33ecfa20da71e72eb5e61f0b25" dependencies = [ - "lazy_static", "log", - "tracing-core", + "once_cell", + "tracing-core 0.1.32 (git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18)", ] [[package]] name = "tracing-serde" version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +source = "git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18#8b7a1dde69797b33ecfa20da71e72eb5e61f0b25" dependencies = [ "serde", - "tracing-core", + "tracing-core 0.1.32 (git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18)", ] [[package]] name = "tracing-subscriber" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +version = "0.3.18" +source = "git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18#8b7a1dde69797b33ecfa20da71e72eb5e61f0b25" dependencies = [ "matchers", "nu-ansi-term", @@ -10266,8 +10319,8 @@ dependencies = [ "sharded-slab", "smallvec", "thread_local", - "tracing", - "tracing-core", + "tracing 0.1.40 (git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18)", + "tracing-core 0.1.32 (git+https://github.com/tokio-rs/tracing?rev=tracing-subscriber-0.3.18)", "tracing-log", "tracing-serde", ] @@ -11157,7 +11210,7 @@ dependencies = [ "anyhow", "log", "once_cell", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "wasmtime", "wasmtime-c-api-macros", ] @@ -11369,7 +11422,7 @@ dependencies = [ "system-interface", "thiserror", "tokio", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "url", "wasmtime", "wiggle", @@ -11606,7 +11659,7 @@ dependencies = [ "async-trait", "bitflags 2.4.2", "thiserror", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "wasmtime", "wiggle-macro", ] @@ -12362,7 +12415,7 @@ dependencies = [ "serde_repr", "sha1", "static_assertions", - "tracing", + "tracing 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", "uds_windows", "windows-sys 0.52.0", "xdg-home", @@ -12434,6 +12487,7 @@ dependencies = [ "futures 0.3.28", "go_to_line", "gpui", + "headless", "image_viewer", "install_cli", "isahc", diff --git a/Cargo.toml b/Cargo.toml index e8b03c0d96..5802072805 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ members = [ "crates/google_ai", "crates/gpui", "crates/gpui_macros", + "crates/headless", "crates/image_viewer", "crates/install_cli", "crates/journal", @@ -164,6 +165,7 @@ go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } gpui = { path = "crates/gpui" } gpui_macros = { path = "crates/gpui_macros" } +headless = { path = "crates/headless" } install_cli = { path = "crates/install_cli" } image_viewer = { path = "crates/image_viewer" } journal = { path = "crates/journal" } @@ -242,6 +244,7 @@ chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.4", features = ["derive"] } clickhouse = { version = "0.11.6" } ctor = "0.2.6" +ctrlc = "3.4.4" core-foundation = { version = "0.9.3" } core-foundation-sys = "0.8.6" derive_more = "0.99.17" diff --git a/assets/icons/server.svg b/assets/icons/server.svg new file mode 100644 index 0000000000..10fbdcbff4 --- /dev/null +++ b/assets/icons/server.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/trash.svg b/assets/icons/trash.svg new file mode 100644 index 0000000000..94d7971f9b --- /dev/null +++ b/assets/icons/trash.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 94c055e6b8..7ec80334e4 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -1182,7 +1182,7 @@ impl Room { cx.emit(Event::RemoteProjectJoined { project_id: id }); cx.spawn(move |this, mut cx| async move { let project = - Project::remote(id, client, user_store, language_registry, fs, cx.clone()).await?; + Project::in_room(id, client, user_store, language_registry, fs, cx.clone()).await?; this.update(&mut cx, |this, cx| { this.joined_projects.retain(|project| { diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index aee92d0f6c..f592c1f8e7 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -11,7 +11,9 @@ pub use channel_chat::{ mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, MessageParams, }; -pub use channel_store::{Channel, ChannelEvent, ChannelMembership, ChannelStore}; +pub use channel_store::{ + Channel, ChannelEvent, ChannelMembership, ChannelStore, DevServer, RemoteProject, +}; #[cfg(test)] mod channel_store_tests; diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 28f9150143..0d323a2fa0 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -3,7 +3,10 @@ mod channel_index; use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage}; use anyhow::{anyhow, Result}; use channel_index::ChannelIndex; -use client::{ChannelId, Client, ClientSettings, ProjectId, Subscription, User, UserId, UserStore}; +use client::{ + ChannelId, Client, ClientSettings, DevServerId, ProjectId, RemoteProjectId, Subscription, User, + UserId, UserStore, +}; use collections::{hash_map, HashMap, HashSet}; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{ @@ -12,7 +15,7 @@ use gpui::{ }; use language::Capability; use rpc::{ - proto::{self, ChannelRole, ChannelVisibility}, + proto::{self, ChannelRole, ChannelVisibility, DevServerStatus}, TypedEnvelope, }; use settings::Settings; @@ -40,7 +43,6 @@ pub struct HostedProject { name: SharedString, _visibility: proto::ChannelVisibility, } - impl From for HostedProject { fn from(project: proto::HostedProject) -> Self { Self { @@ -52,12 +54,56 @@ impl From for HostedProject { } } +#[derive(Debug, Clone)] +pub struct RemoteProject { + pub id: RemoteProjectId, + pub project_id: Option, + pub channel_id: ChannelId, + pub name: SharedString, + pub path: SharedString, + pub dev_server_id: DevServerId, +} + +impl From for RemoteProject { + fn from(project: proto::RemoteProject) -> Self { + Self { + id: RemoteProjectId(project.id), + project_id: project.project_id.map(|id| ProjectId(id)), + channel_id: ChannelId(project.channel_id), + name: project.name.into(), + path: project.path.into(), + dev_server_id: DevServerId(project.dev_server_id), + } + } +} + +#[derive(Debug, Clone)] +pub struct DevServer { + pub id: DevServerId, + pub channel_id: ChannelId, + pub name: SharedString, + pub status: DevServerStatus, +} + +impl From for DevServer { + fn from(dev_server: proto::DevServer) -> Self { + Self { + id: DevServerId(dev_server.dev_server_id), + channel_id: ChannelId(dev_server.channel_id), + status: dev_server.status(), + name: dev_server.name.into(), + } + } +} + pub struct ChannelStore { pub channel_index: ChannelIndex, channel_invitations: Vec>, channel_participants: HashMap>>, channel_states: HashMap, hosted_projects: HashMap, + remote_projects: HashMap, + dev_servers: HashMap, outgoing_invites: HashSet<(ChannelId, UserId)>, update_channels_tx: mpsc::UnboundedSender, @@ -87,6 +133,8 @@ pub struct ChannelState { observed_chat_message: Option, role: Option, projects: HashSet, + dev_servers: HashSet, + remote_projects: HashSet, } impl Channel { @@ -217,6 +265,8 @@ impl ChannelStore { channel_index: ChannelIndex::default(), channel_participants: Default::default(), hosted_projects: Default::default(), + remote_projects: Default::default(), + dev_servers: Default::default(), outgoing_invites: Default::default(), opened_buffers: Default::default(), opened_chats: Default::default(), @@ -316,6 +366,40 @@ impl ChannelStore { projects } + pub fn dev_servers_for_id(&self, channel_id: ChannelId) -> Vec { + let mut dev_servers: Vec = self + .channel_states + .get(&channel_id) + .map(|state| state.dev_servers.clone()) + .unwrap_or_default() + .into_iter() + .flat_map(|id| self.dev_servers.get(&id).cloned()) + .collect(); + dev_servers.sort_by_key(|s| (s.name.clone(), s.id)); + dev_servers + } + + pub fn find_dev_server_by_id(&self, id: DevServerId) -> Option<&DevServer> { + self.dev_servers.get(&id) + } + + pub fn find_remote_project_by_id(&self, id: RemoteProjectId) -> Option<&RemoteProject> { + self.remote_projects.get(&id) + } + + pub fn remote_projects_for_id(&self, channel_id: ChannelId) -> Vec { + let mut remote_projects: Vec = self + .channel_states + .get(&channel_id) + .map(|state| state.remote_projects.clone()) + .unwrap_or_default() + .into_iter() + .flat_map(|id| self.remote_projects.get(&id).cloned()) + .collect(); + remote_projects.sort_by_key(|p| (p.name.clone(), p.id)); + remote_projects + } + pub fn has_open_channel_buffer(&self, channel_id: ChannelId, _cx: &AppContext) -> bool { if let Some(buffer) = self.opened_buffers.get(&channel_id) { if let OpenedModelHandle::Open(buffer) = buffer { @@ -818,6 +902,45 @@ impl ChannelStore { }) } + pub fn create_remote_project( + &mut self, + channel_id: ChannelId, + dev_server_id: DevServerId, + name: String, + path: String, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.clone(); + cx.background_executor().spawn(async move { + client + .request(proto::CreateRemoteProject { + channel_id: channel_id.0, + dev_server_id: dev_server_id.0, + name, + path, + }) + .await + }) + } + + pub fn create_dev_server( + &mut self, + channel_id: ChannelId, + name: String, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.clone(); + cx.background_executor().spawn(async move { + let result = client + .request(proto::CreateDevServer { + channel_id: channel_id.0, + name, + }) + .await?; + Ok(result) + }) + } + pub fn get_channel_member_details( &self, channel_id: ChannelId, @@ -1098,7 +1221,11 @@ impl ChannelStore { || !payload.latest_channel_message_ids.is_empty() || !payload.latest_channel_buffer_versions.is_empty() || !payload.hosted_projects.is_empty() - || !payload.deleted_hosted_projects.is_empty(); + || !payload.deleted_hosted_projects.is_empty() + || !payload.dev_servers.is_empty() + || !payload.deleted_dev_servers.is_empty() + || !payload.remote_projects.is_empty() + || !payload.deleted_remote_projects.is_empty(); if channels_changed { if !payload.delete_channels.is_empty() { @@ -1186,6 +1313,60 @@ impl ChannelStore { .remove_hosted_project(old_project.project_id); } } + + for remote_project in payload.remote_projects { + let remote_project: RemoteProject = remote_project.into(); + if let Some(old_remote_project) = self + .remote_projects + .insert(remote_project.id, remote_project.clone()) + { + self.channel_states + .entry(old_remote_project.channel_id) + .or_default() + .remove_remote_project(old_remote_project.id); + } + self.channel_states + .entry(remote_project.channel_id) + .or_default() + .add_remote_project(remote_project.id); + } + + for remote_project_id in payload.deleted_remote_projects { + let remote_project_id = RemoteProjectId(remote_project_id); + + if let Some(old_project) = self.remote_projects.remove(&remote_project_id) { + self.channel_states + .entry(old_project.channel_id) + .or_default() + .remove_remote_project(old_project.id); + } + } + + for dev_server in payload.dev_servers { + let dev_server: DevServer = dev_server.into(); + if let Some(old_server) = self.dev_servers.insert(dev_server.id, dev_server.clone()) + { + self.channel_states + .entry(old_server.channel_id) + .or_default() + .remove_dev_server(old_server.id); + } + self.channel_states + .entry(dev_server.channel_id) + .or_default() + .add_dev_server(dev_server.id); + } + + for dev_server_id in payload.deleted_dev_servers { + let dev_server_id = DevServerId(dev_server_id); + + if let Some(old_server) = self.dev_servers.remove(&dev_server_id) { + self.channel_states + .entry(old_server.channel_id) + .or_default() + .remove_dev_server(old_server.id); + } + } } cx.notify(); @@ -1300,4 +1481,20 @@ impl ChannelState { fn remove_hosted_project(&mut self, project_id: ProjectId) { self.projects.remove(&project_id); } + + fn add_remote_project(&mut self, remote_project_id: RemoteProjectId) { + self.remote_projects.insert(remote_project_id); + } + + fn remove_remote_project(&mut self, remote_project_id: RemoteProjectId) { + self.remote_projects.remove(&remote_project_id); + } + + fn add_dev_server(&mut self, dev_server_id: DevServerId) { + self.dev_servers.insert(dev_server_id); + } + + fn remove_dev_server(&mut self, dev_server_id: DevServerId) { + self.dev_servers.remove(&dev_server_id); + } } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 7b696d1eaa..db88afb038 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -759,8 +759,9 @@ impl Client { read_credentials_from_keychain(cx).await.is_some() } - pub fn set_dev_server_token(&self, token: DevServerToken) { + pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self { self.state.write().credentials = Some(Credentials::DevServer { token }); + self } #[async_recursion(?Send)] diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index e8be09dd64..2c5632593d 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -27,6 +27,12 @@ impl std::fmt::Display for ChannelId { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] pub struct ProjectId(pub u64); +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub struct DevServerId(pub u64); + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub struct RemoteProjectId(pub u64); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ParticipantIndex(pub u32); diff --git a/crates/collab/.env.toml b/crates/collab/.env.toml index ee01d75782..9bfdf294e4 100644 --- a/crates/collab/.env.toml +++ b/crates/collab/.env.toml @@ -1,5 +1,5 @@ DATABASE_URL = "postgres://postgres@localhost/zed" -# DATABASE_URL = "sqlite:////home/zed/.config/zed/db.sqlite3?mode=rwc" +# DATABASE_URL = "sqlite:////root/0/zed/db.sqlite3?mode=rwc" DATABASE_MAX_CONNECTIONS = 5 HTTP_PORT = 8080 API_TOKEN = "secret" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 7fbc4bfd03..f83403c705 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -63,8 +63,8 @@ tokio.workspace = true toml.workspace = true tower = "0.4" tower-http = { workspace = true, features = ["trace"] } -tracing = "0.1.34" -tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json", "registry", "tracing-log"] } +tracing = "0.1.40" +tracing-subscriber = { git = "https://github.com/tokio-rs/tracing", rev = "tracing-subscriber-0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927 util.workspace = true uuid.workspace = true @@ -102,3 +102,4 @@ theme.workspace = true unindent.workspace = true util.workspace = true workspace = { workspace = true, features = ["test-support"] } +headless.workspace = true diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 9ad045e56d..bc14721e21 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -45,12 +45,13 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE NOT NULL, + "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE, "host_user_id" INTEGER REFERENCES users (id), "host_connection_id" INTEGER, "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, "unregistered" BOOLEAN NOT NULL DEFAULT FALSE, - "hosted_project_id" INTEGER REFERENCES hosted_projects (id) + "hosted_project_id" INTEGER REFERENCES hosted_projects (id), + "remote_project_id" INTEGER REFERENCES remote_projects(id) ); CREATE INDEX "index_projects_on_host_connection_server_id" ON "projects" ("host_connection_server_id"); CREATE INDEX "index_projects_on_host_connection_id_and_host_connection_server_id" ON "projects" ("host_connection_id", "host_connection_server_id"); @@ -397,7 +398,9 @@ CREATE TABLE hosted_projects ( channel_id INTEGER NOT NULL REFERENCES channels(id), name TEXT NOT NULL, visibility TEXT NOT NULL, - deleted_at TIMESTAMP NULL + deleted_at TIMESTAMP NULL, + dev_server_id INTEGER REFERENCES dev_servers(id), + dev_server_path TEXT ); CREATE INDEX idx_hosted_projects_on_channel_id ON hosted_projects (channel_id); CREATE UNIQUE INDEX uix_hosted_projects_on_channel_id_and_name ON hosted_projects (channel_id, name) WHERE (deleted_at IS NULL); @@ -409,3 +412,13 @@ CREATE TABLE dev_servers ( hashed_token TEXT NOT NULL ); CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id); + +CREATE TABLE remote_projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id), + dev_server_id INTEGER NOT NULL REFERENCES dev_servers(id), + name TEXT NOT NULL, + path TEXT NOT NULL +); + +ALTER TABLE hosted_projects ADD COLUMN remote_project_id INTEGER REFERENCES remote_projects(id); diff --git a/crates/collab/migrations/20240402155003_add_dev_server_projects.sql b/crates/collab/migrations/20240402155003_add_dev_server_projects.sql new file mode 100644 index 0000000000..003c43f4e2 --- /dev/null +++ b/crates/collab/migrations/20240402155003_add_dev_server_projects.sql @@ -0,0 +1,9 @@ +CREATE TABLE remote_projects ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + channel_id INT NOT NULL REFERENCES channels(id), + dev_server_id INT NOT NULL REFERENCES dev_servers(id), + name TEXT NOT NULL, + path TEXT NOT NULL +); + +ALTER TABLE projects ADD COLUMN remote_project_id INTEGER REFERENCES remote_projects(id); diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 5daf6e6186..915563d6b4 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -10,6 +10,7 @@ use axum::{ response::IntoResponse, }; use prometheus::{exponential_buckets, register_histogram, Histogram}; +pub use rpc::auth::random_token; use scrypt::{ password_hash::{PasswordHash, PasswordVerifier}, Scrypt, @@ -152,7 +153,7 @@ pub async fn create_access_token( /// Hashing prevents anyone with access to the database being able to login. /// As the token is randomly generated, we don't need to worry about scrypt-style /// protection. -fn hash_access_token(token: &str) -> String { +pub fn hash_access_token(token: &str) -> String { let digest = sha2::Sha256::digest(token); format!( "$sha256${}", @@ -230,18 +231,15 @@ pub async fn verify_access_token( }) } -// a dev_server_token has the format .. This is to make them -// relatively easy to copy/paste around. +pub fn generate_dev_server_token(id: usize, access_token: String) -> String { + format!("{}.{}", id, access_token) +} + pub async fn verify_dev_server_token( dev_server_token: &str, db: &Arc, ) -> anyhow::Result { - let mut parts = dev_server_token.splitn(2, '.'); - let id = DevServerId(parts.next().unwrap_or_default().parse()?); - let token = parts - .next() - .ok_or_else(|| anyhow!("invalid dev server token format"))?; - + let (id, token) = split_dev_server_token(dev_server_token)?; let token_hash = hash_access_token(&token); let server = db.get_dev_server(id).await?; @@ -257,6 +255,17 @@ pub async fn verify_dev_server_token( } } +// a dev_server_token has the format .. This is to make them +// relatively easy to copy/paste around. +pub fn split_dev_server_token(dev_server_token: &str) -> anyhow::Result<(DevServerId, &str)> { + let mut parts = dev_server_token.splitn(2, '.'); + let id = DevServerId(parts.next().unwrap_or_default().parse()?); + let token = parts + .next() + .ok_or_else(|| anyhow!("invalid dev server token format"))?; + Ok((id, token)) +} + #[cfg(test)] mod test { use rand::thread_rng; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e9249edcb1..24bae3fba7 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -56,6 +56,7 @@ pub struct Database { options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, + projects: DashMap>>, rng: Mutex, executor: Executor, notification_kinds_by_id: HashMap, @@ -74,6 +75,7 @@ impl Database { options: options.clone(), pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), + projects: DashMap::with_capacity(16384), rng: Mutex::new(StdRng::seed_from_u64(0)), notification_kinds_by_id: HashMap::default(), notification_kinds_by_name: HashMap::default(), @@ -86,6 +88,7 @@ impl Database { #[cfg(test)] pub fn reset(&self) { self.rooms.clear(); + self.projects.clear(); } /// Runs the database migrations. @@ -190,7 +193,10 @@ impl Database { } /// The same as room_transaction, but if you need to only optionally return a Room. - async fn optional_room_transaction(&self, f: F) -> Result>> + async fn optional_room_transaction( + &self, + f: F, + ) -> Result>> where F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>>, @@ -205,7 +211,7 @@ impl Database { let _guard = lock.lock_owned().await; match tx.commit().await.map_err(Into::into) { Ok(()) => { - return Ok(Some(RoomGuard { + return Ok(Some(TransactionGuard { data, _guard, _not_send: PhantomData, @@ -240,10 +246,63 @@ impl Database { self.run(body).await } + async fn project_transaction( + &self, + project_id: ProjectId, + f: F, + ) -> Result> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let room_id = Database::room_id_for_project(&self, project_id).await?; + let body = async { + let mut i = 0; + loop { + let lock = if let Some(room_id) = room_id { + self.rooms.entry(room_id).or_default().clone() + } else { + self.projects.entry(project_id).or_default().clone() + }; + let _guard = lock.lock_owned().await; + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(data) => match tx.commit().await.map_err(Into::into) { + Ok(()) => { + return Ok(TransactionGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + }, + Err(error) => { + tx.rollback().await?; + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + } + i += 1; + } + }; + + self.run(body).await + } + /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps /// the database locked until it is dropped. This ensures that updates sent to clients are /// properly serialized with respect to database changes. - async fn room_transaction(&self, room_id: RoomId, f: F) -> Result> + async fn room_transaction( + &self, + room_id: RoomId, + f: F, + ) -> Result> where F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, @@ -257,7 +316,7 @@ impl Database { match result { Ok(data) => match tx.commit().await.map_err(Into::into) { Ok(()) => { - return Ok(RoomGuard { + return Ok(TransactionGuard { data, _guard, _not_send: PhantomData, @@ -399,15 +458,16 @@ impl Deref for TransactionHandle { } } -/// [`RoomGuard`] keeps a database transaction alive until it is dropped. -/// so that updates to rooms are serialized. -pub struct RoomGuard { +/// [`TransactionGuard`] keeps a database transaction alive until it is dropped. +/// It wraps data that depends on the state of the database and prevents an additional +/// transaction from starting that would invalidate that data. +pub struct TransactionGuard { data: T, _guard: OwnedMutexGuard<()>, _not_send: PhantomData>, } -impl Deref for RoomGuard { +impl Deref for TransactionGuard { type Target = T; fn deref(&self) -> &T { @@ -415,13 +475,13 @@ impl Deref for RoomGuard { } } -impl DerefMut for RoomGuard { +impl DerefMut for TransactionGuard { fn deref_mut(&mut self) -> &mut T { &mut self.data } } -impl RoomGuard { +impl TransactionGuard { /// Returns the inner value of the guard. pub fn into_inner(self) -> T { self.data @@ -518,6 +578,7 @@ pub struct MembershipUpdated { /// The result of setting a member's role. #[derive(Debug)] +#[allow(clippy::large_enum_variant)] pub enum SetMemberRoleResult { InviteUpdated(Channel), MembershipUpdated(MembershipUpdated), @@ -594,6 +655,8 @@ pub struct ChannelsForUser { pub channel_memberships: Vec, pub channel_participants: HashMap>, pub hosted_projects: Vec, + pub dev_servers: Vec, + pub remote_projects: Vec, pub observed_buffer_versions: Vec, pub observed_channel_messages: Vec, @@ -635,6 +698,30 @@ pub struct RejoinedProject { pub language_servers: Vec, } +impl RejoinedProject { + pub fn to_proto(&self) -> proto::RejoinedProject { + proto::RejoinedProject { + id: self.id.to_proto(), + worktrees: self + .worktrees + .iter() + .map(|worktree| proto::WorktreeMetadata { + id: worktree.id, + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect(), + collaborators: self + .collaborators + .iter() + .map(|collaborator| collaborator.to_proto()) + .collect(), + language_servers: self.language_servers.clone(), + } + } +} + #[derive(Debug)] pub struct RejoinedWorktree { pub id: u64, diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 91c0c440a5..0c2ba2bc13 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -84,6 +84,7 @@ id_type!(NotificationId); id_type!(NotificationKindId); id_type!(ProjectCollaboratorId); id_type!(ProjectId); +id_type!(RemoteProjectId); id_type!(ReplicaId); id_type!(RoomId); id_type!(RoomParticipantId); @@ -270,3 +271,18 @@ impl Into for ChannelVisibility { proto.into() } } + +#[derive(Copy, Clone, Debug, Serialize, PartialEq)] +pub enum PrincipalId { + UserId(UserId), + DevServerId(DevServerId), +} + +/// Indicate whether a [Buffer] has permissions to edit. +#[derive(PartialEq, Clone, Copy, Debug)] +pub enum Capability { + /// The buffer is a mutable replica. + ReadWrite, + /// The buffer is a read-only replica. + ReadOnly, +} diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 0582b8f256..2cbbc67969 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -12,6 +12,7 @@ pub mod messages; pub mod notifications; pub mod projects; pub mod rate_buckets; +pub mod remote_projects; pub mod rooms; pub mod servers; pub mod users; diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 3f168e0854..279f767df8 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -640,10 +640,15 @@ impl Database { .get_hosted_projects(&channel_ids, &roles_by_channel_id, tx) .await?; + let dev_servers = self.get_dev_servers(&channel_ids, tx).await?; + let remote_projects = self.get_remote_projects(&channel_ids, tx).await?; + Ok(ChannelsForUser { channel_memberships, channels, hosted_projects, + dev_servers, + remote_projects, channel_participants, latest_buffer_versions, latest_channel_messages, diff --git a/crates/collab/src/db/queries/dev_servers.rs b/crates/collab/src/db/queries/dev_servers.rs index d95897b51e..4767f24734 100644 --- a/crates/collab/src/db/queries/dev_servers.rs +++ b/crates/collab/src/db/queries/dev_servers.rs @@ -1,6 +1,6 @@ -use sea_orm::EntityTrait; +use sea_orm::{ActiveValue, ColumnTrait, DatabaseTransaction, EntityTrait, QueryFilter}; -use super::{dev_server, Database, DevServerId}; +use super::{channel, dev_server, ChannelId, Database, DevServerId, UserId}; impl Database { pub async fn get_dev_server( @@ -15,4 +15,42 @@ impl Database { }) .await } + + pub async fn get_dev_servers( + &self, + channel_ids: &Vec, + tx: &DatabaseTransaction, + ) -> crate::Result> { + let servers = dev_server::Entity::find() + .filter(dev_server::Column::ChannelId.is_in(channel_ids.iter().map(|id| id.0))) + .all(tx) + .await?; + Ok(servers) + } + + pub async fn create_dev_server( + &self, + channel_id: ChannelId, + name: &str, + hashed_access_token: &str, + user_id: UserId, + ) -> crate::Result<(channel::Model, dev_server::Model)> { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &tx).await?; + self.check_user_is_channel_admin(&channel, user_id, &tx) + .await?; + + let dev_server = dev_server::Entity::insert(dev_server::ActiveModel { + id: ActiveValue::NotSet, + hashed_token: ActiveValue::Set(hashed_access_token.to_string()), + channel_id: ActiveValue::Set(channel_id), + name: ActiveValue::Set(name.to_string()), + }) + .exec_with_returning(&*tx) + .await?; + + Ok((channel, dev_server)) + }) + .await + } } diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index 6bd7022a79..03b8b5d29e 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -1,3 +1,5 @@ +use util::ResultExt; + use super::*; impl Database { @@ -28,7 +30,7 @@ impl Database { room_id: RoomId, connection: ConnectionId, worktrees: &[proto::WorktreeMetadata], - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { let participant = room_participant::Entity::find() .filter( @@ -65,6 +67,7 @@ impl Database { ))), id: ActiveValue::NotSet, hosted_project_id: ActiveValue::Set(None), + remote_project_id: ActiveValue::Set(None), } .insert(&*tx) .await?; @@ -108,20 +111,22 @@ impl Database { &self, project_id: ProjectId, connection: ConnectionId, - ) -> Result)>> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + ) -> Result, Vec)>> { + self.project_transaction(project_id, |tx| async move { let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - let project = project::Entity::find_by_id(project_id) .one(&*tx) .await? .ok_or_else(|| anyhow!("project not found"))?; if project.host_connection()? == connection { + let room = if let Some(room_id) = project.room_id { + Some(self.get_room(room_id, &tx).await?) + } else { + None + }; project::Entity::delete(project.into_active_model()) .exec(&*tx) .await?; - let room = self.get_room(room_id, &tx).await?; Ok((room, guest_connection_ids)) } else { Err(anyhow!("cannot unshare a project hosted by another user"))? @@ -136,9 +141,8 @@ impl Database { project_id: ProjectId, connection: ConnectionId, worktrees: &[proto::WorktreeMetadata], - ) -> Result)>> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + ) -> Result, Vec)>> { + self.project_transaction(project_id, |tx| async move { let project = project::Entity::find_by_id(project_id) .filter( Condition::all() @@ -154,12 +158,14 @@ impl Database { self.update_project_worktrees(project.id, worktrees, &tx) .await?; - let room_id = project - .room_id - .ok_or_else(|| anyhow!("project not in a room"))?; - let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; - let room = self.get_room(room_id, &tx).await?; + + let room = if let Some(room_id) = project.room_id { + Some(self.get_room(room_id, &tx).await?) + } else { + None + }; + Ok((room, guest_connection_ids)) }) .await @@ -204,11 +210,10 @@ impl Database { &self, update: &proto::UpdateWorktree, connection: ConnectionId, - ) -> Result>> { + ) -> Result>> { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + self.project_transaction(project_id, |tx| async move { // Ensure the update comes from the host. let _project = project::Entity::find_by_id(project_id) .filter( @@ -360,11 +365,10 @@ impl Database { &self, update: &proto::UpdateDiagnosticSummary, connection: ConnectionId, - ) -> Result>> { + ) -> Result>> { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + self.project_transaction(project_id, |tx| async move { let summary = update .summary .as_ref() @@ -415,10 +419,9 @@ impl Database { &self, update: &proto::StartLanguageServer, connection: ConnectionId, - ) -> Result>> { + ) -> Result>> { let project_id = ProjectId::from_proto(update.project_id); - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + self.project_transaction(project_id, |tx| async move { let server = update .server .as_ref() @@ -461,10 +464,9 @@ impl Database { &self, update: &proto::UpdateWorktreeSettings, connection: ConnectionId, - ) -> Result>> { + ) -> Result>> { let project_id = ProjectId::from_proto(update.project_id); - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + self.project_transaction(project_id, |tx| async move { // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) .one(&*tx) @@ -542,46 +544,36 @@ impl Database { .await } + pub async fn get_project(&self, id: ProjectId) -> Result { + self.transaction(|tx| async move { + Ok(project::Entity::find_by_id(id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?) + }) + .await + } + /// Adds the given connection to the specified project /// in the current room. - pub async fn join_project_in_room( + pub async fn join_project( &self, project_id: ProjectId, connection: ConnectionId, - ) -> Result> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { - let participant = room_participant::Entity::find() - .filter( - Condition::all() - .add( - room_participant::Column::AnsweringConnectionId - .eq(connection.id as i32), - ) - .add( - room_participant::Column::AnsweringConnectionServerId - .eq(connection.owner_id as i32), - ), + user_id: UserId, + ) -> Result> { + self.project_transaction(project_id, |tx| async move { + let (project, role) = self + .access_project( + project_id, + connection, + PrincipalId::UserId(user_id), + Capability::ReadOnly, + &tx, ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("must join a room first"))?; - - let project = project::Entity::find_by_id(project_id) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such project"))?; - if project.room_id != Some(participant.room_id) { - return Err(anyhow!("no such project"))?; - } - self.join_project_internal( - project, - participant.user_id, - connection, - participant.role.unwrap_or(ChannelRole::Member), - &tx, - ) - .await + .await?; + self.join_project_internal(project, user_id, connection, role, &tx) + .await }) .await } @@ -814,9 +806,8 @@ impl Database { &self, project_id: ProjectId, connection: ConnectionId, - ) -> Result> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + ) -> Result, LeftProject)>> { + self.project_transaction(project_id, |tx| async move { let result = project_collaborator::Entity::delete_many() .filter( Condition::all() @@ -871,7 +862,12 @@ impl Database { .exec(&*tx) .await?; - let room = self.get_room(room_id, &tx).await?; + let room = if let Some(room_id) = project.room_id { + Some(self.get_room(room_id, &tx).await?) + } else { + None + }; + let left_project = LeftProject { id: project_id, host_user_id: project.host_user_id, @@ -888,17 +884,15 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result<()> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { - project_collaborator::Entity::find() + self.project_transaction(project_id, |tx| async move { + project::Entity::find() .filter( Condition::all() - .add(project_collaborator::Column::ProjectId.eq(project_id)) - .add(project_collaborator::Column::IsHost.eq(true)) - .add(project_collaborator::Column::ConnectionId.eq(connection_id.id)) + .add(project::Column::Id.eq(project_id)) + .add(project::Column::HostConnectionId.eq(Some(connection_id.id as i32))) .add( - project_collaborator::Column::ConnectionServerId - .eq(connection_id.owner_id), + project::Column::HostConnectionServerId + .eq(Some(connection_id.owner_id as i32)), ), ) .one(&*tx) @@ -911,39 +905,90 @@ impl Database { .map(|guard| guard.into_inner()) } + /// Returns the current project if the given user is authorized to access it with the specified capability. + pub async fn access_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + principal_id: PrincipalId, + capability: Capability, + tx: &DatabaseTransaction, + ) -> Result<(project::Model, ChannelRole)> { + let (project, remote_project) = project::Entity::find_by_id(project_id) + .find_also_related(remote_project::Entity) + .one(tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + + let user_id = match principal_id { + PrincipalId::DevServerId(_) => { + if project + .host_connection() + .is_ok_and(|connection| connection == connection_id) + { + return Ok((project, ChannelRole::Admin)); + } + return Err(anyhow!("not the project host"))?; + } + PrincipalId::UserId(user_id) => user_id, + }; + + let role = if let Some(remote_project) = remote_project { + let channel = channel::Entity::find_by_id(remote_project.channel_id) + .one(tx) + .await? + .ok_or_else(|| anyhow!("no such channel"))?; + + self.check_user_is_channel_participant(&channel, user_id, &tx) + .await? + } else if let Some(room_id) = project.room_id { + // what's the users role? + let current_participant = room_participant::Entity::find() + .filter(room_participant::Column::RoomId.eq(room_id)) + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id)) + .one(tx) + .await? + .ok_or_else(|| anyhow!("no such room"))?; + + current_participant.role.unwrap_or(ChannelRole::Guest) + } else { + return Err(anyhow!("not authorized to read projects"))?; + }; + + match capability { + Capability::ReadWrite => { + if !role.can_edit_projects() { + return Err(anyhow!("not authorized to edit projects"))?; + } + } + Capability::ReadOnly => { + if !role.can_read_projects() { + return Err(anyhow!("not authorized to read projects"))?; + } + } + } + + Ok((project, role)) + } + /// Returns the host connection for a read-only request to join a shared project. pub async fn host_for_read_only_project_request( &self, project_id: ProjectId, connection_id: ConnectionId, + user_id: UserId, ) -> Result { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { - let current_participant = room_participant::Entity::find() - .filter(room_participant::Column::RoomId.eq(room_id)) - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id)) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such room"))?; - - if !current_participant - .role - .map_or(false, |role| role.can_read_projects()) - { - Err(anyhow!("not authorized to read projects"))?; - } - - let host = project_collaborator::Entity::find() - .filter( - project_collaborator::Column::ProjectId - .eq(project_id) - .and(project_collaborator::Column::IsHost.eq(true)), + self.project_transaction(project_id, |tx| async move { + let (project, _) = self + .access_project( + project_id, + connection_id, + PrincipalId::UserId(user_id), + Capability::ReadOnly, + &tx, ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("failed to read project host"))?; - - Ok(host.connection()) + .await?; + project.host_connection() }) .await .map(|guard| guard.into_inner()) @@ -954,83 +999,56 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, + user_id: UserId, ) -> Result { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { - let current_participant = room_participant::Entity::find() - .filter(room_participant::Column::RoomId.eq(room_id)) - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id)) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such room"))?; - - if !current_participant - .role - .map_or(false, |role| role.can_edit_projects()) - { - Err(anyhow!("not authorized to edit projects"))?; - } - - let host = project_collaborator::Entity::find() - .filter( - project_collaborator::Column::ProjectId - .eq(project_id) - .and(project_collaborator::Column::IsHost.eq(true)), + self.project_transaction(project_id, |tx| async move { + let (project, _) = self + .access_project( + project_id, + connection_id, + PrincipalId::UserId(user_id), + Capability::ReadWrite, + &tx, ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("failed to read project host"))?; - - Ok(host.connection()) + .await?; + project.host_connection() }) .await .map(|guard| guard.into_inner()) } - pub async fn project_collaborators_for_buffer_update( + pub async fn connections_for_buffer_update( &self, project_id: ProjectId, + principal_id: PrincipalId, connection_id: ConnectionId, - requires_write: bool, - ) -> Result>> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { - let current_participant = room_participant::Entity::find() - .filter(room_participant::Column::RoomId.eq(room_id)) - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id)) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such room"))?; + capability: Capability, + ) -> Result)>> { + self.project_transaction(project_id, |tx| async move { + // Authorize + let (project, _) = self + .access_project(project_id, connection_id, principal_id, capability, &tx) + .await?; - if requires_write - && !current_participant - .role - .map_or(false, |role| role.can_edit_projects()) - { - Err(anyhow!("not authorized to edit projects"))?; - } + let host_connection_id = project.host_connection()?; let collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) .all(&*tx) - .await? - .into_iter() - .map(|collaborator| ProjectCollaborator { - connection_id: collaborator.connection(), - user_id: collaborator.user_id, - replica_id: collaborator.replica_id, - is_host: collaborator.is_host, - }) - .collect::>(); + .await?; - if collaborators - .iter() - .any(|collaborator| collaborator.connection_id == connection_id) - { - Ok(collaborators) - } else { - Err(anyhow!("no such project"))? - } + let guest_connection_ids = collaborators + .into_iter() + .filter_map(|collaborator| { + if collaborator.is_host { + None + } else { + Some(collaborator.connection()) + } + }) + .collect(); + + Ok((host_connection_id, guest_connection_ids)) }) .await } @@ -1043,24 +1061,39 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result>> { - let room_id = self.room_id_for_project(project_id).await?; - self.room_transaction(room_id, |tx| async move { + exclude_dev_server: bool, + ) -> Result>> { + self.project_transaction(project_id, |tx| async move { + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + let mut collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) .stream(&*tx) .await?; let mut connection_ids = HashSet::default(); + if let Some(host_connection) = project.host_connection().log_err() { + if !exclude_dev_server { + connection_ids.insert(host_connection); + } + } + while let Some(collaborator) = collaborators.next().await { let collaborator = collaborator?; connection_ids.insert(collaborator.connection()); } - if connection_ids.contains(&connection_id) { + if connection_ids.contains(&connection_id) + || Some(connection_id) == project.host_connection().ok() + { Ok(connection_ids) } else { - Err(anyhow!("no such project"))? + Err(anyhow!( + "can only send project updates to a project you're in" + ))? } }) .await @@ -1089,15 +1122,12 @@ impl Database { } /// Returns the [`RoomId`] for the given project. - pub async fn room_id_for_project(&self, project_id: ProjectId) -> Result { + pub async fn room_id_for_project(&self, project_id: ProjectId) -> Result> { self.transaction(|tx| async move { - let project = project::Entity::find_by_id(project_id) + Ok(project::Entity::find_by_id(project_id) .one(&*tx) .await? - .ok_or_else(|| anyhow!("project {} not found", project_id))?; - Ok(project - .room_id - .ok_or_else(|| anyhow!("project not in room"))?) + .and_then(|project| project.room_id)) }) .await } @@ -1142,7 +1172,7 @@ impl Database { project_id: ProjectId, leader_connection: ConnectionId, follower_connection: ConnectionId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { follower::ActiveModel { room_id: ActiveValue::set(room_id), @@ -1173,7 +1203,7 @@ impl Database { project_id: ProjectId, leader_connection: ConnectionId, follower_connection: ConnectionId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { follower::Entity::delete_many() .filter( diff --git a/crates/collab/src/db/queries/remote_projects.rs b/crates/collab/src/db/queries/remote_projects.rs new file mode 100644 index 0000000000..86538d219e --- /dev/null +++ b/crates/collab/src/db/queries/remote_projects.rs @@ -0,0 +1,261 @@ +use anyhow::anyhow; +use rpc::{proto, ConnectionId}; +use sea_orm::{ + ActiveModelTrait, ActiveValue, ColumnTrait, Condition, DatabaseTransaction, EntityTrait, + ModelTrait, QueryFilter, +}; + +use crate::db::ProjectId; + +use super::{ + channel, project, project_collaborator, remote_project, worktree, ChannelId, Database, + DevServerId, RejoinedProject, RemoteProjectId, ResharedProject, ServerId, UserId, +}; + +impl Database { + pub async fn get_remote_project( + &self, + remote_project_id: RemoteProjectId, + ) -> crate::Result { + self.transaction(|tx| async move { + Ok(remote_project::Entity::find_by_id(remote_project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no remote project with id {}", remote_project_id))?) + }) + .await + } + + pub async fn get_remote_projects( + &self, + channel_ids: &Vec, + tx: &DatabaseTransaction, + ) -> crate::Result> { + let servers = remote_project::Entity::find() + .filter(remote_project::Column::ChannelId.is_in(channel_ids.iter().map(|id| id.0))) + .find_also_related(project::Entity) + .all(tx) + .await?; + Ok(servers + .into_iter() + .map(|(remote_project, project)| proto::RemoteProject { + id: remote_project.id.to_proto(), + project_id: project.map(|p| p.id.to_proto()), + channel_id: remote_project.channel_id.to_proto(), + name: remote_project.name, + dev_server_id: remote_project.dev_server_id.to_proto(), + path: remote_project.path, + }) + .collect()) + } + + pub async fn get_remote_projects_for_dev_server( + &self, + dev_server_id: DevServerId, + ) -> crate::Result> { + self.transaction(|tx| async move { + let servers = remote_project::Entity::find() + .filter(remote_project::Column::DevServerId.eq(dev_server_id)) + .find_also_related(project::Entity) + .all(&*tx) + .await?; + Ok(servers + .into_iter() + .map(|(remote_project, project)| proto::RemoteProject { + id: remote_project.id.to_proto(), + project_id: project.map(|p| p.id.to_proto()), + channel_id: remote_project.channel_id.to_proto(), + name: remote_project.name, + dev_server_id: remote_project.dev_server_id.to_proto(), + path: remote_project.path, + }) + .collect()) + }) + .await + } + + pub async fn get_stale_dev_server_projects( + &self, + connection: ConnectionId, + ) -> crate::Result> { + self.transaction(|tx| async move { + let projects = project::Entity::find() + .filter( + Condition::all() + .add(project::Column::HostConnectionId.eq(connection.id)) + .add(project::Column::HostConnectionServerId.eq(connection.owner_id)), + ) + .all(&*tx) + .await?; + + Ok(projects.into_iter().map(|p| p.id).collect()) + }) + .await + } + + pub async fn create_remote_project( + &self, + channel_id: ChannelId, + dev_server_id: DevServerId, + name: &str, + path: &str, + user_id: UserId, + ) -> crate::Result<(channel::Model, remote_project::Model)> { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &tx).await?; + self.check_user_is_channel_admin(&channel, user_id, &tx) + .await?; + + let project = remote_project::Entity::insert(remote_project::ActiveModel { + name: ActiveValue::Set(name.to_string()), + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + dev_server_id: ActiveValue::Set(dev_server_id), + path: ActiveValue::Set(path.to_string()), + }) + .exec_with_returning(&*tx) + .await?; + + Ok((channel, project)) + }) + .await + } + + pub async fn share_remote_project( + &self, + remote_project_id: RemoteProjectId, + dev_server_id: DevServerId, + connection: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> crate::Result { + self.transaction(|tx| async move { + let remote_project = remote_project::Entity::find_by_id(remote_project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no remote project with id {}", remote_project_id))?; + + if remote_project.dev_server_id != dev_server_id { + return Err(anyhow!("remote project shared from wrong server"))?; + } + + let project = project::ActiveModel { + room_id: ActiveValue::Set(None), + host_user_id: ActiveValue::Set(None), + host_connection_id: ActiveValue::set(Some(connection.id as i32)), + host_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + id: ActiveValue::NotSet, + hosted_project_id: ActiveValue::Set(None), + remote_project_id: ActiveValue::Set(Some(remote_project_id)), + } + .insert(&*tx) + .await?; + + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + completed_scan_id: ActiveValue::set(0), + } + })) + .exec(&*tx) + .await?; + } + + Ok(remote_project.to_proto(Some(project))) + }) + .await + } + + pub async fn reshare_remote_projects( + &self, + reshared_projects: &Vec, + dev_server_id: DevServerId, + connection: ConnectionId, + ) -> crate::Result> { + // todo!() project_transaction? (maybe we can make the lock per-dev-server instead of per-project?) + self.transaction(|tx| async move { + let mut ret = Vec::new(); + for reshared_project in reshared_projects { + let project_id = ProjectId::from_proto(reshared_project.project_id); + let (project, remote_project) = project::Entity::find_by_id(project_id) + .find_also_related(remote_project::Entity) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("project does not exist"))?; + + if remote_project.map(|rp| rp.dev_server_id) != Some(dev_server_id) { + return Err(anyhow!("remote project reshared from wrong server"))?; + } + + let Ok(old_connection_id) = project.host_connection() else { + return Err(anyhow!("remote project was not shared"))?; + }; + + project::Entity::update(project::ActiveModel { + id: ActiveValue::set(project_id), + host_connection_id: ActiveValue::set(Some(connection.id as i32)), + host_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + ..Default::default() + }) + .exec(&*tx) + .await?; + + let collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + + self.update_project_worktrees(project_id, &reshared_project.worktrees, &tx) + .await?; + + ret.push(super::ResharedProject { + id: project_id, + old_connection_id, + collaborators: collaborators + .iter() + .map(|collaborator| super::ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect(), + worktrees: reshared_project.worktrees.clone(), + }); + } + Ok(ret) + }) + .await + } + + pub async fn rejoin_remote_projects( + &self, + rejoined_projects: &Vec, + user_id: UserId, + connection_id: ConnectionId, + ) -> crate::Result> { + // todo!() project_transaction? (maybe we can make the lock per-dev-server instead of per-project?) + self.transaction(|tx| async move { + let mut ret = Vec::new(); + for rejoined_project in rejoined_projects { + if let Some(project) = self + .rejoin_project_internal(&tx, rejoined_project, user_id, connection_id) + .await? + { + ret.push(project); + } + } + Ok(ret) + }) + .await + } +} diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index c53f60872d..46552740f3 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -6,7 +6,7 @@ impl Database { &self, room_id: RoomId, new_server_id: ServerId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { let stale_participant_filter = Condition::all() .add(room_participant::Column::RoomId.eq(room_id)) @@ -149,7 +149,7 @@ impl Database { calling_connection: ConnectionId, called_user_id: UserId, initial_project_id: Option, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { let caller = room_participant::Entity::find() .filter( @@ -201,7 +201,7 @@ impl Database { &self, room_id: RoomId, called_user_id: UserId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { room_participant::Entity::delete_many() .filter( @@ -221,7 +221,7 @@ impl Database { &self, expected_room_id: Option, user_id: UserId, - ) -> Result>> { + ) -> Result>> { self.optional_room_transaction(|tx| async move { let mut filter = Condition::all() .add(room_participant::Column::UserId.eq(user_id)) @@ -258,7 +258,7 @@ impl Database { room_id: RoomId, calling_connection: ConnectionId, called_user_id: UserId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { let participant = room_participant::Entity::find() .filter( @@ -294,7 +294,7 @@ impl Database { room_id: RoomId, user_id: UserId, connection: ConnectionId, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryChannelId { @@ -472,7 +472,7 @@ impl Database { rejoin_room: proto::RejoinRoom, user_id: UserId, connection: ConnectionId, - ) -> Result> { + ) -> Result> { let room_id = RoomId::from_proto(rejoin_room.id); self.room_transaction(room_id, |tx| async { let tx = tx; @@ -572,180 +572,12 @@ impl Database { let mut rejoined_projects = Vec::new(); for rejoined_project in &rejoin_room.rejoined_projects { - let project_id = ProjectId::from_proto(rejoined_project.id); - let Some(project) = project::Entity::find_by_id(project_id).one(&*tx).await? else { - continue; - }; - - let mut worktrees = Vec::new(); - let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; - for db_worktree in db_worktrees { - let mut worktree = RejoinedWorktree { - id: db_worktree.id as u64, - abs_path: db_worktree.abs_path, - root_name: db_worktree.root_name, - visible: db_worktree.visible, - updated_entries: Default::default(), - removed_entries: Default::default(), - updated_repositories: Default::default(), - removed_repositories: Default::default(), - diagnostic_summaries: Default::default(), - settings_files: Default::default(), - scan_id: db_worktree.scan_id as u64, - completed_scan_id: db_worktree.completed_scan_id as u64, - }; - - let rejoined_worktree = rejoined_project - .worktrees - .iter() - .find(|worktree| worktree.id == db_worktree.id as u64); - - // File entries - { - let entry_filter = if let Some(rejoined_worktree) = rejoined_worktree { - worktree_entry::Column::ScanId.gt(rejoined_worktree.scan_id) - } else { - worktree_entry::Column::IsDeleted.eq(false) - }; - - let mut db_entries = worktree_entry::Entity::find() - .filter( - Condition::all() - .add(worktree_entry::Column::ProjectId.eq(project.id)) - .add(worktree_entry::Column::WorktreeId.eq(worktree.id)) - .add(entry_filter), - ) - .stream(&*tx) - .await?; - - while let Some(db_entry) = db_entries.next().await { - let db_entry = db_entry?; - if db_entry.is_deleted { - worktree.removed_entries.push(db_entry.id as u64); - } else { - worktree.updated_entries.push(proto::Entry { - id: db_entry.id as u64, - is_dir: db_entry.is_dir, - path: db_entry.path, - inode: db_entry.inode as u64, - mtime: Some(proto::Timestamp { - seconds: db_entry.mtime_seconds as u64, - nanos: db_entry.mtime_nanos as u32, - }), - is_symlink: db_entry.is_symlink, - is_ignored: db_entry.is_ignored, - is_external: db_entry.is_external, - git_status: db_entry.git_status.map(|status| status as i32), - }); - } - } - } - - // Repository Entries - { - let repository_entry_filter = - if let Some(rejoined_worktree) = rejoined_worktree { - worktree_repository::Column::ScanId.gt(rejoined_worktree.scan_id) - } else { - worktree_repository::Column::IsDeleted.eq(false) - }; - - let mut db_repositories = worktree_repository::Entity::find() - .filter( - Condition::all() - .add(worktree_repository::Column::ProjectId.eq(project.id)) - .add(worktree_repository::Column::WorktreeId.eq(worktree.id)) - .add(repository_entry_filter), - ) - .stream(&*tx) - .await?; - - while let Some(db_repository) = db_repositories.next().await { - let db_repository = db_repository?; - if db_repository.is_deleted { - worktree - .removed_repositories - .push(db_repository.work_directory_id as u64); - } else { - worktree.updated_repositories.push(proto::RepositoryEntry { - work_directory_id: db_repository.work_directory_id as u64, - branch: db_repository.branch, - }); - } - } - } - - worktrees.push(worktree); - } - - let language_servers = project - .find_related(language_server::Entity) - .all(&*tx) + if let Some(rejoined_project) = self + .rejoin_project_internal(&tx, rejoined_project, user_id, connection) .await? - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - }) - .collect::>(); - { - let mut db_settings_files = worktree_settings_file::Entity::find() - .filter(worktree_settings_file::Column::ProjectId.eq(project_id)) - .stream(&*tx) - .await?; - while let Some(db_settings_file) = db_settings_files.next().await { - let db_settings_file = db_settings_file?; - if let Some(worktree) = worktrees - .iter_mut() - .find(|w| w.id == db_settings_file.worktree_id as u64) - { - worktree.settings_files.push(WorktreeSettingsFile { - path: db_settings_file.path, - content: db_settings_file.content, - }); - } - } + rejoined_projects.push(rejoined_project); } - - let mut collaborators = project - .find_related(project_collaborator::Entity) - .all(&*tx) - .await?; - let self_collaborator = if let Some(self_collaborator_ix) = collaborators - .iter() - .position(|collaborator| collaborator.user_id == user_id) - { - collaborators.swap_remove(self_collaborator_ix) - } else { - continue; - }; - let old_connection_id = self_collaborator.connection(); - project_collaborator::Entity::update(project_collaborator::ActiveModel { - connection_id: ActiveValue::set(connection.id as i32), - connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), - ..self_collaborator.into_active_model() - }) - .exec(&*tx) - .await?; - - let collaborators = collaborators - .into_iter() - .map(|collaborator| ProjectCollaborator { - connection_id: collaborator.connection(), - user_id: collaborator.user_id, - replica_id: collaborator.replica_id, - is_host: collaborator.is_host, - }) - .collect::>(); - - rejoined_projects.push(RejoinedProject { - id: project_id, - old_connection_id, - collaborators, - worktrees, - language_servers, - }); } let (channel, room) = self.get_channel_room(room_id, &tx).await?; @@ -760,10 +592,192 @@ impl Database { .await } + pub async fn rejoin_project_internal( + &self, + tx: &DatabaseTransaction, + rejoined_project: &proto::RejoinProject, + user_id: UserId, + connection: ConnectionId, + ) -> Result> { + let project_id = ProjectId::from_proto(rejoined_project.id); + let Some(project) = project::Entity::find_by_id(project_id).one(tx).await? else { + return Ok(None); + }; + + let mut worktrees = Vec::new(); + let db_worktrees = project.find_related(worktree::Entity).all(tx).await?; + for db_worktree in db_worktrees { + let mut worktree = RejoinedWorktree { + id: db_worktree.id as u64, + abs_path: db_worktree.abs_path, + root_name: db_worktree.root_name, + visible: db_worktree.visible, + updated_entries: Default::default(), + removed_entries: Default::default(), + updated_repositories: Default::default(), + removed_repositories: Default::default(), + diagnostic_summaries: Default::default(), + settings_files: Default::default(), + scan_id: db_worktree.scan_id as u64, + completed_scan_id: db_worktree.completed_scan_id as u64, + }; + + let rejoined_worktree = rejoined_project + .worktrees + .iter() + .find(|worktree| worktree.id == db_worktree.id as u64); + + // File entries + { + let entry_filter = if let Some(rejoined_worktree) = rejoined_worktree { + worktree_entry::Column::ScanId.gt(rejoined_worktree.scan_id) + } else { + worktree_entry::Column::IsDeleted.eq(false) + }; + + let mut db_entries = worktree_entry::Entity::find() + .filter( + Condition::all() + .add(worktree_entry::Column::ProjectId.eq(project.id)) + .add(worktree_entry::Column::WorktreeId.eq(worktree.id)) + .add(entry_filter), + ) + .stream(tx) + .await?; + + while let Some(db_entry) = db_entries.next().await { + let db_entry = db_entry?; + if db_entry.is_deleted { + worktree.removed_entries.push(db_entry.id as u64); + } else { + worktree.updated_entries.push(proto::Entry { + id: db_entry.id as u64, + is_dir: db_entry.is_dir, + path: db_entry.path, + inode: db_entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: db_entry.mtime_seconds as u64, + nanos: db_entry.mtime_nanos as u32, + }), + is_symlink: db_entry.is_symlink, + is_ignored: db_entry.is_ignored, + is_external: db_entry.is_external, + git_status: db_entry.git_status.map(|status| status as i32), + }); + } + } + } + + // Repository Entries + { + let repository_entry_filter = if let Some(rejoined_worktree) = rejoined_worktree { + worktree_repository::Column::ScanId.gt(rejoined_worktree.scan_id) + } else { + worktree_repository::Column::IsDeleted.eq(false) + }; + + let mut db_repositories = worktree_repository::Entity::find() + .filter( + Condition::all() + .add(worktree_repository::Column::ProjectId.eq(project.id)) + .add(worktree_repository::Column::WorktreeId.eq(worktree.id)) + .add(repository_entry_filter), + ) + .stream(tx) + .await?; + + while let Some(db_repository) = db_repositories.next().await { + let db_repository = db_repository?; + if db_repository.is_deleted { + worktree + .removed_repositories + .push(db_repository.work_directory_id as u64); + } else { + worktree.updated_repositories.push(proto::RepositoryEntry { + work_directory_id: db_repository.work_directory_id as u64, + branch: db_repository.branch, + }); + } + } + } + + worktrees.push(worktree); + } + + let language_servers = project + .find_related(language_server::Entity) + .all(tx) + .await? + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect::>(); + + { + let mut db_settings_files = worktree_settings_file::Entity::find() + .filter(worktree_settings_file::Column::ProjectId.eq(project_id)) + .stream(tx) + .await?; + while let Some(db_settings_file) = db_settings_files.next().await { + let db_settings_file = db_settings_file?; + if let Some(worktree) = worktrees + .iter_mut() + .find(|w| w.id == db_settings_file.worktree_id as u64) + { + worktree.settings_files.push(WorktreeSettingsFile { + path: db_settings_file.path, + content: db_settings_file.content, + }); + } + } + } + + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(tx) + .await?; + let self_collaborator = if let Some(self_collaborator_ix) = collaborators + .iter() + .position(|collaborator| collaborator.user_id == user_id) + { + collaborators.swap_remove(self_collaborator_ix) + } else { + return Ok(None); + }; + let old_connection_id = self_collaborator.connection(); + project_collaborator::Entity::update(project_collaborator::ActiveModel { + connection_id: ActiveValue::set(connection.id as i32), + connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), + ..self_collaborator.into_active_model() + }) + .exec(tx) + .await?; + + let collaborators = collaborators + .into_iter() + .map(|collaborator| ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect::>(); + + return Ok(Some(RejoinedProject { + id: project_id, + old_connection_id, + collaborators, + worktrees, + language_servers, + })); + } + pub async fn leave_room( &self, connection: ConnectionId, - ) -> Result>> { + ) -> Result>> { self.optional_room_transaction(|tx| async move { let leaving_participant = room_participant::Entity::find() .filter( @@ -935,7 +949,7 @@ impl Database { room_id: RoomId, connection: ConnectionId, location: proto::ParticipantLocation, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async { let tx = tx; let location_kind; @@ -997,7 +1011,7 @@ impl Database { room_id: RoomId, user_id: UserId, role: ChannelRole, - ) -> Result> { + ) -> Result> { self.room_transaction(room_id, |tx| async move { room_participant::Entity::find() .filter( @@ -1150,7 +1164,7 @@ impl Database { &self, room_id: RoomId, connection_id: ConnectionId, - ) -> Result>> { + ) -> Result>> { self.room_transaction(room_id, |tx| async move { let mut participants = room_participant::Entity::find() .filter(room_participant::Column::RoomId.eq(room_id)) diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index b679337943..4a284682b2 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -24,6 +24,7 @@ pub mod observed_channel_messages; pub mod project; pub mod project_collaborator; pub mod rate_buckets; +pub mod remote_project; pub mod room; pub mod room_participant; pub mod server; diff --git a/crates/collab/src/db/tables/dev_server.rs b/crates/collab/src/db/tables/dev_server.rs index 94b1d4dc00..cd98ae4892 100644 --- a/crates/collab/src/db/tables/dev_server.rs +++ b/crates/collab/src/db/tables/dev_server.rs @@ -1,4 +1,5 @@ use crate::db::{ChannelId, DevServerId}; +use rpc::proto; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] @@ -15,3 +16,14 @@ impl ActiveModelBehavior for ActiveModel {} #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation {} + +impl Model { + pub fn to_proto(&self, status: proto::DevServerStatus) -> proto::DevServer { + proto::DevServer { + dev_server_id: self.id.to_proto(), + channel_id: self.channel_id.to_proto(), + name: self.name.clone(), + status: status as i32, + } + } +} diff --git a/crates/collab/src/db/tables/project.rs b/crates/collab/src/db/tables/project.rs index a357634aff..bfb0b17c9a 100644 --- a/crates/collab/src/db/tables/project.rs +++ b/crates/collab/src/db/tables/project.rs @@ -1,4 +1,4 @@ -use crate::db::{HostedProjectId, ProjectId, Result, RoomId, ServerId, UserId}; +use crate::db::{HostedProjectId, ProjectId, RemoteProjectId, Result, RoomId, ServerId, UserId}; use anyhow::anyhow; use rpc::ConnectionId; use sea_orm::entity::prelude::*; @@ -13,6 +13,7 @@ pub struct Model { pub host_connection_id: Option, pub host_connection_server_id: Option, pub hosted_project_id: Option, + pub remote_project_id: Option, } impl Model { @@ -56,6 +57,12 @@ pub enum Relation { to = "super::hosted_project::Column::Id" )] HostedProject, + #[sea_orm( + belongs_to = "super::remote_project::Entity", + from = "Column::RemoteProjectId", + to = "super::remote_project::Column::Id" + )] + RemoteProject, } impl Related for Entity { @@ -94,4 +101,10 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::RemoteProject.def() + } +} + impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/remote_project.rs b/crates/collab/src/db/tables/remote_project.rs new file mode 100644 index 0000000000..ba486d9733 --- /dev/null +++ b/crates/collab/src/db/tables/remote_project.rs @@ -0,0 +1,42 @@ +use super::project; +use crate::db::{ChannelId, DevServerId, RemoteProjectId}; +use rpc::proto; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "remote_projects")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RemoteProjectId, + pub channel_id: ChannelId, + pub dev_server_id: DevServerId, + pub name: String, + pub path: String, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_one = "super::project::Entity")] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl Model { + pub fn to_proto(&self, project: Option) -> proto::RemoteProject { + proto::RemoteProject { + id: self.id.to_proto(), + project_id: project.map(|p| p.id.to_proto()), + channel_id: self.channel_id.to_proto(), + dev_server_id: self.dev_server_id.to_proto(), + name: self.name.clone(), + path: self.path.clone(), + } + } +} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index a2a4c16a2a..bdcfd487f1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,12 +1,13 @@ mod connection_pool; use crate::{ - auth::{self}, + auth, db::{ - self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, - CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, - NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId, - RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId, + self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, + CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated, + MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject, + RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, + ServerId, UpdatedChannelMessage, User, UserId, }, executor::Executor, AppState, Error, RateLimit, RateLimiter, Result, @@ -172,6 +173,10 @@ impl Session { UserSession::new(self) } + fn for_dev_server(self) -> Option { + DevServerSession::new(self) + } + fn user_id(&self) -> Option { match &self.principal { Principal::User(user) => Some(user.id), @@ -179,6 +184,21 @@ impl Session { Principal::DevServer(_) => None, } } + + fn dev_server_id(&self) -> Option { + match &self.principal { + Principal::User(_) | Principal::Impersonated { .. } => None, + Principal::DevServer(dev_server) => Some(dev_server.id), + } + } + + fn principal_id(&self) -> PrincipalId { + match &self.principal { + Principal::User(user) => PrincipalId::UserId(user.id), + Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id), + Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id), + } + } } impl Debug for Session { @@ -224,6 +244,30 @@ impl DerefMut for UserSession { } } +struct DevServerSession(Session); + +impl DevServerSession { + pub fn new(s: Session) -> Option { + s.dev_server_id().map(|_| DevServerSession(s)) + } + pub fn dev_server_id(&self) -> DevServerId { + self.0.dev_server_id().unwrap() + } +} + +impl Deref for DevServerSession { + type Target = Session; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for DevServerSession { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + fn user_handler( handler: impl 'static + Send + Sync + Fn(M, Response, UserSession) -> Fut, ) -> impl 'static + Send + Sync + Fn(M, Response, Session) -> BoxFuture<'static, Result<()>> @@ -237,7 +281,32 @@ where if let Some(user_session) = session.for_user() { Ok(handler(message, response, user_session).await?) } else { - Err(Error::Internal(anyhow!("must be a user"))) + Err(Error::Internal(anyhow!( + "must be a user to call {}", + M::NAME + ))) + } + }) + } +} + +fn dev_server_handler( + handler: impl 'static + Send + Sync + Fn(M, Response, DevServerSession) -> Fut, +) -> impl 'static + Send + Sync + Fn(M, Response, Session) -> BoxFuture<'static, Result<()>> +where + Fut: Send + Future>, +{ + let handler = Arc::new(handler); + move |message, response, session| { + let handler = handler.clone(); + Box::pin(async move { + if let Some(dev_server_session) = session.for_dev_server() { + Ok(handler(message, response, dev_server_session).await?) + } else { + Err(Error::Internal(anyhow!( + "must be a dev server to call {}", + M::NAME + ))) } }) } @@ -256,7 +325,10 @@ where if let Some(user_session) = session.for_user() { Ok(handler(message, user_session).await?) } else { - Err(Error::Internal(anyhow!("must be a user"))) + Err(Error::Internal(anyhow!( + "must be a user to call {}", + M::NAME + ))) } }) } @@ -324,10 +396,16 @@ impl Server { .add_request_handler(user_handler(cancel_call)) .add_message_handler(user_message_handler(decline_call)) .add_request_handler(user_handler(update_participant_location)) - .add_request_handler(share_project) + .add_request_handler(user_handler(share_project)) .add_message_handler(unshare_project) .add_request_handler(user_handler(join_project)) .add_request_handler(user_handler(join_hosted_project)) + .add_request_handler(user_handler(rejoin_remote_projects)) + .add_request_handler(user_handler(create_remote_project)) + .add_request_handler(user_handler(create_dev_server)) + .add_request_handler(dev_server_handler(share_remote_project)) + .add_request_handler(dev_server_handler(shutdown_dev_server)) + .add_request_handler(dev_server_handler(reconnect_dev_server)) .add_message_handler(user_message_handler(leave_project)) .add_request_handler(update_project) .add_request_handler(update_worktree) @@ -335,40 +413,96 @@ impl Server { .add_message_handler(update_language_server) .add_message_handler(update_diagnostic_summary) .add_message_handler(update_worktree_settings) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_read_only_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler( + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_read_only_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( forward_mutating_project_request::, - ) - .add_request_handler( + )) + .add_request_handler(user_handler( forward_mutating_project_request::, - ) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) + .add_request_handler(user_handler( + forward_mutating_project_request::, + )) .add_message_handler(create_buffer_for_peer) .add_request_handler(update_buffer) .add_message_handler(broadcast_project_message_from_host::) @@ -625,9 +759,11 @@ impl Server { let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; + let payload_type = M::NAME; match result { Err(error) => { - tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message") + // todo!(), why isn't this logged inside the span? + tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message") } Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"), } @@ -876,55 +1012,77 @@ impl Server { }, )?; tracing::info!("sent hello message"); - - let Principal::User(user) = principal else { - return Ok(()); - }; - if let Some(send_connection_id) = send_connection_id.take() { let _ = send_connection_id.send(connection_id); } - if !user.connected_once { - self.peer.send(connection_id, proto::ShowContacts {})?; - self.app_state - .db - .set_user_connected_once(user.id, true) + match principal { + Principal::User(user) | Principal::Impersonated { user, admin: _ } => { + if !user.connected_once { + self.peer.send(connection_id, proto::ShowContacts {})?; + self.app_state + .db + .set_user_connected_once(user.id, true) + .await?; + } + + let (contacts, channels_for_user, channel_invites) = future::try_join3( + self.app_state.db.get_contacts(user.id), + self.app_state.db.get_channels_for_user(user.id), + self.app_state.db.get_channel_invites_for_user(user.id), + ) .await?; - } - let (contacts, channels_for_user, channel_invites) = future::try_join3( - self.app_state.db.get_contacts(user.id), - self.app_state.db.get_channels_for_user(user.id), - self.app_state.db.get_channel_invites_for_user(user.id), - ) - .await?; + { + let mut pool = self.connection_pool.lock(); + pool.add_connection(connection_id, user.id, user.admin, zed_version); + for membership in &channels_for_user.channel_memberships { + pool.subscribe_to_channel(user.id, membership.channel_id, membership.role) + } + self.peer.send( + connection_id, + build_initial_contacts_update(contacts, &pool), + )?; + self.peer.send( + connection_id, + build_update_user_channels(&channels_for_user), + )?; + self.peer.send( + connection_id, + build_channels_update(channels_for_user, channel_invites, &pool), + )?; + } - { - let mut pool = self.connection_pool.lock(); - pool.add_connection(connection_id, user.id, user.admin, zed_version); - for membership in &channels_for_user.channel_memberships { - pool.subscribe_to_channel(user.id, membership.channel_id, membership.role) + if let Some(incoming_call) = + self.app_state.db.incoming_call_for_user(user.id).await? + { + self.peer.send(connection_id, incoming_call)?; + } + + update_user_contacts(user.id, &session).await?; + } + Principal::DevServer(dev_server) => { + { + let mut pool = self.connection_pool.lock(); + if pool.dev_server_connection_id(dev_server.id).is_some() { + return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?; + }; + pool.add_dev_server(connection_id, dev_server.id, zed_version); + } + update_dev_server_status(dev_server, proto::DevServerStatus::Online, &session) + .await; + // todo!() allow only one connection. + + let projects = self + .app_state + .db + .get_remote_projects_for_dev_server(dev_server.id) + .await?; + self.peer + .send(connection_id, proto::DevServerInstructions { projects })?; } - self.peer.send( - connection_id, - build_initial_contacts_update(contacts, &pool), - )?; - self.peer.send( - connection_id, - build_update_user_channels(&channels_for_user), - )?; - self.peer.send( - connection_id, - build_channels_update(channels_for_user, channel_invites), - )?; } - if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? { - self.peer.send(connection_id, incoming_call)?; - } - - update_user_contacts(user.id, &session).await?; Ok(()) } @@ -1202,27 +1360,36 @@ async fn connection_lost( futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { - if let Some(session) = session.for_user() { - log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); - leave_room_for_session(&session, session.connection_id).await.trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); + match &session.principal { + Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => { + let session = session.for_user().unwrap(); - if !session - .connection_pool() - .await - .is_user_online(session.user_id()) - { - let db = session.db().await; - if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() { - room_updated(&room, &session.peer); + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); + leave_room_for_session(&session, session.connection_id).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); + + if !session + .connection_pool() + .await + .is_user_online(session.user_id()) + { + let db = session.db().await; + if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() { + room_updated(&room, &session.peer); + } } - } - update_user_contacts(session.user_id(), &session).await?; - } + update_user_contacts(session.user_id(), &session).await?; + }, + Principal::DevServer(dev_server) => { + lost_dev_server_connection(&session).await?; + update_dev_server_status(&dev_server, proto::DevServerStatus::Offline, &session) + .await; + }, } + }, _ = teardown.changed().fuse() => {} } @@ -1377,25 +1544,7 @@ async fn rejoin_room( rejoined_projects: rejoined_room .rejoined_projects .iter() - .map(|rejoined_project| proto::RejoinedProject { - id: rejoined_project.id.to_proto(), - worktrees: rejoined_project - .worktrees - .iter() - .map(|worktree| proto::WorktreeMetadata { - id: worktree.id, - root_name: worktree.root_name.clone(), - visible: worktree.visible, - abs_path: worktree.abs_path.clone(), - }) - .collect(), - collaborators: rejoined_project - .collaborators - .iter() - .map(|collaborator| collaborator.to_proto()) - .collect(), - language_servers: rejoined_project.language_servers.clone(), - }) + .map(|rejoined_project| rejoined_project.to_proto()) .collect(), })?; room_updated(&rejoined_room.room, &session.peer); @@ -1434,86 +1583,7 @@ async fn rejoin_room( ); } - for project in &rejoined_room.rejoined_projects { - for collaborator in &project.collaborators { - session - .peer - .send( - collaborator.connection_id, - proto::UpdateProjectCollaborator { - project_id: project.id.to_proto(), - old_peer_id: Some(project.old_connection_id.into()), - new_peer_id: Some(session.connection_id.into()), - }, - ) - .trace_err(); - } - } - - for project in &mut rejoined_room.rejoined_projects { - for worktree in mem::take(&mut project.worktrees) { - #[cfg(any(test, feature = "test-support"))] - const MAX_CHUNK_SIZE: usize = 2; - #[cfg(not(any(test, feature = "test-support")))] - const MAX_CHUNK_SIZE: usize = 256; - - // Stream this worktree's entries. - let message = proto::UpdateWorktree { - project_id: project.id.to_proto(), - worktree_id: worktree.id, - abs_path: worktree.abs_path.clone(), - root_name: worktree.root_name, - updated_entries: worktree.updated_entries, - removed_entries: worktree.removed_entries, - scan_id: worktree.scan_id, - is_last_update: worktree.completed_scan_id == worktree.scan_id, - updated_repositories: worktree.updated_repositories, - removed_repositories: worktree.removed_repositories, - }; - for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - session.peer.send(session.connection_id, update.clone())?; - } - - // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries { - session.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project.id.to_proto(), - worktree_id: worktree.id, - summary: Some(summary), - }, - )?; - } - - for settings_file in worktree.settings_files { - session.peer.send( - session.connection_id, - proto::UpdateWorktreeSettings { - project_id: project.id.to_proto(), - worktree_id: worktree.id, - path: settings_file.path, - content: Some(settings_file.content), - }, - )?; - } - } - - for language_server in &project.language_servers { - session.peer.send( - session.connection_id, - proto::UpdateLanguageServer { - project_id: project.id.to_proto(), - language_server_id: language_server.id, - variant: Some( - proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( - proto::LspDiskBasedDiagnosticsUpdated {}, - ), - ), - }, - )?; - } - } + notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?; let rejoined_room = rejoined_room.into_inner(); @@ -1534,6 +1604,93 @@ async fn rejoin_room( Ok(()) } +fn notify_rejoined_projects( + rejoined_projects: &mut Vec, + session: &UserSession, +) -> Result<()> { + for project in rejoined_projects.iter() { + for collaborator in &project.collaborators { + session + .peer + .send( + collaborator.connection_id, + proto::UpdateProjectCollaborator { + project_id: project.id.to_proto(), + old_peer_id: Some(project.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + ) + .trace_err(); + } + } + + for project in rejoined_projects { + for worktree in mem::take(&mut project.worktrees) { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; + + // Stream this worktree's entries. + let message = proto::UpdateWorktree { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.updated_entries, + removed_entries: worktree.removed_entries, + scan_id: worktree.scan_id, + is_last_update: worktree.completed_scan_id == worktree.scan_id, + updated_repositories: worktree.updated_repositories, + removed_repositories: worktree.removed_repositories, + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; + } + + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + }, + )?; + } + + for settings_file in worktree.settings_files { + session.peer.send( + session.connection_id, + proto::UpdateWorktreeSettings { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + path: settings_file.path, + content: Some(settings_file.content), + }, + )?; + } + } + + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project.id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; + } + } + Ok(()) +} + /// leave room disconnects from the room. async fn leave_room( _: proto::LeaveRoom, @@ -1757,7 +1914,7 @@ async fn update_participant_location( async fn share_project( request: proto::ShareProject, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let (project_id, room) = &*session .db() @@ -1779,19 +1936,70 @@ async fn share_project( /// Unshare a project from the room. async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); + unshare_project_internal(project_id, &session).await +} +async fn unshare_project_internal(project_id: ProjectId, session: &Session) -> Result<()> { let (room, guest_connection_ids) = &*session .db() .await .unshare_project(project_id, session.connection_id) .await?; + let message = proto::UnshareProject { + project_id: project_id.to_proto(), + }; + broadcast( Some(session.connection_id), guest_connection_ids.iter().copied(), |conn_id| session.peer.send(conn_id, message.clone()), ); - room_updated(&room, &session.peer); + if let Some(room) = room { + room_updated(room, &session.peer); + } + + Ok(()) +} + +/// Share a project into the room. +async fn share_remote_project( + request: proto::ShareRemoteProject, + response: Response, + session: DevServerSession, +) -> Result<()> { + let remote_project = session + .db() + .await + .share_remote_project( + RemoteProjectId::from_proto(request.remote_project_id), + session.dev_server_id(), + session.connection_id, + &request.worktrees, + ) + .await?; + let Some(project_id) = remote_project.project_id else { + return Err(anyhow!("failed to share remote project"))?; + }; + + for (connection_id, _) in session + .connection_pool() + .await + .channel_connection_ids(ChannelId::from_proto(remote_project.channel_id)) + { + session + .peer + .send( + connection_id, + proto::UpdateChannels { + remote_projects: vec![remote_project.clone()], + ..Default::default() + }, + ) + .trace_err(); + } + + response.send(proto::ShareProjectResponse { project_id })?; Ok(()) } @@ -1806,12 +2014,12 @@ async fn join_project( tracing::info!(%project_id, "join project"); - let (project, replica_id) = &mut *session - .db() - .await - .join_project_in_room(project_id, session.connection_id) + let db = session.db().await; + let (project, replica_id) = &mut *db + .join_project(project_id, session.connection_id, session.user_id()) .await?; - + drop(db); + tracing::info!(%project_id, "join remote project"); join_project_internal(response, session, project, replica_id) } @@ -1968,7 +2176,9 @@ async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Re ); project_left(&project, &session); - room_updated(&room, &session.peer); + if let Some(room) = room { + room_updated(&room, &session.peer); + } Ok(()) } @@ -1991,6 +2201,219 @@ async fn join_hosted_project( join_project_internal(response, session, &mut project, &replica_id) } +async fn create_remote_project( + request: proto::CreateRemoteProject, + response: Response, + session: UserSession, +) -> Result<()> { + let (channel, remote_project) = session + .db() + .await + .create_remote_project( + ChannelId(request.channel_id as i32), + DevServerId(request.dev_server_id as i32), + &request.name, + &request.path, + session.user_id(), + ) + .await?; + + let projects = session + .db() + .await + .get_remote_projects_for_dev_server(remote_project.dev_server_id) + .await?; + + let update = proto::UpdateChannels { + remote_projects: vec![remote_project.to_proto(None)], + ..Default::default() + }; + let connection_pool = session.connection_pool().await; + for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) { + if role.can_see_all_descendants() { + session.peer.send(connection_id, update.clone())?; + } + } + + let dev_server_id = remote_project.dev_server_id; + let dev_server_connection_id = connection_pool.dev_server_connection_id(dev_server_id); + if let Some(dev_server_connection_id) = dev_server_connection_id { + session.peer.send( + dev_server_connection_id, + proto::DevServerInstructions { projects }, + )?; + } + + response.send(proto::CreateRemoteProjectResponse { + remote_project: Some(remote_project.to_proto(None)), + })?; + Ok(()) +} + +async fn create_dev_server( + request: proto::CreateDevServer, + response: Response, + session: UserSession, +) -> Result<()> { + let access_token = auth::random_token(); + let hashed_access_token = auth::hash_access_token(&access_token); + + let (channel, dev_server) = session + .db() + .await + .create_dev_server( + ChannelId(request.channel_id as i32), + &request.name, + &hashed_access_token, + session.user_id(), + ) + .await?; + + let update = proto::UpdateChannels { + dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)], + ..Default::default() + }; + let connection_pool = session.connection_pool().await; + for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) { + if role.can_see_channel(channel.visibility) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::CreateDevServerResponse { + dev_server_id: dev_server.id.0 as u64, + channel_id: request.channel_id, + access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token), + name: request.name.clone(), + })?; + Ok(()) +} + +async fn rejoin_remote_projects( + request: proto::RejoinRemoteProjects, + response: Response, + session: UserSession, +) -> Result<()> { + let mut rejoined_projects = { + let db = session.db().await; + db.rejoin_remote_projects( + &request.rejoined_projects, + session.user_id(), + session.0.connection_id, + ) + .await? + }; + notify_rejoined_projects(&mut rejoined_projects, &session)?; + + response.send(proto::RejoinRemoteProjectsResponse { + rejoined_projects: rejoined_projects + .into_iter() + .map(|project| project.to_proto()) + .collect(), + }) +} + +async fn reconnect_dev_server( + request: proto::ReconnectDevServer, + response: Response, + session: DevServerSession, +) -> Result<()> { + let reshared_projects = { + let db = session.db().await; + db.reshare_remote_projects( + &request.reshared_projects, + session.dev_server_id(), + session.0.connection_id, + ) + .await? + }; + + for project in &reshared_projects { + for collaborator in &project.collaborators { + session + .peer + .send( + collaborator.connection_id, + proto::UpdateProjectCollaborator { + project_id: project.id.to_proto(), + old_peer_id: Some(project.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + ) + .trace_err(); + } + + broadcast( + Some(session.connection_id), + project + .collaborators + .iter() + .map(|collaborator| collaborator.connection_id), + |connection_id| { + session.peer.forward_send( + session.connection_id, + connection_id, + proto::UpdateProject { + project_id: project.id.to_proto(), + worktrees: project.worktrees.clone(), + }, + ) + }, + ); + } + + response.send(proto::ReconnectDevServerResponse { + reshared_projects: reshared_projects + .iter() + .map(|project| proto::ResharedProject { + id: project.id.to_proto(), + collaborators: project + .collaborators + .iter() + .map(|collaborator| collaborator.to_proto()) + .collect(), + }) + .collect(), + })?; + + Ok(()) +} + +async fn shutdown_dev_server( + _: proto::ShutdownDevServer, + response: Response, + session: DevServerSession, +) -> Result<()> { + response.send(proto::Ack {})?; + let (remote_projects, dev_server) = { + let dev_server_id = session.dev_server_id(); + let db = session.db().await; + let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?; + let dev_server = db.get_dev_server(dev_server_id).await?; + (remote_projects, dev_server) + }; + + for project_id in remote_projects.iter().filter_map(|p| p.project_id) { + unshare_project_internal(ProjectId::from_proto(project_id), &session.0).await?; + } + + let update = proto::UpdateChannels { + remote_projects, + dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)], + ..Default::default() + }; + + for (connection_id, _) in session + .connection_pool() + .await + .channel_connection_ids(dev_server.channel_id) + { + session.peer.send(connection_id, update.clone()).trace_err(); + } + + Ok(()) +} + /// Updates other participants with changes to the project async fn update_project( request: proto::UpdateProject, @@ -2012,7 +2435,9 @@ async fn update_project( .forward_send(session.connection_id, connection_id, request.clone()) }, ); - room_updated(&room, &session.peer); + if let Some(room) = room { + room_updated(&room, &session.peer); + } response.send(proto::Ack {})?; Ok(()) @@ -2123,7 +2548,7 @@ async fn update_language_server( let project_connection_ids = session .db() .await - .project_connection_ids(project_id, session.connection_id) + .project_connection_ids(project_id, session.connection_id, true) .await?; broadcast( Some(session.connection_id), @@ -2142,7 +2567,7 @@ async fn update_language_server( async fn forward_read_only_project_request( request: T, response: Response, - session: Session, + session: UserSession, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2151,7 +2576,7 @@ where let host_connection_id = session .db() .await - .host_for_read_only_project_request(project_id, session.connection_id) + .host_for_read_only_project_request(project_id, session.connection_id, session.user_id()) .await?; let payload = session .peer @@ -2166,7 +2591,7 @@ where async fn forward_mutating_project_request( request: T, response: Response, - session: Session, + session: UserSession, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2175,7 +2600,7 @@ where let host_connection_id = session .db() .await - .host_for_mutating_project_request(project_id, session.connection_id) + .host_for_mutating_project_request(project_id, session.connection_id, session.user_id()) .await?; let payload = session .peer @@ -2213,52 +2638,46 @@ async fn update_buffer( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let mut guest_connection_ids; - let mut host_connection_id = None; - - let mut requires_write_permission = false; + let mut capability = Capability::ReadOnly; for op in request.operations.iter() { match op.variant { None | Some(proto::operation::Variant::UpdateSelections(_)) => {} - Some(_) => requires_write_permission = true, + Some(_) => capability = Capability::ReadWrite, } } - { - let collaborators = session + let host = { + let guard = session .db() .await - .project_collaborators_for_buffer_update( + .connections_for_buffer_update( project_id, + session.principal_id(), session.connection_id, - requires_write_permission, + capability, ) .await?; - guest_connection_ids = Vec::with_capacity(collaborators.len() - 1); - for collaborator in collaborators.iter() { - if collaborator.is_host { - host_connection_id = Some(collaborator.connection_id); - } else { - guest_connection_ids.push(collaborator.connection_id); - } - } - } - let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?; - broadcast( - Some(session.connection_id), - guest_connection_ids, - |connection_id| { - session - .peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - if host_connection_id != session.connection_id { + let (host, guests) = &*guard; + + broadcast( + Some(session.connection_id), + guests.clone(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + + *host + }; + + if host != session.connection_id { session .peer - .forward_request(session.connection_id, host_connection_id, request.clone()) + .forward_request(session.connection_id, host, request.clone()) .await?; } @@ -2275,7 +2694,7 @@ async fn broadcast_project_message_from_host proto::UpdateUserCh fn build_channels_update( channels: ChannelsForUser, channel_invites: Vec, + pool: &ConnectionPool, ) -> proto::UpdateChannels { let mut update = proto::UpdateChannels::default(); @@ -4124,9 +4544,14 @@ fn build_channels_update( for channel in channel_invites { update.channel_invitations.push(channel.to_proto()); } - for project in channels.hosted_projects { - update.hosted_projects.push(project); - } + + update.hosted_projects = channels.hosted_projects; + update.dev_servers = channels + .dev_servers + .into_iter() + .map(|dev_server| dev_server.to_proto(pool.dev_server_status(dev_server.id))) + .collect(); + update.remote_projects = channels.remote_projects; update } @@ -4214,6 +4639,27 @@ fn channel_updated( ); } +async fn update_dev_server_status( + dev_server: &dev_server::Model, + status: proto::DevServerStatus, + session: &Session, +) { + let pool = session.connection_pool().await; + let connections = pool.channel_connection_ids(dev_server.channel_id); + for (connection_id, _) in connections { + session + .peer + .send( + connection_id, + proto::UpdateChannels { + dev_servers: vec![dev_server.to_proto(status)], + ..Default::default() + }, + ) + .trace_err(); + } +} + async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { let db = session.db().await; @@ -4249,6 +4695,22 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> Ok(()) } +async fn lost_dev_server_connection(session: &Session) -> Result<()> { + log::info!("lost dev server connection, unsharing projects"); + let project_ids = session + .db() + .await + .get_stale_dev_server_projects(session.connection_id) + .await?; + + for project_id in project_ids { + // not unshare re-checks the connection ids match, so we get away with no transaction + unshare_project_internal(project_id, &session).await?; + } + + Ok(()) +} + async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> { let mut contacts_to_update = HashSet::default(); @@ -4384,6 +4846,7 @@ where { type Ok = T; + #[track_caller] fn trace_err(self) -> Option { match self { Ok(value) => Some(value), diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs index 33def0b9bb..ea40d39122 100644 --- a/crates/collab/src/rpc/connection_pool.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -1,7 +1,7 @@ -use crate::db::{ChannelId, ChannelRole, UserId}; +use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId}; use anyhow::{anyhow, Result}; use collections::{BTreeMap, HashMap, HashSet}; -use rpc::ConnectionId; +use rpc::{proto, ConnectionId}; use semantic_version::SemanticVersion; use serde::Serialize; use std::fmt; @@ -10,12 +10,13 @@ use tracing::instrument; #[derive(Default, Serialize)] pub struct ConnectionPool { connections: BTreeMap, - connected_users: BTreeMap, + connected_users: BTreeMap, + connected_dev_servers: BTreeMap, channels: ChannelPool, } #[derive(Default, Serialize)] -struct ConnectedUser { +struct ConnectedPrincipal { connection_ids: HashSet, } @@ -36,7 +37,7 @@ impl ZedVersion { #[derive(Serialize)] pub struct Connection { - pub user_id: UserId, + pub principal_id: PrincipalId, pub admin: bool, pub zed_version: ZedVersion, } @@ -59,7 +60,7 @@ impl ConnectionPool { self.connections.insert( connection_id, Connection { - user_id, + principal_id: PrincipalId::UserId(user_id), admin, zed_version, }, @@ -68,6 +69,25 @@ impl ConnectionPool { connected_user.connection_ids.insert(connection_id); } + pub fn add_dev_server( + &mut self, + connection_id: ConnectionId, + dev_server_id: DevServerId, + zed_version: ZedVersion, + ) { + self.connections.insert( + connection_id, + Connection { + principal_id: PrincipalId::DevServerId(dev_server_id), + admin: false, + zed_version, + }, + ); + + self.connected_dev_servers + .insert(dev_server_id, connection_id); + } + #[instrument(skip(self))] pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> { let connection = self @@ -75,12 +95,18 @@ impl ConnectionPool { .get_mut(&connection_id) .ok_or_else(|| anyhow!("no such connection"))?; - let user_id = connection.user_id; - let connected_user = self.connected_users.get_mut(&user_id).unwrap(); - connected_user.connection_ids.remove(&connection_id); - if connected_user.connection_ids.is_empty() { - self.connected_users.remove(&user_id); - self.channels.remove_user(&user_id); + match connection.principal_id { + PrincipalId::UserId(user_id) => { + let connected_user = self.connected_users.get_mut(&user_id).unwrap(); + connected_user.connection_ids.remove(&connection_id); + if connected_user.connection_ids.is_empty() { + self.connected_users.remove(&user_id); + self.channels.remove_user(&user_id); + } + } + PrincipalId::DevServerId(dev_server_id) => { + self.connected_dev_servers.remove(&dev_server_id); + } } self.connections.remove(&connection_id).unwrap(); Ok(()) @@ -110,6 +136,18 @@ impl ConnectionPool { .copied() } + pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus { + if self.dev_server_connection_id(dev_server_id).is_some() { + proto::DevServerStatus::Online + } else { + proto::DevServerStatus::Offline + } + } + + pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option { + self.connected_dev_servers.get(&dev_server_id).copied() + } + pub fn channel_user_ids( &self, channel_id: ChannelId, @@ -154,22 +192,39 @@ impl ConnectionPool { #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { - assert!(self - .connected_users - .get(&connection.user_id) - .unwrap() - .connection_ids - .contains(connection_id)); + match &connection.principal_id { + PrincipalId::UserId(user_id) => { + assert!(self + .connected_users + .get(user_id) + .unwrap() + .connection_ids + .contains(connection_id)); + } + PrincipalId::DevServerId(dev_server_id) => { + assert_eq!( + self.connected_dev_servers.get(&dev_server_id).unwrap(), + connection_id + ); + } + } } for (user_id, state) in &self.connected_users { for connection_id in &state.connection_ids { assert_eq!( - self.connections.get(connection_id).unwrap().user_id, - *user_id + self.connections.get(connection_id).unwrap().principal_id, + PrincipalId::UserId(*user_id) ); } } + + for (dev_server_id, connection_id) in &self.connected_dev_servers { + assert_eq!( + self.connections.get(connection_id).unwrap().principal_id, + PrincipalId::DevServerId(*dev_server_id) + ); + } } } diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 6185b8d582..bb9ea43fda 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -8,6 +8,7 @@ mod channel_buffer_tests; mod channel_guest_tests; mod channel_message_tests; mod channel_tests; +mod dev_server_tests; mod editor_tests; mod following_tests; mod integration_tests; diff --git a/crates/collab/src/tests/dev_server_tests.rs b/crates/collab/src/tests/dev_server_tests.rs new file mode 100644 index 0000000000..91849b4fb9 --- /dev/null +++ b/crates/collab/src/tests/dev_server_tests.rs @@ -0,0 +1,110 @@ +use std::path::Path; + +use editor::Editor; +use fs::Fs; +use gpui::VisualTestContext; +use rpc::proto::DevServerStatus; +use serde_json::json; + +use crate::tests::TestServer; + +#[gpui::test] +async fn test_dev_server(cx: &mut gpui::TestAppContext, cx2: &mut gpui::TestAppContext) { + let (server, client) = TestServer::start1(cx).await; + + let channel_id = server + .make_channel("test", None, (&client, cx), &mut []) + .await; + + let resp = client + .channel_store() + .update(cx, |store, cx| { + store.create_dev_server(channel_id, "server-1".to_string(), cx) + }) + .await + .unwrap(); + + client.channel_store().update(cx, |store, _| { + assert_eq!(store.dev_servers_for_id(channel_id).len(), 1); + assert_eq!(store.dev_servers_for_id(channel_id)[0].name, "server-1"); + assert_eq!( + store.dev_servers_for_id(channel_id)[0].status, + DevServerStatus::Offline + ); + }); + + let dev_server = server.create_dev_server(resp.access_token, cx2).await; + cx.executor().run_until_parked(); + client.channel_store().update(cx, |store, _| { + assert_eq!( + store.dev_servers_for_id(channel_id)[0].status, + DevServerStatus::Online + ); + }); + + dev_server + .fs() + .insert_tree( + "/remote", + json!({ + "1.txt": "remote\nremote\nremote", + "2.js": "function two() { return 2; }", + "3.rs": "mod test", + }), + ) + .await; + + client + .channel_store() + .update(cx, |store, cx| { + store.create_remote_project( + channel_id, + client::DevServerId(resp.dev_server_id), + "project-1".to_string(), + "/remote".to_string(), + cx, + ) + }) + .await + .unwrap(); + + cx.executor().run_until_parked(); + + let remote_workspace = client + .channel_store() + .update(cx, |store, cx| { + let projects = store.remote_projects_for_id(channel_id); + assert_eq!(projects.len(), 1); + assert_eq!(projects[0].name, "project-1"); + workspace::join_remote_project( + projects[0].project_id.unwrap(), + client.app_state.clone(), + cx, + ) + }) + .await + .unwrap(); + + cx.executor().run_until_parked(); + + let cx2 = VisualTestContext::from_window(remote_workspace.into(), cx).as_mut(); + cx2.simulate_keystrokes("cmd-p 1 enter"); + + let editor = remote_workspace + .update(cx2, |ws, cx| { + ws.active_item_as::(cx).unwrap().clone() + }) + .unwrap(); + editor.update(cx2, |ed, cx| { + assert_eq!(ed.text(cx).to_string(), "remote\nremote\nremote"); + }); + cx2.simulate_input("wow!"); + cx2.simulate_keystrokes("cmd-s"); + + let content = dev_server + .fs() + .load(&Path::new("/remote/1.txt")) + .await + .unwrap(); + assert_eq!(content, "wow!remote\nremote\nremote\n"); +} diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 87eb5d51ba..4370eda090 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -3760,7 +3760,7 @@ async fn test_leaving_project( // Client B can't join the project, unless they re-join the room. cx_b.spawn(|cx| { - Project::remote( + Project::in_room( project_id, client_b.app_state.client.clone(), client_b.user_store().clone(), diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 78323bde76..fc0d0fdaf9 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,4 +1,5 @@ use crate::{ + auth::split_dev_server_token, db::{tests::TestDb, NewUserParams, UserId}, executor::Executor, rpc::{Principal, Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, @@ -302,6 +303,130 @@ impl TestServer { client } + pub async fn create_dev_server( + &self, + access_token: String, + cx: &mut TestAppContext, + ) -> TestClient { + cx.update(|cx| { + if cx.has_global::() { + panic!("Same cx used to create two test clients") + } + let settings = SettingsStore::test(cx); + cx.set_global(settings); + release_channel::init("0.0.0", cx); + client::init_settings(cx); + }); + let (dev_server_id, _) = split_dev_server_token(&access_token).unwrap(); + + let clock = Arc::new(FakeSystemClock::default()); + let http = FakeHttpClient::with_404_response(); + let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); + let server = self.server.clone(); + let db = self.app_state.db.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); + Arc::get_mut(&mut client) + .unwrap() + .set_id(1) + .set_dev_server_token(client::DevServerToken(access_token.clone())) + .override_establish_connection(move |credentials, cx| { + assert_eq!( + credentials, + &Credentials::DevServer { + token: client::DevServerToken(access_token.to_string()) + } + ); + + let server = server.clone(); + let db = db.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(EstablishConnectionError::other(anyhow!( + "server is forbidding connections" + ))) + } else { + let (client_conn, server_conn, killed) = + Connection::in_memory(cx.background_executor().clone()); + let (connection_id_tx, connection_id_rx) = oneshot::channel(); + let dev_server = db + .get_dev_server(dev_server_id) + .await + .expect("retrieving dev_server failed"); + cx.background_executor() + .spawn(server.handle_connection( + server_conn, + "dev-server".to_string(), + Principal::DevServer(dev_server), + ZedVersion(SemanticVersion::new(1, 0, 0)), + Some(connection_id_tx), + Executor::Deterministic(cx.background_executor().clone()), + )) + .detach(); + let connection_id = connection_id_rx.await.map_err(|e| { + EstablishConnectionError::Other(anyhow!( + "{} (is server shutting down?)", + e + )) + })?; + connection_killers + .lock() + .insert(connection_id.into(), killed); + Ok(client_conn) + } + }) + }); + + let fs = FakeFs::new(cx.executor()); + let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx)); + let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx)); + let language_registry = Arc::new(LanguageRegistry::test(cx.executor())); + let app_state = Arc::new(workspace::AppState { + client: client.clone(), + user_store: user_store.clone(), + workspace_store, + languages: language_registry, + fs: fs.clone(), + build_window_options: |_, _| Default::default(), + node_runtime: FakeNodeRuntime::new(), + }); + + cx.update(|cx| { + theme::init(theme::LoadThemes::JustBase, cx); + Project::init(&client, cx); + client::init(&client, cx); + language::init(cx); + editor::init(cx); + workspace::init(app_state.clone(), cx); + call::init(client.clone(), user_store.clone(), cx); + channel::init(&client, user_store.clone(), cx); + notifications::init(client.clone(), user_store, cx); + collab_ui::init(&app_state, cx); + file_finder::init(cx); + menu::init(); + headless::init( + client.clone(), + headless::AppState { + languages: app_state.languages.clone(), + user_store: app_state.user_store.clone(), + fs: fs.clone(), + node_runtime: app_state.node_runtime.clone(), + }, + cx, + ); + }); + + TestClient { + app_state, + username: "dev-server".to_string(), + channel_store: cx.read(ChannelStore::global).clone(), + notification_store: cx.read(NotificationStore::global).clone(), + state: Default::default(), + } + } + pub fn disconnect_client(&self, peer_id: PeerId) { self.connection_killers .lock() diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index ff78b50853..89d669f991 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -39,6 +39,7 @@ db.workspace = true editor.workspace = true emojis.workspace = true extensions_ui.workspace = true +feature_flags.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index aea84ae9ef..b27b891c38 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -1,17 +1,20 @@ mod channel_modal; mod contact_finder; +mod dev_server_modal; use self::channel_modal::ChannelModal; +use self::dev_server_modal::DevServerModal; use crate::{ channel_view::ChannelView, chat_panel::ChatPanel, face_pile::FacePile, CollaborationPanelSettings, }; use call::ActiveCall; -use channel::{Channel, ChannelEvent, ChannelStore}; +use channel::{Channel, ChannelEvent, ChannelStore, RemoteProject}; use client::{ChannelId, Client, Contact, ProjectId, User, UserStore}; use contact_finder::ContactFinder; use db::kvp::KEY_VALUE_STORE; use editor::{Editor, EditorElement, EditorStyle}; +use feature_flags::{self, FeatureFlagAppExt}; use fuzzy::{match_strings, StringMatchCandidate}; use gpui::{ actions, anchored, canvas, deferred, div, fill, list, point, prelude::*, px, AnyElement, @@ -24,7 +27,7 @@ use gpui::{ use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrev}; use project::{Fs, Project}; use rpc::{ - proto::{self, ChannelVisibility, PeerId}, + proto::{self, ChannelVisibility, DevServerStatus, PeerId}, ErrorCode, ErrorExt, }; use serde_derive::{Deserialize, Serialize}; @@ -188,6 +191,7 @@ enum ListEntry { id: ProjectId, name: SharedString, }, + RemoteProject(channel::RemoteProject), Contact { contact: Arc, calling: bool, @@ -278,10 +282,23 @@ impl CollabPanel { .push(cx.observe(&this.user_store, |this, _, cx| { this.update_entries(true, cx) })); - this.subscriptions - .push(cx.observe(&this.channel_store, |this, _, cx| { + let mut has_opened = false; + this.subscriptions.push(cx.observe( + &this.channel_store, + move |this, channel_store, cx| { + if !has_opened { + if !channel_store + .read(cx) + .dev_servers_for_id(ChannelId(1)) + .is_empty() + { + this.manage_remote_projects(ChannelId(1), cx); + has_opened = true; + } + } this.update_entries(true, cx) - })); + }, + )); this.subscriptions .push(cx.observe(&active_call, |this, _, cx| this.update_entries(true, cx))); this.subscriptions.push(cx.subscribe( @@ -569,6 +586,7 @@ impl CollabPanel { } let hosted_projects = channel_store.projects_for_id(channel.id); + let remote_projects = channel_store.remote_projects_for_id(channel.id); let has_children = channel_store .channel_at_index(mat.candidate_id + 1) .map_or(false, |next_channel| { @@ -604,7 +622,13 @@ impl CollabPanel { } for (name, id) in hosted_projects { - self.entries.push(ListEntry::HostedProject { id, name }) + self.entries.push(ListEntry::HostedProject { id, name }); + } + + if cx.has_flag::() { + for remote_project in remote_projects { + self.entries.push(ListEntry::RemoteProject(remote_project)); + } } } } @@ -1065,6 +1089,59 @@ impl CollabPanel { .tooltip(move |cx| Tooltip::text("Open Project", cx)) } + fn render_remote_project( + &self, + remote_project: &RemoteProject, + is_selected: bool, + cx: &mut ViewContext, + ) -> impl IntoElement { + let id = remote_project.id; + let name = remote_project.name.clone(); + let maybe_project_id = remote_project.project_id; + + let dev_server = self + .channel_store + .read(cx) + .find_dev_server_by_id(remote_project.dev_server_id); + + let tooltip_text = SharedString::from(match dev_server { + Some(dev_server) => format!("Open Remote Project ({})", dev_server.name), + None => "Open Remote Project".to_string(), + }); + + let dev_server_is_online = dev_server.map(|s| s.status) == Some(DevServerStatus::Online); + + let dev_server_text_color = if dev_server_is_online { + Color::Default + } else { + Color::Disabled + }; + + ListItem::new(ElementId::NamedInteger( + "remote-project".into(), + id.0 as usize, + )) + .indent_level(2) + .indent_step_size(px(20.)) + .selected(is_selected) + .on_click(cx.listener(move |this, _, cx| { + //TODO display error message if dev server is offline + if dev_server_is_online { + if let Some(project_id) = maybe_project_id { + this.join_remote_project(project_id, cx); + } + } + })) + .start_slot( + h_flex() + .relative() + .gap_1() + .child(IconButton::new(0, IconName::FileTree).icon_color(dev_server_text_color)), + ) + .child(Label::new(name.clone()).color(dev_server_text_color)) + .tooltip(move |cx| Tooltip::text(tooltip_text.clone(), cx)) + } + fn has_subchannels(&self, ix: usize) -> bool { self.entries.get(ix).map_or(false, |entry| { if let ListEntry::Channel { has_children, .. } = entry { @@ -1266,11 +1343,24 @@ impl CollabPanel { } if self.channel_store.read(cx).is_root_channel(channel_id) { - context_menu = context_menu.separator().entry( - "Manage Members", - None, - cx.handler_for(&this, move |this, cx| this.manage_members(channel_id, cx)), - ) + context_menu = context_menu + .separator() + .entry( + "Manage Members", + None, + cx.handler_for(&this, move |this, cx| { + this.manage_members(channel_id, cx) + }), + ) + .when(cx.has_flag::(), |context_menu| { + context_menu.entry( + "Manage Remote Projects", + None, + cx.handler_for(&this, move |this, cx| { + this.manage_remote_projects(channel_id, cx) + }), + ) + }) } else { context_menu = context_menu.entry( "Move this channel", @@ -1534,6 +1624,11 @@ impl CollabPanel { } => { // todo() } + ListEntry::RemoteProject(project) => { + if let Some(project_id) = project.project_id { + self.join_remote_project(project_id, cx) + } + } ListEntry::OutgoingRequest(_) => {} ListEntry::ChannelEditor { .. } => {} @@ -1706,6 +1801,18 @@ impl CollabPanel { self.show_channel_modal(channel_id, channel_modal::Mode::ManageMembers, cx); } + fn manage_remote_projects(&mut self, channel_id: ChannelId, cx: &mut ViewContext) { + let channel_store = self.channel_store.clone(); + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + workspace.update(cx, |workspace, cx| { + workspace.toggle_modal(cx, |cx| { + DevServerModal::new(channel_store.clone(), channel_id, cx) + }); + }); + } + fn remove_selected_channel(&mut self, _: &Remove, cx: &mut ViewContext) { if let Some(channel) = self.selected_channel() { self.remove_channel(channel.id, cx) @@ -2006,6 +2113,18 @@ impl CollabPanel { .detach_and_prompt_err("Failed to join channel", cx, |_, _| None) } + fn join_remote_project(&mut self, project_id: ProjectId, cx: &mut ViewContext) { + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + let app_state = workspace.read(cx).app_state().clone(); + workspace::join_remote_project(project_id, app_state, cx).detach_and_prompt_err( + "Failed to join project", + cx, + |_, _| None, + ) + } + fn join_channel_chat(&mut self, channel_id: ChannelId, cx: &mut ViewContext) { let Some(workspace) = self.workspace.upgrade() else { return; @@ -2141,6 +2260,9 @@ impl CollabPanel { ListEntry::HostedProject { id, name } => self .render_channel_project(*id, name, is_selected, cx) .into_any_element(), + ListEntry::RemoteProject(remote_project) => self + .render_remote_project(remote_project, is_selected, cx) + .into_any_element(), } } @@ -2883,6 +3005,11 @@ impl PartialEq for ListEntry { return id == other_id; } } + ListEntry::RemoteProject(project) => { + if let ListEntry::RemoteProject(other) = other { + return project.id == other.id; + } + } ListEntry::ChannelNotes { channel_id } => { if let ListEntry::ChannelNotes { channel_id: other_id, diff --git a/crates/collab_ui/src/collab_panel/dev_server_modal.rs b/crates/collab_ui/src/collab_panel/dev_server_modal.rs new file mode 100644 index 0000000000..4e2057c140 --- /dev/null +++ b/crates/collab_ui/src/collab_panel/dev_server_modal.rs @@ -0,0 +1,622 @@ +use channel::{ChannelStore, DevServer, RemoteProject}; +use client::{ChannelId, DevServerId, RemoteProjectId}; +use editor::Editor; +use gpui::{ + AppContext, ClipboardItem, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model, + ScrollHandle, Task, View, ViewContext, +}; +use rpc::proto::{self, CreateDevServerResponse, DevServerStatus}; +use ui::{prelude::*, Indicator, List, ListHeader, ModalContent, ModalHeader, Tooltip}; +use util::ResultExt; +use workspace::ModalView; + +pub struct DevServerModal { + mode: Mode, + focus_handle: FocusHandle, + scroll_handle: ScrollHandle, + channel_store: Model, + channel_id: ChannelId, + remote_project_name_editor: View, + remote_project_path_editor: View, + dev_server_name_editor: View, + _subscriptions: [gpui::Subscription; 2], +} + +#[derive(Default)] +struct CreateDevServer { + creating: Option>, + dev_server: Option, +} + +struct CreateRemoteProject { + dev_server_id: DevServerId, + creating: Option>, + remote_project: Option, +} + +enum Mode { + Default, + CreateRemoteProject(CreateRemoteProject), + CreateDevServer(CreateDevServer), +} + +impl DevServerModal { + pub fn new( + channel_store: Model, + channel_id: ChannelId, + cx: &mut ViewContext, + ) -> Self { + let name_editor = cx.new_view(|cx| Editor::single_line(cx)); + let path_editor = cx.new_view(|cx| Editor::single_line(cx)); + let dev_server_name_editor = cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text("Dev server name", cx); + editor + }); + + let focus_handle = cx.focus_handle(); + + let subscriptions = [ + cx.observe(&channel_store, |_, _, cx| { + cx.notify(); + }), + cx.on_focus_out(&focus_handle, |_, _cx| { /* cx.emit(DismissEvent) */ }), + ]; + + Self { + mode: Mode::Default, + focus_handle, + scroll_handle: ScrollHandle::new(), + channel_store, + channel_id, + remote_project_name_editor: name_editor, + remote_project_path_editor: path_editor, + dev_server_name_editor, + _subscriptions: subscriptions, + } + } + + pub fn create_remote_project( + &mut self, + dev_server_id: DevServerId, + cx: &mut ViewContext, + ) { + let channel_id = self.channel_id; + let name = self + .remote_project_name_editor + .read(cx) + .text(cx) + .trim() + .to_string(); + let path = self + .remote_project_path_editor + .read(cx) + .text(cx) + .trim() + .to_string(); + + if name == "" { + return; + } + if path == "" { + return; + } + + let create = self.channel_store.update(cx, |store, cx| { + store.create_remote_project(channel_id, dev_server_id, name, path, cx) + }); + + let task = cx.spawn(|this, mut cx| async move { + let result = create.await; + if let Err(e) = &result { + cx.prompt( + gpui::PromptLevel::Critical, + "Failed to create project", + Some(&format!("{:?}. Please try again.", e)), + &["Ok"], + ) + .await + .log_err(); + } + this.update(&mut cx, |this, _| { + this.mode = Mode::CreateRemoteProject(CreateRemoteProject { + dev_server_id, + creating: None, + remote_project: result.ok().and_then(|r| r.remote_project), + }); + }) + .log_err(); + }); + + self.mode = Mode::CreateRemoteProject(CreateRemoteProject { + dev_server_id, + creating: Some(task), + remote_project: None, + }); + } + + pub fn create_dev_server(&mut self, cx: &mut ViewContext) { + let name = self + .dev_server_name_editor + .read(cx) + .text(cx) + .trim() + .to_string(); + + if name == "" { + return; + } + + let dev_server = self.channel_store.update(cx, |store, cx| { + store.create_dev_server(self.channel_id, name.clone(), cx) + }); + + let task = cx.spawn(|this, mut cx| async move { + match dev_server.await { + Ok(dev_server) => { + this.update(&mut cx, |this, _| { + this.mode = Mode::CreateDevServer(CreateDevServer { + creating: None, + dev_server: Some(dev_server), + }); + }) + .log_err(); + } + Err(e) => { + cx.prompt( + gpui::PromptLevel::Critical, + "Failed to create server", + Some(&format!("{:?}. Please try again.", e)), + &["Ok"], + ) + .await + .log_err(); + this.update(&mut cx, |this, _| { + this.mode = Mode::CreateDevServer(Default::default()); + }) + .log_err(); + } + } + }); + + self.mode = Mode::CreateDevServer(CreateDevServer { + creating: Some(task), + dev_server: None, + }); + cx.notify() + } + + fn cancel(&mut self, _: &menu::Cancel, cx: &mut ViewContext) { + match self.mode { + Mode::Default => cx.emit(DismissEvent), + Mode::CreateRemoteProject(_) | Mode::CreateDevServer(_) => { + self.mode = Mode::Default; + cx.notify(); + } + } + } + + fn render_dev_server( + &mut self, + dev_server: &DevServer, + cx: &mut ViewContext, + ) -> impl IntoElement { + let channel_store = self.channel_store.read(cx); + let dev_server_id = dev_server.id; + let status = dev_server.status; + + v_flex() + .w_full() + .child( + h_flex() + .group("dev-server") + .justify_between() + .child( + h_flex() + .gap_2() + .child( + div() + .id(("status", dev_server.id.0)) + .relative() + .child(Icon::new(IconName::Server).size(IconSize::Small)) + .child( + div().absolute().bottom_0().left(rems_from_px(8.0)).child( + Indicator::dot().color(match status { + DevServerStatus::Online => Color::Created, + DevServerStatus::Offline => Color::Deleted, + }), + ), + ) + .tooltip(move |cx| { + Tooltip::text( + match status { + DevServerStatus::Online => "Online", + DevServerStatus::Offline => "Offline", + }, + cx, + ) + }), + ) + .child(dev_server.name.clone()) + .child( + h_flex() + .visible_on_hover("dev-server") + .gap_1() + .child( + IconButton::new("edit-dev-server", IconName::Pencil) + .disabled(true) //TODO implement this on the collab side + .tooltip(|cx| { + Tooltip::text("Coming Soon - Edit dev server", cx) + }), + ) + .child( + IconButton::new("remove-dev-server", IconName::Trash) + .disabled(true) //TODO implement this on the collab side + .tooltip(|cx| { + Tooltip::text("Coming Soon - Remove dev server", cx) + }), + ), + ), + ) + .child( + h_flex().gap_1().child( + IconButton::new("add-remote-project", IconName::Plus) + .tooltip(|cx| Tooltip::text("Add a remote project", cx)) + .on_click(cx.listener(move |this, _, cx| { + this.mode = Mode::CreateRemoteProject(CreateRemoteProject { + dev_server_id, + creating: None, + remote_project: None, + }); + cx.notify(); + })), + ), + ), + ) + .child( + v_flex() + .w_full() + .bg(cx.theme().colors().title_bar_background) + .border() + .border_color(cx.theme().colors().border_variant) + .rounded_md() + .my_1() + .py_0p5() + .px_3() + .child( + List::new().empty_message("No projects.").children( + channel_store + .remote_projects_for_id(dev_server.channel_id) + .iter() + .filter_map(|remote_project| { + if remote_project.dev_server_id == dev_server.id { + Some(self.render_remote_project(remote_project, cx)) + } else { + None + } + }), + ), + ), + ) + // .child(div().ml_8().child( + // Button::new(("add-project", dev_server_id.0), "Add Project").on_click(cx.listener( + // move |this, _, cx| { + // this.mode = Mode::CreateRemoteProject(CreateRemoteProject { + // dev_server_id, + // creating: None, + // remote_project: None, + // }); + // cx.notify(); + // }, + // )), + // )) + } + + fn render_remote_project( + &mut self, + project: &RemoteProject, + _: &mut ViewContext, + ) -> impl IntoElement { + h_flex() + .gap_2() + .child(Icon::new(IconName::FileTree)) + .child(Label::new(project.name.clone())) + .child(Label::new(format!("({})", project.path.clone())).color(Color::Muted)) + } + + fn render_create_dev_server(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let Mode::CreateDevServer(CreateDevServer { + creating, + dev_server, + }) = &self.mode + else { + unreachable!() + }; + + self.dev_server_name_editor.update(cx, |editor, _| { + editor.set_read_only(creating.is_some() || dev_server.is_some()) + }); + v_flex() + .px_1() + .pt_0p5() + .gap_px() + .child( + v_flex().py_0p5().px_1().child( + h_flex() + .px_1() + .py_0p5() + .child( + IconButton::new("back", IconName::ArrowLeft) + .style(ButtonStyle::Transparent) + .on_click(cx.listener(|this, _: &gpui::ClickEvent, cx| { + this.mode = Mode::Default; + cx.notify(); + })), + ) + .child(Headline::new("Register dev server")), + ), + ) + .child( + h_flex() + .ml_5() + .gap_2() + .child("Name") + .child(self.dev_server_name_editor.clone()) + .on_action( + cx.listener(|this, _: &menu::Confirm, cx| this.create_dev_server(cx)), + ) + .when(creating.is_none() && dev_server.is_none(), |div| { + div.child( + Button::new("create-dev-server", "Create").on_click(cx.listener( + move |this, _, cx| { + this.create_dev_server(cx); + }, + )), + ) + }) + .when(creating.is_some() && dev_server.is_none(), |div| { + div.child(Button::new("create-dev-server", "Creating...").disabled(true)) + }), + ) + .when_some(dev_server.clone(), |div, dev_server| { + let channel_store = self.channel_store.read(cx); + let status = channel_store + .find_dev_server_by_id(DevServerId(dev_server.dev_server_id)) + .map(|server| server.status) + .unwrap_or(DevServerStatus::Offline); + let instructions = SharedString::from(format!( + "zed --dev-server-token {}", + dev_server.access_token + )); + div.child( + v_flex() + .ml_8() + .gap_2() + .child(Label::new(format!( + "Please log into `{}` and run:", + dev_server.name + ))) + .child(instructions.clone()) + .child( + IconButton::new("copy-access-token", IconName::Copy) + .on_click(cx.listener(move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new( + instructions.to_string(), + )) + })) + .icon_size(IconSize::Small) + .tooltip(|cx| Tooltip::text("Copy access token", cx)), + ) + .when(status == DevServerStatus::Offline, |this| { + this.child(Label::new("Waiting for connection...")) + }) + .when(status == DevServerStatus::Online, |this| { + this.child(Label::new("Connection established! 🎊")).child( + Button::new("done", "Done").on_click(cx.listener(|this, _, cx| { + this.mode = Mode::Default; + cx.notify(); + })), + ) + }), + ) + }) + } + + fn render_default(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let channel_store = self.channel_store.read(cx); + let dev_servers = channel_store.dev_servers_for_id(self.channel_id); + // let dev_servers = Vec::new(); + + v_flex() + .id("scroll-container") + .h_full() + .overflow_y_scroll() + .track_scroll(&self.scroll_handle) + .px_1() + .pt_0p5() + .gap_px() + .child( + ModalHeader::new("Manage Remote Project") + .child(Headline::new("Remote Projects").size(HeadlineSize::Small)), + ) + .child( + ModalContent::new().child( + List::new() + .empty_message("No dev servers registered.") + .header(Some( + ListHeader::new("Dev Servers").end_slot( + Button::new("register-dev-server-button", "New Server") + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .tooltip(|cx| Tooltip::text("Register a new dev server", cx)) + .on_click(cx.listener(|this, _, cx| { + this.mode = Mode::CreateDevServer(Default::default()); + this.dev_server_name_editor + .read(cx) + .focus_handle(cx) + .focus(cx); + cx.notify(); + })), + ), + )) + .children(dev_servers.iter().map(|dev_server| { + self.render_dev_server(dev_server, cx).into_any_element() + })), + ), + ) + } + + fn render_create_project(&self, cx: &mut ViewContext) -> impl IntoElement { + let Mode::CreateRemoteProject(CreateRemoteProject { + dev_server_id, + creating, + remote_project, + }) = &self.mode + else { + unreachable!() + }; + let channel_store = self.channel_store.read(cx); + let (dev_server_name, dev_server_status) = channel_store + .find_dev_server_by_id(*dev_server_id) + .map(|server| (server.name.clone(), server.status)) + .unwrap_or((SharedString::from(""), DevServerStatus::Offline)); + v_flex() + .px_1() + .pt_0p5() + .gap_px() + .child( + ModalHeader::new("Manage Remote Project") + .child(Headline::new("Manage Remote Projects")), + ) + .child( + h_flex() + .py_0p5() + .px_1() + .child(div().px_1().py_0p5().child( + IconButton::new("back", IconName::ArrowLeft).on_click(cx.listener( + |this, _, cx| { + this.mode = Mode::Default; + cx.notify() + }, + )), + )) + .child("Add Project..."), + ) + .child( + h_flex() + .ml_5() + .gap_2() + .child( + div() + .id(("status", dev_server_id.0)) + .relative() + .child(Icon::new(IconName::Server)) + .child(div().absolute().bottom_0().left(rems_from_px(12.0)).child( + Indicator::dot().color(match dev_server_status { + DevServerStatus::Online => Color::Created, + DevServerStatus::Offline => Color::Deleted, + }), + )) + .tooltip(move |cx| { + Tooltip::text( + match dev_server_status { + DevServerStatus::Online => "Online", + DevServerStatus::Offline => "Offline", + }, + cx, + ) + }), + ) + .child(dev_server_name.clone()), + ) + .child( + h_flex() + .ml_5() + .gap_2() + .child("Name") + .child(self.remote_project_name_editor.clone()) + .on_action(cx.listener(|this, _: &menu::Confirm, cx| { + cx.focus_view(&this.remote_project_path_editor) + })), + ) + .child( + h_flex() + .ml_5() + .gap_2() + .child("Path") + .child(self.remote_project_path_editor.clone()) + .on_action( + cx.listener(|this, _: &menu::Confirm, cx| this.create_dev_server(cx)), + ) + .when(creating.is_none() && remote_project.is_none(), |div| { + div.child(Button::new("create-remote-server", "Create").on_click({ + let dev_server_id = *dev_server_id; + cx.listener(move |this, _, cx| { + this.create_remote_project(dev_server_id, cx) + }) + })) + }) + .when(creating.is_some(), |div| { + div.child(Button::new("create-dev-server", "Creating...").disabled(true)) + }), + ) + .when_some(remote_project.clone(), |div, remote_project| { + let channel_store = self.channel_store.read(cx); + let status = channel_store + .find_remote_project_by_id(RemoteProjectId(remote_project.id)) + .map(|project| { + if project.project_id.is_some() { + DevServerStatus::Online + } else { + DevServerStatus::Offline + } + }) + .unwrap_or(DevServerStatus::Offline); + div.child( + v_flex() + .ml_5() + .ml_8() + .gap_2() + .when(status == DevServerStatus::Offline, |this| { + this.child(Label::new("Waiting for project...")) + }) + .when(status == DevServerStatus::Online, |this| { + this.child(Label::new("Project online! 🎊")).child( + Button::new("done", "Done").on_click(cx.listener(|this, _, cx| { + this.mode = Mode::Default; + cx.notify(); + })), + ) + }), + ) + }) + } +} +impl ModalView for DevServerModal {} + +impl FocusableView for DevServerModal { + fn focus_handle(&self, _cx: &AppContext) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for DevServerModal {} + +impl Render for DevServerModal { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + div() + .track_focus(&self.focus_handle) + .elevation_3(cx) + .key_context("DevServerModal") + .on_action(cx.listener(Self::cancel)) + .pb_4() + .w(rems(34.)) + .min_h(rems(20.)) + .max_h(rems(40.)) + .child(match &self.mode { + Mode::Default => self.render_default(cx).into_any_element(), + Mode::CreateRemoteProject(_) => self.render_create_project(cx).into_any_element(), + Mode::CreateDevServer(_) => self.render_create_dev_server(cx).into_any_element(), + }) + } +} diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 700f70be78..0823e2f460 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -18,6 +18,11 @@ pub trait FeatureFlag { const NAME: &'static str; } +pub struct Remoting {} +impl FeatureFlag for Remoting { + const NAME: &'static str = "remoting"; +} + pub trait FeatureFlagViewExt { fn observe_flag(&mut self, callback: F) -> Subscription where diff --git a/crates/gpui/src/app/async_context.rs b/crates/gpui/src/app/async_context.rs index de87e35fd0..b251c638f5 100644 --- a/crates/gpui/src/app/async_context.rs +++ b/crates/gpui/src/app/async_context.rs @@ -1,10 +1,12 @@ use crate::{ AnyView, AnyWindowHandle, AppCell, AppContext, BackgroundExecutor, BorrowAppContext, Context, - DismissEvent, FocusableView, ForegroundExecutor, Global, Model, ModelContext, Render, - Reservation, Result, Task, View, ViewContext, VisualContext, WindowContext, WindowHandle, + DismissEvent, FocusableView, ForegroundExecutor, Global, Model, ModelContext, PromptLevel, + Render, Reservation, Result, Task, View, ViewContext, VisualContext, WindowContext, + WindowHandle, }; use anyhow::{anyhow, Context as _}; use derive_more::{Deref, DerefMut}; +use futures::channel::oneshot; use std::{future::Future, rc::Weak}; /// An async-friendly version of [AppContext] with a static lifetime so it can be held across `await` points in async code. @@ -285,6 +287,21 @@ impl AsyncWindowContext { { self.foreground_executor.spawn(f(self.clone())) } + + /// Present a platform dialog. + /// The provided message will be presented, along with buttons for each answer. + /// When a button is clicked, the returned Receiver will receive the index of the clicked button. + pub fn prompt( + &mut self, + level: PromptLevel, + message: &str, + detail: Option<&str>, + answers: &[&str], + ) -> oneshot::Receiver { + self.window + .update(self, |_, cx| cx.prompt(level, message, detail, answers)) + .unwrap_or_else(|_| oneshot::channel().1) + } } impl Context for AsyncWindowContext { diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 509160471c..62be295632 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -73,12 +73,17 @@ pub(crate) fn current_platform() -> Rc { #[cfg(target_os = "linux")] pub(crate) fn current_platform() -> Rc { let wayland_display = std::env::var_os("WAYLAND_DISPLAY"); + let x11_display = std::env::var_os("DISPLAY"); + let use_wayland = wayland_display.is_some_and(|display| !display.is_empty()); + let use_x11 = x11_display.is_some_and(|display| !display.is_empty()); if use_wayland { Rc::new(WaylandClient::new()) - } else { + } else if use_x11 { Rc::new(X11Client::new()) + } else { + Rc::new(HeadlessClient::new()) } } // todo("windows") diff --git a/crates/gpui/src/platform/linux.rs b/crates/gpui/src/platform/linux.rs index 6bf7cc4840..1628e22f37 100644 --- a/crates/gpui/src/platform/linux.rs +++ b/crates/gpui/src/platform/linux.rs @@ -2,11 +2,13 @@ #![allow(unused)] mod dispatcher; +mod headless; mod platform; mod wayland; mod x11; pub(crate) use dispatcher::*; +pub(crate) use headless::*; pub(crate) use platform::*; pub(crate) use wayland::*; pub(crate) use x11::*; diff --git a/crates/gpui/src/platform/linux/headless.rs b/crates/gpui/src/platform/linux/headless.rs new file mode 100644 index 0000000000..2237aeb194 --- /dev/null +++ b/crates/gpui/src/platform/linux/headless.rs @@ -0,0 +1,3 @@ +mod client; + +pub(crate) use client::*; diff --git a/crates/gpui/src/platform/linux/headless/client.rs b/crates/gpui/src/platform/linux/headless/client.rs new file mode 100644 index 0000000000..fdad401851 --- /dev/null +++ b/crates/gpui/src/platform/linux/headless/client.rs @@ -0,0 +1,98 @@ +use std::cell::RefCell; +use std::ops::Deref; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use calloop::{EventLoop, LoopHandle}; +use collections::HashMap; + +use util::ResultExt; + +use crate::platform::linux::LinuxClient; +use crate::platform::{LinuxCommon, PlatformWindow}; +use crate::{ + px, AnyWindowHandle, Bounds, CursorStyle, DisplayId, Modifiers, ModifiersChangedEvent, Pixels, + PlatformDisplay, PlatformInput, Point, ScrollDelta, Size, TouchPhase, WindowParams, +}; + +use calloop::{ + generic::{FdWrapper, Generic}, + RegistrationToken, +}; + +pub struct HeadlessClientState { + pub(crate) loop_handle: LoopHandle<'static, HeadlessClient>, + pub(crate) event_loop: Option>, + pub(crate) common: LinuxCommon, +} + +#[derive(Clone)] +pub(crate) struct HeadlessClient(Rc>); + +impl HeadlessClient { + pub(crate) fn new() -> Self { + let event_loop = EventLoop::try_new().unwrap(); + + let (common, main_receiver) = LinuxCommon::new(event_loop.get_signal()); + + let handle = event_loop.handle(); + + handle.insert_source(main_receiver, |event, _, _: &mut HeadlessClient| { + if let calloop::channel::Event::Msg(runnable) = event { + runnable.run(); + } + }); + + HeadlessClient(Rc::new(RefCell::new(HeadlessClientState { + event_loop: Some(event_loop), + loop_handle: handle, + common, + }))) + } +} + +impl LinuxClient for HeadlessClient { + fn with_common(&self, f: impl FnOnce(&mut LinuxCommon) -> R) -> R { + f(&mut self.0.borrow_mut().common) + } + + fn displays(&self) -> Vec> { + vec![] + } + + fn primary_display(&self) -> Option> { + None + } + + fn display(&self, id: DisplayId) -> Option> { + None + } + + fn open_window( + &self, + _handle: AnyWindowHandle, + params: WindowParams, + ) -> Box { + unimplemented!() + } + + //todo(linux) + fn set_cursor_style(&self, _style: CursorStyle) {} + + fn write_to_clipboard(&self, item: crate::ClipboardItem) {} + + fn read_from_clipboard(&self) -> Option { + None + } + + fn run(&self) { + let mut event_loop = self + .0 + .borrow_mut() + .event_loop + .take() + .expect("App is already running"); + + event_loop.run(None, &mut self.clone(), |_| {}).log_err(); + } +} diff --git a/crates/headless/Cargo.toml b/crates/headless/Cargo.toml new file mode 100644 index 0000000000..772f625a6f --- /dev/null +++ b/crates/headless/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "headless" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/headless.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +client.workspace = true +ctrlc.workspace = true +gpui.workspace = true +log.workspace = true +rpc.workspace = true +util.workspace = true +node_runtime.workspace = true +language.workspace = true +project.workspace = true +fs.workspace = true +futures.workspace = true +settings.workspace = true +postage.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +fs = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +rpc = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } diff --git a/crates/headless/LICENSE-GPL b/crates/headless/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/headless/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/headless/src/headless.rs b/crates/headless/src/headless.rs new file mode 100644 index 0000000000..677389b637 --- /dev/null +++ b/crates/headless/src/headless.rs @@ -0,0 +1,265 @@ +use anyhow::Result; +use client::{user::UserStore, Client, ClientSettings, RemoteProjectId}; +use fs::Fs; +use futures::Future; +use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ModelContext, Task, WeakModel}; +use language::LanguageRegistry; +use node_runtime::NodeRuntime; +use postage::stream::Stream; +use project::Project; +use rpc::{proto, TypedEnvelope}; +use settings::Settings; +use std::{collections::HashMap, sync::Arc}; +use util::{ResultExt, TryFutureExt}; + +pub struct DevServer { + client: Arc, + app_state: AppState, + projects: HashMap>, + _subscriptions: Vec, + _maintain_connection: Task>, +} + +pub struct AppState { + pub node_runtime: Arc, + pub user_store: Model, + pub languages: Arc, + pub fs: Arc, +} + +struct GlobalDevServer(Model); + +impl Global for GlobalDevServer {} + +pub fn init(client: Arc, app_state: AppState, cx: &mut AppContext) { + let dev_server = cx.new_model(|cx| DevServer::new(client.clone(), app_state, cx)); + cx.set_global(GlobalDevServer(dev_server.clone())); + + // Set up a handler when the dev server is shut down by the user pressing Ctrl-C + let (tx, rx) = futures::channel::oneshot::channel(); + set_ctrlc_handler(move || tx.send(()).log_err().unwrap()).log_err(); + + cx.spawn(|cx| async move { + rx.await.log_err(); + log::info!("Received interrupt signal"); + cx.update(|cx| cx.quit()).log_err(); + }) + .detach(); + + let server_url = ClientSettings::get_global(&cx).server_url.clone(); + cx.spawn(|cx| async move { + match client.authenticate_and_connect(false, &cx).await { + Ok(_) => { + log::info!("Connected to {}", server_url); + } + Err(e) => { + log::error!("Error connecting to {}: {}", server_url, e); + cx.update(|cx| cx.quit()).log_err(); + } + } + }) + .detach(); +} + +fn set_ctrlc_handler(f: F) -> Result<(), ctrlc::Error> +where + F: FnOnce() + 'static + Send, +{ + let f = std::sync::Mutex::new(Some(f)); + ctrlc::set_handler(move || { + if let Ok(mut guard) = f.lock() { + let f = guard.take().expect("f can only be taken once"); + f(); + } + }) +} + +impl DevServer { + pub fn global(cx: &AppContext) -> Model { + cx.global::().0.clone() + } + + pub fn new(client: Arc, app_state: AppState, cx: &mut ModelContext) -> Self { + cx.on_app_quit(Self::app_will_quit).detach(); + + let maintain_connection = cx.spawn({ + let client = client.clone(); + move |this, cx| Self::maintain_connection(this, client.clone(), cx).log_err() + }); + + DevServer { + _subscriptions: vec![ + client.add_message_handler(cx.weak_model(), Self::handle_dev_server_instructions) + ], + _maintain_connection: maintain_connection, + projects: Default::default(), + app_state, + client, + } + } + + fn app_will_quit(&mut self, _: &mut ModelContext) -> impl Future { + let request = self.client.request(proto::ShutdownDevServer {}); + async move { + request.await.log_err(); + } + } + + async fn handle_dev_server_instructions( + this: Model, + envelope: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let (added_projects, removed_projects_ids) = this.read_with(&mut cx, |this, _| { + let removed_projects = this + .projects + .keys() + .filter(|remote_project_id| { + !envelope + .payload + .projects + .iter() + .any(|p| p.id == remote_project_id.0) + }) + .cloned() + .collect::>(); + + let added_projects = envelope + .payload + .projects + .into_iter() + .filter(|project| !this.projects.contains_key(&RemoteProjectId(project.id))) + .collect::>(); + + (added_projects, removed_projects) + })?; + + for remote_project in added_projects { + DevServer::share_project(this.clone(), &remote_project, &mut cx).await?; + } + + this.update(&mut cx, |this, cx| { + for old_project_id in &removed_projects_ids { + this.unshare_project(old_project_id, cx)?; + } + Ok::<(), anyhow::Error>(()) + })??; + Ok(()) + } + + fn unshare_project( + &mut self, + remote_project_id: &RemoteProjectId, + cx: &mut ModelContext, + ) -> Result<()> { + if let Some(project) = self.projects.remove(remote_project_id) { + project.update(cx, |project, cx| project.unshare(cx))?; + } + Ok(()) + } + + async fn share_project( + this: Model, + remote_project: &proto::RemoteProject, + cx: &mut AsyncAppContext, + ) -> Result<()> { + let (client, project) = this.update(cx, |this, cx| { + let project = Project::local( + this.client.clone(), + this.app_state.node_runtime.clone(), + this.app_state.user_store.clone(), + this.app_state.languages.clone(), + this.app_state.fs.clone(), + cx, + ); + + (this.client.clone(), project) + })?; + + project + .update(cx, |project, cx| { + project.find_or_create_local_worktree(&remote_project.path, true, cx) + })? + .await?; + + let worktrees = + project.read_with(cx, |project, cx| project.worktree_metadata_protos(cx))?; + + let response = client + .request(proto::ShareRemoteProject { + remote_project_id: remote_project.id, + worktrees, + }) + .await?; + + let project_id = response.project_id; + project.update(cx, |project, cx| project.shared(project_id, cx))??; + this.update(cx, |this, _| { + this.projects + .insert(RemoteProjectId(remote_project.id), project); + })?; + Ok(()) + } + + async fn maintain_connection( + this: WeakModel, + client: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let mut client_status = client.status(); + + let _ = client_status.try_recv(); + let current_status = *client_status.borrow(); + if current_status.is_connected() { + // wait for first disconnect + client_status.recv().await; + } + + loop { + let Some(current_status) = client_status.recv().await else { + return Ok(()); + }; + let Some(this) = this.upgrade() else { + return Ok(()); + }; + + if !current_status.is_connected() { + continue; + } + + this.update(&mut cx, |this, cx| this.rejoin(cx))?.await?; + } + } + + fn rejoin(&mut self, cx: &mut ModelContext) -> Task> { + let mut projects: HashMap> = HashMap::default(); + let request = self.client.request(proto::ReconnectDevServer { + reshared_projects: self + .projects + .iter() + .flat_map(|(_, handle)| { + let project = handle.read(cx); + let project_id = project.remote_id()?; + projects.insert(project_id, handle.clone()); + Some(proto::UpdateProject { + project_id, + worktrees: project.worktree_metadata_protos(cx), + }) + }) + .collect(), + }); + cx.spawn(|_, mut cx| async move { + let response = request.await?; + + for reshared_project in response.reshared_projects { + if let Some(project) = projects.get(&reshared_project.id) { + project.update(&mut cx, |project, cx| { + project.reshared(reshared_project, cx).log_err(); + })?; + } + } + Ok(()) + }) + } +} diff --git a/crates/project/src/connection_manager.rs b/crates/project/src/connection_manager.rs new file mode 100644 index 0000000000..f9b342ea98 --- /dev/null +++ b/crates/project/src/connection_manager.rs @@ -0,0 +1,212 @@ +use super::Project; +use anyhow::Result; +use client::Client; +use collections::{HashMap, HashSet}; +use futures::{FutureExt, StreamExt}; +use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ModelContext, Task, WeakModel}; +use postage::stream::Stream; +use rpc::proto; +use std::{sync::Arc, time::Duration}; +use util::{ResultExt, TryFutureExt}; + +impl Global for GlobalManager {} +struct GlobalManager(Model); + +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); + +pub struct Manager { + client: Arc, + maintain_connection: Option>>, + projects: HashSet>, +} + +pub fn init(client: Arc, cx: &mut AppContext) { + let manager = cx.new_model(|_| Manager { + client, + maintain_connection: None, + projects: HashSet::default(), + }); + cx.set_global(GlobalManager(manager)); +} + +impl Manager { + pub fn global(cx: &AppContext) -> Model { + cx.global::().0.clone() + } + + pub fn maintain_project_connection( + &mut self, + project: &Model, + cx: &mut ModelContext, + ) { + let manager = cx.weak_model(); + project.update(cx, |_, cx| { + let manager = manager.clone(); + cx.on_release(move |project, cx| { + manager + .update(cx, |manager, cx| { + manager.projects.retain(|p| { + if let Some(p) = p.upgrade() { + p.read(cx).remote_id() != project.remote_id() + } else { + false + } + }); + if manager.projects.is_empty() { + manager.maintain_connection.take(); + } + }) + .ok(); + }) + .detach(); + }); + + self.projects.insert(project.downgrade()); + if self.maintain_connection.is_none() { + self.maintain_connection = Some(cx.spawn({ + let client = self.client.clone(); + move |_, cx| Self::maintain_connection(manager, client.clone(), cx).log_err() + })); + } + } + + fn reconnected(&mut self, cx: &mut ModelContext) -> Task> { + let mut projects = HashMap::default(); + + let request = self.client.request_envelope(proto::RejoinRemoteProjects { + rejoined_projects: self + .projects + .iter() + .filter_map(|project| { + if let Some(handle) = project.upgrade() { + let project = handle.read(cx); + let project_id = project.remote_id()?; + projects.insert(project_id, handle.clone()); + Some(proto::RejoinProject { + id: project_id, + worktrees: project + .worktrees() + .map(|worktree| { + let worktree = worktree.read(cx); + proto::RejoinWorktree { + id: worktree.id().to_proto(), + scan_id: worktree.completed_scan_id() as u64, + } + }) + .collect(), + }) + } else { + None + } + }) + .collect(), + }); + + cx.spawn(|this, mut cx| async move { + let response = request.await?; + let message_id = response.message_id; + + this.update(&mut cx, |_, cx| { + for rejoined_project in response.payload.rejoined_projects { + if let Some(project) = projects.get(&rejoined_project.id) { + project.update(cx, |project, cx| { + project.rejoined(rejoined_project, message_id, cx).log_err(); + }); + } + } + }) + }) + } + + fn connection_lost(&mut self, cx: &mut ModelContext) { + for project in self.projects.drain() { + if let Some(project) = project.upgrade() { + project.update(cx, |project, cx| { + project.disconnected_from_host(cx); + project.close(cx); + }); + } + } + self.maintain_connection.take(); + } + + async fn maintain_connection( + this: WeakModel, + client: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let mut client_status = client.status(); + loop { + let _ = client_status.try_recv(); + + let is_connected = client_status.borrow().is_connected(); + // Even if we're initially connected, any future change of the status means we momentarily disconnected. + if !is_connected || client_status.next().await.is_some() { + log::info!("detected client disconnection"); + + // Wait for client to re-establish a connection to the server. + { + let mut reconnection_timeout = + cx.background_executor().timer(RECONNECT_TIMEOUT).fuse(); + let client_reconnection = async { + let mut remaining_attempts = 3; + while remaining_attempts > 0 { + if client_status.borrow().is_connected() { + log::info!("client reconnected, attempting to rejoin projects"); + + let Some(this) = this.upgrade() else { break }; + match this.update(&mut cx, |this, cx| this.reconnected(cx)) { + Ok(task) => { + if task.await.log_err().is_some() { + return true; + } else { + remaining_attempts -= 1; + } + } + Err(_app_dropped) => return false, + } + } else if client_status.borrow().is_signed_out() { + return false; + } + + log::info!( + "waiting for client status change, remaining attempts {}", + remaining_attempts + ); + client_status.next().await; + } + false + } + .fuse(); + futures::pin_mut!(client_reconnection); + + futures::select_biased! { + reconnected = client_reconnection => { + if reconnected { + log::info!("successfully reconnected"); + // If we successfully joined the room, go back around the loop + // waiting for future connection status changes. + continue; + } + } + _ = reconnection_timeout => { + log::info!("rejoin project reconnection timeout expired"); + } + } + } + + break; + } + } + + // The client failed to re-establish a connection to the server + // or an error occurred while trying to re-join the room. Either way + // we leave the room and return an error. + if let Some(this) = this.upgrade() { + log::info!("reconnection failed, disconnecting projects"); + let _ = this.update(&mut cx, |this, cx| this.connection_lost(cx))?; + } + + Ok(()) + } +} diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index bd08662dea..d6e5f62516 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1,3 +1,4 @@ +pub mod connection_manager; pub mod debounced_delay; pub mod lsp_command; pub mod lsp_ext_command; @@ -234,6 +235,7 @@ enum BufferOrderedMessage { Resync, } +#[derive(Debug)] enum LocalProjectUpdate { WorktreesChanged, CreateBufferForPeer { @@ -597,6 +599,7 @@ impl Project { } pub fn init(client: &Arc, cx: &mut AppContext) { + connection_manager::init(client.clone(), cx); Self::init_settings(cx); client.add_model_message_handler(Self::handle_add_collaborator); @@ -733,6 +736,24 @@ impl Project { languages: Arc, fs: Arc, cx: AsyncAppContext, + ) -> Result> { + let project = + Self::in_room(remote_id, client, user_store, languages, fs, cx.clone()).await?; + cx.update(|cx| { + connection_manager::Manager::global(cx).update(cx, |manager, cx| { + manager.maintain_project_connection(&project, cx) + }) + })?; + Ok(project) + } + + pub async fn in_room( + remote_id: u64, + client: Arc, + user_store: Model, + languages: Arc, + fs: Arc, + cx: AsyncAppContext, ) -> Result> { client.authenticate_and_connect(true, &cx).await?; @@ -753,6 +774,7 @@ impl Project { ) .await } + async fn from_join_project_response( response: TypedEnvelope, subscription: PendingEntitySubscription, @@ -1561,7 +1583,7 @@ impl Project { }) })? .await; - if update_project.is_ok() { + if update_project.log_err().is_some() { for worktree in worktrees { worktree.update(&mut cx, |worktree, cx| { let worktree = worktree.as_local_mut().unwrap(); diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 76779c85e3..b97cbead3b 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -213,7 +213,21 @@ message Envelope { UpdateNotification update_notification = 174; MultiLspQuery multi_lsp_query = 175; - MultiLspQueryResponse multi_lsp_query_response = 176; // current max + MultiLspQueryResponse multi_lsp_query_response = 176; + + CreateRemoteProject create_remote_project = 177; + CreateRemoteProjectResponse create_remote_project_response = 188; // current max + CreateDevServer create_dev_server = 178; + CreateDevServerResponse create_dev_server_response = 179; + ShutdownDevServer shutdown_dev_server = 180; + DevServerInstructions dev_server_instructions = 181; + ReconnectDevServer reconnect_dev_server = 182; + ReconnectDevServerResponse reconnect_dev_server_response = 183; + + ShareRemoteProject share_remote_project = 184; + JoinRemoteProject join_remote_project = 185; + RejoinRemoteProjects rejoin_remote_projects = 186; + RejoinRemoteProjectsResponse rejoin_remote_projects_response = 187; } reserved 158 to 161; @@ -249,6 +263,7 @@ enum ErrorCode { WrongMoveTarget = 11; UnsharedItem = 12; NoSuchProject = 13; + DevServerAlreadyOnline = 14; reserved 6; } @@ -280,6 +295,13 @@ message RejoinRoom { repeated UpdateProject reshared_projects = 2; repeated RejoinProject rejoined_projects = 3; } +message RejoinRemoteProjects { + repeated RejoinProject rejoined_projects = 1; +} + +message RejoinRemoteProjectsResponse { + repeated RejoinedProject rejoined_projects = 1; +} message RejoinProject { uint64 id = 1; @@ -429,6 +451,52 @@ message JoinHostedProject { uint64 project_id = 1; } +message CreateRemoteProject { + uint64 channel_id = 1; + string name = 2; + uint64 dev_server_id = 3; + string path = 4; +} +message CreateRemoteProjectResponse { + RemoteProject remote_project = 1; +} + +message CreateDevServer { + uint64 channel_id = 1; + string name = 2; +} + +message CreateDevServerResponse { + uint64 dev_server_id = 1; + uint64 channel_id = 2; + string access_token = 3; + string name = 4; +} + +message ShutdownDevServer { +} + +message ReconnectDevServer { + repeated UpdateProject reshared_projects = 1; +} + +message ReconnectDevServerResponse { + repeated ResharedProject reshared_projects = 1; +} + +message DevServerInstructions { + repeated RemoteProject projects = 1; +} + +message ShareRemoteProject { + uint64 remote_project_id = 1; + repeated WorktreeMetadata worktrees = 2; +} + +message JoinRemoteProject { + uint64 remote_project_id = 1; +} + message JoinProjectResponse { uint64 project_id = 5; uint32 replica_id = 1; @@ -1057,6 +1125,12 @@ message UpdateChannels { repeated HostedProject hosted_projects = 10; repeated uint64 deleted_hosted_projects = 11; + + repeated DevServer dev_servers = 12; + repeated uint64 deleted_dev_servers = 13; + + repeated RemoteProject remote_projects = 14; + repeated uint64 deleted_remote_projects = 15; } message UpdateUserChannels { @@ -1092,6 +1166,27 @@ message HostedProject { ChannelVisibility visibility = 4; } +message RemoteProject { + uint64 id = 1; + optional uint64 project_id = 2; + uint64 channel_id = 3; + string name = 4; + uint64 dev_server_id = 5; + string path = 6; +} + +message DevServer { + uint64 channel_id = 1; + uint64 dev_server_id = 2; + string name = 3; + DevServerStatus status = 4; +} + +enum DevServerStatus { + Offline = 0; + Online = 1; +} + message JoinChannel { uint64 channel_id = 1; } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index c5a8f7d32b..a117648cec 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -299,6 +299,18 @@ messages!( (SetRoomParticipantRole, Foreground), (BlameBuffer, Foreground), (BlameBufferResponse, Foreground), + (CreateRemoteProject, Foreground), + (CreateRemoteProjectResponse, Foreground), + (CreateDevServer, Foreground), + (CreateDevServerResponse, Foreground), + (DevServerInstructions, Foreground), + (ShutdownDevServer, Foreground), + (ReconnectDevServer, Foreground), + (ReconnectDevServerResponse, Foreground), + (ShareRemoteProject, Foreground), + (JoinRemoteProject, Foreground), + (RejoinRemoteProjects, Foreground), + (RejoinRemoteProjectsResponse, Foreground), (MultiLspQuery, Background), (MultiLspQueryResponse, Background), ); @@ -392,6 +404,13 @@ request_messages!( (LspExtExpandMacro, LspExtExpandMacroResponse), (SetRoomParticipantRole, Ack), (BlameBuffer, BlameBufferResponse), + (CreateRemoteProject, CreateRemoteProjectResponse), + (CreateDevServer, CreateDevServerResponse), + (ShutdownDevServer, Ack), + (ShareRemoteProject, ShareProjectResponse), + (JoinRemoteProject, JoinProjectResponse), + (RejoinRemoteProjects, RejoinRemoteProjectsResponse), + (ReconnectDevServer, ReconnectDevServerResponse), (MultiLspQuery, MultiLspQueryResponse), ); diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index b3fc5dd2ee..2a38130720 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -9,6 +9,7 @@ mod indicator; mod keybinding; mod label; mod list; +mod modal; mod popover; mod popover_menu; mod right_click_menu; @@ -32,6 +33,7 @@ pub use indicator::*; pub use keybinding::*; pub use label::*; pub use list::*; +pub use modal::*; pub use popover::*; pub use popover_menu::*; pub use right_click_menu::*; diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index 9236e3c9fb..11806f7c1d 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -106,12 +106,14 @@ pub enum IconName { Settings, Screen, SelectAll, + Server, Shift, Snip, Space, Split, Tab, Terminal, + Trash, Update, WholeWord, XCircle, @@ -202,12 +204,14 @@ impl IconName { IconName::Settings => "icons/file_icons/settings.svg", IconName::Screen => "icons/desktop.svg", IconName::SelectAll => "icons/select_all.svg", + IconName::Server => "icons/server.svg", IconName::Shift => "icons/shift.svg", IconName::Snip => "icons/snip.svg", IconName::Space => "icons/space.svg", IconName::Split => "icons/split.svg", IconName::Tab => "icons/tab.svg", IconName::Terminal => "icons/terminal.svg", + IconName::Trash => "icons/trash.svg", IconName::Update => "icons/update.svg", IconName::WholeWord => "icons/word_search.svg", IconName::XCircle => "icons/error.svg", diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs new file mode 100644 index 0000000000..7ce9707b0a --- /dev/null +++ b/crates/ui/src/components/modal.rs @@ -0,0 +1,133 @@ +use gpui::*; +use smallvec::SmallVec; + +use crate::{h_flex, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize}; + +#[derive(IntoElement)] +pub struct ModalHeader { + id: ElementId, + children: SmallVec<[AnyElement; 2]>, +} + +impl ModalHeader { + pub fn new(id: impl Into) -> Self { + Self { + id: id.into(), + children: SmallVec::new(), + } + } +} + +impl ParentElement for ModalHeader { + fn extend(&mut self, elements: impl Iterator) { + self.children.extend(elements) + } +} + +impl RenderOnce for ModalHeader { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { + h_flex() + .id(self.id) + .w_full() + .px_2() + .py_1p5() + .child(div().flex_1().children(self.children)) + .justify_between() + .child(IconButton::new("dismiss", IconName::Close).shape(IconButtonShape::Square)) + } +} + +#[derive(IntoElement)] +pub struct ModalContent { + children: SmallVec<[AnyElement; 2]>, +} + +impl ModalContent { + pub fn new() -> Self { + Self { + children: SmallVec::new(), + } + } +} + +impl ParentElement for ModalContent { + fn extend(&mut self, elements: impl Iterator) { + self.children.extend(elements) + } +} + +impl RenderOnce for ModalContent { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { + h_flex().w_full().px_2().py_1p5().children(self.children) + } +} + +#[derive(IntoElement)] +pub struct ModalRow { + children: SmallVec<[AnyElement; 2]>, +} + +impl ModalRow { + pub fn new() -> Self { + Self { + children: SmallVec::new(), + } + } +} + +impl ParentElement for ModalRow { + fn extend(&mut self, elements: impl Iterator) { + self.children.extend(elements) + } +} + +impl RenderOnce for ModalRow { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { + h_flex().w_full().px_2().py_1().children(self.children) + } +} + +#[derive(IntoElement)] +pub struct SectionHeader { + /// The label of the header. + label: SharedString, + /// A slot for content that appears after the label, usually on the other side of the header. + /// This might be a button, a disclosure arrow, a face pile, etc. + end_slot: Option, +} + +impl SectionHeader { + pub fn new(label: impl Into) -> Self { + Self { + label: label.into(), + end_slot: None, + } + } + + pub fn end_slot(mut self, end_slot: impl Into>) -> Self { + self.end_slot = end_slot.into().map(IntoElement::into_any_element); + self + } +} + +impl RenderOnce for SectionHeader { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { + h_flex().id(self.label.clone()).w_full().child( + div() + .h_7() + .flex() + .items_center() + .justify_between() + .w_full() + .gap_1() + .child( + div().flex_1().child( + Label::new(self.label.clone()) + .size(LabelSize::Large) + .into_element(), + ), + ) + .child(h_flex().children(self.end_slot)), + ) + } +} diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index d38f252a68..1f53f3230f 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -4704,6 +4704,61 @@ pub fn join_hosted_project( }) } +pub fn join_remote_project( + project_id: ProjectId, + app_state: Arc, + cx: &mut AppContext, +) -> Task>> { + let windows = cx.windows(); + cx.spawn(|mut cx| async move { + let existing_workspace = windows.into_iter().find_map(|window| { + window.downcast::().and_then(|window| { + window + .update(&mut cx, |workspace, cx| { + if workspace.project().read(cx).remote_id() == Some(project_id.0) { + Some(window) + } else { + None + } + }) + .unwrap_or(None) + }) + }); + + let workspace = if let Some(existing_workspace) = existing_workspace { + existing_workspace + } else { + let project = Project::remote( + project_id.0, + app_state.client.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + cx.clone(), + ) + .await?; + + let window_bounds_override = window_bounds_env_override(); + cx.update(|cx| { + let mut options = (app_state.build_window_options)(None, cx); + options.bounds = window_bounds_override; + cx.open_window(options, |cx| { + cx.new_view(|cx| { + Workspace::new(Default::default(), project, app_state.clone(), cx) + }) + }) + })? + }; + + workspace.update(&mut cx, |_, cx| { + cx.activate(true); + cx.activate_window(); + })?; + + anyhow::Ok(workspace) + }) +} + pub fn join_in_room_project( project_id: u64, follow_user_id: u64, diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index cd5791ccac..3d8ff9a0cf 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -48,6 +48,7 @@ fs.workspace = true futures.workspace = true go_to_line.workspace = true gpui.workspace = true +headless.workspace = true image_viewer.workspace = true install_cli.workspace = true isahc.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 7a6f767d60..df09366f59 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -10,9 +10,7 @@ use backtrace::Backtrace; use chrono::Utc; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{ - parse_zed_link, telemetry::Telemetry, Client, ClientSettings, DevServerToken, UserStore, -}; +use client::{parse_zed_link, telemetry::Telemetry, Client, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; use copilot::Copilot; use copilot_ui::CopilotCompletionProvider; @@ -88,7 +86,72 @@ fn fail_to_launch(e: anyhow::Error) { }) } -fn main() { +fn init_headless(dev_server_token: DevServerToken) { + if let Err(e) = init_paths() { + log::error!("Failed to launch: {}", e); + return; + } + init_logger(); + + App::new().run(|cx| { + release_channel::init(env!("CARGO_PKG_VERSION"), cx); + if let Some(build_sha) = option_env!("ZED_COMMIT_SHA") { + AppCommitSha::set_global(AppCommitSha(build_sha.into()), cx); + } + + let mut store = SettingsStore::default(); + store + .set_default_settings(default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(store); + + client::init_settings(cx); + + let clock = Arc::new(clock::RealSystemClock); + let http = Arc::new(HttpClientWithUrl::new( + &client::ClientSettings::get_global(cx).server_url, + )); + + let client = client::Client::new(clock, http.clone(), cx); + let client = client.clone(); + client.set_dev_server_token(dev_server_token); + + project::Project::init(&client, cx); + client::init(&client, cx); + + let git_binary_path = if option_env!("ZED_BUNDLE").as_deref() == Some("true") { + cx.path_for_auxiliary_executable("git") + .context("could not find git binary path") + .log_err() + } else { + None + }; + let fs = Arc::new(RealFs::new(git_binary_path)); + + let mut languages = + LanguageRegistry::new(Task::ready(()), cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone()); + let languages = Arc::new(languages); + let node_runtime = RealNodeRuntime::new(http.clone()); + + language::init(cx); + languages::init(languages.clone(), node_runtime.clone(), cx); + let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx)); + + headless::init( + client.clone(), + headless::AppState { + languages: languages.clone(), + user_store: user_store.clone(), + fs: fs.clone(), + node_runtime: node_runtime.clone(), + }, + cx, + ); + }) +} + +fn init_ui() { menu::init(); zed_actions::init(); @@ -269,7 +332,6 @@ fn main() { .to_string(), ); telemetry.flush_events(); - let app_state = Arc::new(AppState { languages: languages.clone(), client: client.clone(), @@ -277,7 +339,7 @@ fn main() { fs: fs.clone(), build_window_options, workspace_store, - node_runtime, + node_runtime: node_runtime.clone(), }); AppState::set_global(Arc::downgrade(&app_state), cx); @@ -319,31 +381,17 @@ fn main() { cx.activate(true); - let mut args = Args::parse(); - if let Some(dev_server_token) = args.dev_server_token.take() { - let dev_server_token = DevServerToken(dev_server_token); - let server_url = ClientSettings::get_global(&cx).server_url.clone(); - let client = client.clone(); - client.set_dev_server_token(dev_server_token); - cx.spawn(|cx| async move { - client.authenticate_and_connect(false, &cx).await?; - log::info!("Connected to {}", server_url); - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } else { - let urls: Vec<_> = args - .paths_or_urls - .iter() - .filter_map(|arg| parse_url_arg(arg, cx).log_err()) - .collect(); - - if !urls.is_empty() { - listener.open_urls(urls) - } - } - + let args = Args::parse(); let mut triggered_authentication = false; + let urls: Vec<_> = args + .paths_or_urls + .iter() + .filter_map(|arg| parse_url_arg(arg, cx).log_err()) + .collect(); + + if !urls.is_empty() { + listener.open_urls(urls) + } match open_rx .try_next() @@ -382,6 +430,16 @@ fn main() { }); } +fn main() { + let mut args = Args::parse(); + if let Some(dev_server_token) = args.dev_server_token.take() { + let dev_server_token = DevServerToken(dev_server_token); + init_headless(dev_server_token) + } else { + init_ui() + } +} + fn handle_open_request( request: OpenRequest, app_state: Arc,