diff --git a/Cargo.lock b/Cargo.lock index fcbf5a0f95..0c951dd441 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1710,6 +1710,78 @@ dependencies = [ "workspace", ] +[[package]] +name = "collab2" +version = "0.28.0" +dependencies = [ + "anyhow", + "async-trait", + "async-tungstenite", + "audio2", + "axum", + "axum-extra", + "base64 0.13.1", + "call2", + "channel2", + "clap 3.2.25", + "client2", + "clock", + "collab_ui", + "collections", + "ctor", + "dashmap", + "editor2", + "env_logger 0.9.3", + "envy", + "fs2", + "futures 0.3.28", + "git3", + "gpui2", + "hyper", + "indoc", + "language2", + "lazy_static", + "lipsum", + "live_kit_client2", + "live_kit_server", + "log", + "lsp2", + "nanoid", + "node_runtime", + "parking_lot 0.11.2", + "pretty_assertions", + "project2", + "prometheus", + "prost 0.8.0", + "rand 0.8.5", + "reqwest", + "rpc2", + "scrypt", + "sea-orm", + "serde", + "serde_derive", + "serde_json", + "settings2", + "sha-1 0.9.8", + "smallvec", + "sqlx", + "text2", + "theme2", + "time", + "tokio", + "tokio-tungstenite", + "toml 0.5.11", + "tonic", + "tower", + "tracing", + "tracing-log", + "tracing-subscriber", + "unindent", + "util", + "uuid 1.4.1", + "workspace2", +] + [[package]] name = "collab_ui" version = "0.1.0" @@ -7241,8 +7313,10 @@ dependencies = [ "rsa 0.4.0", "serde", "serde_derive", + "serde_json", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", diff --git a/Cargo.toml b/Cargo.toml index 81aff80c90..f2ad2949fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "crates/client2", "crates/clock", "crates/collab", + "crates/collab2", "crates/collab_ui", "crates/collections", "crates/command_palette", diff --git a/crates/call2/src/call2.rs b/crates/call2/src/call2.rs index 477931919d..b19720bcdc 100644 --- a/crates/call2/src/call2.rs +++ b/crates/call2/src/call2.rs @@ -10,7 +10,7 @@ use client::{ ZED_ALWAYS_ACTIVE, }; use collections::HashSet; -use futures::{future::Shared, FutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt}; use gpui::{ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel, @@ -30,6 +30,37 @@ pub fn init(client: Arc, user_store: Model, cx: &mut AppConte cx.set_global(active_call); } +pub struct OneAtATime { + cancel: Option>, +} + +impl OneAtATime { + /// spawn a task in the given context. + /// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None) + /// otherwise you'll see the result of the task. + fn spawn(&mut self, cx: &mut AppContext, f: F) -> Task>> + where + F: 'static + FnOnce(AsyncAppContext) -> Fut, + Fut: Future>, + R: 'static, + { + let (tx, rx) = oneshot::channel(); + self.cancel.replace(tx); + cx.spawn(|cx| async move { + futures::select_biased! { + _ = rx.fuse() => Ok(None), + result = f(cx).fuse() => result.map(Some), + } + }) + } + + fn running(&self) -> bool { + self.cancel + .as_ref() + .is_some_and(|cancel| !cancel.is_canceled()) + } +} + #[derive(Clone)] pub struct IncomingCall { pub room_id: u64, @@ -43,6 +74,7 @@ pub struct ActiveCall { room: Option<(Model, Vec)>, pending_room_creation: Option, Arc>>>>, location: Option>, + _join_debouncer: OneAtATime, pending_invites: HashSet, incoming_call: ( watch::Sender>, @@ -65,7 +97,7 @@ impl ActiveCall { location: None, pending_invites: Default::default(), incoming_call: watch::channel(), - + _join_debouncer: OneAtATime { cancel: None }, _subscriptions: vec![ client.add_request_handler(cx.weak_model(), Self::handle_incoming_call), client.add_message_handler(cx.weak_model(), Self::handle_call_canceled), @@ -140,6 +172,10 @@ impl ActiveCall { } cx.notify(); + if self._join_debouncer.running() { + return Task::ready(Ok(())); + } + let room = if let Some(room) = self.room().cloned() { Some(Task::ready(Ok(room)).shared()) } else { @@ -256,11 +292,20 @@ impl ActiveCall { return Task::ready(Err(anyhow!("no incoming call"))); }; - let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(())); + } + + let room_id = call.room_id.clone(); + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self + ._join_debouncer + .spawn(cx, move |cx| Room::join(room_id, client, user_store, cx)); cx.spawn(|this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))? + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))? .await?; this.update(&mut cx, |this, cx| { this.report_call_event("accept incoming", cx) @@ -287,20 +332,28 @@ impl ActiveCall { &mut self, channel_id: u64, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>>> { if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { - return Task::ready(Ok(room)); + return Task::ready(Ok(Some(room))); } else { room.update(cx, |room, cx| room.clear_state(cx)); } } - let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(None)); + } + + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self._join_debouncer.spawn(cx, move |cx| async move { + Room::join_channel(channel_id, client, user_store, cx).await + }); cx.spawn(|this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))? + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))? .await?; this.update(&mut cx, |this, cx| { this.report_call_event("join channel", cx) @@ -459,3 +512,40 @@ pub fn report_call_event_for_channel( }; telemetry.report_clickhouse_event(event, telemetry_settings); } + +#[cfg(test)] +mod test { + use gpui::TestAppContext; + + use crate::OneAtATime; + + #[gpui::test] + async fn test_one_at_a_time(cx: &mut TestAppContext) { + let mut one_at_a_time = OneAtATime { cancel: None }; + + assert_eq!( + cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) })) + .await + .unwrap(), + Some(1) + ); + + let (a, b) = cx.update(|cx| { + ( + one_at_a_time.spawn(cx, |_| async { + assert!(false); + Ok(2) + }), + one_at_a_time.spawn(cx, |_| async { Ok(3) }), + ) + }); + + assert_eq!(a.await.unwrap(), None); + assert_eq!(b.await.unwrap(), Some(3)); + + let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) })); + drop(one_at_a_time); + + assert_eq!(promise.await.unwrap(), None); + } +} diff --git a/crates/call2/src/room.rs b/crates/call2/src/room.rs index a46269a508..27bc51a277 100644 --- a/crates/call2/src/room.rs +++ b/crates/call2/src/room.rs @@ -1,7 +1,6 @@ use crate::{ call_settings::CallSettings, participant::{LocalParticipant, ParticipantLocation, RemoteParticipant}, - IncomingCall, }; use anyhow::{anyhow, Result}; use audio::{Audio, Sound}; @@ -284,37 +283,32 @@ impl Room { }) } - pub(crate) fn join_channel( + pub(crate) async fn join_channel( channel_id: u64, client: Arc, user_store: Model, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(move |cx| async move { - Self::from_join_response( - client.request(proto::JoinChannel { channel_id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinChannel { channel_id }).await?, + client, + user_store, + cx, + ) } - pub(crate) fn join( - call: &IncomingCall, + pub(crate) async fn join( + room_id: u64, client: Arc, user_store: Model, - cx: &mut AppContext, - ) -> Task>> { - let id = call.room_id; - cx.spawn(move |cx| async move { - Self::from_join_response( - client.request(proto::JoinRoom { id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinRoom { id: room_id }).await?, + client, + user_store, + cx, + ) } fn released(&mut self, cx: &mut AppContext) { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index d6ebe1e84e..245f34ebac 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -124,6 +124,7 @@ impl TestServer { if cx.has_global::() { panic!("Same cx used to create two test clients") } + cx.set_global(SettingsStore::test(cx)); }); diff --git a/crates/collab2/.env.toml b/crates/collab2/.env.toml new file mode 100644 index 0000000000..01866012ea --- /dev/null +++ b/crates/collab2/.env.toml @@ -0,0 +1,12 @@ +DATABASE_URL = "postgres://postgres@localhost/zed" +DATABASE_MAX_CONNECTIONS = 5 +HTTP_PORT = 8080 +API_TOKEN = "secret" +INVITE_LINK_PREFIX = "http://localhost:3000/invites/" +ZED_ENVIRONMENT = "development" +LIVE_KIT_SERVER = "http://localhost:7880" +LIVE_KIT_KEY = "devkey" +LIVE_KIT_SECRET = "secret" + +# RUST_LOG=info +# LOG_JSON=true diff --git a/crates/collab2/Cargo.toml b/crates/collab2/Cargo.toml new file mode 100644 index 0000000000..fe050a2aa8 --- /dev/null +++ b/crates/collab2/Cargo.toml @@ -0,0 +1,101 @@ +[package] +authors = ["Nathan Sobo "] +default-run = "collab2" +edition = "2021" +name = "collab2" +version = "0.28.0" +publish = false + +[[bin]] +name = "collab2" + +[[bin]] +name = "seed" +required-features = ["seed-support"] + +[dependencies] +clock = { path = "../clock" } +collections = { path = "../collections" } +live_kit_server = { path = "../live_kit_server" } +text = { package = "text2", path = "../text2" } +rpc = { package = "rpc2", path = "../rpc2" } +util = { path = "../util" } + +anyhow.workspace = true +async-tungstenite = "0.16" +axum = { version = "0.5", features = ["json", "headers", "ws"] } +axum-extra = { version = "0.3", features = ["erased-json"] } +base64 = "0.13" +clap = { version = "3.1", features = ["derive"], optional = true } +dashmap = "5.4" +envy = "0.4.2" +futures.workspace = true +hyper = "0.14" +lazy_static.workspace = true +lipsum = { version = "0.8", optional = true } +log.workspace = true +nanoid = "0.4" +parking_lot.workspace = true +prometheus = "0.13" +prost.workspace = true +rand.workspace = true +reqwest = { version = "0.11", features = ["json"], optional = true } +scrypt = "0.7" +smallvec.workspace = true +sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] } +serde.workspace = true +serde_derive.workspace = true +serde_json.workspace = true +sha-1 = "0.9" +sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] } +time.workspace = true +tokio = { version = "1", features = ["full"] } +tokio-tungstenite = "0.17" +tonic = "0.6" +tower = "0.4" +toml.workspace = true +tracing = "0.1.34" +tracing-log = "0.1.3" +tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } +uuid.workspace = true + +[dev-dependencies] +audio = { package = "audio2", path = "../audio2" } +collections = { path = "../collections", features = ["test-support"] } +gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] } +call = { package = "call2", path = "../call2", features = ["test-support"] } +client = { package = "client2", path = "../client2", features = ["test-support"] } +channel = { package = "channel2", path = "../channel2" } +editor = { package = "editor2", path = "../editor2", features = ["test-support"] } +language = { package = "language2", path = "../language2", features = ["test-support"] } +fs = { package = "fs2", path = "../fs2", features = ["test-support"] } +git = { package = "git3", path = "../git3", features = ["test-support"] } +live_kit_client = { package = "live_kit_client2", path = "../live_kit_client2", features = ["test-support"] } +lsp = { package = "lsp2", path = "../lsp2", features = ["test-support"] } + +node_runtime = { path = "../node_runtime" } +#todo!(notifications) +#notifications = { path = "../notifications", features = ["test-support"] } + +project = { package = "project2", path = "../project2", features = ["test-support"] } +rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] } +settings = { package = "settings2", path = "../settings2", features = ["test-support"] } +theme = { package = "theme2", path = "../theme2" } +workspace = { package = "workspace2", path = "../workspace2", features = ["test-support"] } + +collab_ui = { path = "../collab_ui", features = ["test-support"] } + +async-trait.workspace = true +pretty_assertions.workspace = true +ctor.workspace = true +env_logger.workspace = true +indoc.workspace = true +util = { path = "../util" } +lazy_static.workspace = true +sea-orm = { version = "0.12.x", features = ["sqlx-sqlite"] } +serde_json.workspace = true +sqlx = { version = "0.7", features = ["sqlite"] } +unindent.workspace = true + +[features] +seed-support = ["clap", "lipsum", "reqwest"] diff --git a/crates/collab2/README.md b/crates/collab2/README.md new file mode 100644 index 0000000000..d766324255 --- /dev/null +++ b/crates/collab2/README.md @@ -0,0 +1,5 @@ +# Zed Server + +This crate is what we run at https://collab.zed.dev. + +It contains our back-end logic for collaboration, to which we connect from the Zed client via a websocket after authenticating via https://zed.dev, which is a separate repo running on Vercel. diff --git a/crates/collab2/admin_api.conf b/crates/collab2/admin_api.conf new file mode 100644 index 0000000000..5d3b0e65b7 --- /dev/null +++ b/crates/collab2/admin_api.conf @@ -0,0 +1,4 @@ +db-uri = "postgres://postgres@localhost/zed" +server-port = 8081 +jwt-secret = "the-postgrest-jwt-secret-for-authorization" +log-level = "info" diff --git a/crates/collab2/basic.conf b/crates/collab2/basic.conf new file mode 100644 index 0000000000..c6db392dba --- /dev/null +++ b/crates/collab2/basic.conf @@ -0,0 +1,12 @@ + +[Interface] +PrivateKey = B5Fp/yVfP0QYlb+YJv9ea+EMI1mWODPD3akh91cVjvc= +Address = fdaa:0:2ce3:a7b:bea:0:a:2/120 +DNS = fdaa:0:2ce3::3 + +[Peer] +PublicKey = RKAYPljEJiuaELNDdQIEJmQienT9+LRISfIHwH45HAw= +AllowedIPs = fdaa:0:2ce3::/48 +Endpoint = ord1.gateway.6pn.dev:51820 +PersistentKeepalive = 15 + diff --git a/crates/collab2/k8s/environments/preview.sh b/crates/collab2/k8s/environments/preview.sh new file mode 100644 index 0000000000..132a1ef53c --- /dev/null +++ b/crates/collab2/k8s/environments/preview.sh @@ -0,0 +1,4 @@ +ZED_ENVIRONMENT=preview +RUST_LOG=info +INVITE_LINK_PREFIX=https://zed.dev/invites/ +DATABASE_MAX_CONNECTIONS=10 diff --git a/crates/collab2/k8s/environments/production.sh b/crates/collab2/k8s/environments/production.sh new file mode 100644 index 0000000000..cb1d4b4de7 --- /dev/null +++ b/crates/collab2/k8s/environments/production.sh @@ -0,0 +1,4 @@ +ZED_ENVIRONMENT=production +RUST_LOG=info +INVITE_LINK_PREFIX=https://zed.dev/invites/ +DATABASE_MAX_CONNECTIONS=85 diff --git a/crates/collab2/k8s/environments/staging.sh b/crates/collab2/k8s/environments/staging.sh new file mode 100644 index 0000000000..b9689ccb19 --- /dev/null +++ b/crates/collab2/k8s/environments/staging.sh @@ -0,0 +1,4 @@ +ZED_ENVIRONMENT=staging +RUST_LOG=info +INVITE_LINK_PREFIX=https://staging.zed.dev/invites/ +DATABASE_MAX_CONNECTIONS=5 diff --git a/crates/collab2/k8s/manifest.template.yml b/crates/collab2/k8s/manifest.template.yml new file mode 100644 index 0000000000..d4a7a7033e --- /dev/null +++ b/crates/collab2/k8s/manifest.template.yml @@ -0,0 +1,177 @@ +--- +apiVersion: v1 +kind: Namespace +metadata: + name: ${ZED_KUBE_NAMESPACE} + +--- +kind: Service +apiVersion: v1 +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: collab + annotations: + service.beta.kubernetes.io/do-loadbalancer-tls-ports: "443" + service.beta.kubernetes.io/do-loadbalancer-certificate-id: ${ZED_DO_CERTIFICATE_ID} +spec: + type: LoadBalancer + selector: + app: collab + ports: + - name: web + protocol: TCP + port: 443 + targetPort: 8080 + +--- +kind: Service +apiVersion: v1 +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: pgadmin + annotations: + service.beta.kubernetes.io/do-loadbalancer-tls-ports: "443" + service.beta.kubernetes.io/do-loadbalancer-certificate-id: ${ZED_DO_CERTIFICATE_ID} +spec: + type: LoadBalancer + selector: + app: postgrest + ports: + - name: web + protocol: TCP + port: 443 + targetPort: 8080 + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: collab + +spec: + replicas: 1 + selector: + matchLabels: + app: collab + template: + metadata: + labels: + app: collab + annotations: + ad.datadoghq.com/collab.check_names: | + ["openmetrics"] + ad.datadoghq.com/collab.init_configs: | + [{}] + ad.datadoghq.com/collab.instances: | + [ + { + "openmetrics_endpoint": "http://%%host%%:%%port%%/metrics", + "namespace": "collab_${ZED_KUBE_NAMESPACE}", + "metrics": [".*"] + } + ] + spec: + containers: + - name: collab + image: "${ZED_IMAGE_ID}" + args: + - serve + ports: + - containerPort: 8080 + protocol: TCP + livenessProbe: + httpGet: + path: /healthz + port: 8080 + initialDelaySeconds: 5 + periodSeconds: 5 + timeoutSeconds: 5 + readinessProbe: + httpGet: + path: / + port: 8080 + initialDelaySeconds: 1 + periodSeconds: 1 + env: + - name: HTTP_PORT + value: "8080" + - name: DATABASE_URL + valueFrom: + secretKeyRef: + name: database + key: url + - name: DATABASE_MAX_CONNECTIONS + value: "${DATABASE_MAX_CONNECTIONS}" + - name: API_TOKEN + valueFrom: + secretKeyRef: + name: api + key: token + - name: LIVE_KIT_SERVER + valueFrom: + secretKeyRef: + name: livekit + key: server + - name: LIVE_KIT_KEY + valueFrom: + secretKeyRef: + name: livekit + key: key + - name: LIVE_KIT_SECRET + valueFrom: + secretKeyRef: + name: livekit + key: secret + - name: INVITE_LINK_PREFIX + value: ${INVITE_LINK_PREFIX} + - name: RUST_BACKTRACE + value: "1" + - name: RUST_LOG + value: ${RUST_LOG} + - name: LOG_JSON + value: "true" + - name: ZED_ENVIRONMENT + value: ${ZED_ENVIRONMENT} + securityContext: + capabilities: + # FIXME - Switch to the more restrictive `PERFMON` capability. + # This capability isn't yet available in a stable version of Debian. + add: ["SYS_ADMIN"] + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: postgrest + +spec: + replicas: 1 + selector: + matchLabels: + app: postgrest + template: + metadata: + labels: + app: postgrest + spec: + containers: + - name: postgrest + image: "postgrest/postgrest" + ports: + - containerPort: 8080 + protocol: TCP + env: + - name: PGRST_SERVER_PORT + value: "8080" + - name: PGRST_DB_URI + valueFrom: + secretKeyRef: + name: database + key: url + - name: PGRST_JWT_SECRET + valueFrom: + secretKeyRef: + name: postgrest + key: jwt_secret diff --git a/crates/collab2/k8s/migrate.template.yml b/crates/collab2/k8s/migrate.template.yml new file mode 100644 index 0000000000..c890d7b330 --- /dev/null +++ b/crates/collab2/k8s/migrate.template.yml @@ -0,0 +1,21 @@ +apiVersion: batch/v1 +kind: Job +metadata: + namespace: ${ZED_KUBE_NAMESPACE} + name: ${ZED_MIGRATE_JOB_NAME} +spec: + template: + spec: + restartPolicy: Never + containers: + - name: migrator + imagePullPolicy: Always + image: ${ZED_IMAGE_ID} + args: + - migrate + env: + - name: DATABASE_URL + valueFrom: + secretKeyRef: + name: database + key: url diff --git a/crates/collab2/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab2/migrations.sqlite/20221109000000_test_schema.sql new file mode 100644 index 0000000000..775a4c1bbe --- /dev/null +++ b/crates/collab2/migrations.sqlite/20221109000000_test_schema.sql @@ -0,0 +1,344 @@ +CREATE TABLE "users" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "github_login" VARCHAR, + "admin" BOOLEAN, + "email_address" VARCHAR(255) DEFAULT NULL, + "invite_code" VARCHAR(64), + "invite_count" INTEGER NOT NULL DEFAULT 0, + "inviter_id" INTEGER REFERENCES users (id), + "connected_once" BOOLEAN NOT NULL DEFAULT false, + "created_at" TIMESTAMP NOT NULL DEFAULT now, + "metrics_id" TEXT, + "github_user_id" INTEGER +); +CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); +CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); +CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); +CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); + +CREATE TABLE "access_tokens" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "user_id" INTEGER REFERENCES users (id), + "hash" VARCHAR(128) +); +CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); + +CREATE TABLE "contacts" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "user_id_a" INTEGER REFERENCES users (id) NOT NULL, + "user_id_b" INTEGER REFERENCES users (id) NOT NULL, + "a_to_b" BOOLEAN NOT NULL, + "should_notify" BOOLEAN NOT NULL, + "accepted" BOOLEAN NOT NULL +); +CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); +CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); + +CREATE TABLE "rooms" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "live_kit_room" VARCHAR NOT NULL, + "enviroment" VARCHAR, + "channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); + +CREATE TABLE "projects" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE NOT NULL, + "host_user_id" INTEGER REFERENCES users (id) NOT NULL, + "host_connection_id" INTEGER, + "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, + "unregistered" BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX "index_projects_on_host_connection_server_id" ON "projects" ("host_connection_server_id"); +CREATE INDEX "index_projects_on_host_connection_id_and_host_connection_server_id" ON "projects" ("host_connection_id", "host_connection_server_id"); + +CREATE TABLE "worktrees" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INTEGER NOT NULL, + "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INTEGER NOT NULL, + "is_complete" BOOL NOT NULL DEFAULT FALSE, + "completed_scan_id" INTEGER NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE "worktree_entries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "scan_id" INTEGER NOT NULL, + "id" INTEGER NOT NULL, + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INTEGER NOT NULL, + "mtime_seconds" INTEGER NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_external" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + "is_deleted" BOOL NOT NULL, + "git_status" INTEGER, + PRIMARY KEY(project_id, worktree_id, id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_repositories" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "work_directory_id" INTEGER NOT NULL, + "branch" VARCHAR, + "scan_id" INTEGER NOT NULL, + "is_deleted" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, work_directory_id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE, + FOREIGN KEY(project_id, worktree_id, work_directory_id) REFERENCES worktree_entries (project_id, worktree_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_repositories_on_project_id" ON "worktree_repositories" ("project_id"); +CREATE INDEX "index_worktree_repositories_on_project_id_and_worktree_id" ON "worktree_repositories" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_settings_files" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "path" VARCHAR NOT NULL, + "content" TEXT, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_settings_files_on_project_id" ON "worktree_settings_files" ("project_id"); +CREATE INDEX "index_worktree_settings_files_on_project_id_and_worktree_id" ON "worktree_settings_files" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "path" VARCHAR NOT NULL, + "language_server_id" INTEGER NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + +CREATE TABLE "project_collaborators" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_server_id" ON "project_collaborators" ("connection_server_id"); +CREATE INDEX "index_project_collaborators_on_connection_id" ON "project_collaborators" ("connection_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and_server_id" ON "project_collaborators" ("project_id", "connection_id", "connection_server_id"); + +CREATE TABLE "room_participants" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "answering_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, + "answering_connection_lost" BOOLEAN NOT NULL, + "location_kind" INTEGER, + "location_project_id" INTEGER, + "initial_project_id" INTEGER, + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_connection_id" INTEGER NOT NULL, + "calling_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE SET NULL, + "participant_index" INTEGER +); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_room_id" ON "room_participants" ("room_id"); +CREATE INDEX "index_room_participants_on_answering_connection_server_id" ON "room_participants" ("answering_connection_server_id"); +CREATE INDEX "index_room_participants_on_calling_connection_server_id" ON "room_participants" ("calling_connection_server_id"); +CREATE INDEX "index_room_participants_on_answering_connection_id" ON "room_participants" ("answering_connection_id"); +CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_server_id" ON "room_participants" ("answering_connection_id", "answering_connection_server_id"); + +CREATE TABLE "servers" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "environment" VARCHAR NOT NULL +); + +CREATE TABLE "followers" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "room_id" INTEGER NOT NULL REFERENCES rooms (id) ON DELETE CASCADE, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "leader_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "leader_connection_id" INTEGER NOT NULL, + "follower_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "follower_connection_id" INTEGER NOT NULL +); +CREATE UNIQUE INDEX + "index_followers_on_project_id_and_leader_connection_server_id_and_leader_connection_id_and_follower_connection_server_id_and_follower_connection_id" +ON "followers" ("project_id", "leader_connection_server_id", "leader_connection_id", "follower_connection_server_id", "follower_connection_id"); +CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); + +CREATE TABLE "channels" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "visibility" VARCHAR NOT NULL, + "parent_path" TEXT +); + +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); + +CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "user_id" INTEGER NOT NULL REFERENCES users (id), + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE +); +CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id"); + +CREATE TABLE IF NOT EXISTS "channel_messages" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "sender_id" INTEGER NOT NULL REFERENCES users (id), + "body" TEXT NOT NULL, + "sent_at" TIMESTAMP, + "nonce" BLOB NOT NULL +); +CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); + +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); + +CREATE TABLE "channel_members" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "admin" BOOLEAN NOT NULL DEFAULT false, + "role" VARCHAR, + "accepted" BOOLEAN NOT NULL DEFAULT false, + "updated_at" TIMESTAMP NOT NULL DEFAULT now +); + +CREATE UNIQUE INDEX "index_channel_members_on_channel_id_and_user_id" ON "channel_members" ("channel_id", "user_id"); + +CREATE TABLE "buffers" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL DEFAULT 0 +); + +CREATE INDEX "index_buffers_on_channel_id" ON "buffers" ("channel_id"); + +CREATE TABLE "buffer_operations" ( + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "lamport_timestamp" INTEGER NOT NULL, + "value" BLOB NOT NULL, + PRIMARY KEY(buffer_id, epoch, lamport_timestamp, replica_id) +); + +CREATE TABLE "buffer_snapshots" ( + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "text" TEXT NOT NULL, + "operation_serialization_version" INTEGER NOT NULL, + PRIMARY KEY(buffer_id, epoch) +); + +CREATE TABLE "channel_buffer_collaborators" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "connection_lost" BOOLEAN NOT NULL DEFAULT false, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "replica_id" INTEGER NOT NULL +); + +CREATE INDEX "index_channel_buffer_collaborators_on_channel_id" ON "channel_buffer_collaborators" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_and_replica_id" ON "channel_buffer_collaborators" ("channel_id", "replica_id"); +CREATE INDEX "index_channel_buffer_collaborators_on_connection_server_id" ON "channel_buffer_collaborators" ("connection_server_id"); +CREATE INDEX "index_channel_buffer_collaborators_on_connection_id" ON "channel_buffer_collaborators" ("connection_id"); +CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection_id_and_server_id" ON "channel_buffer_collaborators" ("channel_id", "connection_id", "connection_server_id"); + + +CREATE TABLE "feature_flags" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "flag" TEXT NOT NULL UNIQUE +); + +CREATE INDEX "index_feature_flags" ON "feature_flags" ("id"); + + +CREATE TABLE "user_features" ( + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "feature_id" INTEGER NOT NULL REFERENCES feature_flags (id) ON DELETE CASCADE, + PRIMARY KEY (user_id, feature_id) +); + +CREATE UNIQUE INDEX "index_user_features_user_id_and_feature_id" ON "user_features" ("user_id", "feature_id"); +CREATE INDEX "index_user_features_on_user_id" ON "user_features" ("user_id"); +CREATE INDEX "index_user_features_on_feature_id" ON "user_features" ("feature_id"); + + +CREATE TABLE "observed_buffer_edits" ( + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "lamport_timestamp" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + PRIMARY KEY (user_id, buffer_id) +); + +CREATE UNIQUE INDEX "index_observed_buffers_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id"); + +CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "channel_message_id" INTEGER NOT NULL, + PRIMARY KEY (user_id, channel_id) +); + +CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab2/migrations/20210527024318_initial_schema.sql b/crates/collab2/migrations/20210527024318_initial_schema.sql new file mode 100644 index 0000000000..4b06531848 --- /dev/null +++ b/crates/collab2/migrations/20210527024318_initial_schema.sql @@ -0,0 +1,20 @@ +CREATE TABLE IF NOT EXISTS "sessions" ( + "id" VARCHAR NOT NULL PRIMARY KEY, + "expires" TIMESTAMP WITH TIME ZONE NULL, + "session" TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "users" ( + "id" SERIAL PRIMARY KEY, + "github_login" VARCHAR, + "admin" BOOLEAN +); + +CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); + +CREATE TABLE IF NOT EXISTS "signups" ( + "id" SERIAL PRIMARY KEY, + "github_login" VARCHAR, + "email_address" VARCHAR, + "about" TEXT +); diff --git a/crates/collab2/migrations/20210607190313_create_access_tokens.sql b/crates/collab2/migrations/20210607190313_create_access_tokens.sql new file mode 100644 index 0000000000..60745a98ba --- /dev/null +++ b/crates/collab2/migrations/20210607190313_create_access_tokens.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "access_tokens" ( + "id" SERIAL PRIMARY KEY, + "user_id" INTEGER REFERENCES users (id), + "hash" VARCHAR(128) +); + +CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); diff --git a/crates/collab2/migrations/20210805175147_create_chat_tables.sql b/crates/collab2/migrations/20210805175147_create_chat_tables.sql new file mode 100644 index 0000000000..5bba4689d9 --- /dev/null +++ b/crates/collab2/migrations/20210805175147_create_chat_tables.sql @@ -0,0 +1,46 @@ +CREATE TABLE IF NOT EXISTS "orgs" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL, + "slug" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_orgs_slug" ON "orgs" ("slug"); + +CREATE TABLE IF NOT EXISTS "org_memberships" ( + "id" SERIAL PRIMARY KEY, + "org_id" INTEGER REFERENCES orgs (id) NOT NULL, + "user_id" INTEGER REFERENCES users (id) NOT NULL, + "admin" BOOLEAN NOT NULL +); + +CREATE INDEX "index_org_memberships_user_id" ON "org_memberships" ("user_id"); +CREATE UNIQUE INDEX "index_org_memberships_org_id_and_user_id" ON "org_memberships" ("org_id", "user_id"); + +CREATE TABLE IF NOT EXISTS "channels" ( + "id" SERIAL PRIMARY KEY, + "owner_id" INTEGER NOT NULL, + "owner_is_user" BOOLEAN NOT NULL, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_channels_owner_and_name" ON "channels" ("owner_is_user", "owner_id", "name"); + +CREATE TABLE IF NOT EXISTS "channel_memberships" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER REFERENCES channels (id) NOT NULL, + "user_id" INTEGER REFERENCES users (id) NOT NULL, + "admin" BOOLEAN NOT NULL +); + +CREATE INDEX "index_channel_memberships_user_id" ON "channel_memberships" ("user_id"); +CREATE UNIQUE INDEX "index_channel_memberships_channel_id_and_user_id" ON "channel_memberships" ("channel_id", "user_id"); + +CREATE TABLE IF NOT EXISTS "channel_messages" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER REFERENCES channels (id) NOT NULL, + "sender_id" INTEGER REFERENCES users (id) NOT NULL, + "body" TEXT NOT NULL, + "sent_at" TIMESTAMP +); + +CREATE INDEX "index_channel_messages_channel_id" ON "channel_messages" ("channel_id"); diff --git a/crates/collab2/migrations/20210916123647_add_nonce_to_channel_messages.sql b/crates/collab2/migrations/20210916123647_add_nonce_to_channel_messages.sql new file mode 100644 index 0000000000..ee4d4aa319 --- /dev/null +++ b/crates/collab2/migrations/20210916123647_add_nonce_to_channel_messages.sql @@ -0,0 +1,4 @@ +ALTER TABLE "channel_messages" +ADD "nonce" UUID NOT NULL DEFAULT gen_random_uuid(); + +CREATE UNIQUE INDEX "index_channel_messages_nonce" ON "channel_messages" ("nonce"); diff --git a/crates/collab2/migrations/20210920192001_add_interests_to_signups.sql b/crates/collab2/migrations/20210920192001_add_interests_to_signups.sql new file mode 100644 index 0000000000..2457abfc75 --- /dev/null +++ b/crates/collab2/migrations/20210920192001_add_interests_to_signups.sql @@ -0,0 +1,4 @@ +ALTER TABLE "signups" + ADD "wants_releases" BOOLEAN, + ADD "wants_updates" BOOLEAN, + ADD "wants_community" BOOLEAN; \ No newline at end of file diff --git a/crates/collab2/migrations/20220421165757_drop_signups.sql b/crates/collab2/migrations/20220421165757_drop_signups.sql new file mode 100644 index 0000000000..d7cd6e204c --- /dev/null +++ b/crates/collab2/migrations/20220421165757_drop_signups.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "signups"; diff --git a/crates/collab2/migrations/20220505144506_add_trigram_index_to_users.sql b/crates/collab2/migrations/20220505144506_add_trigram_index_to_users.sql new file mode 100644 index 0000000000..3d6fd3179a --- /dev/null +++ b/crates/collab2/migrations/20220505144506_add_trigram_index_to_users.sql @@ -0,0 +1,2 @@ +CREATE EXTENSION IF NOT EXISTS pg_trgm; +CREATE INDEX trigram_index_users_on_github_login ON users USING GIN(github_login gin_trgm_ops); diff --git a/crates/collab2/migrations/20220506130724_create_contacts.sql b/crates/collab2/migrations/20220506130724_create_contacts.sql new file mode 100644 index 0000000000..56beb70fd0 --- /dev/null +++ b/crates/collab2/migrations/20220506130724_create_contacts.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS "contacts" ( + "id" SERIAL PRIMARY KEY, + "user_id_a" INTEGER REFERENCES users (id) NOT NULL, + "user_id_b" INTEGER REFERENCES users (id) NOT NULL, + "a_to_b" BOOLEAN NOT NULL, + "should_notify" BOOLEAN NOT NULL, + "accepted" BOOLEAN NOT NULL +); + +CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); +CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); diff --git a/crates/collab2/migrations/20220518151305_add_invites_to_users.sql b/crates/collab2/migrations/20220518151305_add_invites_to_users.sql new file mode 100644 index 0000000000..2ac89b649e --- /dev/null +++ b/crates/collab2/migrations/20220518151305_add_invites_to_users.sql @@ -0,0 +1,9 @@ +ALTER TABLE users +ADD email_address VARCHAR(255) DEFAULT NULL, +ADD invite_code VARCHAR(64), +ADD invite_count INTEGER NOT NULL DEFAULT 0, +ADD inviter_id INTEGER REFERENCES users (id), +ADD connected_once BOOLEAN NOT NULL DEFAULT false, +ADD created_at TIMESTAMP NOT NULL DEFAULT NOW(); + +CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); diff --git a/crates/collab2/migrations/20220523232954_allow_user_deletes.sql b/crates/collab2/migrations/20220523232954_allow_user_deletes.sql new file mode 100644 index 0000000000..ddf3f6f9bd --- /dev/null +++ b/crates/collab2/migrations/20220523232954_allow_user_deletes.sql @@ -0,0 +1,6 @@ +ALTER TABLE contacts DROP CONSTRAINT contacts_user_id_a_fkey; +ALTER TABLE contacts DROP CONSTRAINT contacts_user_id_b_fkey; +ALTER TABLE contacts ADD CONSTRAINT contacts_user_id_a_fkey FOREIGN KEY (user_id_a) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE contacts ADD CONSTRAINT contacts_user_id_b_fkey FOREIGN KEY (user_id_b) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE users DROP CONSTRAINT users_inviter_id_fkey; +ALTER TABLE users ADD CONSTRAINT users_inviter_id_fkey FOREIGN KEY (inviter_id) REFERENCES users(id) ON DELETE SET NULL; diff --git a/crates/collab2/migrations/20220620211403_create_projects.sql b/crates/collab2/migrations/20220620211403_create_projects.sql new file mode 100644 index 0000000000..d813c9f7a1 --- /dev/null +++ b/crates/collab2/migrations/20220620211403_create_projects.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS "projects" ( + "id" SERIAL PRIMARY KEY, + "host_user_id" INTEGER REFERENCES users (id) NOT NULL, + "unregistered" BOOLEAN NOT NULL DEFAULT false +); + +CREATE TABLE IF NOT EXISTS "worktree_extensions" ( + "id" SERIAL PRIMARY KEY, + "project_id" INTEGER REFERENCES projects (id) NOT NULL, + "worktree_id" INTEGER NOT NULL, + "extension" VARCHAR(255), + "count" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "project_activity_periods" ( + "id" SERIAL PRIMARY KEY, + "duration_millis" INTEGER NOT NULL, + "ended_at" TIMESTAMP NOT NULL, + "user_id" INTEGER REFERENCES users (id) NOT NULL, + "project_id" INTEGER REFERENCES projects (id) NOT NULL +); + +CREATE INDEX "index_project_activity_periods_on_ended_at" ON "project_activity_periods" ("ended_at"); +CREATE UNIQUE INDEX "index_worktree_extensions_on_project_id_and_worktree_id_and_extension" ON "worktree_extensions" ("project_id", "worktree_id", "extension"); \ No newline at end of file diff --git a/crates/collab2/migrations/20220913211150_create_signups.sql b/crates/collab2/migrations/20220913211150_create_signups.sql new file mode 100644 index 0000000000..19559b747c --- /dev/null +++ b/crates/collab2/migrations/20220913211150_create_signups.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS "signups" ( + "id" SERIAL PRIMARY KEY, + "email_address" VARCHAR NOT NULL, + "email_confirmation_code" VARCHAR(64) NOT NULL, + "email_confirmation_sent" BOOLEAN NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "device_id" VARCHAR, + "user_id" INTEGER REFERENCES users (id) ON DELETE CASCADE, + "inviting_user_id" INTEGER REFERENCES users (id) ON DELETE SET NULL, + + "platform_mac" BOOLEAN NOT NULL, + "platform_linux" BOOLEAN NOT NULL, + "platform_windows" BOOLEAN NOT NULL, + "platform_unknown" BOOLEAN NOT NULL, + + "editor_features" VARCHAR[], + "programming_languages" VARCHAR[] +); + +CREATE UNIQUE INDEX "index_signups_on_email_address" ON "signups" ("email_address"); +CREATE INDEX "index_signups_on_email_confirmation_sent" ON "signups" ("email_confirmation_sent"); + +ALTER TABLE "users" + ADD "github_user_id" INTEGER; + +CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); +CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); diff --git a/crates/collab2/migrations/20220929182110_add_metrics_id.sql b/crates/collab2/migrations/20220929182110_add_metrics_id.sql new file mode 100644 index 0000000000..665d6323bf --- /dev/null +++ b/crates/collab2/migrations/20220929182110_add_metrics_id.sql @@ -0,0 +1,2 @@ +ALTER TABLE "users" + ADD "metrics_id" uuid NOT NULL DEFAULT gen_random_uuid(); diff --git a/crates/collab2/migrations/20221111092550_reconnection_support.sql b/crates/collab2/migrations/20221111092550_reconnection_support.sql new file mode 100644 index 0000000000..3289f6bbdd --- /dev/null +++ b/crates/collab2/migrations/20221111092550_reconnection_support.sql @@ -0,0 +1,90 @@ +CREATE TABLE IF NOT EXISTS "rooms" ( + "id" SERIAL PRIMARY KEY, + "live_kit_room" VARCHAR NOT NULL +); + +ALTER TABLE "projects" + ADD "room_id" INTEGER REFERENCES rooms (id), + ADD "host_connection_id" INTEGER, + ADD "host_connection_epoch" UUID; +CREATE INDEX "index_projects_on_host_connection_epoch" ON "projects" ("host_connection_epoch"); + +CREATE TABLE "worktrees" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INT8 NOT NULL, + "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INT8 NOT NULL, + "is_complete" BOOL NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE "worktree_entries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "id" INT8 NOT NULL, + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INT8 NOT NULL, + "mtime_seconds" INT8 NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "path" VARCHAR NOT NULL, + "language_server_id" INT8 NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INT8 NOT NULL, + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + +CREATE TABLE "project_collaborators" ( + "id" SERIAL PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_epoch" UUID NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch"); + +CREATE TABLE "room_participants" ( + "id" SERIAL PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "answering_connection_epoch" UUID, + "location_kind" INTEGER, + "location_project_id" INTEGER, + "initial_project_id" INTEGER, + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_connection_id" INTEGER NOT NULL, + "calling_connection_epoch" UUID NOT NULL +); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch"); +CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch"); diff --git a/crates/collab2/migrations/20221125192125_add_added_to_mailing_list_to_signups.sql b/crates/collab2/migrations/20221125192125_add_added_to_mailing_list_to_signups.sql new file mode 100644 index 0000000000..b154396df1 --- /dev/null +++ b/crates/collab2/migrations/20221125192125_add_added_to_mailing_list_to_signups.sql @@ -0,0 +1,2 @@ +ALTER TABLE "signups" + ADD "added_to_mailing_list" BOOLEAN NOT NULL DEFAULT FALSE; \ No newline at end of file diff --git a/crates/collab2/migrations/20221207165001_add_connection_lost_to_room_participants.sql b/crates/collab2/migrations/20221207165001_add_connection_lost_to_room_participants.sql new file mode 100644 index 0000000000..ed0cf972bc --- /dev/null +++ b/crates/collab2/migrations/20221207165001_add_connection_lost_to_room_participants.sql @@ -0,0 +1,7 @@ +ALTER TABLE "room_participants" + ADD "answering_connection_lost" BOOLEAN NOT NULL DEFAULT FALSE; + +CREATE INDEX "index_project_collaborators_on_connection_id" ON "project_collaborators" ("connection_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and_epoch" ON "project_collaborators" ("project_id", "connection_id", "connection_epoch"); +CREATE INDEX "index_room_participants_on_answering_connection_id" ON "room_participants" ("answering_connection_id"); +CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_epoch" ON "room_participants" ("answering_connection_id", "answering_connection_epoch"); diff --git a/crates/collab2/migrations/20221213125710_index_room_participants_on_room_id.sql b/crates/collab2/migrations/20221213125710_index_room_participants_on_room_id.sql new file mode 100644 index 0000000000..f40ca81906 --- /dev/null +++ b/crates/collab2/migrations/20221213125710_index_room_participants_on_room_id.sql @@ -0,0 +1 @@ +CREATE INDEX "index_room_participants_on_room_id" ON "room_participants" ("room_id"); diff --git a/crates/collab2/migrations/20221214144346_change_epoch_from_uuid_to_integer.sql b/crates/collab2/migrations/20221214144346_change_epoch_from_uuid_to_integer.sql new file mode 100644 index 0000000000..5e02f76ce2 --- /dev/null +++ b/crates/collab2/migrations/20221214144346_change_epoch_from_uuid_to_integer.sql @@ -0,0 +1,30 @@ +CREATE TABLE servers ( + id SERIAL PRIMARY KEY, + environment VARCHAR NOT NULL +); + +DROP TABLE worktree_extensions; +DROP TABLE project_activity_periods; +DELETE from projects; +ALTER TABLE projects + DROP COLUMN host_connection_epoch, + ADD COLUMN host_connection_server_id INTEGER REFERENCES servers (id) ON DELETE CASCADE; +CREATE INDEX "index_projects_on_host_connection_server_id" ON "projects" ("host_connection_server_id"); +CREATE INDEX "index_projects_on_host_connection_id_and_host_connection_server_id" ON "projects" ("host_connection_id", "host_connection_server_id"); + +DELETE FROM project_collaborators; +ALTER TABLE project_collaborators + DROP COLUMN connection_epoch, + ADD COLUMN connection_server_id INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE; +CREATE INDEX "index_project_collaborators_on_connection_server_id" ON "project_collaborators" ("connection_server_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and_server_id" ON "project_collaborators" ("project_id", "connection_id", "connection_server_id"); + +DELETE FROM room_participants; +ALTER TABLE room_participants + DROP COLUMN answering_connection_epoch, + DROP COLUMN calling_connection_epoch, + ADD COLUMN answering_connection_server_id INTEGER REFERENCES servers (id) ON DELETE CASCADE, + ADD COLUMN calling_connection_server_id INTEGER REFERENCES servers (id) ON DELETE SET NULL; +CREATE INDEX "index_room_participants_on_answering_connection_server_id" ON "room_participants" ("answering_connection_server_id"); +CREATE INDEX "index_room_participants_on_calling_connection_server_id" ON "room_participants" ("calling_connection_server_id"); +CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_server_id" ON "room_participants" ("answering_connection_id", "answering_connection_server_id"); diff --git a/crates/collab2/migrations/20221219181850_project_reconnection_support.sql b/crates/collab2/migrations/20221219181850_project_reconnection_support.sql new file mode 100644 index 0000000000..6efef5571c --- /dev/null +++ b/crates/collab2/migrations/20221219181850_project_reconnection_support.sql @@ -0,0 +1,3 @@ +ALTER TABLE "worktree_entries" + ADD COLUMN "scan_id" INT8, + ADD COLUMN "is_deleted" BOOL; diff --git a/crates/collab2/migrations/20230103200902_replace_is_completed_with_completed_scan_id.sql b/crates/collab2/migrations/20230103200902_replace_is_completed_with_completed_scan_id.sql new file mode 100644 index 0000000000..1894d888b9 --- /dev/null +++ b/crates/collab2/migrations/20230103200902_replace_is_completed_with_completed_scan_id.sql @@ -0,0 +1,3 @@ +ALTER TABLE worktrees + ALTER COLUMN is_complete SET DEFAULT FALSE, + ADD COLUMN completed_scan_id INT8; diff --git a/crates/collab2/migrations/20230202155735_followers.sql b/crates/collab2/migrations/20230202155735_followers.sql new file mode 100644 index 0000000000..c82d6ba3bd --- /dev/null +++ b/crates/collab2/migrations/20230202155735_followers.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS "followers" ( + "id" SERIAL PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id) ON DELETE CASCADE, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "leader_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "leader_connection_id" INTEGER NOT NULL, + "follower_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "follower_connection_id" INTEGER NOT NULL +); + +CREATE UNIQUE INDEX + "index_followers_on_project_id_and_leader_connection_server_id_and_leader_connection_id_and_follower_connection_server_id_and_follower_connection_id" +ON "followers" ("project_id", "leader_connection_server_id", "leader_connection_id", "follower_connection_server_id", "follower_connection_id"); + +CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); diff --git a/crates/collab2/migrations/20230508211523_add-repository-entries.sql b/crates/collab2/migrations/20230508211523_add-repository-entries.sql new file mode 100644 index 0000000000..1e59347939 --- /dev/null +++ b/crates/collab2/migrations/20230508211523_add-repository-entries.sql @@ -0,0 +1,13 @@ +CREATE TABLE "worktree_repositories" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "work_directory_id" INT8 NOT NULL, + "scan_id" INT8 NOT NULL, + "branch" VARCHAR, + "is_deleted" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, work_directory_id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE, + FOREIGN KEY(project_id, worktree_id, work_directory_id) REFERENCES worktree_entries (project_id, worktree_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_repositories_on_project_id" ON "worktree_repositories" ("project_id"); +CREATE INDEX "index_worktree_repositories_on_project_id_and_worktree_id" ON "worktree_repositories" ("project_id", "worktree_id"); diff --git a/crates/collab2/migrations/20230511004019_add_repository_statuses.sql b/crates/collab2/migrations/20230511004019_add_repository_statuses.sql new file mode 100644 index 0000000000..862561c686 --- /dev/null +++ b/crates/collab2/migrations/20230511004019_add_repository_statuses.sql @@ -0,0 +1,15 @@ +CREATE TABLE "worktree_repository_statuses" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "work_directory_id" INT8 NOT NULL, + "repo_path" VARCHAR NOT NULL, + "status" INT8 NOT NULL, + "scan_id" INT8 NOT NULL, + "is_deleted" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, work_directory_id, repo_path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE, + FOREIGN KEY(project_id, worktree_id, work_directory_id) REFERENCES worktree_entries (project_id, worktree_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_wt_repos_statuses_on_project_id" ON "worktree_repository_statuses" ("project_id"); +CREATE INDEX "index_wt_repos_statuses_on_project_id_and_wt_id" ON "worktree_repository_statuses" ("project_id", "worktree_id"); +CREATE INDEX "index_wt_repos_statuses_on_project_id_and_wt_id_and_wd_id" ON "worktree_repository_statuses" ("project_id", "worktree_id", "work_directory_id"); diff --git a/crates/collab2/migrations/20230529164700_add_worktree_settings_files.sql b/crates/collab2/migrations/20230529164700_add_worktree_settings_files.sql new file mode 100644 index 0000000000..973a40af0f --- /dev/null +++ b/crates/collab2/migrations/20230529164700_add_worktree_settings_files.sql @@ -0,0 +1,10 @@ +CREATE TABLE "worktree_settings_files" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "path" VARCHAR NOT NULL, + "content" TEXT NOT NULL, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_settings_files_on_project_id" ON "worktree_settings_files" ("project_id"); +CREATE INDEX "index_settings_files_on_project_id_and_wt_id" ON "worktree_settings_files" ("project_id", "worktree_id"); diff --git a/crates/collab2/migrations/20230605191135_remove_repository_statuses.sql b/crates/collab2/migrations/20230605191135_remove_repository_statuses.sql new file mode 100644 index 0000000000..3e5f907c44 --- /dev/null +++ b/crates/collab2/migrations/20230605191135_remove_repository_statuses.sql @@ -0,0 +1,2 @@ +ALTER TABLE "worktree_entries" +ADD "git_status" INT8; diff --git a/crates/collab2/migrations/20230616134535_add_is_external_to_worktree_entries.sql b/crates/collab2/migrations/20230616134535_add_is_external_to_worktree_entries.sql new file mode 100644 index 0000000000..e4348af0cc --- /dev/null +++ b/crates/collab2/migrations/20230616134535_add_is_external_to_worktree_entries.sql @@ -0,0 +1,2 @@ +ALTER TABLE "worktree_entries" +ADD "is_external" BOOL NOT NULL DEFAULT FALSE; diff --git a/crates/collab2/migrations/20230727150500_add_channels.sql b/crates/collab2/migrations/20230727150500_add_channels.sql new file mode 100644 index 0000000000..df981838bf --- /dev/null +++ b/crates/collab2/migrations/20230727150500_add_channels.sql @@ -0,0 +1,30 @@ +DROP TABLE "channel_messages"; +DROP TABLE "channel_memberships"; +DROP TABLE "org_memberships"; +DROP TABLE "orgs"; +DROP TABLE "channels"; + +CREATE TABLE "channels" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT now() +); + +CREATE TABLE "channel_paths" ( + "id_path" VARCHAR NOT NULL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE +); +CREATE INDEX "index_channel_paths_on_channel_id" ON "channel_paths" ("channel_id"); + +CREATE TABLE "channel_members" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "admin" BOOLEAN NOT NULL DEFAULT false, + "accepted" BOOLEAN NOT NULL DEFAULT false, + "updated_at" TIMESTAMP NOT NULL DEFAULT now() +); + +CREATE UNIQUE INDEX "index_channel_members_on_channel_id_and_user_id" ON "channel_members" ("channel_id", "user_id"); + +ALTER TABLE rooms ADD COLUMN "channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE; diff --git a/crates/collab2/migrations/20230819154600_add_channel_buffers.sql b/crates/collab2/migrations/20230819154600_add_channel_buffers.sql new file mode 100644 index 0000000000..5e6e7ce339 --- /dev/null +++ b/crates/collab2/migrations/20230819154600_add_channel_buffers.sql @@ -0,0 +1,40 @@ +CREATE TABLE "buffers" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL DEFAULT 0 +); + +CREATE INDEX "index_buffers_on_channel_id" ON "buffers" ("channel_id"); + +CREATE TABLE "buffer_operations" ( + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "lamport_timestamp" INTEGER NOT NULL, + "value" BYTEA NOT NULL, + PRIMARY KEY(buffer_id, epoch, lamport_timestamp, replica_id) +); + +CREATE TABLE "buffer_snapshots" ( + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "text" TEXT NOT NULL, + "operation_serialization_version" INTEGER NOT NULL, + PRIMARY KEY(buffer_id, epoch) +); + +CREATE TABLE "channel_buffer_collaborators" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, + "connection_lost" BOOLEAN NOT NULL DEFAULT FALSE, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "replica_id" INTEGER NOT NULL +); + +CREATE INDEX "index_channel_buffer_collaborators_on_channel_id" ON "channel_buffer_collaborators" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_and_replica_id" ON "channel_buffer_collaborators" ("channel_id", "replica_id"); +CREATE INDEX "index_channel_buffer_collaborators_on_connection_server_id" ON "channel_buffer_collaborators" ("connection_server_id"); +CREATE INDEX "index_channel_buffer_collaborators_on_connection_id" ON "channel_buffer_collaborators" ("connection_id"); +CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection_id_and_server_id" ON "channel_buffer_collaborators" ("channel_id", "connection_id", "connection_server_id"); diff --git a/crates/collab2/migrations/20230825190322_add_server_feature_flags.sql b/crates/collab2/migrations/20230825190322_add_server_feature_flags.sql new file mode 100644 index 0000000000..fffde54a20 --- /dev/null +++ b/crates/collab2/migrations/20230825190322_add_server_feature_flags.sql @@ -0,0 +1,16 @@ +CREATE TABLE "feature_flags" ( + "id" SERIAL PRIMARY KEY, + "flag" VARCHAR(255) NOT NULL UNIQUE +); + +CREATE UNIQUE INDEX "index_feature_flags" ON "feature_flags" ("id"); + +CREATE TABLE "user_features" ( + "user_id" INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + "feature_id" INTEGER NOT NULL REFERENCES feature_flags(id) ON DELETE CASCADE, + PRIMARY KEY (user_id, feature_id) +); + +CREATE UNIQUE INDEX "index_user_features_user_id_and_feature_id" ON "user_features" ("user_id", "feature_id"); +CREATE INDEX "index_user_features_on_user_id" ON "user_features" ("user_id"); +CREATE INDEX "index_user_features_on_feature_id" ON "user_features" ("feature_id"); diff --git a/crates/collab2/migrations/20230907114200_add_channel_messages.sql b/crates/collab2/migrations/20230907114200_add_channel_messages.sql new file mode 100644 index 0000000000..abe7753ca6 --- /dev/null +++ b/crates/collab2/migrations/20230907114200_add_channel_messages.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS "channel_messages" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "sender_id" INTEGER NOT NULL REFERENCES users (id), + "body" TEXT NOT NULL, + "sent_at" TIMESTAMP, + "nonce" UUID NOT NULL +); +CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); + +CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( + "id" SERIAL PRIMARY KEY, + "user_id" INTEGER NOT NULL REFERENCES users (id), + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE +); +CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id"); diff --git a/crates/collab2/migrations/20230925210437_add_channel_changes.sql b/crates/collab2/migrations/20230925210437_add_channel_changes.sql new file mode 100644 index 0000000000..250a9ac731 --- /dev/null +++ b/crates/collab2/migrations/20230925210437_add_channel_changes.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS "observed_buffer_edits" ( + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, + "epoch" INTEGER NOT NULL, + "lamport_timestamp" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + PRIMARY KEY (user_id, buffer_id) +); + +CREATE UNIQUE INDEX "index_observed_buffer_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id"); + +CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "channel_message_id" INTEGER NOT NULL, + PRIMARY KEY (user_id, channel_id) +); + +CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); diff --git a/crates/collab2/migrations/20230926102500_add_participant_index_to_room_participants.sql b/crates/collab2/migrations/20230926102500_add_participant_index_to_room_participants.sql new file mode 100644 index 0000000000..1493119e2a --- /dev/null +++ b/crates/collab2/migrations/20230926102500_add_participant_index_to_room_participants.sql @@ -0,0 +1 @@ +ALTER TABLE room_participants ADD COLUMN participant_index INTEGER; diff --git a/crates/collab2/migrations/20231004130100_create_notifications.sql b/crates/collab2/migrations/20231004130100_create_notifications.sql new file mode 100644 index 0000000000..93c282c631 --- /dev/null +++ b/crates/collab2/migrations/20231004130100_create_notifications.sql @@ -0,0 +1,22 @@ +CREATE TABLE "notification_kinds" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE notifications ( + "id" SERIAL PRIMARY KEY, + "created_at" TIMESTAMP NOT NULL DEFAULT now(), + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab2/migrations/20231009181554_add_release_channel_to_rooms.sql b/crates/collab2/migrations/20231009181554_add_release_channel_to_rooms.sql new file mode 100644 index 0000000000..8f3a704add --- /dev/null +++ b/crates/collab2/migrations/20231009181554_add_release_channel_to_rooms.sql @@ -0,0 +1 @@ +ALTER TABLE rooms ADD COLUMN enviroment TEXT; diff --git a/crates/collab2/migrations/20231010114600_add_unique_index_on_rooms_channel_id.sql b/crates/collab2/migrations/20231010114600_add_unique_index_on_rooms_channel_id.sql new file mode 100644 index 0000000000..21ec4cfbb7 --- /dev/null +++ b/crates/collab2/migrations/20231010114600_add_unique_index_on_rooms_channel_id.sql @@ -0,0 +1 @@ +CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); diff --git a/crates/collab2/migrations/20231011214412_add_guest_role.sql b/crates/collab2/migrations/20231011214412_add_guest_role.sql new file mode 100644 index 0000000000..1713547158 --- /dev/null +++ b/crates/collab2/migrations/20231011214412_add_guest_role.sql @@ -0,0 +1,4 @@ +ALTER TABLE channel_members ADD COLUMN role TEXT; +UPDATE channel_members SET role = CASE WHEN admin THEN 'admin' ELSE 'member' END; + +ALTER TABLE channels ADD COLUMN visibility TEXT NOT NULL DEFAULT 'members'; diff --git a/crates/collab2/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql b/crates/collab2/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql new file mode 100644 index 0000000000..be535ff7fa --- /dev/null +++ b/crates/collab2/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql @@ -0,0 +1,8 @@ +-- Add migration script here + +ALTER TABLE projects + DROP CONSTRAINT projects_room_id_fkey, + ADD CONSTRAINT projects_room_id_fkey + FOREIGN KEY (room_id) + REFERENCES rooms (id) + ON DELETE CASCADE; diff --git a/crates/collab2/migrations/20231018102700_create_mentions.sql b/crates/collab2/migrations/20231018102700_create_mentions.sql new file mode 100644 index 0000000000..221a1748cf --- /dev/null +++ b/crates/collab2/migrations/20231018102700_create_mentions.sql @@ -0,0 +1,11 @@ +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); + +-- We use 'on conflict update' with this index, so it should be per-user. +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); +DROP INDEX "index_channel_messages_on_nonce"; diff --git a/crates/collab2/migrations/20231024085546_move_channel_paths_to_channels_table.sql b/crates/collab2/migrations/20231024085546_move_channel_paths_to_channels_table.sql new file mode 100644 index 0000000000..d9fc6c8722 --- /dev/null +++ b/crates/collab2/migrations/20231024085546_move_channel_paths_to_channels_table.sql @@ -0,0 +1,12 @@ +ALTER TABLE channels ADD COLUMN parent_path TEXT; + +UPDATE channels +SET parent_path = substr( + channel_paths.id_path, + 2, + length(channel_paths.id_path) - length('/' || channel_paths.channel_id::text || '/') +) +FROM channel_paths +WHERE channel_paths.channel_id = channels.id; + +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); diff --git a/crates/collab2/src/api.rs b/crates/collab2/src/api.rs new file mode 100644 index 0000000000..a84fcf328b --- /dev/null +++ b/crates/collab2/src/api.rs @@ -0,0 +1,184 @@ +use crate::{ + auth, + db::{User, UserId}, + rpc, AppState, Error, Result, +}; +use anyhow::anyhow; +use axum::{ + body::Body, + extract::{Path, Query}, + http::{self, Request, StatusCode}, + middleware::{self, Next}, + response::IntoResponse, + routing::{get, post}, + Extension, Json, Router, +}; +use axum_extra::response::ErasedJson; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tower::ServiceBuilder; +use tracing::instrument; + +pub fn routes(rpc_server: Arc, state: Arc) -> Router { + Router::new() + .route("/user", get(get_authenticated_user)) + .route("/users/:id/access_tokens", post(create_access_token)) + .route("/panic", post(trace_panic)) + .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) + .layer( + ServiceBuilder::new() + .layer(Extension(state)) + .layer(Extension(rpc_server)) + .layer(middleware::from_fn(validate_api_token)), + ) +} + +pub async fn validate_api_token(req: Request, next: Next) -> impl IntoResponse { + let token = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing authorization header".to_string(), + ) + })? + .strip_prefix("token ") + .ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "invalid authorization header".to_string(), + ) + })?; + + let state = req.extensions().get::>().unwrap(); + + if token != state.config.api_token { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "invalid authorization token".to_string(), + ))? + } + + Ok::<_, Error>(next.run(req).await) +} + +#[derive(Debug, Deserialize)] +struct AuthenticatedUserParams { + github_user_id: Option, + github_login: String, + github_email: Option, +} + +#[derive(Debug, Serialize)] +struct AuthenticatedUserResponse { + user: User, + metrics_id: String, +} + +async fn get_authenticated_user( + Query(params): Query, + Extension(app): Extension>, +) -> Result> { + let user = app + .db + .get_or_create_user_by_github_account( + ¶ms.github_login, + params.github_user_id, + params.github_email.as_deref(), + ) + .await? + .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "user not found".into()))?; + let metrics_id = app.db.get_user_metrics_id(user.id).await?; + return Ok(Json(AuthenticatedUserResponse { user, metrics_id })); +} + +#[derive(Deserialize, Debug)] +struct CreateUserParams { + github_user_id: i32, + github_login: String, + email_address: String, + email_confirmation_code: Option, + #[serde(default)] + admin: bool, + #[serde(default)] + invite_count: i32, +} + +#[derive(Serialize, Debug)] +struct CreateUserResponse { + user: User, + signup_device_id: Option, + metrics_id: String, +} + +#[derive(Debug, Deserialize)] +struct Panic { + version: String, + text: String, +} + +#[instrument(skip(panic))] +async fn trace_panic(panic: Json) -> Result<()> { + tracing::error!(version = %panic.version, text = %panic.text, "panic report"); + Ok(()) +} + +async fn get_rpc_server_snapshot( + Extension(rpc_server): Extension>, +) -> Result { + Ok(ErasedJson::pretty(rpc_server.snapshot().await)) +} + +#[derive(Deserialize)] +struct CreateAccessTokenQueryParams { + public_key: String, + impersonate: Option, +} + +#[derive(Serialize)] +struct CreateAccessTokenResponse { + user_id: UserId, + encrypted_access_token: String, +} + +async fn create_access_token( + Path(user_id): Path, + Query(params): Query, + Extension(app): Extension>, +) -> Result> { + let user = app + .db + .get_user_by_id(user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let mut user_id = user.id; + if let Some(impersonate) = params.impersonate { + if user.admin { + if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { + user_id = impersonated_user.id; + } else { + return Err(Error::Http( + StatusCode::UNPROCESSABLE_ENTITY, + format!("user {impersonate} does not exist"), + )); + } + } else { + return Err(Error::Http( + StatusCode::UNAUTHORIZED, + "you do not have permission to impersonate other users".to_string(), + )); + } + } + + let access_token = auth::create_access_token(app.db.as_ref(), user_id).await?; + let encrypted_access_token = + auth::encrypt_access_token(&access_token, params.public_key.clone())?; + + Ok(Json(CreateAccessTokenResponse { + user_id, + encrypted_access_token, + })) +} diff --git a/crates/collab2/src/auth.rs b/crates/collab2/src/auth.rs new file mode 100644 index 0000000000..9ce602c577 --- /dev/null +++ b/crates/collab2/src/auth.rs @@ -0,0 +1,151 @@ +use crate::{ + db::{self, AccessTokenId, Database, UserId}, + AppState, Error, Result, +}; +use anyhow::{anyhow, Context}; +use axum::{ + http::{self, Request, StatusCode}, + middleware::Next, + response::IntoResponse, +}; +use lazy_static::lazy_static; +use prometheus::{exponential_buckets, register_histogram, Histogram}; +use rand::thread_rng; +use scrypt::{ + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, + Scrypt, +}; +use serde::{Deserialize, Serialize}; +use std::{sync::Arc, time::Instant}; + +lazy_static! { + static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!( + "access_token_hashing_time", + "time spent hashing access tokens", + exponential_buckets(10.0, 2.0, 10).unwrap(), + ) + .unwrap(); +} + +pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or_else(|| { + Error::Http( + StatusCode::UNAUTHORIZED, + "missing authorization header".to_string(), + ) + })? + .split_whitespace(); + + let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing user id in authorization header".to_string(), + ) + })?); + + let access_token = auth_header.next().ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing access token in authorization header".to_string(), + ) + })?; + + let state = req.extensions().get::>().unwrap(); + let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") { + state.config.api_token == admin_token + } else { + verify_access_token(&access_token, user_id, &state.db) + .await + .unwrap_or(false) + }; + + if credentials_valid { + let user = state + .db + .get_user_by_id(user_id) + .await? + .ok_or_else(|| anyhow!("user {} not found", user_id))?; + req.extensions_mut().insert(user); + Ok::<_, Error>(next.run(req).await) + } else { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "invalid credentials".to_string(), + )) + } +} + +const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; + +#[derive(Serialize, Deserialize)] +struct AccessTokenJson { + version: usize, + id: AccessTokenId, + token: String, +} + +pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { + const VERSION: usize = 1; + let access_token = rpc::auth::random_token(); + let access_token_hash = + hash_access_token(&access_token).context("failed to hash access token")?; + let id = db + .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) + .await?; + Ok(serde_json::to_string(&AccessTokenJson { + version: VERSION, + id, + token: access_token, + })?) +} + +fn hash_access_token(token: &str) -> Result { + // Avoid slow hashing in debug mode. + let params = if cfg!(debug_assertions) { + scrypt::Params::new(1, 1, 1).unwrap() + } else { + scrypt::Params::new(14, 8, 1).unwrap() + }; + + Ok(Scrypt + .hash_password( + token.as_bytes(), + None, + params, + &SaltString::generate(thread_rng()), + ) + .map_err(anyhow::Error::new)? + .to_string()) +} + +pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result { + let native_app_public_key = + rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?; + let encrypted_access_token = native_app_public_key + .encrypt_string(access_token) + .context("failed to encrypt access token with public key")?; + Ok(encrypted_access_token) +} + +pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc) -> Result { + let token: AccessTokenJson = serde_json::from_str(&token)?; + + let db_token = db.get_access_token(token.id).await?; + if db_token.user_id != user_id { + return Err(anyhow!("no such access token"))?; + } + + let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?; + let t0 = Instant::now(); + let is_valid = Scrypt + .verify_password(token.token.as_bytes(), &db_hash) + .is_ok(); + let duration = t0.elapsed(); + log::info!("hashed access token in {:?}", duration); + METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64); + Ok(is_valid) +} diff --git a/crates/collab2/src/bin/dotenv.rs b/crates/collab2/src/bin/dotenv.rs new file mode 100644 index 0000000000..c093bcb6e9 --- /dev/null +++ b/crates/collab2/src/bin/dotenv.rs @@ -0,0 +1,20 @@ +use anyhow::anyhow; +use std::fs; + +fn main() -> anyhow::Result<()> { + let env: toml::map::Map = toml::de::from_str( + &fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?, + )?; + + for (key, value) in env { + let value = match value { + toml::Value::String(value) => value, + toml::Value::Integer(value) => value.to_string(), + toml::Value::Float(value) => value.to_string(), + _ => panic!("unsupported TOML value in .env.toml for key {}", key), + }; + println!("export {}=\"{}\"", key, value); + } + + Ok(()) +} diff --git a/crates/collab2/src/bin/seed.rs b/crates/collab2/src/bin/seed.rs new file mode 100644 index 0000000000..a7127bbb77 --- /dev/null +++ b/crates/collab2/src/bin/seed.rs @@ -0,0 +1,107 @@ +use collab2::{db, executor::Executor}; +use db::{ConnectOptions, Database}; +use serde::{de::DeserializeOwned, Deserialize}; +use std::fmt::Write; + +#[derive(Debug, Deserialize)] +struct GitHubUser { + id: i32, + login: String, + email: Option, +} + +#[tokio::main] +async fn main() { + let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var"); + let db = Database::new(ConnectOptions::new(database_url), Executor::Production) + .await + .expect("failed to connect to postgres database"); + let github_token = std::env::var("GITHUB_TOKEN").expect("missing GITHUB_TOKEN env var"); + let client = reqwest::Client::new(); + + let mut current_user = + fetch_github::(&client, &github_token, "https://api.github.com/user").await; + current_user + .email + .get_or_insert_with(|| "placeholder@example.com".to_string()); + let staff_users = fetch_github::>( + &client, + &github_token, + "https://api.github.com/orgs/zed-industries/teams/staff/members", + ) + .await; + + let mut zed_users = Vec::new(); + zed_users.push((current_user, true)); + zed_users.extend(staff_users.into_iter().map(|user| (user, true))); + + let user_count = db + .get_all_users(0, 200) + .await + .expect("failed to load users from db") + .len(); + if user_count < 100 { + let mut last_user_id = None; + for _ in 0..10 { + let mut uri = "https://api.github.com/users?per_page=100".to_string(); + if let Some(last_user_id) = last_user_id { + write!(&mut uri, "&since={}", last_user_id).unwrap(); + } + let users = fetch_github::>(&client, &github_token, &uri).await; + if let Some(last_user) = users.last() { + last_user_id = Some(last_user.id); + zed_users.extend(users.into_iter().map(|user| (user, false))); + } else { + break; + } + } + } + + for (github_user, admin) in zed_users { + if db + .get_user_by_github_login(&github_user.login) + .await + .expect("failed to fetch user") + .is_none() + { + if admin { + db.create_user( + &format!("{}@zed.dev", github_user.login), + admin, + db::NewUserParams { + github_login: github_user.login, + github_user_id: github_user.id, + }, + ) + .await + .expect("failed to insert user"); + } else { + db.get_or_create_user_by_github_account( + &github_user.login, + Some(github_user.id), + github_user.email.as_deref(), + ) + .await + .expect("failed to insert user"); + } + } + } +} + +async fn fetch_github( + client: &reqwest::Client, + access_token: &str, + url: &str, +) -> T { + let response = client + .get(url) + .bearer_auth(&access_token) + .header("user-agent", "zed") + .send() + .await + .expect(&format!("failed to fetch '{}'", url)); + response + .json() + .await + .expect(&format!("failed to deserialize github user from '{}'", url)) +} diff --git a/crates/collab2/src/db.rs b/crates/collab2/src/db.rs new file mode 100644 index 0000000000..df33416a46 --- /dev/null +++ b/crates/collab2/src/db.rs @@ -0,0 +1,672 @@ +#[cfg(test)] +pub mod tests; + +#[cfg(test)] +pub use tests::TestDb; + +mod ids; +mod queries; +mod tables; + +use crate::{executor::Executor, Error, Result}; +use anyhow::anyhow; +use collections::{BTreeMap, HashMap, HashSet}; +use dashmap::DashMap; +use futures::StreamExt; +use rand::{prelude::StdRng, Rng, SeedableRng}; +use rpc::{ + proto::{self}, + ConnectionId, +}; +use sea_orm::{ + entity::prelude::*, + sea_query::{Alias, Expr, OnConflict}, + ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, + FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, + TransactionTrait, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{ + migrate::{Migrate, Migration, MigrationSource}, + Connection, +}; +use std::{ + fmt::Write as _, + future::Future, + marker::PhantomData, + ops::{Deref, DerefMut}, + path::Path, + rc::Rc, + sync::Arc, + time::Duration, +}; +use tables::*; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +pub use ids::*; +pub use sea_orm::ConnectOptions; +pub use tables::user::Model as User; + +pub struct Database { + options: ConnectOptions, + pool: DatabaseConnection, + rooms: DashMap>>, + rng: Mutex, + executor: Executor, + notification_kinds_by_id: HashMap, + notification_kinds_by_name: HashMap, + #[cfg(test)] + runtime: Option, +} + +// The `Database` type has so many methods that its impl blocks are split into +// separate files in the `queries` folder. +impl Database { + pub async fn new(options: ConnectOptions, executor: Executor) -> Result { + sqlx::any::install_default_drivers(); + Ok(Self { + options: options.clone(), + pool: sea_orm::Database::connect(options).await?, + rooms: DashMap::with_capacity(16384), + rng: Mutex::new(StdRng::seed_from_u64(0)), + notification_kinds_by_id: HashMap::default(), + notification_kinds_by_name: HashMap::default(), + executor, + #[cfg(test)] + runtime: None, + }) + } + + #[cfg(test)] + pub fn reset(&self) { + self.rooms.clear(); + } + + pub async fn migrate( + &self, + migrations_path: &Path, + ignore_checksum_mismatch: bool, + ) -> anyhow::Result> { + let migrations = MigrationSource::resolve(migrations_path) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; + + let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; + + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await? + .into_iter() + .map(|m| (m.version, m)) + .collect(); + + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch + { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = connection.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } + } + + Ok(new_migrations) + } + + pub async fn initialize_static_data(&mut self) -> Result<()> { + self.initialize_notification_kinds().await?; + Ok(()) + } + + pub async fn transaction(&self, f: F) -> Result + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + let mut i = 0; + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(result) => match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(result), + Err(error) => { + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + }, + Err(error) => { + tx.rollback().await?; + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + } + i += 1; + } + }; + + self.run(body).await + } + + async fn optional_room_transaction(&self, f: F) -> Result>> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>>, + { + let body = async { + let mut i = 0; + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(Some((room_id, data))) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match tx.commit().await.map_err(Into::into) { + Ok(()) => { + return Ok(Some(RoomGuard { + data, + _guard, + _not_send: PhantomData, + })); + } + Err(error) => { + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + } + } + Ok(None) => match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(None), + Err(error) => { + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + }, + Err(error) => { + tx.rollback().await?; + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + } + i += 1; + } + }; + + self.run(body).await + } + + async fn room_transaction(&self, room_id: RoomId, f: F) -> Result> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + let mut i = 0; + loop { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(data) => match tx.commit().await.map_err(Into::into) { + Ok(()) => { + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + }, + Err(error) => { + tx.rollback().await?; + if !self.retry_on_serialization_error(&error, i).await { + return Err(error); + } + } + } + i += 1; + } + }; + + self.run(body).await + } + + async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let tx = self + .pool + .begin_with_config(Some(IsolationLevel::Serializable), None) + .await?; + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else { + return Err(anyhow!( + "couldn't complete transaction because it's still in use" + ))?; + }; + + Ok((tx, result)) + } + + async fn run(&self, future: F) -> Result + where + F: Future>, + { + #[cfg(test)] + { + if let Executor::Deterministic(executor) = &self.executor { + executor.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(future) + } + + #[cfg(not(test))] + { + future.await + } + } + + async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool { + // If the error is due to a failure to serialize concurrent transactions, then retry + // this transaction after a delay. With each subsequent retry, double the delay duration. + // Also vary the delay randomly in order to ensure different database connections retry + // at different times. + if is_serialization_error(error) { + let base_delay = 4_u64 << prev_attempt_count.min(16); + let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0); + log::info!( + "retrying transaction after serialization error. delay: {} ms.", + randomized_delay + ); + self.executor + .sleep(Duration::from_millis(randomized_delay as u64)) + .await; + true + } else { + false + } + } +} + +fn is_serialization_error(error: &Error) -> bool { + const SERIALIZATION_FAILURE_CODE: &'static str = "40001"; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some(SERIALIZATION_FAILURE_CODE) => + { + true + } + _ => false, + } +} + +pub struct TransactionHandle(Arc>); + +impl Deref for TransactionHandle { + type Target = DatabaseTransaction; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().as_ref().unwrap() + } +} + +pub struct RoomGuard { + data: T, + _guard: OwnedMutexGuard<()>, + _not_send: PhantomData>, +} + +impl Deref for RoomGuard { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for RoomGuard { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +impl RoomGuard { + pub fn into_inner(self) -> T { + self.data + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Contact { + Accepted { user_id: UserId, busy: bool }, + Outgoing { user_id: UserId }, + Incoming { user_id: UserId }, +} + +impl Contact { + pub fn user_id(&self) -> UserId { + match self { + Contact::Accepted { user_id, .. } => *user_id, + Contact::Outgoing { user_id } => *user_id, + Contact::Incoming { user_id, .. } => *user_id, + } + } +} + +pub type NotificationBatch = Vec<(UserId, proto::Notification)>; + +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + +#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] +pub struct Invite { + pub email_address: String, + pub email_confirmation_code: String, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct NewSignup { + pub email_address: String, + pub platform_mac: bool, + pub platform_windows: bool, + pub platform_linux: bool, + pub editor_features: Vec, + pub programming_languages: Vec, + pub device_id: Option, + pub added_to_mailing_list: bool, + pub created_at: Option, +} + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)] +pub struct WaitlistSummary { + pub count: i64, + pub linux_count: i64, + pub mac_count: i64, + pub windows_count: i64, + pub unknown_count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct NewUserParams { + pub github_login: String, + pub github_user_id: i32, +} + +#[derive(Debug)] +pub struct NewUserResult { + pub user_id: UserId, + pub metrics_id: String, + pub inviting_user_id: Option, + pub signup_device_id: Option, +} + +#[derive(Debug)] +pub struct MoveChannelResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub moved_channels: HashSet, +} + +#[derive(Debug)] +pub struct RenameChannelResult { + pub channel: Channel, + pub participants_to_update: HashMap, +} + +#[derive(Debug)] +pub struct CreateChannelResult { + pub channel: Channel, + pub participants_to_update: Vec<(UserId, ChannelsForUser)>, +} + +#[derive(Debug)] +pub struct SetChannelVisibilityResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub channels_to_remove: Vec, +} + +#[derive(Debug)] +pub struct MembershipUpdated { + pub channel_id: ChannelId, + pub new_channels: ChannelsForUser, + pub removed_channels: Vec, +} + +#[derive(Debug)] +pub enum SetMemberRoleResult { + InviteUpdated(Channel), + MembershipUpdated(MembershipUpdated), +} + +#[derive(Debug)] +pub struct InviteMemberResult { + pub channel: Channel, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RespondToChannelInvite { + pub membership_update: Option, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RemoveChannelMemberResult { + pub membership_update: MembershipUpdated, + pub notification_id: Option, +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct Channel { + pub id: ChannelId, + pub name: String, + pub visibility: ChannelVisibility, + pub role: ChannelRole, + pub parent_path: Vec, +} + +impl Channel { + fn from_model(value: channel::Model, role: ChannelRole) -> Self { + Channel { + id: value.id, + visibility: value.visibility, + name: value.clone().name, + role, + parent_path: value.ancestors().collect(), + } + } + + pub fn to_proto(&self) -> proto::Channel { + proto::Channel { + id: self.id.to_proto(), + name: self.name.clone(), + visibility: self.visibility.into(), + role: self.role.into(), + parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ChannelMember { + pub role: ChannelRole, + pub user_id: UserId, + pub kind: proto::channel_member::Kind, +} + +impl ChannelMember { + pub fn to_proto(&self) -> proto::ChannelMember { + proto::ChannelMember { + role: self.role.into(), + user_id: self.user_id.to_proto(), + kind: self.kind.into(), + } + } +} + +#[derive(Debug, PartialEq)] +pub struct ChannelsForUser { + pub channels: Vec, + pub channel_participants: HashMap>, + pub unseen_buffer_changes: Vec, + pub channel_messages: Vec, +} + +#[derive(Debug)] +pub struct RejoinedChannelBuffer { + pub buffer: proto::RejoinedChannelBuffer, + pub old_connection_id: ConnectionId, +} + +#[derive(Clone)] +pub struct JoinRoom { + pub room: proto::Room, + pub channel_id: Option, + pub channel_members: Vec, +} + +pub struct RejoinedRoom { + pub room: proto::Room, + pub rejoined_projects: Vec, + pub reshared_projects: Vec, + pub channel_id: Option, + pub channel_members: Vec, +} + +pub struct ResharedProject { + pub id: ProjectId, + pub old_connection_id: ConnectionId, + pub collaborators: Vec, + pub worktrees: Vec, +} + +pub struct RejoinedProject { + pub id: ProjectId, + pub old_connection_id: ConnectionId, + pub collaborators: Vec, + pub worktrees: Vec, + pub language_servers: Vec, +} + +#[derive(Debug)] +pub struct RejoinedWorktree { + pub id: u64, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub updated_entries: Vec, + pub removed_entries: Vec, + pub updated_repositories: Vec, + pub removed_repositories: Vec, + pub diagnostic_summaries: Vec, + pub settings_files: Vec, + pub scan_id: u64, + pub completed_scan_id: u64, +} + +pub struct LeftRoom { + pub room: proto::Room, + pub channel_id: Option, + pub channel_members: Vec, + pub left_projects: HashMap, + pub canceled_calls_to_user_ids: Vec, + pub deleted: bool, +} + +pub struct RefreshedRoom { + pub room: proto::Room, + pub channel_id: Option, + pub channel_members: Vec, + pub stale_participant_user_ids: Vec, + pub canceled_calls_to_user_ids: Vec, +} + +pub struct RefreshedChannelBuffer { + pub connection_ids: Vec, + pub collaborators: Vec, +} + +pub struct Project { + pub collaborators: Vec, + pub worktrees: BTreeMap, + pub language_servers: Vec, +} + +pub struct ProjectCollaborator { + pub connection_id: ConnectionId, + pub user_id: UserId, + pub replica_id: ReplicaId, + pub is_host: bool, +} + +impl ProjectCollaborator { + pub fn to_proto(&self) -> proto::Collaborator { + proto::Collaborator { + peer_id: Some(self.connection_id.into()), + replica_id: self.replica_id.0 as u32, + user_id: self.user_id.to_proto(), + } + } +} + +#[derive(Debug)] +pub struct LeftProject { + pub id: ProjectId, + pub host_user_id: UserId, + pub host_connection_id: ConnectionId, + pub connection_ids: Vec, +} + +pub struct Worktree { + pub id: u64, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub entries: Vec, + pub repository_entries: BTreeMap, + pub diagnostic_summaries: Vec, + pub settings_files: Vec, + pub scan_id: u64, + pub completed_scan_id: u64, +} + +#[derive(Debug)] +pub struct WorktreeSettingsFile { + pub path: String, + pub content: String, +} diff --git a/crates/collab2/src/db/ids.rs b/crates/collab2/src/db/ids.rs new file mode 100644 index 0000000000..5f0df90811 --- /dev/null +++ b/crates/collab2/src/db/ids.rs @@ -0,0 +1,199 @@ +use crate::Result; +use rpc::proto; +use sea_orm::{entity::prelude::*, DbErr}; +use serde::{Deserialize, Serialize}; + +macro_rules! id_type { + ($name:ident) => { + #[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Serialize, + Deserialize, + DeriveValueType, + )] + #[serde(transparent)] + pub struct $name(pub i32); + + impl $name { + #[allow(unused)] + pub const MAX: Self = Self(i32::MAX); + + #[allow(unused)] + pub fn from_proto(value: u64) -> Self { + Self(value as i32) + } + + #[allow(unused)] + pub fn to_proto(self) -> u64 { + self.0 as u64 + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl sea_orm::TryFromU64 for $name { + fn try_from_u64(n: u64) -> Result { + Ok(Self(n.try_into().map_err(|_| { + DbErr::ConvertFromU64(concat!( + "error converting ", + stringify!($name), + " to u64" + )) + })?)) + } + } + + impl sea_orm::sea_query::Nullable for $name { + fn null() -> Value { + Value::Int(None) + } + } + }; +} + +id_type!(BufferId); +id_type!(AccessTokenId); +id_type!(ChannelChatParticipantId); +id_type!(ChannelId); +id_type!(ChannelMemberId); +id_type!(MessageId); +id_type!(ContactId); +id_type!(FollowerId); +id_type!(RoomId); +id_type!(RoomParticipantId); +id_type!(ProjectId); +id_type!(ProjectCollaboratorId); +id_type!(ReplicaId); +id_type!(ServerId); +id_type!(SignupId); +id_type!(UserId); +id_type!(ChannelBufferCollaboratorId); +id_type!(FlagId); +id_type!(NotificationId); +id_type!(NotificationKindId); + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelRole { + #[sea_orm(string_value = "admin")] + Admin, + #[sea_orm(string_value = "member")] + #[default] + Member, + #[sea_orm(string_value = "guest")] + Guest, + #[sea_orm(string_value = "banned")] + Banned, +} + +impl ChannelRole { + pub fn should_override(&self, other: Self) -> bool { + use ChannelRole::*; + match self { + Admin => matches!(other, Member | Banned | Guest), + Member => matches!(other, Banned | Guest), + Banned => matches!(other, Guest), + Guest => false, + } + } + + pub fn max(&self, other: Self) -> Self { + if self.should_override(other) { + *self + } else { + other + } + } + + pub fn can_see_all_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Admin | Member => true, + Guest | Banned => false, + } + } + + pub fn can_only_see_public_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Guest => true, + Admin | Member | Banned => false, + } + } +} + +impl From for ChannelRole { + fn from(value: proto::ChannelRole) -> Self { + match value { + proto::ChannelRole::Admin => ChannelRole::Admin, + proto::ChannelRole::Member => ChannelRole::Member, + proto::ChannelRole::Guest => ChannelRole::Guest, + proto::ChannelRole::Banned => ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> proto::ChannelRole { + match self { + ChannelRole::Admin => proto::ChannelRole::Admin, + ChannelRole::Member => proto::ChannelRole::Member, + ChannelRole::Guest => proto::ChannelRole::Guest, + ChannelRole::Banned => proto::ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> i32 { + let proto: proto::ChannelRole = self.into(); + proto.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelVisibility { + #[sea_orm(string_value = "public")] + Public, + #[sea_orm(string_value = "members")] + #[default] + Members, +} + +impl From for ChannelVisibility { + fn from(value: proto::ChannelVisibility) -> Self { + match value { + proto::ChannelVisibility::Public => ChannelVisibility::Public, + proto::ChannelVisibility::Members => ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> proto::ChannelVisibility { + match self { + ChannelVisibility::Public => proto::ChannelVisibility::Public, + ChannelVisibility::Members => proto::ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> i32 { + let proto: proto::ChannelVisibility = self.into(); + proto.into() + } +} diff --git a/crates/collab2/src/db/queries.rs b/crates/collab2/src/db/queries.rs new file mode 100644 index 0000000000..629e26f1a9 --- /dev/null +++ b/crates/collab2/src/db/queries.rs @@ -0,0 +1,12 @@ +use super::*; + +pub mod access_tokens; +pub mod buffers; +pub mod channels; +pub mod contacts; +pub mod messages; +pub mod notifications; +pub mod projects; +pub mod rooms; +pub mod servers; +pub mod users; diff --git a/crates/collab2/src/db/queries/access_tokens.rs b/crates/collab2/src/db/queries/access_tokens.rs new file mode 100644 index 0000000000..589b6483df --- /dev/null +++ b/crates/collab2/src/db/queries/access_tokens.rs @@ -0,0 +1,54 @@ +use super::*; +use sea_orm::sea_query::Query; + +impl Database { + pub async fn create_access_token( + &self, + user_id: UserId, + access_token_hash: &str, + max_access_token_count: usize, + ) -> Result { + self.transaction(|tx| async { + let tx = tx; + + let token = access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&*tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&*tx) + .await?; + Ok(token.id) + }) + .await + } + + pub async fn get_access_token( + &self, + access_token_id: AccessTokenId, + ) -> Result { + self.transaction(|tx| async move { + Ok(access_token::Entity::find_by_id(access_token_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such access token"))?) + }) + .await + } +} diff --git a/crates/collab2/src/db/queries/buffers.rs b/crates/collab2/src/db/queries/buffers.rs new file mode 100644 index 0000000000..9eddb1f618 --- /dev/null +++ b/crates/collab2/src/db/queries/buffers.rs @@ -0,0 +1,1078 @@ +use super::*; +use prost::Message; +use text::{EditOperation, UndoOperation}; + +pub struct LeftChannelBuffer { + pub channel_id: ChannelId, + pub collaborators: Vec, + pub connections: Vec, +} + +impl Database { + pub async fn join_channel_buffer( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + ) -> Result { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &tx) + .await?; + + let buffer = channel::Model { + id: channel_id, + ..Default::default() + } + .find_related(buffer::Entity) + .one(&*tx) + .await?; + + let buffer = if let Some(buffer) = buffer { + buffer + } else { + let buffer = buffer::ActiveModel { + channel_id: ActiveValue::Set(channel_id), + ..Default::default() + } + .insert(&*tx) + .await?; + buffer_snapshot::ActiveModel { + buffer_id: ActiveValue::Set(buffer.id), + epoch: ActiveValue::Set(0), + text: ActiveValue::Set(String::new()), + operation_serialization_version: ActiveValue::Set( + storage::SERIALIZATION_VERSION, + ), + } + .insert(&*tx) + .await?; + buffer + }; + + // Join the collaborators + let mut collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + let replica_ids = collaborators + .iter() + .map(|c| c.replica_id) + .collect::>(); + let mut replica_id = ReplicaId(0); + while replica_ids.contains(&replica_id) { + replica_id.0 += 1; + } + let collaborator = channel_buffer_collaborator::ActiveModel { + channel_id: ActiveValue::Set(channel_id), + connection_id: ActiveValue::Set(connection.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)), + user_id: ActiveValue::Set(user_id), + replica_id: ActiveValue::Set(replica_id), + ..Default::default() + } + .insert(&*tx) + .await?; + collaborators.push(collaborator); + + let (base_text, operations, max_operation) = + self.get_buffer_state(&buffer, &tx).await?; + + // Save the last observed operation + if let Some(op) = max_operation { + observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel { + user_id: ActiveValue::Set(user_id), + buffer_id: ActiveValue::Set(buffer.id), + epoch: ActiveValue::Set(op.epoch), + lamport_timestamp: ActiveValue::Set(op.lamport_timestamp), + replica_id: ActiveValue::Set(op.replica_id), + }) + .on_conflict( + OnConflict::columns([ + observed_buffer_edits::Column::UserId, + observed_buffer_edits::Column::BufferId, + ]) + .update_columns([ + observed_buffer_edits::Column::Epoch, + observed_buffer_edits::Column::LamportTimestamp, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + } + + Ok(proto::JoinChannelBufferResponse { + buffer_id: buffer.id.to_proto(), + replica_id: replica_id.to_proto() as u32, + base_text, + operations, + epoch: buffer.epoch as u64, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }) + }) + .await + } + + pub async fn rejoin_channel_buffers( + &self, + buffers: &[proto::ChannelBufferVersion], + user_id: UserId, + connection_id: ConnectionId, + ) -> Result> { + self.transaction(|tx| async move { + let mut results = Vec::new(); + for client_buffer in buffers { + let channel = self + .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx) + .await?; + if self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await + .is_err() + { + log::info!("user is not a member of channel"); + continue; + } + + let buffer = self.get_channel_buffer(channel.id, &*tx).await?; + let mut collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id)) + .all(&*tx) + .await?; + + // If the buffer epoch hasn't changed since the client lost + // connection, then the client's buffer can be syncronized with + // the server's buffer. + if buffer.epoch as u64 != client_buffer.epoch { + log::info!("can't rejoin buffer, epoch has changed"); + continue; + } + + // Find the collaborator record for this user's previous lost + // connection. Update it with the new connection id. + let server_id = ServerId(connection_id.owner_id as i32); + let Some(self_collaborator) = collaborators.iter_mut().find(|c| { + c.user_id == user_id + && (c.connection_lost || c.connection_server_id != server_id) + }) else { + log::info!("can't rejoin buffer, no previous collaborator found"); + continue; + }; + let old_connection_id = self_collaborator.connection(); + *self_collaborator = channel_buffer_collaborator::ActiveModel { + id: ActiveValue::Unchanged(self_collaborator.id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), + connection_lost: ActiveValue::Set(false), + ..Default::default() + } + .update(&*tx) + .await?; + + let client_version = version_from_wire(&client_buffer.version); + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; + + let mut rows = buffer_operation::Entity::find() + .filter( + buffer_operation::Column::BufferId + .eq(buffer.id) + .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), + ) + .stream(&*tx) + .await?; + + // Find the server's version vector and any operations + // that the client has not seen. + let mut server_version = clock::Global::new(); + let mut operations = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let timestamp = clock::Lamport { + replica_id: row.replica_id as u16, + value: row.lamport_timestamp as u32, + }; + server_version.observe(timestamp); + if !client_version.observed(timestamp) { + operations.push(proto::Operation { + variant: Some(operation_from_storage(row, serialization_version)?), + }) + } + } + + results.push(RejoinedChannelBuffer { + old_connection_id, + buffer: proto::RejoinedChannelBuffer { + channel_id: client_buffer.channel_id, + version: version_to_wire(&server_version), + operations, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }, + }); + } + + Ok(results) + }) + .await + } + + pub async fn clear_stale_channel_buffer_collaborators( + &self, + channel_id: ChannelId, + server_id: ServerId, + ) -> Result { + self.transaction(|tx| async move { + let db_collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + + let mut connection_ids = Vec::new(); + let mut collaborators = Vec::new(); + let mut collaborator_ids_to_remove = Vec::new(); + for db_collaborator in &db_collaborators { + if !db_collaborator.connection_lost + && db_collaborator.connection_server_id == server_id + { + connection_ids.push(db_collaborator.connection()); + collaborators.push(proto::Collaborator { + peer_id: Some(db_collaborator.connection().into()), + replica_id: db_collaborator.replica_id.0 as u32, + user_id: db_collaborator.user_id.to_proto(), + }) + } else { + collaborator_ids_to_remove.push(db_collaborator.id); + } + } + + channel_buffer_collaborator::Entity::delete_many() + .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove)) + .exec(&*tx) + .await?; + + Ok(RefreshedChannelBuffer { + connection_ids, + collaborators, + }) + }) + .await + } + + pub async fn leave_channel_buffer( + &self, + channel_id: ChannelId, + connection: ConnectionId, + ) -> Result { + self.transaction(|tx| async move { + self.leave_channel_buffer_internal(channel_id, connection, &*tx) + .await + }) + .await + } + + pub async fn channel_buffer_connection_lost( + &self, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result<()> { + channel_buffer_collaborator::Entity::update_many() + .filter( + Condition::all() + .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32)) + .add( + channel_buffer_collaborator::Column::ConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .set(channel_buffer_collaborator::ActiveModel { + connection_lost: ActiveValue::set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(()) + } + + pub async fn leave_channel_buffers( + &self, + connection: ConnectionId, + ) -> Result> { + self.transaction(|tx| async move { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + + let channel_ids: Vec = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .filter(Condition::all().add( + channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), + )) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + let mut result = Vec::new(); + for channel_id in channel_ids { + let left_channel_buffer = self + .leave_channel_buffer_internal(channel_id, connection, &*tx) + .await?; + result.push(left_channel_buffer); + } + + Ok(result) + }) + .await + } + + pub async fn leave_channel_buffer_internal( + &self, + channel_id: ChannelId, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result { + let result = channel_buffer_collaborator::Entity::delete_many() + .filter( + Condition::all() + .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32)) + .add( + channel_buffer_collaborator::Column::ConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("not a collaborator on this project"))?; + } + + let mut collaborators = Vec::new(); + let mut connections = Vec::new(); + let mut rows = channel_buffer_collaborator::Entity::find() + .filter( + Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)), + ) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + let connection = row.connection(); + connections.push(connection); + collaborators.push(proto::Collaborator { + peer_id: Some(connection.into()), + replica_id: row.replica_id.0 as u32, + user_id: row.user_id.to_proto(), + }); + } + + drop(rows); + + if collaborators.is_empty() { + self.snapshot_channel_buffer(channel_id, &tx).await?; + } + + Ok(LeftChannelBuffer { + channel_id, + collaborators, + connections, + }) + } + + pub async fn get_channel_buffer_collaborators( + &self, + channel_id: ChannelId, + ) -> Result> { + self.transaction(|tx| async move { + self.get_channel_buffer_collaborators_internal(channel_id, &*tx) + .await + }) + .await + } + + async fn get_channel_buffer_collaborators_internal( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryUserIds { + UserId, + } + + let users: Vec = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::UserId) + .filter( + Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)), + ) + .into_values::<_, QueryUserIds>() + .all(&*tx) + .await?; + + Ok(users) + } + + pub async fn update_channel_buffer( + &self, + channel_id: ChannelId, + user: UserId, + operations: &[proto::Operation], + ) -> Result<( + Vec, + Vec, + i32, + Vec, + )> { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_member(&channel, user, &*tx) + .await?; + + let buffer = buffer::Entity::find() + .filter(buffer::Column::ChannelId.eq(channel_id)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such buffer"))?; + + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; + + let operations = operations + .iter() + .filter_map(|op| operation_to_storage(op, &buffer, serialization_version)) + .collect::>(); + + let mut channel_members; + let max_version; + + if !operations.is_empty() { + let max_operation = operations + .iter() + .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref())) + .unwrap(); + + max_version = vec![proto::VectorClockEntry { + replica_id: *max_operation.replica_id.as_ref() as u32, + timestamp: *max_operation.lamport_timestamp.as_ref() as u32, + }]; + + // get current channel participants and save the max operation above + self.save_max_operation( + user, + buffer.id, + buffer.epoch, + *max_operation.replica_id.as_ref(), + *max_operation.lamport_timestamp.as_ref(), + &*tx, + ) + .await?; + + channel_members = self.get_channel_participants(&channel, &*tx).await?; + let collaborators = self + .get_channel_buffer_collaborators_internal(channel_id, &*tx) + .await?; + channel_members.retain(|member| !collaborators.contains(member)); + + buffer_operation::Entity::insert_many(operations) + .on_conflict( + OnConflict::columns([ + buffer_operation::Column::BufferId, + buffer_operation::Column::Epoch, + buffer_operation::Column::LamportTimestamp, + buffer_operation::Column::ReplicaId, + ]) + .do_nothing() + .to_owned(), + ) + .exec(&*tx) + .await?; + } else { + channel_members = Vec::new(); + max_version = Vec::new(); + } + + let mut connections = Vec::new(); + let mut rows = channel_buffer_collaborator::Entity::find() + .filter( + Condition::all() + .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)), + ) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + connections.push(ConnectionId { + id: row.connection_id as u32, + owner_id: row.connection_server_id.0 as u32, + }); + } + + Ok((connections, channel_members, buffer.epoch, max_version)) + }) + .await + } + + async fn save_max_operation( + &self, + user_id: UserId, + buffer_id: BufferId, + epoch: i32, + replica_id: i32, + lamport_timestamp: i32, + tx: &DatabaseTransaction, + ) -> Result<()> { + use observed_buffer_edits::Column; + + observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel { + user_id: ActiveValue::Set(user_id), + buffer_id: ActiveValue::Set(buffer_id), + epoch: ActiveValue::Set(epoch), + replica_id: ActiveValue::Set(replica_id), + lamport_timestamp: ActiveValue::Set(lamport_timestamp), + }) + .on_conflict( + OnConflict::columns([Column::UserId, Column::BufferId]) + .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId]) + .action_cond_where( + Condition::any().add(Column::Epoch.lt(epoch)).add( + Condition::all().add(Column::Epoch.eq(epoch)).add( + Condition::any() + .add(Column::LamportTimestamp.lt(lamport_timestamp)) + .add( + Column::LamportTimestamp + .eq(lamport_timestamp) + .and(Column::ReplicaId.lt(replica_id)), + ), + ), + ), + ) + .to_owned(), + ) + .exec_without_returning(tx) + .await?; + + Ok(()) + } + + async fn get_buffer_operation_serialization_version( + &self, + buffer_id: BufferId, + epoch: i32, + tx: &DatabaseTransaction, + ) -> Result { + Ok(buffer_snapshot::Entity::find() + .filter(buffer_snapshot::Column::BufferId.eq(buffer_id)) + .filter(buffer_snapshot::Column::Epoch.eq(epoch)) + .select_only() + .column(buffer_snapshot::Column::OperationSerializationVersion) + .into_values::<_, QueryOperationSerializationVersion>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("missing buffer snapshot"))?) + } + + pub async fn get_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Model { + id: channel_id, + ..Default::default() + } + .find_related(buffer::Entity) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such buffer"))?) + } + + async fn get_buffer_state( + &self, + buffer: &buffer::Model, + tx: &DatabaseTransaction, + ) -> Result<( + String, + Vec, + Option, + )> { + let id = buffer.id; + let (base_text, version) = if buffer.epoch > 0 { + let snapshot = buffer_snapshot::Entity::find() + .filter( + buffer_snapshot::Column::BufferId + .eq(id) + .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such snapshot"))?; + + let version = snapshot.operation_serialization_version; + (snapshot.text, version) + } else { + (String::new(), storage::SERIALIZATION_VERSION) + }; + + let mut rows = buffer_operation::Entity::find() + .filter( + buffer_operation::Column::BufferId + .eq(id) + .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), + ) + .order_by_asc(buffer_operation::Column::LamportTimestamp) + .order_by_asc(buffer_operation::Column::ReplicaId) + .stream(&*tx) + .await?; + + let mut operations = Vec::new(); + let mut last_row = None; + while let Some(row) = rows.next().await { + let row = row?; + last_row = Some(buffer_operation::Model { + buffer_id: row.buffer_id, + epoch: row.epoch, + lamport_timestamp: row.lamport_timestamp, + replica_id: row.lamport_timestamp, + value: Default::default(), + }); + operations.push(proto::Operation { + variant: Some(operation_from_storage(row, version)?), + }); + } + + Ok((base_text, operations, last_row)) + } + + async fn snapshot_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result<()> { + let buffer = self.get_channel_buffer(channel_id, tx).await?; + let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?; + if operations.is_empty() { + return Ok(()); + } + + let mut text_buffer = text::Buffer::new(0, 0, base_text); + text_buffer + .apply_ops(operations.into_iter().filter_map(operation_from_wire)) + .unwrap(); + + let base_text = text_buffer.text(); + let epoch = buffer.epoch + 1; + + buffer_snapshot::Model { + buffer_id: buffer.id, + epoch, + text: base_text, + operation_serialization_version: storage::SERIALIZATION_VERSION, + } + .into_active_model() + .insert(tx) + .await?; + + buffer::ActiveModel { + id: ActiveValue::Unchanged(buffer.id), + epoch: ActiveValue::Set(epoch), + ..Default::default() + } + .save(tx) + .await?; + + Ok(()) + } + + pub async fn observe_buffer_version( + &self, + buffer_id: BufferId, + user_id: UserId, + epoch: i32, + version: &[proto::VectorClockEntry], + ) -> Result<()> { + self.transaction(|tx| async move { + // For now, combine concurrent operations. + let Some(component) = version.iter().max_by_key(|version| version.timestamp) else { + return Ok(()); + }; + self.save_max_operation( + user_id, + buffer_id, + epoch, + component.replica_id as i32, + component.timestamp as i32, + &*tx, + ) + .await?; + Ok(()) + }) + .await + } + + pub async fn unseen_channel_buffer_changes( + &self, + user_id: UserId, + channel_ids: &[ChannelId], + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryIds { + ChannelId, + Id, + } + + let mut channel_ids_by_buffer_id = HashMap::default(); + let mut rows = buffer::Entity::find() + .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied())) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + channel_ids_by_buffer_id.insert(row.id, row.channel_id); + } + drop(rows); + + let mut observed_edits_by_buffer_id = HashMap::default(); + let mut rows = observed_buffer_edits::Entity::find() + .filter(observed_buffer_edits::Column::UserId.eq(user_id)) + .filter( + observed_buffer_edits::Column::BufferId + .is_in(channel_ids_by_buffer_id.keys().copied()), + ) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + observed_edits_by_buffer_id.insert(row.buffer_id, row); + } + drop(rows); + + let latest_operations = self + .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx) + .await?; + + let mut changes = Vec::default(); + for latest in latest_operations { + if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) { + if ( + observed.epoch, + observed.lamport_timestamp, + observed.replica_id, + ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id) + { + continue; + } + } + + if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) { + changes.push(proto::UnseenChannelBufferChange { + channel_id: channel_id.to_proto(), + epoch: latest.epoch as u64, + version: vec![proto::VectorClockEntry { + replica_id: latest.replica_id as u32, + timestamp: latest.lamport_timestamp as u32, + }], + }); + } + } + + Ok(changes) + } + + pub async fn get_latest_operations_for_buffers( + &self, + buffer_ids: impl IntoIterator, + tx: &DatabaseTransaction, + ) -> Result> { + let mut values = String::new(); + for id in buffer_ids { + if !values.is_empty() { + values.push_str(", "); + } + write!(&mut values, "({})", id).unwrap(); + } + + if values.is_empty() { + return Ok(Vec::default()); + } + + let sql = format!( + r#" + SELECT + * + FROM + ( + SELECT + *, + row_number() OVER ( + PARTITION BY buffer_id + ORDER BY + epoch DESC, + lamport_timestamp DESC, + replica_id DESC + ) as row_number + FROM buffer_operations + WHERE + buffer_id in ({values}) + ) AS last_operations + WHERE + row_number = 1 + "#, + ); + + let stmt = Statement::from_string(self.pool.get_database_backend(), sql); + Ok(buffer_operation::Entity::find() + .from_raw_sql(stmt) + .all(&*tx) + .await?) + } +} + +fn operation_to_storage( + operation: &proto::Operation, + buffer: &buffer::Model, + _format: i32, +) -> Option { + let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? { + proto::operation::Variant::Edit(operation) => ( + operation.replica_id, + operation.lamport_timestamp, + storage::Operation { + version: version_to_storage(&operation.version), + is_undo: false, + edit_ranges: operation + .ranges + .iter() + .map(|range| storage::Range { + start: range.start, + end: range.end, + }) + .collect(), + edit_texts: operation.new_text.clone(), + undo_counts: Vec::new(), + }, + ), + proto::operation::Variant::Undo(operation) => ( + operation.replica_id, + operation.lamport_timestamp, + storage::Operation { + version: version_to_storage(&operation.version), + is_undo: true, + edit_ranges: Vec::new(), + edit_texts: Vec::new(), + undo_counts: operation + .counts + .iter() + .map(|entry| storage::UndoCount { + replica_id: entry.replica_id, + lamport_timestamp: entry.lamport_timestamp, + count: entry.count, + }) + .collect(), + }, + ), + _ => None?, + }; + + Some(buffer_operation::ActiveModel { + buffer_id: ActiveValue::Set(buffer.id), + epoch: ActiveValue::Set(buffer.epoch), + replica_id: ActiveValue::Set(replica_id as i32), + lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32), + value: ActiveValue::Set(value.encode_to_vec()), + }) +} + +fn operation_from_storage( + row: buffer_operation::Model, + _format_version: i32, +) -> Result { + let operation = + storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?; + let version = version_from_storage(&operation.version); + Ok(if operation.is_undo { + proto::operation::Variant::Undo(proto::operation::Undo { + replica_id: row.replica_id as u32, + lamport_timestamp: row.lamport_timestamp as u32, + version, + counts: operation + .undo_counts + .iter() + .map(|entry| proto::UndoCount { + replica_id: entry.replica_id, + lamport_timestamp: entry.lamport_timestamp, + count: entry.count, + }) + .collect(), + }) + } else { + proto::operation::Variant::Edit(proto::operation::Edit { + replica_id: row.replica_id as u32, + lamport_timestamp: row.lamport_timestamp as u32, + version, + ranges: operation + .edit_ranges + .into_iter() + .map(|range| proto::Range { + start: range.start, + end: range.end, + }) + .collect(), + new_text: operation.edit_texts, + }) + }) +} + +fn version_to_storage(version: &Vec) -> Vec { + version + .iter() + .map(|entry| storage::VectorClockEntry { + replica_id: entry.replica_id, + timestamp: entry.timestamp, + }) + .collect() +} + +fn version_from_storage(version: &Vec) -> Vec { + version + .iter() + .map(|entry| proto::VectorClockEntry { + replica_id: entry.replica_id, + timestamp: entry.timestamp, + }) + .collect() +} + +// This is currently a manual copy of the deserialization code in the client's langauge crate +pub fn operation_from_wire(operation: proto::Operation) -> Option { + match operation.variant? { + proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation { + timestamp: clock::Lamport { + replica_id: edit.replica_id as text::ReplicaId, + value: edit.lamport_timestamp, + }, + version: version_from_wire(&edit.version), + ranges: edit + .ranges + .into_iter() + .map(|range| { + text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize) + }) + .collect(), + new_text: edit.new_text.into_iter().map(Arc::from).collect(), + })), + proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation { + timestamp: clock::Lamport { + replica_id: undo.replica_id as text::ReplicaId, + value: undo.lamport_timestamp, + }, + version: version_from_wire(&undo.version), + counts: undo + .counts + .into_iter() + .map(|c| { + ( + clock::Lamport { + replica_id: c.replica_id as text::ReplicaId, + value: c.lamport_timestamp, + }, + c.count, + ) + }) + .collect(), + })), + _ => None, + } +} + +fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global { + let mut version = clock::Global::new(); + for entry in message { + version.observe(clock::Lamport { + replica_id: entry.replica_id as text::ReplicaId, + value: entry.timestamp, + }); + } + version +} + +fn version_to_wire(version: &clock::Global) -> Vec { + let mut message = Vec::new(); + for entry in version.iter() { + message.push(proto::VectorClockEntry { + replica_id: entry.replica_id as u32, + timestamp: entry.value, + }); + } + message +} + +#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] +enum QueryOperationSerializationVersion { + OperationSerializationVersion, +} + +mod storage { + #![allow(non_snake_case)] + use prost::Message; + pub const SERIALIZATION_VERSION: i32 = 1; + + #[derive(Message)] + pub struct Operation { + #[prost(message, repeated, tag = "2")] + pub version: Vec, + #[prost(bool, tag = "3")] + pub is_undo: bool, + #[prost(message, repeated, tag = "4")] + pub edit_ranges: Vec, + #[prost(string, repeated, tag = "5")] + pub edit_texts: Vec, + #[prost(message, repeated, tag = "6")] + pub undo_counts: Vec, + } + + #[derive(Message)] + pub struct VectorClockEntry { + #[prost(uint32, tag = "1")] + pub replica_id: u32, + #[prost(uint32, tag = "2")] + pub timestamp: u32, + } + + #[derive(Message)] + pub struct Range { + #[prost(uint64, tag = "1")] + pub start: u64, + #[prost(uint64, tag = "2")] + pub end: u64, + } + + #[derive(Message)] + pub struct UndoCount { + #[prost(uint32, tag = "1")] + pub replica_id: u32, + #[prost(uint32, tag = "2")] + pub lamport_timestamp: u32, + #[prost(uint32, tag = "3")] + pub count: u32, + } +} diff --git a/crates/collab2/src/db/queries/channels.rs b/crates/collab2/src/db/queries/channels.rs new file mode 100644 index 0000000000..68b06e435d --- /dev/null +++ b/crates/collab2/src/db/queries/channels.rs @@ -0,0 +1,1312 @@ +use super::*; +use rpc::proto::channel_member::Kind; +use sea_orm::TryGetableMany; + +impl Database { + #[cfg(test)] + pub async fn all_channels(&self) -> Result> { + self.transaction(move |tx| async move { + let mut channels = Vec::new(); + let mut rows = channel::Entity::find().stream(&*tx).await?; + while let Some(row) = rows.next().await { + let row = row?; + channels.push((row.id, row.name)); + } + Ok(channels) + }) + .await + } + + #[cfg(test)] + pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result { + Ok(self + .create_channel(name, None, creator_id) + .await? + .channel + .id) + } + + #[cfg(test)] + pub async fn create_sub_channel( + &self, + name: &str, + parent: ChannelId, + creator_id: UserId, + ) -> Result { + Ok(self + .create_channel(name, Some(parent), creator_id) + .await? + .channel + .id) + } + + pub async fn create_channel( + &self, + name: &str, + parent_channel_id: Option, + admin_id: UserId, + ) -> Result { + let name = Self::sanitize_channel_name(name)?; + self.transaction(move |tx| async move { + let mut parent = None; + + if let Some(parent_channel_id) = parent_channel_id { + let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?; + self.check_user_is_channel_admin(&parent_channel, admin_id, &*tx) + .await?; + parent = Some(parent_channel); + } + + let channel = channel::ActiveModel { + id: ActiveValue::NotSet, + name: ActiveValue::Set(name.to_string()), + visibility: ActiveValue::Set(ChannelVisibility::Members), + parent_path: ActiveValue::Set( + parent + .as_ref() + .map_or(String::new(), |parent| parent.path()), + ), + } + .insert(&*tx) + .await?; + + let participants_to_update; + if let Some(parent) = &parent { + participants_to_update = self + .participants_to_notify_for_channel_change(parent, &*tx) + .await?; + } else { + participants_to_update = vec![]; + + channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel.id), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + }; + + Ok(CreateChannelResult { + channel: Channel::from_model(channel, ChannelRole::Admin), + participants_to_update, + }) + }) + .await + } + + pub async fn join_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + environment: &str, + ) -> Result<(JoinRoom, Option, ChannelRole)> { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let mut role = self.channel_role_for_user(&channel, user_id, &*tx).await?; + + let mut accept_invite_result = None; + + if role.is_none() { + if let Some(invitation) = self + .pending_invite_for_channel(&channel, user_id, &*tx) + .await? + { + // note, this may be a parent channel + role = Some(invitation.role); + channel_member::Entity::update(channel_member::ActiveModel { + accepted: ActiveValue::Set(true), + ..invitation.into_active_model() + }) + .exec(&*tx) + .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ); + + debug_assert!( + self.channel_role_for_user(&channel, user_id, &*tx).await? == role + ); + } + } + + if channel.visibility == ChannelVisibility::Public { + role = Some(ChannelRole::Guest); + let channel_to_join = self + .public_ancestors_including_self(&channel, &*tx) + .await? + .first() + .cloned() + .unwrap_or(channel.clone()); + + channel_member::Entity::insert(channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_to_join.id), + user_id: ActiveValue::Set(user_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Guest), + }) + .exec(&*tx) + .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel_to_join, user_id, &*tx) + .await?, + ); + + debug_assert!(self.channel_role_for_user(&channel, user_id, &*tx).await? == role); + } + + if role.is_none() || role == Some(ChannelRole::Banned) { + Err(anyhow!("not allowed"))? + } + + let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + let room_id = self + .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) + .await?; + + self.join_channel_room_internal(room_id, user_id, connection, &*tx) + .await + .map(|jr| (jr, accept_invite_result, role.unwrap())) + }) + .await + } + + pub async fn set_channel_visibility( + &self, + channel_id: ChannelId, + visibility: ChannelVisibility, + admin_id: UserId, + ) -> Result { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let previous_members = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + let mut model = channel.into_active_model(); + model.visibility = ActiveValue::Set(visibility); + let channel = model.update(&*tx).await?; + + let mut participants_to_update: HashMap = self + .participants_to_notify_for_channel_change(&channel, &*tx) + .await? + .into_iter() + .collect(); + + let mut channels_to_remove: Vec = vec![]; + let mut participants_to_remove: HashSet = HashSet::default(); + match visibility { + ChannelVisibility::Members => { + let all_descendents: Vec = self + .get_channel_descendants_including_self(vec![channel_id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove = channel::Entity::find() + .filter( + channel::Column::Id + .is_in(all_descendents) + .and(channel::Column::Visibility.eq(ChannelVisibility::Public)), + ) + .all(&*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove.push(channel_id); + + for member in previous_members { + if member.role.can_only_see_public_descendants() { + participants_to_remove.insert(member.user_id); + } + } + } + ChannelVisibility::Public => { + if let Some(public_parent) = self.public_parent_channel(&channel, &*tx).await? { + let parent_updates = self + .participants_to_notify_for_channel_change(&public_parent, &*tx) + .await?; + + for (user_id, channels) in parent_updates { + participants_to_update.insert(user_id, channels); + } + } + } + } + + Ok(SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + }) + }) + .await + } + + pub async fn delete_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + ) -> Result<(Vec, Vec)> { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, user_id, &*tx) + .await?; + + let members_to_notify: Vec = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) + .select_only() + .column(channel_member::Column::UserId) + .distinct() + .into_values::<_, QueryUserIds>() + .all(&*tx) + .await?; + + let channels_to_remove = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect::>(); + + channel::Entity::delete_many() + .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied())) + .exec(&*tx) + .await?; + + Ok((channels_to_remove, members_to_notify)) + }) + .await + } + + pub async fn invite_channel_member( + &self, + channel_id: ChannelId, + invitee_id: UserId, + inviter_id: UserId, + role: ChannelRole, + ) -> Result { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, inviter_id, &*tx) + .await?; + + channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(invitee_id), + accepted: ActiveValue::Set(false), + role: ActiveValue::Set(role), + } + .insert(&*tx) + .await?; + + let channel = Channel::from_model(channel, role); + + let notifications = self + .create_notification( + invitee_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: channel.name.clone(), + inviter_id: inviter_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect(); + + Ok(InviteMemberResult { + channel, + notifications, + }) + }) + .await + } + + fn sanitize_channel_name(name: &str) -> Result<&str> { + let new_name = name.trim().trim_start_matches('#'); + if new_name == "" { + Err(anyhow!("channel name can't be blank"))?; + } + Ok(new_name) + } + + pub async fn rename_channel( + &self, + channel_id: ChannelId, + admin_id: UserId, + new_name: &str, + ) -> Result { + self.transaction(move |tx| async move { + let new_name = Self::sanitize_channel_name(new_name)?.to_string(); + + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let mut model = channel.into_active_model(); + model.name = ActiveValue::Set(new_name.clone()); + let channel = model.update(&*tx).await?; + + let participants = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + Ok(RenameChannelResult { + channel: Channel::from_model(channel.clone(), role), + participants_to_update: participants + .iter() + .map(|participant| { + ( + participant.user_id, + Channel::from_model(channel.clone(), participant.role), + ) + }) + .collect(), + }) + }) + .await + } + + pub async fn respond_to_channel_invite( + &self, + channel_id: ChannelId, + user_id: UserId, + accept: bool, + ) -> Result { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + let membership_update = if accept { + let rows_affected = channel_member::Entity::update_many() + .set(channel_member::ActiveModel { + accepted: ActiveValue::Set(accept), + ..Default::default() + }) + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected; + + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; + } + + Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ) + } else { + let rows_affected = channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected; + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; + } + + None + }; + + Ok(RespondToChannelInvite { + membership_update, + notifications: self + .mark_notification_as_read_with_response( + user_id, + &rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + accept, + &*tx, + ) + .await? + .into_iter() + .collect(), + }) + }) + .await + } + + async fn calculate_membership_updated( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let new_channels = self.get_user_channels(user_id, Some(channel), &*tx).await?; + let removed_channels = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? + .into_iter() + .filter_map(|channel| { + if !new_channels.channels.iter().any(|c| c.id == channel.id) { + Some(channel.id) + } else { + None + } + }) + .collect::>(); + + Ok(MembershipUpdated { + channel_id: channel.id, + new_channels, + removed_channels, + }) + } + + pub async fn remove_channel_member( + &self, + channel_id: ChannelId, + member_id: UserId, + admin_id: UserId, + ) -> Result { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let result = channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(member_id)), + ) + .exec(&*tx) + .await?; + + if result.rows_affected == 0 { + Err(anyhow!("no such member"))?; + } + + Ok(RemoveChannelMemberResult { + membership_update: self + .calculate_membership_updated(&channel, member_id, &*tx) + .await?, + notification_id: self + .remove_notification( + member_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + &*tx, + ) + .await?, + }) + }) + .await + } + + pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result> { + self.transaction(|tx| async move { + let mut role_for_channel: HashMap = HashMap::default(); + + let channel_invites = channel_member::Entity::find() + .filter( + channel_member::Column::UserId + .eq(user_id) + .and(channel_member::Column::Accepted.eq(false)), + ) + .all(&*tx) + .await?; + + for invite in channel_invites { + role_for_channel.insert(invite.channel_id, invite.role); + } + + let channels = channel::Entity::find() + .filter(channel::Column::Id.is_in(role_for_channel.keys().copied())) + .all(&*tx) + .await?; + + let channels = channels + .into_iter() + .filter_map(|channel| { + let role = *role_for_channel.get(&channel.id)?; + Some(Channel::from_model(channel, role)) + }) + .collect(); + + Ok(channels) + }) + .await + } + + pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { + self.transaction(|tx| async move { + let tx = tx; + + self.get_user_channels(user_id, None, &tx).await + }) + .await + } + + pub async fn get_user_channels( + &self, + user_id: UserId, + ancestor_channel: Option<&channel::Model>, + tx: &DatabaseTransaction, + ) -> Result { + let channel_memberships = channel_member::Entity::find() + .filter( + channel_member::Column::UserId + .eq(user_id) + .and(channel_member::Column::Accepted.eq(true)), + ) + .all(&*tx) + .await?; + + let descendants = self + .get_channel_descendants_including_self( + channel_memberships.iter().map(|m| m.channel_id), + &*tx, + ) + .await?; + + let mut roles_by_channel_id: HashMap = HashMap::default(); + for membership in channel_memberships.iter() { + roles_by_channel_id.insert(membership.channel_id, membership.role); + } + + let mut visible_channel_ids: HashSet = HashSet::default(); + + let channels: Vec = descendants + .into_iter() + .filter_map(|channel| { + let parent_role = channel + .parent_id() + .and_then(|parent_id| roles_by_channel_id.get(&parent_id)); + + let role = if let Some(parent_role) = parent_role { + let role = if let Some(existing_role) = roles_by_channel_id.get(&channel.id) { + existing_role.max(*parent_role) + } else { + *parent_role + }; + roles_by_channel_id.insert(channel.id, role); + role + } else { + *roles_by_channel_id.get(&channel.id)? + }; + + let can_see_parent_paths = role.can_see_all_descendants() + || role.can_only_see_public_descendants() + && channel.visibility == ChannelVisibility::Public; + if !can_see_parent_paths { + return None; + } + + visible_channel_ids.insert(channel.id); + + if let Some(ancestor) = ancestor_channel { + if !channel + .ancestors_including_self() + .any(|id| id == ancestor.id) + { + return None; + } + } + + let mut channel = Channel::from_model(channel, role); + channel + .parent_path + .retain(|id| visible_channel_ids.contains(&id)); + + Some(channel) + }) + .collect(); + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryUserIdsAndChannelIds { + ChannelId, + UserId, + } + + let mut channel_participants: HashMap> = HashMap::default(); + { + let mut rows = room_participant::Entity::find() + .inner_join(room::Entity) + .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id))) + .select_only() + .column(room::Column::ChannelId) + .column(room_participant::Column::UserId) + .into_values::<_, QueryUserIdsAndChannelIds>() + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row: (ChannelId, UserId) = row?; + channel_participants.entry(row.0).or_default().push(row.1) + } + } + + let channel_ids = channels.iter().map(|c| c.id).collect::>(); + let channel_buffer_changes = self + .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx) + .await?; + + let unseen_messages = self + .unseen_channel_messages(user_id, &channel_ids, &*tx) + .await?; + + Ok(ChannelsForUser { + channels, + channel_participants, + unseen_buffer_changes: channel_buffer_changes, + channel_messages: unseen_messages, + }) + } + + async fn participants_to_notify_for_channel_change( + &self, + new_parent: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut results: Vec<(UserId, ChannelsForUser)> = Vec::new(); + + let members = self + .get_channel_participant_details_internal(new_parent, &*tx) + .await?; + + for member in members.iter() { + if !member.role.can_see_all_descendants() { + continue; + } + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(new_parent), &*tx) + .await?, + )) + } + + let public_parents = self + .public_ancestors_including_self(new_parent, &*tx) + .await?; + let public_parent = public_parents.last(); + + let Some(public_parent) = public_parent else { + return Ok(results); + }; + + // could save some time in the common case by skipping this if the + // new channel is not public and has no public descendants. + let public_members = if public_parent == new_parent { + members + } else { + self.get_channel_participant_details_internal(public_parent, &*tx) + .await? + }; + + for member in public_members { + if !member.role.can_only_see_public_descendants() { + continue; + }; + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(public_parent), &*tx) + .await?, + )) + } + + Ok(results) + } + + pub async fn set_channel_member_role( + &self, + channel_id: ChannelId, + admin_id: UserId, + for_user: UserId, + role: ChannelRole, + ) -> Result { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let membership = channel_member::Entity::find() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(for_user)), + ) + .one(&*tx) + .await?; + + let Some(membership) = membership else { + Err(anyhow!("no such member"))? + }; + + let mut update = membership.into_active_model(); + update.role = ActiveValue::Set(role); + let updated = channel_member::Entity::update(update).exec(&*tx).await?; + + if updated.accepted { + Ok(SetMemberRoleResult::MembershipUpdated( + self.calculate_membership_updated(&channel, for_user, &*tx) + .await?, + )) + } else { + Ok(SetMemberRoleResult::InviteUpdated(Channel::from_model( + channel, role, + ))) + } + }) + .await + } + + pub async fn get_channel_participant_details( + &self, + channel_id: ChannelId, + user_id: UserId, + ) -> Result> { + let (role, members) = self + .transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + Ok(( + role, + self.get_channel_participant_details_internal(&channel, &*tx) + .await?, + )) + }) + .await?; + + if role == ChannelRole::Admin { + Ok(members + .into_iter() + .map(|channel_member| channel_member.to_proto()) + .collect()) + } else { + return Ok(members + .into_iter() + .filter_map(|member| { + if member.kind == proto::channel_member::Kind::Invitee { + return None; + } + Some(ChannelMember { + role: member.role, + user_id: member.user_id, + kind: proto::channel_member::Kind::Member, + }) + }) + .map(|channel_member| channel_member.to_proto()) + .collect()); + } + } + + async fn get_channel_participant_details_internal( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryMemberDetails { + UserId, + Role, + IsDirectMember, + Accepted, + Visibility, + } + + let mut stream = channel_member::Entity::find() + .left_join(channel::Entity) + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) + .select_only() + .column(channel_member::Column::UserId) + .column(channel_member::Column::Role) + .column_as( + channel_member::Column::ChannelId.eq(channel.id), + QueryMemberDetails::IsDirectMember, + ) + .column(channel_member::Column::Accepted) + .column(channel::Column::Visibility) + .into_values::<_, QueryMemberDetails>() + .stream(&*tx) + .await?; + + let mut user_details: HashMap = HashMap::default(); + + while let Some(user_membership) = stream.next().await { + let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): ( + UserId, + ChannelRole, + bool, + bool, + ChannelVisibility, + ) = user_membership?; + let kind = match (is_direct_member, is_invite_accepted) { + (true, true) => proto::channel_member::Kind::Member, + (true, false) => proto::channel_member::Kind::Invitee, + (false, true) => proto::channel_member::Kind::AncestorMember, + (false, false) => continue, + }; + + if channel_role == ChannelRole::Guest + && visibility != ChannelVisibility::Public + && channel.visibility != ChannelVisibility::Public + { + continue; + } + + if let Some(details_mut) = user_details.get_mut(&user_id) { + if channel_role.should_override(details_mut.role) { + details_mut.role = channel_role; + } + if kind == Kind::Member { + details_mut.kind = kind; + // the UI is going to be a bit confusing if you already have permissions + // that are greater than or equal to the ones you're being invited to. + } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember { + details_mut.kind = kind; + } + } else { + user_details.insert( + user_id, + ChannelMember { + user_id, + kind, + role: channel_role, + }, + ); + } + } + + Ok(user_details + .into_iter() + .map(|(_, details)| details) + .collect()) + } + + pub async fn get_channel_participants( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let participants = self + .get_channel_participant_details_internal(channel, &*tx) + .await?; + Ok(participants + .into_iter() + .map(|member| member.user_id) + .collect()) + } + + pub async fn check_user_is_channel_admin( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) => Ok(role.unwrap()), + Some(ChannelRole::Member) + | Some(ChannelRole::Banned) + | Some(ChannelRole::Guest) + | None => Err(anyhow!( + "user is not a channel admin or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_member( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let channel_role = self.channel_role_for_user(channel, user_id, tx).await?; + match channel_role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()), + Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( + "user is not a channel member or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_participant( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => { + Ok(role.unwrap()) + } + Some(ChannelRole::Banned) | None => Err(anyhow!( + "user is not a channel participant or channel does not exist" + ))?, + } + } + + pub async fn pending_invite_for_channel( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + let row = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) + .filter(channel_member::Column::UserId.eq(user_id)) + .filter(channel_member::Column::Accepted.eq(false)) + .one(&*tx) + .await?; + + Ok(row) + } + + pub async fn public_parent_channel( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut path = self.public_ancestors_including_self(channel, &*tx).await?; + if path.last().unwrap().id == channel.id { + path.pop(); + } + Ok(path.pop()) + } + + pub async fn public_ancestors_including_self( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let visible_channels = channel::Entity::find() + .filter(channel::Column::Id.is_in(channel.ancestors_including_self())) + .filter(channel::Column::Visibility.eq(ChannelVisibility::Public)) + .order_by_asc(channel::Column::ParentPath) + .all(&*tx) + .await?; + + Ok(visible_channels) + } + + pub async fn channel_role_for_user( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelMembership { + ChannelId, + Role, + Visibility, + } + + let mut rows = channel_member::Entity::find() + .left_join(channel::Entity) + .filter( + channel_member::Column::ChannelId + .is_in(channel.ancestors_including_self()) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(true)), + ) + .select_only() + .column(channel_member::Column::ChannelId) + .column(channel_member::Column::Role) + .column(channel::Column::Visibility) + .into_values::<_, QueryChannelMembership>() + .stream(&*tx) + .await?; + + let mut user_role: Option = None; + + let mut is_participant = false; + let mut current_channel_visibility = None; + + // note these channels are not iterated in any particular order, + // our current logic takes the highest permission available. + while let Some(row) = rows.next().await { + let (membership_channel, role, visibility): ( + ChannelId, + ChannelRole, + ChannelVisibility, + ) = row?; + + match role { + ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => { + if let Some(users_role) = user_role { + user_role = Some(users_role.max(role)); + } else { + user_role = Some(role) + } + } + ChannelRole::Guest if visibility == ChannelVisibility::Public => { + is_participant = true + } + ChannelRole::Guest => {} + } + if channel.id == membership_channel { + current_channel_visibility = Some(visibility); + } + } + // free up database connection + drop(rows); + + if is_participant && user_role.is_none() { + if current_channel_visibility.is_none() { + current_channel_visibility = channel::Entity::find() + .filter(channel::Column::Id.eq(channel.id)) + .one(&*tx) + .await? + .map(|channel| channel.visibility); + } + if current_channel_visibility == Some(ChannelVisibility::Public) { + user_role = Some(ChannelRole::Guest); + } + } + + Ok(user_role) + } + + // Get the descendants of the given set if channels, ordered by their + // path. + async fn get_channel_descendants_including_self( + &self, + channel_ids: impl IntoIterator, + tx: &DatabaseTransaction, + ) -> Result> { + let mut values = String::new(); + for id in channel_ids { + if !values.is_empty() { + values.push_str(", "); + } + write!(&mut values, "({})", id).unwrap(); + } + + if values.is_empty() { + return Ok(vec![]); + } + + let sql = format!( + r#" + SELECT DISTINCT + descendant_channels.*, + descendant_channels.parent_path || descendant_channels.id as full_path + FROM + channels parent_channels, channels descendant_channels + WHERE + descendant_channels.id IN ({values}) OR + ( + parent_channels.id IN ({values}) AND + descendant_channels.parent_path LIKE (parent_channels.parent_path || parent_channels.id || '/%') + ) + ORDER BY + full_path ASC + "# + ); + + Ok(channel::Entity::find() + .from_raw_sql(Statement::from_string( + self.pool.get_database_backend(), + sql, + )) + .all(tx) + .await?) + } + + /// Returns the channel with the given ID + pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + + Ok(Channel::from_model(channel, role)) + }) + .await + } + + pub async fn get_channel_internal( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Entity::find_by_id(channel_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such channel"))?) + } + + pub(crate) async fn get_or_create_channel_room( + &self, + channel_id: ChannelId, + live_kit_room: &str, + environment: &str, + tx: &DatabaseTransaction, + ) -> Result { + let room = room::Entity::find() + .filter(room::Column::ChannelId.eq(channel_id)) + .one(&*tx) + .await?; + + let room_id = if let Some(room) = room { + if let Some(env) = room.enviroment { + if &env != environment { + Err(anyhow!("must join using the {} release", env))?; + } + } + room.id + } else { + let result = room::Entity::insert(room::ActiveModel { + channel_id: ActiveValue::Set(Some(channel_id)), + live_kit_room: ActiveValue::Set(live_kit_room.to_string()), + enviroment: ActiveValue::Set(Some(environment.to_string())), + ..Default::default() + }) + .exec(&*tx) + .await?; + + result.last_insert_id + }; + + Ok(room_id) + } + + /// Move a channel from one parent to another + pub async fn move_channel( + &self, + channel_id: ChannelId, + new_parent_id: Option, + admin_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let new_parent_path; + let new_parent_channel; + if let Some(new_parent_id) = new_parent_id { + let new_parent = self.get_channel_internal(new_parent_id, &*tx).await?; + self.check_user_is_channel_admin(&new_parent, admin_id, &*tx) + .await?; + + new_parent_path = new_parent.path(); + new_parent_channel = Some(new_parent); + } else { + new_parent_path = String::new(); + new_parent_channel = None; + }; + + let previous_participants = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + let old_path = format!("{}{}/", channel.parent_path, channel.id); + let new_path = format!("{}{}/", new_parent_path, channel.id); + + if old_path == new_path { + return Ok(None); + } + + let mut model = channel.into_active_model(); + model.parent_path = ActiveValue::Set(new_parent_path); + let channel = model.update(&*tx).await?; + + if new_parent_channel.is_none() { + channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + } + + let descendent_ids = + ChannelId::find_by_statement::(Statement::from_sql_and_values( + self.pool.get_database_backend(), + " + UPDATE channels SET parent_path = REPLACE(parent_path, $1, $2) + WHERE parent_path LIKE $3 || '%' + RETURNING id + ", + [old_path.clone().into(), new_path.into(), old_path.into()], + )) + .all(&*tx) + .await?; + + let participants_to_update: HashMap<_, _> = self + .participants_to_notify_for_channel_change( + new_parent_channel.as_ref().unwrap_or(&channel), + &*tx, + ) + .await? + .into_iter() + .collect(); + + let mut moved_channels: HashSet = HashSet::default(); + for id in descendent_ids { + moved_channels.insert(id); + } + moved_channels.insert(channel_id); + + let mut participants_to_remove: HashSet = HashSet::default(); + for participant in previous_participants { + if participant.kind == proto::channel_member::Kind::AncestorMember { + if !participants_to_update.contains_key(&participant.user_id) { + participants_to_remove.insert(participant.user_id); + } + } + } + + Ok(Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + })) + }) + .await + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryIds { + Id, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryUserIds { + UserId, +} diff --git a/crates/collab2/src/db/queries/contacts.rs b/crates/collab2/src/db/queries/contacts.rs new file mode 100644 index 0000000000..f31f1addbd --- /dev/null +++ b/crates/collab2/src/db/queries/contacts.rs @@ -0,0 +1,353 @@ +use super::*; + +impl Database { + pub async fn get_contacts(&self, user_id: UserId) -> Result> { + #[derive(Debug, FromQueryResult)] + struct ContactWithUserBusyStatuses { + user_id_a: UserId, + user_id_b: UserId, + a_to_b: bool, + accepted: bool, + user_a_busy: bool, + user_b_busy: bool, + } + + self.transaction(|tx| async move { + let user_a_participant = Alias::new("user_a_participant"); + let user_b_participant = Alias::new("user_b_participant"); + let mut db_contacts = contact::Entity::find() + .column_as( + Expr::col((user_a_participant.clone(), room_participant::Column::Id)) + .is_not_null(), + "user_a_busy", + ) + .column_as( + Expr::col((user_b_participant.clone(), room_participant::Column::Id)) + .is_not_null(), + "user_b_busy", + ) + .filter( + contact::Column::UserIdA + .eq(user_id) + .or(contact::Column::UserIdB.eq(user_id)), + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserARoomParticipant.def(), + user_a_participant, + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserBRoomParticipant.def(), + user_b_participant, + ) + .into_model::() + .stream(&*tx) + .await?; + + let mut contacts = Vec::new(); + while let Some(db_contact) = db_contacts.next().await { + let db_contact = db_contact?; + if db_contact.user_id_a == user_id { + if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_b, + busy: db_contact.user_b_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_b, + }) + } else { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_b, + }); + } + } else if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_a, + busy: db_contact.user_a_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_a, + }); + } else { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_a, + }); + } + } + + contacts.sort_unstable_by_key(|contact| contact.user_id()); + + Ok(contacts) + }) + .await + } + + pub async fn is_user_busy(&self, user_id: UserId) -> Result { + self.transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::UserId.eq(user_id)) + .one(&*tx) + .await?; + Ok(participant.is_some()) + }) + .await + } + + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + self.transaction(|tx| async move { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + + Ok(contact::Entity::find() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::Accepted.eq(true)), + ) + .one(&*tx) + .await? + .is_some()) + }) + .await + } + + pub async fn send_contact_request( + &self, + sender_id: UserId, + receiver_id: UserId, + ) -> Result { + self.transaction(|tx| async move { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + + let rows_affected = contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(id_a), + user_id_b: ActiveValue::set(id_b), + a_to_b: ActiveValue::set(a_to_b), + accepted: ActiveValue::set(false), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) + .values([ + (contact::Column::Accepted, true.into()), + (contact::Column::ShouldNotify, false.into()), + ]) + .action_and_where( + contact::Column::Accepted.eq(false).and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::UserIdA.eq(id_b)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::UserIdA.eq(id_a))), + ), + ) + .to_owned(), + ) + .exec_without_returning(&*tx) + .await?; + + if rows_affected == 0 { + Err(anyhow!("contact already requested"))?; + } + + Ok(self + .create_notification( + receiver_id, + rpc::Notification::ContactRequest { + sender_id: sender_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) + }) + .await + } + + /// Returns a bool indicating whether the removed contact had originally accepted or not + /// + /// Deletes the contact identified by the requester and responder ids, and then returns + /// whether the deleted contact had originally accepted or was a pending contact request. + /// + /// # Arguments + /// + /// * `requester_id` - The user that initiates this request + /// * `responder_id` - The user that will be removed + pub async fn remove_contact( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<(bool, Option)> { + self.transaction(|tx| async move { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else { + (requester_id, responder_id) + }; + + let contact = contact::Entity::find() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such contact"))?; + + contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; + + let mut deleted_notification_id = None; + if !contact.accepted { + deleted_notification_id = self + .remove_notification( + responder_id, + rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + &*tx, + ) + .await?; + } + + Ok((contact.accepted, deleted_notification_id)) + }) + .await + } + + pub async fn dismiss_contact_notification( + &self, + user_id: UserId, + contact_user_id: UserId, + ) -> Result<()> { + self.transaction(|tx| async move { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; + + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + should_notify: ActiveValue::set(false), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::Accepted.eq(true)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::Accepted.eq(false))), + ), + ) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such contact request"))? + } else { + Ok(()) + } + }) + .await + } + + pub async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result { + self.transaction(|tx| async move { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let rows_affected = if accept { + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + accepted: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)), + ) + .exec(&*tx) + .await?; + result.rows_affected + } else { + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)) + .and(contact::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await?; + + result.rows_affected + }; + + if rows_affected == 0 { + Err(anyhow!("no such contact request"))? + } + + let mut notifications = Vec::new(); + notifications.extend( + self.mark_notification_as_read_with_response( + responder_id, + &rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + accept, + &*tx, + ) + .await?, + ); + + if accept { + notifications.extend( + self.create_notification( + requester_id, + rpc::Notification::ContactRequestAccepted { + responder_id: responder_id.to_proto(), + }, + true, + &*tx, + ) + .await?, + ); + } + + Ok(notifications) + }) + .await + } +} diff --git a/crates/collab2/src/db/queries/messages.rs b/crates/collab2/src/db/queries/messages.rs new file mode 100644 index 0000000000..47bb27df39 --- /dev/null +++ b/crates/collab2/src/db/queries/messages.rs @@ -0,0 +1,505 @@ +use super::*; +use rpc::Notification; +use sea_orm::TryInsertResult; +use time::OffsetDateTime; + +impl Database { + pub async fn join_channel_chat( + &self, + channel_id: ChannelId, + connection_id: ConnectionId, + user_id: UserId, + ) -> Result<()> { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + channel_chat_participant::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(user_id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), + } + .insert(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn channel_chat_connection_lost( + &self, + connection_id: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result<()> { + channel_chat_participant::Entity::delete_many() + .filter( + Condition::all() + .add( + channel_chat_participant::Column::ConnectionServerId + .eq(connection_id.owner_id), + ) + .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)), + ) + .exec(tx) + .await?; + Ok(()) + } + + pub async fn leave_channel_chat( + &self, + channel_id: ChannelId, + connection_id: ConnectionId, + _user_id: UserId, + ) -> Result<()> { + self.transaction(|tx| async move { + channel_chat_participant::Entity::delete_many() + .filter( + Condition::all() + .add( + channel_chat_participant::Column::ConnectionServerId + .eq(connection_id.owner_id), + ) + .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)) + .add(channel_chat_participant::Column::ChannelId.eq(channel_id)), + ) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } + + pub async fn get_channel_messages( + &self, + channel_id: ChannelId, + user_id: UserId, + count: usize, + before_message_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + + let mut condition = + Condition::all().add(channel_message::Column::ChannelId.eq(channel_id)); + + if let Some(before_message_id) = before_message_id { + condition = condition.add(channel_message::Column::Id.lt(before_message_id)); + } + + let rows = channel_message::Entity::find() + .filter(condition) + .order_by_desc(channel_message::Column::Id) + .limit(count as u64) + .all(&*tx) + .await?; + + self.load_channel_messages(rows, &*tx).await + }) + .await + } + + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) + .all(&*tx) + .await?; + + let mut channels = HashMap::::default(); + for row in &rows { + channels.insert( + row.channel_id, + self.get_channel_internal(row.channel_id, &*tx).await?, + ); + } + + for (_, channel) in channels { + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + } + + let messages = self.load_channel_messages(rows, &*tx).await?; + Ok(messages) + }) + .await + } + + async fn load_channel_messages( + &self, + rows: Vec, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = rows + .into_iter() + .map(|row| { + let nonce = row.nonce.as_u64_pair(); + proto::ChannelMessage { + id: row.id.to_proto(), + sender_id: row.sender_id.to_proto(), + body: row.body, + timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], + nonce: Some(proto::Nonce { + upper_half: nonce.0, + lower_half: nonce.1, + }), + } + }) + .collect::>(); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + .order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } + } + } + + Ok(messages) + } + + pub async fn create_channel_message( + &self, + channel_id: ChannelId, + user_id: UserId, + body: &str, + mentions: &[proto::ChatMention], + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result { + self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + + let mut rows = channel_chat_participant::Entity::find() + .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) + .stream(&*tx) + .await?; + + let mut is_participant = false; + let mut participant_connection_ids = Vec::new(); + let mut participant_user_ids = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + if row.user_id == user_id { + is_participant = true; + } + participant_user_ids.push(row.user_id); + participant_connection_ids.push(row.connection()); + } + drop(rows); + + if !is_participant { + Err(anyhow!("not a chat participant"))?; + } + + let timestamp = timestamp.to_offset(time::UtcOffset::UTC); + let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); + + let result = channel_message::Entity::insert(channel_message::ActiveModel { + channel_id: ActiveValue::Set(channel_id), + sender_id: ActiveValue::Set(user_id), + body: ActiveValue::Set(body.to_string()), + sent_at: ActiveValue::Set(timestamp), + nonce: ActiveValue::Set(Uuid::from_u128(nonce)), + id: ActiveValue::NotSet, + }) + .on_conflict( + OnConflict::columns([ + channel_message::Column::SenderId, + channel_message::Column::Nonce, + ]) + .do_nothing() + .to_owned(), + ) + .do_nothing() + .exec(&*tx) + .await?; + + let message_id; + let mut notifications = Vec::new(); + match result { + TryInsertResult::Inserted(result) => { + message_id = result.last_insert_id; + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions + .iter() + .filter_map(|mention| { + let range = mention.range.as_ref()?; + if !body.is_char_boundary(range.start as usize) + || !body.is_char_boundary(range.end as usize) + { + return None; + } + Some(channel_message_mention::ActiveModel { + message_id: ActiveValue::Set(message_id), + start_offset: ActiveValue::Set(range.start as i32), + end_offset: ActiveValue::Set(range.end as i32), + user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), + }) + }) + .collect::>(); + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) + .exec(&*tx) + .await?; + } + + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + } + _ => { + message_id = channel_message::Entity::find() + .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert message"))? + .id; + } + } + + let mut channel_members = self.get_channel_participants(&channel, &*tx).await?; + channel_members.retain(|member| !participant_user_ids.contains(member)); + + Ok(CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + }) + }) + .await + } + + pub async fn observe_channel_message( + &self, + channel_id: ChannelId, + user_id: UserId, + message_id: MessageId, + ) -> Result { + self.transaction(|tx| async move { + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + let mut batch = NotificationBatch::default(); + batch.extend( + self.mark_notification_as_read( + user_id, + &Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: Default::default(), + channel_id: Default::default(), + }, + &*tx, + ) + .await?, + ); + Ok(batch) + }) + .await + } + + async fn observe_channel_message_internal( + &self, + channel_id: ChannelId, + user_id: UserId, + message_id: MessageId, + tx: &DatabaseTransaction, + ) -> Result<()> { + observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel { + user_id: ActiveValue::Set(user_id), + channel_id: ActiveValue::Set(channel_id), + channel_message_id: ActiveValue::Set(message_id), + }) + .on_conflict( + OnConflict::columns([ + observed_channel_messages::Column::ChannelId, + observed_channel_messages::Column::UserId, + ]) + .update_column(observed_channel_messages::Column::ChannelMessageId) + .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id)) + .to_owned(), + ) + // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug + .exec_without_returning(&*tx) + .await?; + Ok(()) + } + + pub async fn unseen_channel_messages( + &self, + user_id: UserId, + channel_ids: &[ChannelId], + tx: &DatabaseTransaction, + ) -> Result> { + let mut observed_messages_by_channel_id = HashMap::default(); + let mut rows = observed_channel_messages::Entity::find() + .filter(observed_channel_messages::Column::UserId.eq(user_id)) + .filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied())) + .stream(&*tx) + .await?; + + while let Some(row) = rows.next().await { + let row = row?; + observed_messages_by_channel_id.insert(row.channel_id, row); + } + drop(rows); + let mut values = String::new(); + for id in channel_ids { + if !values.is_empty() { + values.push_str(", "); + } + write!(&mut values, "({})", id).unwrap(); + } + + if values.is_empty() { + return Ok(Default::default()); + } + + let sql = format!( + r#" + SELECT + * + FROM ( + SELECT + *, + row_number() OVER ( + PARTITION BY channel_id + ORDER BY id DESC + ) as row_number + FROM channel_messages + WHERE + channel_id in ({values}) + ) AS messages + WHERE + row_number = 1 + "#, + ); + + let stmt = Statement::from_string(self.pool.get_database_backend(), sql); + let last_messages = channel_message::Model::find_by_statement(stmt) + .all(&*tx) + .await?; + + let mut changes = Vec::new(); + for last_message in last_messages { + if let Some(observed_message) = + observed_messages_by_channel_id.get(&last_message.channel_id) + { + if observed_message.channel_message_id == last_message.id { + continue; + } + } + changes.push(proto::UnseenChannelMessage { + channel_id: last_message.channel_id.to_proto(), + message_id: last_message.id.to_proto(), + }); + } + + Ok(changes) + } + + pub async fn remove_channel_message( + &self, + channel_id: ChannelId, + message_id: MessageId, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let mut rows = channel_chat_participant::Entity::find() + .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) + .stream(&*tx) + .await?; + + let mut is_participant = false; + let mut participant_connection_ids = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + if row.user_id == user_id { + is_participant = true; + } + participant_connection_ids.push(row.connection()); + } + drop(rows); + + if !is_participant { + Err(anyhow!("not a chat participant"))?; + } + + let result = channel_message::Entity::delete_by_id(message_id) + .filter(channel_message::Column::SenderId.eq(user_id)) + .exec(&*tx) + .await?; + + if result.rows_affected == 0 { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + if self + .check_user_is_channel_admin(&channel, user_id, &*tx) + .await + .is_ok() + { + let result = channel_message::Entity::delete_by_id(message_id) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such message"))?; + } + } else { + Err(anyhow!("operation could not be completed"))?; + } + } + + Ok(participant_connection_ids) + }) + .await + } +} diff --git a/crates/collab2/src/db/queries/notifications.rs b/crates/collab2/src/db/queries/notifications.rs new file mode 100644 index 0000000000..6f2511c23e --- /dev/null +++ b/crates/collab2/src/db/queries/notifications.rs @@ -0,0 +1,262 @@ +use super::*; +use rpc::Notification; + +impl Database { + pub async fn initialize_notification_kinds(&mut self) -> Result<()> { + notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map( + |kind| notification_kind::ActiveModel { + name: ActiveValue::Set(kind.to_string()), + ..Default::default() + }, + )) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&self.pool) + .await?; + + let mut rows = notification_kind::Entity::find().stream(&self.pool).await?; + while let Some(row) = rows.next().await { + let row = row?; + self.notification_kinds_by_name.insert(row.name, row.id); + } + + for name in Notification::all_variant_names() { + if let Some(id) = self.notification_kinds_by_name.get(*name).copied() { + self.notification_kinds_by_id.insert(id, name); + } + } + + Ok(()) + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + before_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut condition = + Condition::all().add(notification::Column::RecipientId.eq(recipient_id)); + + if let Some(before_id) = before_id { + condition = condition.add(notification::Column::Id.lt(before_id)); + } + + let mut rows = notification::Entity::find() + .filter(condition) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + let kind = row.kind; + if let Some(proto) = model_to_proto(self, row) { + result.push(proto); + } else { + log::warn!("unknown notification kind {:?}", kind); + } + } + result.reverse(); + Ok(result) + }) + .await + } + + /// Create a notification. If `avoid_duplicates` is set to true, then avoid + /// creating a new notification if the given recipient already has an + /// unread notification with the given kind and entity id. + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + avoid_duplicates: bool, + tx: &DatabaseTransaction, + ) -> Result> { + if avoid_duplicates { + if self + .find_notification(recipient_id, ¬ification, tx) + .await? + .is_some() + { + return Ok(None); + } + } + + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + let model = notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind), + entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)), + content: ActiveValue::Set(proto.content.clone()), + ..Default::default() + } + .save(&*tx) + .await?; + + Ok(Some(( + recipient_id, + proto::Notification { + id: model.id.as_ref().to_proto(), + kind: proto.kind, + timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, + is_read: false, + response: None, + content: proto.content, + entity_id: proto.entity_id, + }, + ))) + } + + /// Remove an unread notification with the given recipient, kind and + /// entity id. + pub async fn remove_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let id = self + .find_notification(recipient_id, ¬ification, tx) + .await?; + if let Some(id) = id { + notification::Entity::delete_by_id(id).exec(tx).await?; + } + Ok(id) + } + + /// Populate the response for the notification with the given kind and + /// entity id. + pub async fn mark_notification_as_read_with_response( + &self, + recipient_id: UserId, + notification: &Notification, + response: bool, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx) + .await + } + + pub async fn mark_notification_as_read( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, None, tx) + .await + } + + pub async fn mark_notification_as_read_by_id( + &self, + recipient_id: UserId, + notification_id: NotificationId, + ) -> Result { + self.transaction(|tx| async move { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(notification_id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .into_iter() + .collect()) + }) + .await + } + + async fn mark_notification_as_read_internal( + &self, + recipient_id: UserId, + notification: &Notification, + response: Option, + tx: &DatabaseTransaction, + ) -> Result> { + if let Some(id) = self + .find_notification(recipient_id, notification, &*tx) + .await? + { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + response: if let Some(response) = response { + ActiveValue::Set(Some(response)) + } else { + ActiveValue::NotSet + }, + ..Default::default() + }) + .exec(tx) + .await?; + Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + } else { + Ok(None) + } + } + + /// Find an unread notification by its recipient, kind and entity id. + async fn find_notification( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryIds { + Id, + } + + Ok(notification::Entity::find() + .select_only() + .column(notification::Column::Id) + .filter( + Condition::all() + .add(notification::Column::RecipientId.eq(recipient_id)) + .add(notification::Column::IsRead.eq(false)) + .add(notification::Column::Kind.eq(kind)) + .add(if proto.entity_id.is_some() { + notification::Column::EntityId.eq(proto.entity_id) + } else { + notification::Column::EntityId.is_null() + }), + ) + .into_values::<_, QueryIds>() + .one(&*tx) + .await?) + } +} + +fn model_to_proto(this: &Database, row: notification::Model) -> Option { + let kind = this.notification_kinds_by_id.get(&row.kind)?; + Some(proto::Notification { + id: row.id.to_proto(), + kind: kind.to_string(), + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + response: row.response, + content: row.content, + entity_id: row.entity_id.map(|id| id as u64), + }) +} + +fn notification_kind_from_proto( + this: &Database, + proto: &proto::Notification, +) -> Result { + Ok(this + .notification_kinds_by_name + .get(&proto.kind) + .copied() + .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?) +} diff --git a/crates/collab2/src/db/queries/projects.rs b/crates/collab2/src/db/queries/projects.rs new file mode 100644 index 0000000000..3e2c003378 --- /dev/null +++ b/crates/collab2/src/db/queries/projects.rs @@ -0,0 +1,960 @@ +use super::*; + +impl Database { + pub async fn project_count_excluding_admins(&self) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Count, + } + + self.transaction(|tx| async move { + Ok(project::Entity::find() + .select_only() + .column_as(project::Column::Id.count(), QueryAs::Count) + .inner_join(user::Entity) + .filter(user::Column::Admin.eq(false)) + .into_values::<_, QueryAs>() + .one(&*tx) + .await? + .unwrap_or(0i64) as usize) + }) + .await + } + + pub async fn share_project( + &self, + room_id: RoomId, + connection: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + let participant = room_participant::Entity::find() + .filter( + Condition::all() + .add( + room_participant::Column::AnsweringConnectionId + .eq(connection.id as i32), + ) + .add( + room_participant::Column::AnsweringConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id { + return Err(anyhow!("shared project on unexpected room"))?; + } + + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(Some(connection.id as i32)), + host_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + ..Default::default() + } + .insert(&*tx) + .await?; + + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + completed_scan_id: ActiveValue::set(0), + } + })) + .exec(&*tx) + .await?; + } + + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection.id as i32), + connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(ReplicaId(0)), + is_host: ActiveValue::set(true), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok((project.id, room)) + }) + .await + } + + pub async fn unshare_project( + &self, + project_id: ProjectId, + connection: ConnectionId, + ) -> Result)>> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("project not found"))?; + if project.host_connection()? == connection { + project::Entity::delete(project.into_active_model()) + .exec(&*tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + Ok((room, guest_connection_ids)) + } else { + Err(anyhow!("cannot unshare a project hosted by another user"))? + } + }) + .await + } + + pub async fn update_project( + &self, + project_id: ProjectId, + connection: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result)>> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let project = project::Entity::find_by_id(project_id) + .filter( + Condition::all() + .add(project::Column::HostConnectionId.eq(connection.id as i32)) + .add( + project::Column::HostConnectionServerId.eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + + self.update_project_worktrees(project.id, worktrees, &tx) + .await?; + + let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; + let room = self.get_room(project.room_id, &tx).await?; + Ok((room, guest_connection_ids)) + }) + .await + } + + pub(in crate::db) async fn update_project_worktrees( + &self, + project_id: ProjectId, + worktrees: &[proto::WorktreeMetadata], + tx: &DatabaseTransaction, + ) -> Result<()> { + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project_id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + completed_scan_id: ActiveValue::set(0), + })) + .on_conflict( + OnConflict::columns([worktree::Column::ProjectId, worktree::Column::Id]) + .update_column(worktree::Column::RootName) + .to_owned(), + ) + .exec(&*tx) + .await?; + } + + worktree::Entity::delete_many() + .filter(worktree::Column::ProjectId.eq(project_id).and( + worktree::Column::Id.is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)), + )) + .exec(&*tx) + .await?; + + Ok(()) + } + + pub async fn update_worktree( + &self, + update: &proto::UpdateWorktree, + connection: ConnectionId, + ) -> Result>> { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = update.worktree_id as i64; + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + // Ensure the update comes from the host. + let _project = project::Entity::find_by_id(project_id) + .filter( + Condition::all() + .add(project::Column::HostConnectionId.eq(connection.id as i32)) + .add( + project::Column::HostConnectionServerId.eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + + // Update metadata. + worktree::Entity::update(worktree::ActiveModel { + id: ActiveValue::set(worktree_id), + project_id: ActiveValue::set(project_id), + root_name: ActiveValue::set(update.root_name.clone()), + scan_id: ActiveValue::set(update.scan_id as i64), + completed_scan_id: if update.is_last_update { + ActiveValue::set(update.scan_id as i64) + } else { + ActiveValue::default() + }, + abs_path: ActiveValue::set(update.abs_path.clone()), + ..Default::default() + }) + .exec(&*tx) + .await?; + + if !update.updated_entries.is_empty() { + worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| { + let mtime = entry.mtime.clone().unwrap_or_default(); + worktree_entry::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + id: ActiveValue::set(entry.id as i64), + is_dir: ActiveValue::set(entry.is_dir), + path: ActiveValue::set(entry.path.clone()), + inode: ActiveValue::set(entry.inode as i64), + mtime_seconds: ActiveValue::set(mtime.seconds as i64), + mtime_nanos: ActiveValue::set(mtime.nanos as i32), + is_symlink: ActiveValue::set(entry.is_symlink), + is_ignored: ActiveValue::set(entry.is_ignored), + is_external: ActiveValue::set(entry.is_external), + git_status: ActiveValue::set(entry.git_status.map(|status| status as i64)), + is_deleted: ActiveValue::set(false), + scan_id: ActiveValue::set(update.scan_id as i64), + } + })) + .on_conflict( + OnConflict::columns([ + worktree_entry::Column::ProjectId, + worktree_entry::Column::WorktreeId, + worktree_entry::Column::Id, + ]) + .update_columns([ + worktree_entry::Column::IsDir, + worktree_entry::Column::Path, + worktree_entry::Column::Inode, + worktree_entry::Column::MtimeSeconds, + worktree_entry::Column::MtimeNanos, + worktree_entry::Column::IsSymlink, + worktree_entry::Column::IsIgnored, + worktree_entry::Column::GitStatus, + worktree_entry::Column::ScanId, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + } + + if !update.removed_entries.is_empty() { + worktree_entry::Entity::update_many() + .filter( + worktree_entry::Column::ProjectId + .eq(project_id) + .and(worktree_entry::Column::WorktreeId.eq(worktree_id)) + .and( + worktree_entry::Column::Id + .is_in(update.removed_entries.iter().map(|id| *id as i64)), + ), + ) + .set(worktree_entry::ActiveModel { + is_deleted: ActiveValue::Set(true), + scan_id: ActiveValue::Set(update.scan_id as i64), + ..Default::default() + }) + .exec(&*tx) + .await?; + } + + if !update.updated_repositories.is_empty() { + worktree_repository::Entity::insert_many(update.updated_repositories.iter().map( + |repository| worktree_repository::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + work_directory_id: ActiveValue::set(repository.work_directory_id as i64), + scan_id: ActiveValue::set(update.scan_id as i64), + branch: ActiveValue::set(repository.branch.clone()), + is_deleted: ActiveValue::set(false), + }, + )) + .on_conflict( + OnConflict::columns([ + worktree_repository::Column::ProjectId, + worktree_repository::Column::WorktreeId, + worktree_repository::Column::WorkDirectoryId, + ]) + .update_columns([ + worktree_repository::Column::ScanId, + worktree_repository::Column::Branch, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + } + + if !update.removed_repositories.is_empty() { + worktree_repository::Entity::update_many() + .filter( + worktree_repository::Column::ProjectId + .eq(project_id) + .and(worktree_repository::Column::WorktreeId.eq(worktree_id)) + .and( + worktree_repository::Column::WorkDirectoryId + .is_in(update.removed_repositories.iter().map(|id| *id as i64)), + ), + ) + .set(worktree_repository::ActiveModel { + is_deleted: ActiveValue::Set(true), + scan_id: ActiveValue::Set(update.scan_id as i64), + ..Default::default() + }) + .exec(&*tx) + .await?; + } + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok(connection_ids) + }) + .await + } + + pub async fn update_diagnostic_summary( + &self, + update: &proto::UpdateDiagnosticSummary, + connection: ConnectionId, + ) -> Result>> { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = update.worktree_id as i64; + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let summary = update + .summary + .as_ref() + .ok_or_else(|| anyhow!("invalid summary"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection()? != connection { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + // Update summary. + worktree_diagnostic_summary::Entity::insert(worktree_diagnostic_summary::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + path: ActiveValue::set(summary.path.clone()), + language_server_id: ActiveValue::set(summary.language_server_id as i64), + error_count: ActiveValue::set(summary.error_count as i32), + warning_count: ActiveValue::set(summary.warning_count as i32), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + worktree_diagnostic_summary::Column::ProjectId, + worktree_diagnostic_summary::Column::WorktreeId, + worktree_diagnostic_summary::Column::Path, + ]) + .update_columns([ + worktree_diagnostic_summary::Column::LanguageServerId, + worktree_diagnostic_summary::Column::ErrorCount, + worktree_diagnostic_summary::Column::WarningCount, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok(connection_ids) + }) + .await + } + + pub async fn start_language_server( + &self, + update: &proto::StartLanguageServer, + connection: ConnectionId, + ) -> Result>> { + let project_id = ProjectId::from_proto(update.project_id); + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let server = update + .server + .as_ref() + .ok_or_else(|| anyhow!("invalid language server"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection()? != connection { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + // Add the newly-started language server. + language_server::Entity::insert(language_server::ActiveModel { + project_id: ActiveValue::set(project_id), + id: ActiveValue::set(server.id as i64), + name: ActiveValue::set(server.name.clone()), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + language_server::Column::ProjectId, + language_server::Column::Id, + ]) + .update_column(language_server::Column::Name) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok(connection_ids) + }) + .await + } + + pub async fn update_worktree_settings( + &self, + update: &proto::UpdateWorktreeSettings, + connection: ConnectionId, + ) -> Result>> { + let project_id = ProjectId::from_proto(update.project_id); + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection()? != connection { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + if let Some(content) = &update.content { + worktree_settings_file::Entity::insert(worktree_settings_file::ActiveModel { + project_id: ActiveValue::Set(project_id), + worktree_id: ActiveValue::Set(update.worktree_id as i64), + path: ActiveValue::Set(update.path.clone()), + content: ActiveValue::Set(content.clone()), + }) + .on_conflict( + OnConflict::columns([ + worktree_settings_file::Column::ProjectId, + worktree_settings_file::Column::WorktreeId, + worktree_settings_file::Column::Path, + ]) + .update_column(worktree_settings_file::Column::Content) + .to_owned(), + ) + .exec(&*tx) + .await?; + } else { + worktree_settings_file::Entity::delete(worktree_settings_file::ActiveModel { + project_id: ActiveValue::Set(project_id), + worktree_id: ActiveValue::Set(update.worktree_id as i64), + path: ActiveValue::Set(update.path.clone()), + ..Default::default() + }) + .exec(&*tx) + .await?; + } + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok(connection_ids) + }) + .await + } + + pub async fn join_project( + &self, + project_id: ProjectId, + connection: ConnectionId, + ) -> Result> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let participant = room_participant::Entity::find() + .filter( + Condition::all() + .add( + room_participant::Column::AnsweringConnectionId + .eq(connection.id as i32), + ) + .add( + room_participant::Column::AnsweringConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("must join a room first"))?; + + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.room_id != participant.room_id { + return Err(anyhow!("no such project"))?; + } + + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let replica_ids = collaborators + .iter() + .map(|c| c.replica_id) + .collect::>(); + let mut replica_id = ReplicaId(1); + while replica_ids.contains(&replica_id) { + replica_id.0 += 1; + } + let new_collaborator = project_collaborator::ActiveModel { + project_id: ActiveValue::set(project_id), + connection_id: ActiveValue::set(connection.id as i32), + connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(replica_id), + is_host: ActiveValue::set(false), + ..Default::default() + } + .insert(&*tx) + .await?; + collaborators.push(new_collaborator); + + let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; + let mut worktrees = db_worktrees + .into_iter() + .map(|db_worktree| { + ( + db_worktree.id as u64, + Worktree { + id: db_worktree.id as u64, + abs_path: db_worktree.abs_path, + root_name: db_worktree.root_name, + visible: db_worktree.visible, + entries: Default::default(), + repository_entries: Default::default(), + diagnostic_summaries: Default::default(), + settings_files: Default::default(), + scan_id: db_worktree.scan_id as u64, + completed_scan_id: db_worktree.completed_scan_id as u64, + }, + ) + }) + .collect::>(); + + // Populate worktree entries. + { + let mut db_entries = worktree_entry::Entity::find() + .filter( + Condition::all() + .add(worktree_entry::Column::ProjectId.eq(project_id)) + .add(worktree_entry::Column::IsDeleted.eq(false)), + ) + .stream(&*tx) + .await?; + while let Some(db_entry) = db_entries.next().await { + let db_entry = db_entry?; + if let Some(worktree) = worktrees.get_mut(&(db_entry.worktree_id as u64)) { + worktree.entries.push(proto::Entry { + id: db_entry.id as u64, + is_dir: db_entry.is_dir, + path: db_entry.path, + inode: db_entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: db_entry.mtime_seconds as u64, + nanos: db_entry.mtime_nanos as u32, + }), + is_symlink: db_entry.is_symlink, + is_ignored: db_entry.is_ignored, + is_external: db_entry.is_external, + git_status: db_entry.git_status.map(|status| status as i32), + }); + } + } + } + + // Populate repository entries. + { + let mut db_repository_entries = worktree_repository::Entity::find() + .filter( + Condition::all() + .add(worktree_repository::Column::ProjectId.eq(project_id)) + .add(worktree_repository::Column::IsDeleted.eq(false)), + ) + .stream(&*tx) + .await?; + while let Some(db_repository_entry) = db_repository_entries.next().await { + let db_repository_entry = db_repository_entry?; + if let Some(worktree) = + worktrees.get_mut(&(db_repository_entry.worktree_id as u64)) + { + worktree.repository_entries.insert( + db_repository_entry.work_directory_id as u64, + proto::RepositoryEntry { + work_directory_id: db_repository_entry.work_directory_id as u64, + branch: db_repository_entry.branch, + }, + ); + } + } + } + + // Populate worktree diagnostic summaries. + { + let mut db_summaries = worktree_diagnostic_summary::Entity::find() + .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + while let Some(db_summary) = db_summaries.next().await { + let db_summary = db_summary?; + if let Some(worktree) = worktrees.get_mut(&(db_summary.worktree_id as u64)) { + worktree + .diagnostic_summaries + .push(proto::DiagnosticSummary { + path: db_summary.path, + language_server_id: db_summary.language_server_id as u64, + error_count: db_summary.error_count as u32, + warning_count: db_summary.warning_count as u32, + }); + } + } + } + + // Populate worktree settings files + { + let mut db_settings_files = worktree_settings_file::Entity::find() + .filter(worktree_settings_file::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + while let Some(db_settings_file) = db_settings_files.next().await { + let db_settings_file = db_settings_file?; + if let Some(worktree) = + worktrees.get_mut(&(db_settings_file.worktree_id as u64)) + { + worktree.settings_files.push(WorktreeSettingsFile { + path: db_settings_file.path, + content: db_settings_file.content, + }); + } + } + } + + // Populate language servers. + let language_servers = project + .find_related(language_server::Entity) + .all(&*tx) + .await?; + + let project = Project { + collaborators: collaborators + .into_iter() + .map(|collaborator| ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect(), + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect(), + }; + Ok((project, replica_id as ReplicaId)) + }) + .await + } + + pub async fn leave_project( + &self, + project_id: ProjectId, + connection: ConnectionId, + ) -> Result> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let result = project_collaborator::Entity::delete_many() + .filter( + Condition::all() + .add(project_collaborator::Column::ProjectId.eq(project_id)) + .add(project_collaborator::Column::ConnectionId.eq(connection.id as i32)) + .add( + project_collaborator::Column::ConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("not a collaborator on this project"))?; + } + + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + let collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let connection_ids = collaborators + .into_iter() + .map(|collaborator| collaborator.connection()) + .collect(); + + follower::Entity::delete_many() + .filter( + Condition::any() + .add( + Condition::all() + .add(follower::Column::ProjectId.eq(Some(project_id))) + .add( + follower::Column::LeaderConnectionServerId + .eq(connection.owner_id), + ) + .add(follower::Column::LeaderConnectionId.eq(connection.id)), + ) + .add( + Condition::all() + .add(follower::Column::ProjectId.eq(Some(project_id))) + .add( + follower::Column::FollowerConnectionServerId + .eq(connection.owner_id), + ) + .add(follower::Column::FollowerConnectionId.eq(connection.id)), + ), + ) + .exec(&*tx) + .await?; + + let room = self.get_room(project.room_id, &tx).await?; + let left_project = LeftProject { + id: project_id, + host_user_id: project.host_user_id, + host_connection_id: project.host_connection()?, + connection_ids, + }; + Ok((room, left_project)) + }) + .await + } + + pub async fn project_collaborators( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result>> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.eq(project_id)) + .all(&*tx) + .await? + .into_iter() + .map(|collaborator| ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect::>(); + + if collaborators + .iter() + .any(|collaborator| collaborator.connection_id == connection_id) + { + Ok(collaborators) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + + pub async fn project_connection_ids( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result>> { + let room_id = self.room_id_for_project(project_id).await?; + self.room_transaction(room_id, |tx| async move { + let mut collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + + let mut connection_ids = HashSet::default(); + while let Some(collaborator) = collaborators.next().await { + let collaborator = collaborator?; + connection_ids.insert(collaborator.connection()); + } + + if connection_ids.contains(&connection_id) { + Ok(connection_ids) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + + async fn project_guest_connection_ids( + &self, + project_id: ProjectId, + tx: &DatabaseTransaction, + ) -> Result> { + let mut collaborators = project_collaborator::Entity::find() + .filter( + project_collaborator::Column::ProjectId + .eq(project_id) + .and(project_collaborator::Column::IsHost.eq(false)), + ) + .stream(tx) + .await?; + + let mut guest_connection_ids = Vec::new(); + while let Some(collaborator) = collaborators.next().await { + let collaborator = collaborator?; + guest_connection_ids.push(collaborator.connection()); + } + Ok(guest_connection_ids) + } + + pub async fn room_id_for_project(&self, project_id: ProjectId) -> Result { + self.transaction(|tx| async move { + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("project {} not found", project_id))?; + Ok(project.room_id) + }) + .await + } + + pub async fn check_room_participants( + &self, + room_id: RoomId, + leader_id: ConnectionId, + follower_id: ConnectionId, + ) -> Result<()> { + self.transaction(|tx| async move { + use room_participant::Column; + + let count = room_participant::Entity::find() + .filter( + Condition::all().add(Column::RoomId.eq(room_id)).add( + Condition::any() + .add(Column::AnsweringConnectionId.eq(leader_id.id as i32).and( + Column::AnsweringConnectionServerId.eq(leader_id.owner_id as i32), + )) + .add(Column::AnsweringConnectionId.eq(follower_id.id as i32).and( + Column::AnsweringConnectionServerId.eq(follower_id.owner_id as i32), + )), + ), + ) + .count(&*tx) + .await?; + + if count < 2 { + Err(anyhow!("not room participants"))?; + } + + Ok(()) + }) + .await + } + + pub async fn follow( + &self, + room_id: RoomId, + project_id: ProjectId, + leader_connection: ConnectionId, + follower_connection: ConnectionId, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + follower::ActiveModel { + room_id: ActiveValue::set(room_id), + project_id: ActiveValue::set(project_id), + leader_connection_server_id: ActiveValue::set(ServerId( + leader_connection.owner_id as i32, + )), + leader_connection_id: ActiveValue::set(leader_connection.id as i32), + follower_connection_server_id: ActiveValue::set(ServerId( + follower_connection.owner_id as i32, + )), + follower_connection_id: ActiveValue::set(follower_connection.id as i32), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room_id, &*tx).await?; + Ok(room) + }) + .await + } + + pub async fn unfollow( + &self, + room_id: RoomId, + project_id: ProjectId, + leader_connection: ConnectionId, + follower_connection: ConnectionId, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + follower::Entity::delete_many() + .filter( + Condition::all() + .add(follower::Column::RoomId.eq(room_id)) + .add(follower::Column::ProjectId.eq(project_id)) + .add( + follower::Column::LeaderConnectionServerId + .eq(leader_connection.owner_id), + ) + .add(follower::Column::LeaderConnectionId.eq(leader_connection.id)) + .add( + follower::Column::FollowerConnectionServerId + .eq(follower_connection.owner_id), + ) + .add(follower::Column::FollowerConnectionId.eq(follower_connection.id)), + ) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &*tx).await?; + Ok(room) + }) + .await + } +} diff --git a/crates/collab2/src/db/queries/rooms.rs b/crates/collab2/src/db/queries/rooms.rs new file mode 100644 index 0000000000..40fdf5d58f --- /dev/null +++ b/crates/collab2/src/db/queries/rooms.rs @@ -0,0 +1,1203 @@ +use super::*; + +impl Database { + pub async fn clear_stale_room_participants( + &self, + room_id: RoomId, + new_server_id: ServerId, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + let stale_participant_filter = Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::AnsweringConnectionId.is_not_null()) + .add(room_participant::Column::AnsweringConnectionServerId.ne(new_server_id)); + + let stale_participant_user_ids = room_participant::Entity::find() + .filter(stale_participant_filter.clone()) + .all(&*tx) + .await? + .into_iter() + .map(|participant| participant.user_id) + .collect::>(); + + // Delete participants who failed to reconnect and cancel their calls. + let mut canceled_calls_to_user_ids = Vec::new(); + room_participant::Entity::delete_many() + .filter(stale_participant_filter) + .exec(&*tx) + .await?; + let called_participants = room_participant::Entity::find() + .filter( + Condition::all() + .add( + room_participant::Column::CallingUserId + .is_in(stale_participant_user_ids.iter().copied()), + ) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .all(&*tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::Id + .is_in(called_participants.iter().map(|participant| participant.id)), + ) + .exec(&*tx) + .await?; + canceled_calls_to_user_ids.extend( + called_participants + .into_iter() + .map(|participant| participant.user_id), + ); + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel_members; + if let Some(channel) = &channel { + channel_members = self.get_channel_participants(channel, &tx).await?; + } else { + channel_members = Vec::new(); + + // Delete the room if it becomes empty. + if room.participants.is_empty() { + project::Entity::delete_many() + .filter(project::Column::RoomId.eq(room_id)) + .exec(&*tx) + .await?; + room::Entity::delete_by_id(room_id).exec(&*tx).await?; + } + }; + + Ok(RefreshedRoom { + room, + channel_id: channel.map(|channel| channel.id), + channel_members, + stale_participant_user_ids, + canceled_calls_to_user_ids, + }) + }) + .await + } + + pub async fn incoming_call_for_user( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let pending_participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(user_id) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&*tx) + .await?; + + if let Some(pending_participant) = pending_participant { + let room = self.get_room(pending_participant.room_id, &tx).await?; + Ok(Self::build_incoming_call(&room, user_id)) + } else { + Ok(None) + } + }) + .await + } + + pub async fn create_room( + &self, + user_id: UserId, + connection: ConnectionId, + live_kit_room: &str, + release_channel: &str, + ) -> Result { + self.transaction(|tx| async move { + let room = room::ActiveModel { + live_kit_room: ActiveValue::set(live_kit_room.into()), + enviroment: ActiveValue::set(Some(release_channel.to_string())), + ..Default::default() + } + .insert(&*tx) + .await?; + room_participant::ActiveModel { + room_id: ActiveValue::set(room.id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::set(Some(0)), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room.id, &tx).await?; + Ok(room) + }) + .await + } + + pub async fn call( + &self, + room_id: RoomId, + calling_user_id: UserId, + calling_connection: ConnectionId, + called_user_id: UserId, + initial_project_id: Option, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(called_user_id), + answering_connection_lost: ActiveValue::set(false), + participant_index: ActiveValue::NotSet, + calling_user_id: ActiveValue::set(calling_user_id), + calling_connection_id: ActiveValue::set(calling_connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + calling_connection.owner_id as i32, + ))), + initial_project_id: ActiveValue::set(initial_project_id), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + let incoming_call = Self::build_incoming_call(&room, called_user_id) + .ok_or_else(|| anyhow!("failed to build incoming call"))?; + Ok((room, incoming_call)) + }) + .await + } + + pub async fn call_failed( + &self, + room_id: RoomId, + called_user_id: UserId, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + room_participant::Entity::delete_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::UserId.eq(called_user_id)), + ) + .exec(&*tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + Ok(room) + }) + .await + } + + pub async fn decline_call( + &self, + expected_room_id: Option, + user_id: UserId, + ) -> Result>> { + self.optional_room_transaction(|tx| async move { + let mut filter = Condition::all() + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()); + if let Some(room_id) = expected_room_id { + filter = filter.add(room_participant::Column::RoomId.eq(room_id)); + } + let participant = room_participant::Entity::find() + .filter(filter) + .one(&*tx) + .await?; + + let participant = if let Some(participant) = participant { + participant + } else if expected_room_id.is_some() { + return Err(anyhow!("could not find call to decline"))?; + } else { + return Ok(None); + }; + + let room_id = participant.room_id; + room_participant::Entity::delete(participant.into_active_model()) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok(Some((room_id, room))) + }) + .await + } + + pub async fn cancel_call( + &self, + room_id: RoomId, + calling_connection: ConnectionId, + called_user_id: UserId, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + let participant = room_participant::Entity::find() + .filter( + Condition::all() + .add(room_participant::Column::UserId.eq(called_user_id)) + .add(room_participant::Column::RoomId.eq(room_id)) + .add( + room_participant::Column::CallingConnectionId + .eq(calling_connection.id as i32), + ) + .add( + room_participant::Column::CallingConnectionServerId + .eq(calling_connection.owner_id as i32), + ) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no call to cancel"))?; + + room_participant::Entity::delete(participant.into_active_model()) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok(room) + }) + .await + } + + pub async fn join_room( + &self, + room_id: RoomId, + user_id: UserId, + connection: ConnectionId, + enviroment: &str, + ) -> Result> { + self.room_transaction(room_id, |tx| async move { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelIdAndEnviroment { + ChannelId, + Enviroment, + } + + let (channel_id, release_channel): (Option, Option) = + room::Entity::find() + .select_only() + .column(room::Column::ChannelId) + .column(room::Column::Enviroment) + .filter(room::Column::Id.eq(room_id)) + .into_values::<_, QueryChannelIdAndEnviroment>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such room"))?; + + if let Some(release_channel) = release_channel { + if &release_channel != enviroment { + Err(anyhow!("must join using the {} release", release_channel))?; + } + } + + if channel_id.is_some() { + Err(anyhow!("tried to join channel call directly"))? + } + + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) + .await?; + + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + participant_index: ActiveValue::Set(Some(participant_index)), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + ..Default::default() + }) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))?; + } + + let room = self.get_room(room_id, &tx).await?; + Ok(JoinRoom { + room, + channel_id: None, + channel_members: vec![], + }) + }) + .await + } + + async fn get_next_participant_index_internal( + &self, + room_id: RoomId, + tx: &DatabaseTransaction, + ) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryParticipantIndices { + ParticipantIndex, + } + let existing_participant_indices: Vec = room_participant::Entity::find() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::ParticipantIndex.is_not_null()), + ) + .select_only() + .column(room_participant::Column::ParticipantIndex) + .into_values::<_, QueryParticipantIndices>() + .all(&*tx) + .await?; + + let mut participant_index = 0; + while existing_participant_indices.contains(&participant_index) { + participant_index += 1; + } + + Ok(participant_index) + } + + pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result> { + self.transaction(|tx| async move { + let room: Option = room::Entity::find() + .filter(room::Column::Id.eq(room_id)) + .one(&*tx) + .await?; + + Ok(room.and_then(|room| room.channel_id)) + }) + .await + } + + pub(crate) async fn join_channel_room_internal( + &self, + room_id: RoomId, + user_id: UserId, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result { + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) + .await?; + + room_participant::Entity::insert_many([room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::Set(Some(participant_index)), + ..Default::default() + }]) + .on_conflict( + OnConflict::columns([room_participant::Column::UserId]) + .update_columns([ + room_participant::Column::AnsweringConnectionId, + room_participant::Column::AnsweringConnectionServerId, + room_participant::Column::AnsweringConnectionLost, + room_participant::Column::ParticipantIndex, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; + let channel_members = self.get_channel_participants(&channel, &*tx).await?; + Ok(JoinRoom { + room, + channel_id: Some(channel.id), + channel_members, + }) + } + + pub async fn rejoin_room( + &self, + rejoin_room: proto::RejoinRoom, + user_id: UserId, + connection: ConnectionId, + ) -> Result> { + let room_id = RoomId::from_proto(rejoin_room.id); + self.room_transaction(room_id, |tx| async { + let tx = tx; + let participant_update = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_not_null()) + .add( + Condition::any() + .add(room_participant::Column::AnsweringConnectionLost.eq(true)) + .add( + room_participant::Column::AnsweringConnectionServerId + .ne(connection.owner_id as i32), + ), + ), + ) + .set(room_participant::ActiveModel { + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + ..Default::default() + }) + .exec(&*tx) + .await?; + if participant_update.rows_affected == 0 { + return Err(anyhow!("room does not exist or was already joined"))?; + } + + let mut reshared_projects = Vec::new(); + for reshared_project in &rejoin_room.reshared_projects { + let project_id = ProjectId::from_proto(reshared_project.project_id); + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("project does not exist"))?; + if project.host_user_id != user_id { + return Err(anyhow!("no such project"))?; + } + + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let host_ix = collaborators + .iter() + .position(|collaborator| { + collaborator.user_id == user_id && collaborator.is_host + }) + .ok_or_else(|| anyhow!("host not found among collaborators"))?; + let host = collaborators.swap_remove(host_ix); + let old_connection_id = host.connection(); + + project::Entity::update(project::ActiveModel { + host_connection_id: ActiveValue::set(Some(connection.id as i32)), + host_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + ..project.into_active_model() + }) + .exec(&*tx) + .await?; + project_collaborator::Entity::update(project_collaborator::ActiveModel { + connection_id: ActiveValue::set(connection.id as i32), + connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), + ..host.into_active_model() + }) + .exec(&*tx) + .await?; + + self.update_project_worktrees(project_id, &reshared_project.worktrees, &tx) + .await?; + + reshared_projects.push(ResharedProject { + id: project_id, + old_connection_id, + collaborators: collaborators + .iter() + .map(|collaborator| ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect(), + worktrees: reshared_project.worktrees.clone(), + }); + } + + project::Entity::delete_many() + .filter( + Condition::all() + .add(project::Column::RoomId.eq(room_id)) + .add(project::Column::HostUserId.eq(user_id)) + .add( + project::Column::Id + .is_not_in(reshared_projects.iter().map(|project| project.id)), + ), + ) + .exec(&*tx) + .await?; + + let mut rejoined_projects = Vec::new(); + for rejoined_project in &rejoin_room.rejoined_projects { + let project_id = ProjectId::from_proto(rejoined_project.id); + let Some(project) = project::Entity::find_by_id(project_id).one(&*tx).await? else { + continue; + }; + + let mut worktrees = Vec::new(); + let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; + for db_worktree in db_worktrees { + let mut worktree = RejoinedWorktree { + id: db_worktree.id as u64, + abs_path: db_worktree.abs_path, + root_name: db_worktree.root_name, + visible: db_worktree.visible, + updated_entries: Default::default(), + removed_entries: Default::default(), + updated_repositories: Default::default(), + removed_repositories: Default::default(), + diagnostic_summaries: Default::default(), + settings_files: Default::default(), + scan_id: db_worktree.scan_id as u64, + completed_scan_id: db_worktree.completed_scan_id as u64, + }; + + let rejoined_worktree = rejoined_project + .worktrees + .iter() + .find(|worktree| worktree.id == db_worktree.id as u64); + + // File entries + { + let entry_filter = if let Some(rejoined_worktree) = rejoined_worktree { + worktree_entry::Column::ScanId.gt(rejoined_worktree.scan_id) + } else { + worktree_entry::Column::IsDeleted.eq(false) + }; + + let mut db_entries = worktree_entry::Entity::find() + .filter( + Condition::all() + .add(worktree_entry::Column::ProjectId.eq(project.id)) + .add(worktree_entry::Column::WorktreeId.eq(worktree.id)) + .add(entry_filter), + ) + .stream(&*tx) + .await?; + + while let Some(db_entry) = db_entries.next().await { + let db_entry = db_entry?; + if db_entry.is_deleted { + worktree.removed_entries.push(db_entry.id as u64); + } else { + worktree.updated_entries.push(proto::Entry { + id: db_entry.id as u64, + is_dir: db_entry.is_dir, + path: db_entry.path, + inode: db_entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: db_entry.mtime_seconds as u64, + nanos: db_entry.mtime_nanos as u32, + }), + is_symlink: db_entry.is_symlink, + is_ignored: db_entry.is_ignored, + is_external: db_entry.is_external, + git_status: db_entry.git_status.map(|status| status as i32), + }); + } + } + } + + // Repository Entries + { + let repository_entry_filter = + if let Some(rejoined_worktree) = rejoined_worktree { + worktree_repository::Column::ScanId.gt(rejoined_worktree.scan_id) + } else { + worktree_repository::Column::IsDeleted.eq(false) + }; + + let mut db_repositories = worktree_repository::Entity::find() + .filter( + Condition::all() + .add(worktree_repository::Column::ProjectId.eq(project.id)) + .add(worktree_repository::Column::WorktreeId.eq(worktree.id)) + .add(repository_entry_filter), + ) + .stream(&*tx) + .await?; + + while let Some(db_repository) = db_repositories.next().await { + let db_repository = db_repository?; + if db_repository.is_deleted { + worktree + .removed_repositories + .push(db_repository.work_directory_id as u64); + } else { + worktree.updated_repositories.push(proto::RepositoryEntry { + work_directory_id: db_repository.work_directory_id as u64, + branch: db_repository.branch, + }); + } + } + } + + worktrees.push(worktree); + } + + let language_servers = project + .find_related(language_server::Entity) + .all(&*tx) + .await? + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect::>(); + + { + let mut db_settings_files = worktree_settings_file::Entity::find() + .filter(worktree_settings_file::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + while let Some(db_settings_file) = db_settings_files.next().await { + let db_settings_file = db_settings_file?; + if let Some(worktree) = worktrees + .iter_mut() + .find(|w| w.id == db_settings_file.worktree_id as u64) + { + worktree.settings_files.push(WorktreeSettingsFile { + path: db_settings_file.path, + content: db_settings_file.content, + }); + } + } + } + + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let self_collaborator = if let Some(self_collaborator_ix) = collaborators + .iter() + .position(|collaborator| collaborator.user_id == user_id) + { + collaborators.swap_remove(self_collaborator_ix) + } else { + continue; + }; + let old_connection_id = self_collaborator.connection(); + project_collaborator::Entity::update(project_collaborator::ActiveModel { + connection_id: ActiveValue::set(connection.id as i32), + connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)), + ..self_collaborator.into_active_model() + }) + .exec(&*tx) + .await?; + + let collaborators = collaborators + .into_iter() + .map(|collaborator| ProjectCollaborator { + connection_id: collaborator.connection(), + user_id: collaborator.user_id, + replica_id: collaborator.replica_id, + is_host: collaborator.is_host, + }) + .collect::>(); + + rejoined_projects.push(RejoinedProject { + id: project_id, + old_connection_id, + collaborators, + worktrees, + language_servers, + }); + } + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(&channel, &tx).await? + } else { + Vec::new() + }; + + Ok(RejoinedRoom { + room, + channel_id: channel.map(|channel| channel.id), + channel_members, + rejoined_projects, + reshared_projects, + }) + }) + .await + } + + pub async fn leave_room( + &self, + connection: ConnectionId, + ) -> Result>> { + self.optional_room_transaction(|tx| async move { + let leaving_participant = room_participant::Entity::find() + .filter( + Condition::all() + .add( + room_participant::Column::AnsweringConnectionId + .eq(connection.id as i32), + ) + .add( + room_participant::Column::AnsweringConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await?; + + if let Some(leaving_participant) = leaving_participant { + // Leave room. + let room_id = leaving_participant.room_id; + room_participant::Entity::delete_by_id(leaving_participant.id) + .exec(&*tx) + .await?; + + // Cancel pending calls initiated by the leaving user. + let called_participants = room_participant::Entity::find() + .filter( + Condition::all() + .add( + room_participant::Column::CallingUserId + .eq(leaving_participant.user_id), + ) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .all(&*tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::Id + .is_in(called_participants.iter().map(|participant| participant.id)), + ) + .exec(&*tx) + .await?; + let canceled_calls_to_user_ids = called_participants + .into_iter() + .map(|participant| participant.user_id) + .collect(); + + // Detect left projects. + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryProjectIds { + ProjectId, + } + let project_ids: Vec = project_collaborator::Entity::find() + .select_only() + .column_as( + project_collaborator::Column::ProjectId, + QueryProjectIds::ProjectId, + ) + .filter( + Condition::all() + .add( + project_collaborator::Column::ConnectionId.eq(connection.id as i32), + ) + .add( + project_collaborator::Column::ConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .into_values::<_, QueryProjectIds>() + .all(&*tx) + .await?; + let mut left_projects = HashMap::default(); + let mut collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.is_in(project_ids)) + .stream(&*tx) + .await?; + while let Some(collaborator) = collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + host_connection_id: Default::default(), + }); + + let collaborator_connection_id = collaborator.connection(); + if collaborator_connection_id != connection { + left_project.connection_ids.push(collaborator_connection_id); + } + + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + left_project.host_connection_id = collaborator_connection_id; + } + } + drop(collaborators); + + // Leave projects. + project_collaborator::Entity::delete_many() + .filter( + Condition::all() + .add( + project_collaborator::Column::ConnectionId.eq(connection.id as i32), + ) + .add( + project_collaborator::Column::ConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .exec(&*tx) + .await?; + + // Unshare projects. + project::Entity::delete_many() + .filter( + Condition::all() + .add(project::Column::RoomId.eq(room_id)) + .add(project::Column::HostConnectionId.eq(connection.id as i32)) + .add( + project::Column::HostConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .exec(&*tx) + .await?; + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let deleted = if room.participants.is_empty() { + let result = room::Entity::delete_by_id(room_id).exec(&*tx).await?; + result.rows_affected > 0 + } else { + false + }; + + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(channel, &tx).await? + } else { + Vec::new() + }; + let left_room = LeftRoom { + room, + channel_id: channel.map(|channel| channel.id), + channel_members, + left_projects, + canceled_calls_to_user_ids, + deleted, + }; + + if left_room.room.participants.is_empty() { + self.rooms.remove(&room_id); + } + + Ok(Some((room_id, left_room))) + } else { + Ok(None) + } + }) + .await + } + + pub async fn update_room_participant_location( + &self, + room_id: RoomId, + connection: ConnectionId, + location: proto::ParticipantLocation, + ) -> Result> { + self.room_transaction(room_id, |tx| async { + let tx = tx; + let location_kind; + let location_project_id; + match location + .variant + .as_ref() + .ok_or_else(|| anyhow!("invalid location"))? + { + proto::participant_location::Variant::SharedProject(project) => { + location_kind = 0; + location_project_id = Some(ProjectId::from_proto(project.id)); + } + proto::participant_location::Variant::UnsharedProject(_) => { + location_kind = 1; + location_project_id = None; + } + proto::participant_location::Variant::External(_) => { + location_kind = 2; + location_project_id = None; + } + } + + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add( + room_participant::Column::AnsweringConnectionId + .eq(connection.id as i32), + ) + .add( + room_participant::Column::AnsweringConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .set(room_participant::ActiveModel { + location_kind: ActiveValue::set(Some(location_kind)), + location_project_id: ActiveValue::set(location_project_id), + ..Default::default() + }) + .exec(&*tx) + .await?; + + if result.rows_affected == 1 { + let room = self.get_room(room_id, &tx).await?; + Ok(room) + } else { + Err(anyhow!("could not update room participant location"))? + } + }) + .await + } + + pub async fn connection_lost(&self, connection: ConnectionId) -> Result<()> { + self.transaction(|tx| async move { + self.room_connection_lost(connection, &*tx).await?; + self.channel_buffer_connection_lost(connection, &*tx) + .await?; + self.channel_chat_connection_lost(connection, &*tx).await?; + Ok(()) + }) + .await + } + + pub async fn room_connection_lost( + &self, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result<()> { + let participant = room_participant::Entity::find() + .filter( + Condition::all() + .add(room_participant::Column::AnsweringConnectionId.eq(connection.id as i32)) + .add( + room_participant::Column::AnsweringConnectionServerId + .eq(connection.owner_id as i32), + ), + ) + .one(&*tx) + .await?; + + if let Some(participant) = participant { + room_participant::Entity::update(room_participant::ActiveModel { + answering_connection_lost: ActiveValue::set(true), + ..participant.into_active_model() + }) + .exec(&*tx) + .await?; + } + Ok(()) + } + + fn build_incoming_call( + room: &proto::Room, + called_user_id: UserId, + ) -> Option { + let pending_participant = room + .pending_participants + .iter() + .find(|participant| participant.user_id == called_user_id.to_proto())?; + + Some(proto::IncomingCall { + room_id: room.id, + calling_user_id: pending_participant.calling_user_id, + participant_user_ids: room + .participants + .iter() + .map(|participant| participant.user_id) + .collect(), + initial_project: room.participants.iter().find_map(|participant| { + let initial_project_id = pending_participant.initial_project_id?; + participant + .projects + .iter() + .find(|project| project.id == initial_project_id) + .cloned() + }), + }) + } + + pub async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { + let (_, room) = self.get_channel_room(room_id, tx).await?; + Ok(room) + } + + pub async fn room_connection_ids( + &self, + room_id: RoomId, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(room_id, |tx| async move { + let mut participants = room_participant::Entity::find() + .filter(room_participant::Column::RoomId.eq(room_id)) + .stream(&*tx) + .await?; + + let mut is_participant = false; + let mut connection_ids = HashSet::default(); + while let Some(participant) = participants.next().await { + let participant = participant?; + if let Some(answering_connection) = participant.answering_connection() { + if answering_connection == connection_id { + is_participant = true; + } else { + connection_ids.insert(answering_connection); + } + } + } + + if !is_participant { + Err(anyhow!("not a room participant"))?; + } + + Ok(connection_ids) + }) + .await + } + + async fn get_channel_room( + &self, + room_id: RoomId, + tx: &DatabaseTransaction, + ) -> Result<(Option, proto::Room)> { + let db_room = room::Entity::find_by_id(room_id) + .one(tx) + .await? + .ok_or_else(|| anyhow!("could not find room"))?; + + let mut db_participants = db_room + .find_related(room_participant::Entity) + .stream(tx) + .await?; + let mut participants = HashMap::default(); + let mut pending_participants = Vec::new(); + while let Some(db_participant) = db_participants.next().await { + let db_participant = db_participant?; + if let ( + Some(answering_connection_id), + Some(answering_connection_server_id), + Some(participant_index), + ) = ( + db_participant.answering_connection_id, + db_participant.answering_connection_server_id, + db_participant.participant_index, + ) { + let location = match ( + db_participant.location_kind, + db_participant.location_project_id, + ) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id.to_proto(), + }, + )) + } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; + + let answering_connection = ConnectionId { + owner_id: answering_connection_server_id.0 as u32, + id: answering_connection_id as u32, + }; + participants.insert( + answering_connection, + proto::Participant { + user_id: db_participant.user_id.to_proto(), + peer_id: Some(answering_connection.into()), + projects: Default::default(), + location: Some(proto::ParticipantLocation { variant: location }), + participant_index: participant_index as u32, + }, + ); + } else { + pending_participants.push(proto::PendingParticipant { + user_id: db_participant.user_id.to_proto(), + calling_user_id: db_participant.calling_user_id.to_proto(), + initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), + }); + } + } + drop(db_participants); + + let mut db_projects = db_room + .find_related(project::Entity) + .find_with_related(worktree::Entity) + .stream(tx) + .await?; + + while let Some(row) = db_projects.next().await { + let (db_project, db_worktree) = row?; + let host_connection = db_project.host_connection()?; + if let Some(participant) = participants.get_mut(&host_connection) { + let project = if let Some(project) = participant + .projects + .iter_mut() + .find(|project| project.id == db_project.id.to_proto()) + { + project + } else { + participant.projects.push(proto::ParticipantProject { + id: db_project.id.to_proto(), + worktree_root_names: Default::default(), + }); + participant.projects.last_mut().unwrap() + }; + + if let Some(db_worktree) = db_worktree { + if db_worktree.visible { + project.worktree_root_names.push(db_worktree.root_name); + } + } + } + } + drop(db_projects); + + let mut db_followers = db_room.find_related(follower::Entity).stream(tx).await?; + let mut followers = Vec::new(); + while let Some(db_follower) = db_followers.next().await { + let db_follower = db_follower?; + followers.push(proto::Follower { + leader_id: Some(db_follower.leader_connection().into()), + follower_id: Some(db_follower.follower_connection().into()), + project_id: db_follower.project_id.to_proto(), + }); + } + drop(db_followers); + + let channel = if let Some(channel_id) = db_room.channel_id { + Some(self.get_channel_internal(channel_id, &*tx).await?) + } else { + None + }; + + Ok(( + channel, + proto::Room { + id: db_room.id.to_proto(), + live_kit_room: db_room.live_kit_room, + participants: participants.into_values().collect(), + pending_participants, + followers, + }, + )) + } +} diff --git a/crates/collab2/src/db/queries/servers.rs b/crates/collab2/src/db/queries/servers.rs new file mode 100644 index 0000000000..e5ceee8887 --- /dev/null +++ b/crates/collab2/src/db/queries/servers.rs @@ -0,0 +1,99 @@ +use super::*; + +impl Database { + pub async fn create_server(&self, environment: &str) -> Result { + self.transaction(|tx| async move { + let server = server::ActiveModel { + environment: ActiveValue::set(environment.into()), + ..Default::default() + } + .insert(&*tx) + .await?; + Ok(server.id) + }) + .await + } + + pub async fn stale_server_resource_ids( + &self, + environment: &str, + new_server_id: ServerId, + ) -> Result<(Vec, Vec)> { + self.transaction(|tx| async move { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryRoomIds { + RoomId, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + + let stale_server_epochs = self + .stale_server_ids(environment, new_server_id, &tx) + .await?; + let room_ids = room_participant::Entity::find() + .select_only() + .column(room_participant::Column::RoomId) + .distinct() + .filter( + room_participant::Column::AnsweringConnectionServerId + .is_in(stale_server_epochs.iter().copied()), + ) + .into_values::<_, QueryRoomIds>() + .all(&*tx) + .await?; + let channel_ids = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .distinct() + .filter( + channel_buffer_collaborator::Column::ConnectionServerId + .is_in(stale_server_epochs.iter().copied()), + ) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + Ok((room_ids, channel_ids)) + }) + .await + } + + pub async fn delete_stale_servers( + &self, + environment: &str, + new_server_id: ServerId, + ) -> Result<()> { + self.transaction(|tx| async move { + server::Entity::delete_many() + .filter( + Condition::all() + .add(server::Column::Environment.eq(environment)) + .add(server::Column::Id.ne(new_server_id)), + ) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + async fn stale_server_ids( + &self, + environment: &str, + new_server_id: ServerId, + tx: &DatabaseTransaction, + ) -> Result> { + let stale_servers = server::Entity::find() + .filter( + Condition::all() + .add(server::Column::Environment.eq(environment)) + .add(server::Column::Id.ne(new_server_id)), + ) + .all(&*tx) + .await?; + Ok(stale_servers.into_iter().map(|server| server.id).collect()) + } +} diff --git a/crates/collab2/src/db/queries/users.rs b/crates/collab2/src/db/queries/users.rs new file mode 100644 index 0000000000..27e64e2598 --- /dev/null +++ b/crates/collab2/src/db/queries/users.rs @@ -0,0 +1,259 @@ +use super::*; + +impl Database { + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + self.transaction(|tx| async { + let tx = tx; + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(email_address.into())), + github_login: ActiveValue::set(params.github_login.clone()), + github_user_id: ActiveValue::set(Some(params.github_user_id)), + admin: ActiveValue::set(admin), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_column(user::Column::GithubLogin) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; + + Ok(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + signup_device_id: None, + inviting_user_id: None, + }) + }) + .await + } + + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) }) + .await + } + + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + self.transaction(|tx| async { + let tx = tx; + Ok(user::Entity::find() + .filter(user::Column::Id.is_in(ids.iter().copied())) + .all(&*tx) + .await?) + }) + .await + } + + pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&*tx) + .await?) + }) + .await + } + + pub async fn get_or_create_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + github_email: Option<&str>, + ) -> Result> { + self.transaction(|tx| async move { + let tx = &*tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(tx).await?)) + } else { + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(github_email.map(|email| email.into())), + github_login: ActiveValue::set(github_login.into()), + github_user_id: ActiveValue::set(Some(github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(0), + invite_code: ActiveValue::set(None), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .exec_with_returning(&*tx) + .await?; + Ok(Some(user)) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(tx) + .await?) + } + }) + .await + } + + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .order_by_asc(user::Column::GithubLogin) + .limit(limit as u64) + .offset(page as u64 * limit as u64) + .all(&*tx) + .await?) + }) + .await + } + + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + MetricsId, + } + + self.transaction(|tx| async move { + let metrics_id: Uuid = user::Entity::find_by_id(id) + .select_only() + .column(user::Column::MetricsId) + .into_values::<_, QueryAs>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("could not find user"))?; + Ok(metrics_id.to_string()) + }) + .await + } + + pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { + self.transaction(|tx| async move { + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .set(user::ActiveModel { + connected_once: ActiveValue::set(connected_once), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn destroy_user(&self, id: UserId) -> Result<()> { + self.transaction(|tx| async move { + access_token::Entity::delete_many() + .filter(access_token::Column::UserId.eq(id)) + .exec(&*tx) + .await?; + user::Entity::delete_by_id(id).exec(&*tx).await?; + Ok(()) + }) + .await + } + + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + self.transaction(|tx| async { + let tx = tx; + let like_string = Self::fuzzy_like_string(name_query); + let query = " + SELECT users.* + FROM users + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 + "; + + Ok(user::Entity::find() + .from_raw_sql(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query, + vec![like_string.into(), name_query.into(), limit.into()], + )) + .all(&*tx) + .await?) + }) + .await + } + + pub fn fuzzy_like_string(string: &str) -> String { + let mut result = String::with_capacity(string.len() * 2 + 1); + for c in string.chars() { + if c.is_alphanumeric() { + result.push('%'); + result.push(c); + } + } + result.push('%'); + result + } + + pub async fn create_user_flag(&self, flag: &str) -> Result { + self.transaction(|tx| async move { + let flag = feature_flag::Entity::insert(feature_flag::ActiveModel { + flag: ActiveValue::set(flag.to_string()), + ..Default::default() + }) + .exec(&*tx) + .await? + .last_insert_id; + + Ok(flag) + }) + .await + } + + pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> { + self.transaction(|tx| async move { + user_feature::Entity::insert(user_feature::ActiveModel { + user_id: ActiveValue::set(user), + feature_id: ActiveValue::set(flag), + }) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } + + pub async fn get_user_flags(&self, user: UserId) -> Result> { + self.transaction(|tx| async move { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Flag, + } + + let flags = user::Model { + id: user, + ..Default::default() + } + .find_linked(user::UserFlags) + .select_only() + .column(feature_flag::Column::Flag) + .into_values::<_, QueryAs>() + .all(&*tx) + .await?; + + Ok(flags) + }) + .await + } +} diff --git a/crates/collab2/src/db/tables.rs b/crates/collab2/src/db/tables.rs new file mode 100644 index 0000000000..4f28ce4fbd --- /dev/null +++ b/crates/collab2/src/db/tables.rs @@ -0,0 +1,32 @@ +pub mod access_token; +pub mod buffer; +pub mod buffer_operation; +pub mod buffer_snapshot; +pub mod channel; +pub mod channel_buffer_collaborator; +pub mod channel_chat_participant; +pub mod channel_member; +pub mod channel_message; +pub mod channel_message_mention; +pub mod contact; +pub mod feature_flag; +pub mod follower; +pub mod language_server; +pub mod notification; +pub mod notification_kind; +pub mod observed_buffer_edits; +pub mod observed_channel_messages; +pub mod project; +pub mod project_collaborator; +pub mod room; +pub mod room_participant; +pub mod server; +pub mod signup; +pub mod user; +pub mod user_feature; +pub mod worktree; +pub mod worktree_diagnostic_summary; +pub mod worktree_entry; +pub mod worktree_repository; +pub mod worktree_repository_statuses; +pub mod worktree_settings_file; diff --git a/crates/collab2/src/db/tables/access_token.rs b/crates/collab2/src/db/tables/access_token.rs new file mode 100644 index 0000000000..da7392b98c --- /dev/null +++ b/crates/collab2/src/db/tables/access_token.rs @@ -0,0 +1,29 @@ +use crate::db::{AccessTokenId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "access_tokens")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: AccessTokenId, + pub user_id: UserId, + pub hash: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/buffer.rs b/crates/collab2/src/db/tables/buffer.rs new file mode 100644 index 0000000000..ec2ffd4a68 --- /dev/null +++ b/crates/collab2/src/db/tables/buffer.rs @@ -0,0 +1,45 @@ +use crate::db::{BufferId, ChannelId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "buffers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: BufferId, + pub epoch: i32, + pub channel_id: ChannelId, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::buffer_operation::Entity")] + Operations, + #[sea_orm(has_many = "super::buffer_snapshot::Entity")] + Snapshots, + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Operations.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Snapshots.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/buffer_operation.rs b/crates/collab2/src/db/tables/buffer_operation.rs new file mode 100644 index 0000000000..37bd4bedfe --- /dev/null +++ b/crates/collab2/src/db/tables/buffer_operation.rs @@ -0,0 +1,34 @@ +use crate::db::BufferId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "buffer_operations")] +pub struct Model { + #[sea_orm(primary_key)] + pub buffer_id: BufferId, + #[sea_orm(primary_key)] + pub epoch: i32, + #[sea_orm(primary_key)] + pub lamport_timestamp: i32, + #[sea_orm(primary_key)] + pub replica_id: i32, + pub value: Vec, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::buffer::Entity", + from = "Column::BufferId", + to = "super::buffer::Column::Id" + )] + Buffer, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Buffer.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/buffer_snapshot.rs b/crates/collab2/src/db/tables/buffer_snapshot.rs new file mode 100644 index 0000000000..c9de665e43 --- /dev/null +++ b/crates/collab2/src/db/tables/buffer_snapshot.rs @@ -0,0 +1,31 @@ +use crate::db::BufferId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "buffer_snapshots")] +pub struct Model { + #[sea_orm(primary_key)] + pub buffer_id: BufferId, + #[sea_orm(primary_key)] + pub epoch: i32, + pub text: String, + pub operation_serialization_version: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::buffer::Entity", + from = "Column::BufferId", + to = "super::buffer::Column::Id" + )] + Buffer, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Buffer.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/channel.rs b/crates/collab2/src/db/tables/channel.rs new file mode 100644 index 0000000000..e30ec9af61 --- /dev/null +++ b/crates/collab2/src/db/tables/channel.rs @@ -0,0 +1,79 @@ +use crate::db::{ChannelId, ChannelVisibility}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channels")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ChannelId, + pub name: String, + pub visibility: ChannelVisibility, + pub parent_path: String, +} + +impl Model { + pub fn parent_id(&self) -> Option { + self.ancestors().last() + } + + pub fn ancestors(&self) -> impl Iterator + '_ { + self.parent_path + .trim_end_matches('/') + .split('/') + .filter_map(|id| Some(ChannelId::from_proto(id.parse().ok()?))) + } + + pub fn ancestors_including_self(&self) -> impl Iterator + '_ { + self.ancestors().chain(Some(self.id)) + } + + pub fn path(&self) -> String { + format!("{}{}/", self.parent_path, self.id) + } +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_one = "super::room::Entity")] + Room, + #[sea_orm(has_one = "super::buffer::Entity")] + Buffer, + #[sea_orm(has_many = "super::channel_member::Entity")] + Member, + #[sea_orm(has_many = "super::channel_buffer_collaborator::Entity")] + BufferCollaborators, + #[sea_orm(has_many = "super::channel_chat_participant::Entity")] + ChatParticipants, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Member.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Buffer.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::BufferCollaborators.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::ChatParticipants.def() + } +} diff --git a/crates/collab2/src/db/tables/channel_buffer_collaborator.rs b/crates/collab2/src/db/tables/channel_buffer_collaborator.rs new file mode 100644 index 0000000000..ac2637b36e --- /dev/null +++ b/crates/collab2/src/db/tables/channel_buffer_collaborator.rs @@ -0,0 +1,43 @@ +use crate::db::{ChannelBufferCollaboratorId, ChannelId, ReplicaId, ServerId, UserId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_buffer_collaborators")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ChannelBufferCollaboratorId, + pub channel_id: ChannelId, + pub connection_id: i32, + pub connection_server_id: ServerId, + pub connection_lost: bool, + pub user_id: UserId, + pub replica_id: ReplicaId, +} + +impl Model { + pub fn connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.connection_server_id.0 as u32, + id: self.connection_id as u32, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/channel_chat_participant.rs b/crates/collab2/src/db/tables/channel_chat_participant.rs new file mode 100644 index 0000000000..f3ef36c289 --- /dev/null +++ b/crates/collab2/src/db/tables/channel_chat_participant.rs @@ -0,0 +1,41 @@ +use crate::db::{ChannelChatParticipantId, ChannelId, ServerId, UserId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_chat_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ChannelChatParticipantId, + pub channel_id: ChannelId, + pub user_id: UserId, + pub connection_id: i32, + pub connection_server_id: ServerId, +} + +impl Model { + pub fn connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.connection_server_id.0 as u32, + id: self.connection_id as u32, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/channel_member.rs b/crates/collab2/src/db/tables/channel_member.rs new file mode 100644 index 0000000000..5498a00856 --- /dev/null +++ b/crates/collab2/src/db/tables/channel_member.rs @@ -0,0 +1,59 @@ +use crate::db::{channel_member, ChannelId, ChannelMemberId, ChannelRole, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_members")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ChannelMemberId, + pub channel_id: ChannelId, + pub user_id: UserId, + pub accepted: bool, + pub role: ChannelRole, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +#[derive(Debug)] +pub struct UserToChannel; + +impl Linked for UserToChannel { + type FromEntity = super::user::Entity; + + type ToEntity = super::channel::Entity; + + fn link(&self) -> Vec { + vec![ + channel_member::Relation::User.def().rev(), + channel_member::Relation::Channel.def(), + ] + } +} diff --git a/crates/collab2/src/db/tables/channel_message.rs b/crates/collab2/src/db/tables/channel_message.rs new file mode 100644 index 0000000000..ff49c63ba7 --- /dev/null +++ b/crates/collab2/src/db/tables/channel_message.rs @@ -0,0 +1,45 @@ +use crate::db::{ChannelId, MessageId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_messages")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: MessageId, + pub channel_id: ChannelId, + pub sender_id: UserId, + pub body: String, + pub sent_at: PrimitiveDateTime, + pub nonce: Uuid, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::SenderId", + to = "super::user::Column::Id" + )] + Sender, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Sender.def() + } +} diff --git a/crates/collab2/src/db/tables/channel_message_mention.rs b/crates/collab2/src/db/tables/channel_message_mention.rs new file mode 100644 index 0000000000..6155b057f0 --- /dev/null +++ b/crates/collab2/src/db/tables/channel_message_mention.rs @@ -0,0 +1,43 @@ +use crate::db::{MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_message_mentions")] +pub struct Model { + #[sea_orm(primary_key)] + pub message_id: MessageId, + #[sea_orm(primary_key)] + pub start_offset: i32, + pub end_offset: i32, + pub user_id: UserId, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel_message::Entity", + from = "Column::MessageId", + to = "super::channel_message::Column::Id" + )] + Message, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + MentionedUser, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Message.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::MentionedUser.def() + } +} diff --git a/crates/collab2/src/db/tables/contact.rs b/crates/collab2/src/db/tables/contact.rs new file mode 100644 index 0000000000..38af8b782b --- /dev/null +++ b/crates/collab2/src/db/tables/contact.rs @@ -0,0 +1,32 @@ +use crate::db::{ContactId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "contacts")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ContactId, + pub user_id_a: UserId, + pub user_id_b: UserId, + pub a_to_b: bool, + pub should_notify: bool, + pub accepted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdA", + to = "super::room_participant::Column::UserId" + )] + UserARoomParticipant, + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdB", + to = "super::room_participant::Column::UserId" + )] + UserBRoomParticipant, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/feature_flag.rs b/crates/collab2/src/db/tables/feature_flag.rs new file mode 100644 index 0000000000..41c1451c64 --- /dev/null +++ b/crates/collab2/src/db/tables/feature_flag.rs @@ -0,0 +1,40 @@ +use sea_orm::entity::prelude::*; + +use crate::db::FlagId; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "feature_flags")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: FlagId, + pub flag: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::user_feature::Entity")] + UserFeature, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserFeature.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} + +pub struct FlaggedUsers; + +impl Linked for FlaggedUsers { + type FromEntity = Entity; + + type ToEntity = super::user::Entity; + + fn link(&self) -> Vec { + vec![ + super::user_feature::Relation::Flag.def().rev(), + super::user_feature::Relation::User.def(), + ] + } +} diff --git a/crates/collab2/src/db/tables/follower.rs b/crates/collab2/src/db/tables/follower.rs new file mode 100644 index 0000000000..ffd45434e9 --- /dev/null +++ b/crates/collab2/src/db/tables/follower.rs @@ -0,0 +1,50 @@ +use crate::db::{FollowerId, ProjectId, RoomId, ServerId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "followers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: FollowerId, + pub room_id: RoomId, + pub project_id: ProjectId, + pub leader_connection_server_id: ServerId, + pub leader_connection_id: i32, + pub follower_connection_server_id: ServerId, + pub follower_connection_id: i32, +} + +impl Model { + pub fn leader_connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.leader_connection_server_id.0 as u32, + id: self.leader_connection_id as u32, + } + } + + pub fn follower_connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.follower_connection_server_id.0 as u32, + id: self.follower_connection_id as u32, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/language_server.rs b/crates/collab2/src/db/tables/language_server.rs new file mode 100644 index 0000000000..9ff8c75fc6 --- /dev/null +++ b/crates/collab2/src/db/tables/language_server.rs @@ -0,0 +1,30 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "language_servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub id: i64, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/notification.rs b/crates/collab2/src/db/tables/notification.rs new file mode 100644 index 0000000000..3105198fa2 --- /dev/null +++ b/crates/collab2/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, NotificationKindId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub created_at: PrimitiveDateTime, + pub recipient_id: UserId, + pub kind: NotificationKindId, + pub entity_id: Option, + pub content: String, + pub is_read: bool, + pub response: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/notification_kind.rs b/crates/collab2/src/db/tables/notification_kind.rs new file mode 100644 index 0000000000..865b5da04b --- /dev/null +++ b/crates/collab2/src/db/tables/notification_kind.rs @@ -0,0 +1,15 @@ +use crate::db::NotificationKindId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationKindId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/observed_buffer_edits.rs b/crates/collab2/src/db/tables/observed_buffer_edits.rs new file mode 100644 index 0000000000..e8e7aafaa2 --- /dev/null +++ b/crates/collab2/src/db/tables/observed_buffer_edits.rs @@ -0,0 +1,43 @@ +use crate::db::{BufferId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "observed_buffer_edits")] +pub struct Model { + #[sea_orm(primary_key)] + pub user_id: UserId, + pub buffer_id: BufferId, + pub epoch: i32, + pub lamport_timestamp: i32, + pub replica_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::buffer::Entity", + from = "Column::BufferId", + to = "super::buffer::Column::Id" + )] + Buffer, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Buffer.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/observed_channel_messages.rs b/crates/collab2/src/db/tables/observed_channel_messages.rs new file mode 100644 index 0000000000..18259f8442 --- /dev/null +++ b/crates/collab2/src/db/tables/observed_channel_messages.rs @@ -0,0 +1,41 @@ +use crate::db::{ChannelId, MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "observed_channel_messages")] +pub struct Model { + #[sea_orm(primary_key)] + pub user_id: UserId, + pub channel_id: ChannelId, + pub channel_message_id: MessageId, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/project.rs b/crates/collab2/src/db/tables/project.rs new file mode 100644 index 0000000000..8c26836046 --- /dev/null +++ b/crates/collab2/src/db/tables/project.rs @@ -0,0 +1,84 @@ +use crate::db::{ProjectId, Result, RoomId, ServerId, UserId}; +use anyhow::anyhow; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "projects")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ProjectId, + pub room_id: RoomId, + pub host_user_id: UserId, + pub host_connection_id: Option, + pub host_connection_server_id: Option, +} + +impl Model { + pub fn host_connection(&self) -> Result { + let host_connection_server_id = self + .host_connection_server_id + .ok_or_else(|| anyhow!("empty host_connection_server_id"))?; + let host_connection_id = self + .host_connection_id + .ok_or_else(|| anyhow!("empty host_connection_id"))?; + Ok(ConnectionId { + owner_id: host_connection_server_id.0 as u32, + id: host_connection_id as u32, + }) + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::HostUserId", + to = "super::user::Column::Id" + )] + HostUser, + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, + #[sea_orm(has_many = "super::worktree::Entity")] + Worktrees, + #[sea_orm(has_many = "super::project_collaborator::Entity")] + Collaborators, + #[sea_orm(has_many = "super::language_server::Entity")] + LanguageServers, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostUser.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Worktrees.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Collaborators.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::LanguageServers.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/project_collaborator.rs b/crates/collab2/src/db/tables/project_collaborator.rs new file mode 100644 index 0000000000..ac57befa63 --- /dev/null +++ b/crates/collab2/src/db/tables/project_collaborator.rs @@ -0,0 +1,43 @@ +use crate::db::{ProjectCollaboratorId, ProjectId, ReplicaId, ServerId, UserId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "project_collaborators")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ProjectCollaboratorId, + pub project_id: ProjectId, + pub connection_id: i32, + pub connection_server_id: ServerId, + pub user_id: UserId, + pub replica_id: ReplicaId, + pub is_host: bool, +} + +impl Model { + pub fn connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.connection_server_id.0 as u32, + id: self.connection_id as u32, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/room.rs b/crates/collab2/src/db/tables/room.rs new file mode 100644 index 0000000000..4150c741ac --- /dev/null +++ b/crates/collab2/src/db/tables/room.rs @@ -0,0 +1,54 @@ +use crate::db::{ChannelId, RoomId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Default, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "rooms")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RoomId, + pub live_kit_room: String, + pub channel_id: Option, + pub enviroment: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::room_participant::Entity")] + RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + Project, + #[sea_orm(has_many = "super::follower::Entity")] + Follower, + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Follower.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/room_participant.rs b/crates/collab2/src/db/tables/room_participant.rs new file mode 100644 index 0000000000..4c5b8cc11c --- /dev/null +++ b/crates/collab2/src/db/tables/room_participant.rs @@ -0,0 +1,61 @@ +use crate::db::{ProjectId, RoomId, RoomParticipantId, ServerId, UserId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "room_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RoomParticipantId, + pub room_id: RoomId, + pub user_id: UserId, + pub answering_connection_id: Option, + pub answering_connection_server_id: Option, + pub answering_connection_lost: bool, + pub location_kind: Option, + pub location_project_id: Option, + pub initial_project_id: Option, + pub calling_user_id: UserId, + pub calling_connection_id: i32, + pub calling_connection_server_id: Option, + pub participant_index: Option, +} + +impl Model { + pub fn answering_connection(&self) -> Option { + Some(ConnectionId { + owner_id: self.answering_connection_server_id?.0 as u32, + id: self.answering_connection_id? as u32, + }) + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/server.rs b/crates/collab2/src/db/tables/server.rs new file mode 100644 index 0000000000..ea847bdf74 --- /dev/null +++ b/crates/collab2/src/db/tables/server.rs @@ -0,0 +1,15 @@ +use crate::db::ServerId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ServerId, + pub environment: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/signup.rs b/crates/collab2/src/db/tables/signup.rs new file mode 100644 index 0000000000..79d9f0580c --- /dev/null +++ b/crates/collab2/src/db/tables/signup.rs @@ -0,0 +1,28 @@ +use crate::db::{SignupId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "signups")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: SignupId, + pub email_address: String, + pub email_confirmation_code: String, + pub email_confirmation_sent: bool, + pub created_at: DateTime, + pub device_id: Option, + pub user_id: Option, + pub inviting_user_id: Option, + pub platform_mac: bool, + pub platform_linux: bool, + pub platform_windows: bool, + pub platform_unknown: bool, + pub editor_features: Option>, + pub programming_languages: Option>, + pub added_to_mailing_list: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/user.rs b/crates/collab2/src/db/tables/user.rs new file mode 100644 index 0000000000..739693527f --- /dev/null +++ b/crates/collab2/src/db/tables/user.rs @@ -0,0 +1,80 @@ +use crate::db::UserId; +use sea_orm::entity::prelude::*; +use serde::Serialize; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel, Serialize)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: UserId, + pub github_login: String, + pub github_user_id: Option, + pub email_address: Option, + pub admin: bool, + pub invite_code: Option, + pub invite_count: i32, + pub inviter_id: Option, + pub connected_once: bool, + pub metrics_id: Uuid, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::access_token::Entity")] + AccessToken, + #[sea_orm(has_one = "super::room_participant::Entity")] + RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + HostedProjects, + #[sea_orm(has_many = "super::channel_member::Entity")] + ChannelMemberships, + #[sea_orm(has_many = "super::user_feature::Entity")] + UserFeatures, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::AccessToken.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostedProjects.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::ChannelMemberships.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserFeatures.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} + +pub struct UserFlags; + +impl Linked for UserFlags { + type FromEntity = Entity; + + type ToEntity = super::feature_flag::Entity; + + fn link(&self) -> Vec { + vec![ + super::user_feature::Relation::User.def().rev(), + super::user_feature::Relation::Flag.def(), + ] + } +} diff --git a/crates/collab2/src/db/tables/user_feature.rs b/crates/collab2/src/db/tables/user_feature.rs new file mode 100644 index 0000000000..cc24b5e796 --- /dev/null +++ b/crates/collab2/src/db/tables/user_feature.rs @@ -0,0 +1,42 @@ +use sea_orm::entity::prelude::*; + +use crate::db::{FlagId, UserId}; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "user_features")] +pub struct Model { + #[sea_orm(primary_key)] + pub user_id: UserId, + #[sea_orm(primary_key)] + pub feature_id: FlagId, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::feature_flag::Entity", + from = "Column::FeatureId", + to = "super::feature_flag::Column::Id" + )] + Flag, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Flag.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree.rs b/crates/collab2/src/db/tables/worktree.rs new file mode 100644 index 0000000000..46d9877dff --- /dev/null +++ b/crates/collab2/src/db/tables/worktree.rs @@ -0,0 +1,36 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktrees")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i64, + #[sea_orm(primary_key)] + pub project_id: ProjectId, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + /// The last scan for which we've observed entries. It may be in progress. + pub scan_id: i64, + /// The last scan that fully completed. + pub completed_scan_id: i64, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree_diagnostic_summary.rs b/crates/collab2/src/db/tables/worktree_diagnostic_summary.rs new file mode 100644 index 0000000000..5620ed255f --- /dev/null +++ b/crates/collab2/src/db/tables/worktree_diagnostic_summary.rs @@ -0,0 +1,21 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_diagnostic_summaries")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub path: String, + pub language_server_id: i64, + pub error_count: i32, + pub warning_count: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree_entry.rs b/crates/collab2/src/db/tables/worktree_entry.rs new file mode 100644 index 0000000000..81bf6e2d53 --- /dev/null +++ b/crates/collab2/src/db/tables/worktree_entry.rs @@ -0,0 +1,29 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_entries")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub id: i64, + pub is_dir: bool, + pub path: String, + pub inode: i64, + pub mtime_seconds: i64, + pub mtime_nanos: i32, + pub git_status: Option, + pub is_symlink: bool, + pub is_ignored: bool, + pub is_external: bool, + pub is_deleted: bool, + pub scan_id: i64, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree_repository.rs b/crates/collab2/src/db/tables/worktree_repository.rs new file mode 100644 index 0000000000..6f86ff0c2d --- /dev/null +++ b/crates/collab2/src/db/tables/worktree_repository.rs @@ -0,0 +1,21 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_repositories")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub work_directory_id: i64, + pub scan_id: i64, + pub branch: Option, + pub is_deleted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree_repository_statuses.rs b/crates/collab2/src/db/tables/worktree_repository_statuses.rs new file mode 100644 index 0000000000..cab016749d --- /dev/null +++ b/crates/collab2/src/db/tables/worktree_repository_statuses.rs @@ -0,0 +1,23 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_repository_statuses")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub work_directory_id: i64, + #[sea_orm(primary_key)] + pub repo_path: String, + pub status: i64, + pub scan_id: i64, + pub is_deleted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tables/worktree_settings_file.rs b/crates/collab2/src/db/tables/worktree_settings_file.rs new file mode 100644 index 0000000000..92348c1ec9 --- /dev/null +++ b/crates/collab2/src/db/tables/worktree_settings_file.rs @@ -0,0 +1,19 @@ +use crate::db::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_settings_files")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub path: String, + pub content: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab2/src/db/tests.rs b/crates/collab2/src/db/tests.rs new file mode 100644 index 0000000000..56e37abc1d --- /dev/null +++ b/crates/collab2/src/db/tests.rs @@ -0,0 +1,187 @@ +mod buffer_tests; +mod channel_tests; +mod db_tests; +mod feature_flag_tests; +mod message_tests; + +use super::*; +use gpui::BackgroundExecutor; +use parking_lot::Mutex; +use sea_orm::ConnectionTrait; +use sqlx::migrate::MigrateDatabase; +use std::sync::{ + atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + Arc, +}; + +const TEST_RELEASE_CHANNEL: &'static str = "test"; + +pub struct TestDb { + pub db: Option>, + pub connection: Option, +} + +impl TestDb { + pub fn sqlite(background: BackgroundExecutor) -> Self { + let url = format!("sqlite::memory:"); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let mut db = Database::new(options, Executor::Deterministic(background)) + .await + .unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations.sqlite/20221109000000_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql, + )) + .await + .unwrap(); + db.initialize_notification_kinds().await.unwrap(); + db + }); + + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + connection: None, + } + } + + pub fn postgres(background: BackgroundExecutor) -> Self { + static LOCK: Mutex<()> = Mutex::new(()); + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let url = format!( + "postgres://postgres@localhost/zed-test-{}", + rng.gen::() + ); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + sqlx::Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let mut db = Database::new(options, Executor::Deterministic(background)) + .await + .unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); + db.migrate(Path::new(migrations_path), false).await.unwrap(); + db.initialize_notification_kinds().await.unwrap(); + db + }); + + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + connection: None, + } + } + + pub fn db(&self) -> &Arc { + self.db.as_ref().unwrap() + } +} + +#[macro_export] +macro_rules! test_both_dbs { + ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => { + #[gpui::test] + async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { + let test_db = crate::db::TestDb::postgres(cx.executor().clone()); + $test_name(test_db.db()).await; + } + + #[gpui::test] + async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) { + let test_db = crate::db::TestDb::sqlite(cx.executor().clone()); + $test_name(test_db.db()).await; + } + }; +} + +impl Drop for TestDb { + fn drop(&mut self) { + let db = self.db.take().unwrap(); + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query, + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } + } +} + +fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str, ChannelRole)]) -> Vec { + channels + .iter() + .map(|(id, parent_path, name, role)| Channel { + id: *id, + name: name.to_string(), + visibility: ChannelVisibility::Members, + role: *role, + parent_path: parent_path.to_vec(), + }) + .collect() +} + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst), + }, + ) + .await + .unwrap() + .user_id +} + +static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); +fn new_test_connection(server: ServerId) -> ConnectionId { + ConnectionId { + id: TEST_CONNECTION_ID.fetch_add(1, SeqCst), + owner_id: server.0 as u32, + } +} diff --git a/crates/collab2/src/db/tests/buffer_tests.rs b/crates/collab2/src/db/tests/buffer_tests.rs new file mode 100644 index 0000000000..222514da0b --- /dev/null +++ b/crates/collab2/src/db/tests/buffer_tests.rs @@ -0,0 +1,506 @@ +use super::*; +use crate::test_both_dbs; +use language::proto::{self, serialize_version}; +use text::Buffer; + +test_both_dbs!( + test_channel_buffers, + test_channel_buffers_postgres, + test_channel_buffers_sqlite +); + +async fn test_channel_buffers(db: &Arc) { + let a_id = db + .create_user( + "user_a@example.com", + false, + NewUserParams { + github_login: "user_a".into(), + github_user_id: 101, + }, + ) + .await + .unwrap() + .user_id; + let b_id = db + .create_user( + "user_b@example.com", + false, + NewUserParams { + github_login: "user_b".into(), + github_user_id: 102, + }, + ) + .await + .unwrap() + .user_id; + + // This user will not be a part of the channel + let c_id = db + .create_user( + "user_c@example.com", + false, + NewUserParams { + github_login: "user_c".into(), + github_user_id: 102, + }, + ) + .await + .unwrap() + .user_id; + + let owner_id = db.create_server("production").await.unwrap().0 as u32; + + let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); + + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) + .await + .unwrap(); + + db.respond_to_channel_invite(zed_id, b_id, true) + .await + .unwrap(); + + let connection_id_a = ConnectionId { owner_id, id: 1 }; + let _ = db + .join_channel_buffer(zed_id, a_id, connection_id_a) + .await + .unwrap(); + + let mut buffer_a = Buffer::new(0, 0, "".to_string()); + let mut operations = Vec::new(); + operations.push(buffer_a.edit([(0..0, "hello world")])); + operations.push(buffer_a.edit([(5..5, ", cruel")])); + operations.push(buffer_a.edit([(0..5, "goodbye")])); + operations.push(buffer_a.undo().unwrap().1); + assert_eq!(buffer_a.text(), "hello, cruel world"); + + let operations = operations + .into_iter() + .map(|op| proto::serialize_operation(&language::Operation::Buffer(op))) + .collect::>(); + + db.update_channel_buffer(zed_id, a_id, &operations) + .await + .unwrap(); + + let connection_id_b = ConnectionId { owner_id, id: 2 }; + let buffer_response_b = db + .join_channel_buffer(zed_id, b_id, connection_id_b) + .await + .unwrap(); + + let mut buffer_b = Buffer::new(0, 0, buffer_response_b.base_text); + buffer_b + .apply_ops(buffer_response_b.operations.into_iter().map(|operation| { + let operation = proto::deserialize_operation(operation).unwrap(); + if let language::Operation::Buffer(operation) = operation { + operation + } else { + unreachable!() + } + })) + .unwrap(); + + assert_eq!(buffer_b.text(), "hello, cruel world"); + + // Ensure that C fails to open the buffer + assert!(db + .join_channel_buffer(zed_id, c_id, ConnectionId { owner_id, id: 3 }) + .await + .is_err()); + + // Ensure that both collaborators have shown up + assert_eq!( + buffer_response_b.collaborators, + &[ + rpc::proto::Collaborator { + user_id: a_id.to_proto(), + peer_id: Some(rpc::proto::PeerId { id: 1, owner_id }), + replica_id: 0, + }, + rpc::proto::Collaborator { + user_id: b_id.to_proto(), + peer_id: Some(rpc::proto::PeerId { id: 2, owner_id }), + replica_id: 1, + } + ] + ); + + // Ensure that get_channel_buffer_collaborators works + let zed_collaborats = db.get_channel_buffer_collaborators(zed_id).await.unwrap(); + assert_eq!(zed_collaborats, &[a_id, b_id]); + + let left_buffer = db + .leave_channel_buffer(zed_id, connection_id_b) + .await + .unwrap(); + + assert_eq!(left_buffer.connections, &[connection_id_a],); + + let cargo_id = db.create_root_channel("cargo", a_id).await.unwrap(); + let _ = db + .join_channel_buffer(cargo_id, a_id, connection_id_a) + .await + .unwrap(); + + db.leave_channel_buffers(connection_id_a).await.unwrap(); + + let zed_collaborators = db.get_channel_buffer_collaborators(zed_id).await.unwrap(); + let cargo_collaborators = db.get_channel_buffer_collaborators(cargo_id).await.unwrap(); + assert_eq!(zed_collaborators, &[]); + assert_eq!(cargo_collaborators, &[]); + + // When everyone has left the channel, the operations are collapsed into + // a new base text. + let buffer_response_b = db + .join_channel_buffer(zed_id, b_id, connection_id_b) + .await + .unwrap(); + assert_eq!(buffer_response_b.base_text, "hello, cruel world"); + assert_eq!(buffer_response_b.operations, &[]); +} + +test_both_dbs!( + test_channel_buffers_last_operations, + test_channel_buffers_last_operations_postgres, + test_channel_buffers_last_operations_sqlite +); + +async fn test_channel_buffers_last_operations(db: &Database) { + let user_id = db + .create_user( + "user_a@example.com", + false, + NewUserParams { + github_login: "user_a".into(), + github_user_id: 101, + }, + ) + .await + .unwrap() + .user_id; + let observer_id = db + .create_user( + "user_b@example.com", + false, + NewUserParams { + github_login: "user_b".into(), + github_user_id: 102, + }, + ) + .await + .unwrap() + .user_id; + let owner_id = db.create_server("production").await.unwrap().0 as u32; + let connection_id = ConnectionId { + owner_id, + id: user_id.0 as u32, + }; + + let mut buffers = Vec::new(); + let mut text_buffers = Vec::new(); + for i in 0..3 { + let channel = db + .create_root_channel(&format!("channel-{i}"), user_id) + .await + .unwrap(); + + db.invite_channel_member(channel, observer_id, user_id, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, observer_id, true) + .await + .unwrap(); + + db.join_channel_buffer(channel, user_id, connection_id) + .await + .unwrap(); + + buffers.push( + db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await }) + .await + .unwrap(), + ); + + text_buffers.push(Buffer::new(0, 0, "".to_string())); + } + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_latest_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx) + .await + } + }) + .await + .unwrap(); + + assert!(operations.is_empty()); + + update_buffer( + buffers[0].channel_id, + user_id, + db, + vec![ + text_buffers[0].edit([(0..0, "a")]), + text_buffers[0].edit([(0..0, "b")]), + text_buffers[0].edit([(0..0, "c")]), + ], + ) + .await; + + update_buffer( + buffers[1].channel_id, + user_id, + db, + vec![ + text_buffers[1].edit([(0..0, "d")]), + text_buffers[1].edit([(1..1, "e")]), + text_buffers[1].edit([(2..2, "f")]), + ], + ) + .await; + + // cause buffer 1's epoch to increment. + db.leave_channel_buffer(buffers[1].channel_id, connection_id) + .await + .unwrap(); + db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id) + .await + .unwrap(); + text_buffers[1] = Buffer::new(1, 0, "def".to_string()); + update_buffer( + buffers[1].channel_id, + user_id, + db, + vec![ + text_buffers[1].edit([(0..0, "g")]), + text_buffers[1].edit([(0..0, "h")]), + ], + ) + .await; + + update_buffer( + buffers[2].channel_id, + user_id, + db, + vec![text_buffers[2].edit([(0..0, "i")])], + ) + .await; + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_latest_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx) + .await + } + }) + .await + .unwrap(); + assert_operations( + &operations, + &[ + (buffers[1].id, 1, &text_buffers[1]), + (buffers[2].id, 0, &text_buffers[2]), + ], + ); + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_latest_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx) + .await + } + }) + .await + .unwrap(); + assert_operations( + &operations, + &[ + (buffers[0].id, 0, &text_buffers[0]), + (buffers[1].id, 1, &text_buffers[1]), + ], + ); + + let buffer_changes = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.unseen_channel_buffer_changes( + observer_id, + &[ + buffers[0].channel_id, + buffers[1].channel_id, + buffers[2].channel_id, + ], + &*tx, + ) + .await + } + }) + .await + .unwrap(); + + pretty_assertions::assert_eq!( + buffer_changes, + [ + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[0].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[0].version()), + }, + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[1].channel_id.to_proto(), + epoch: 1, + version: serialize_version(&text_buffers[1].version()) + .into_iter() + .filter(|vector| vector.replica_id + == buffer_changes[1].version.first().unwrap().replica_id) + .collect::>(), + }, + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[2].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[2].version()), + }, + ] + ); + + db.observe_buffer_version( + buffers[1].id, + observer_id, + 1, + serialize_version(&text_buffers[1].version()).as_slice(), + ) + .await + .unwrap(); + + let buffer_changes = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.unseen_channel_buffer_changes( + observer_id, + &[ + buffers[0].channel_id, + buffers[1].channel_id, + buffers[2].channel_id, + ], + &*tx, + ) + .await + } + }) + .await + .unwrap(); + + assert_eq!( + buffer_changes, + [ + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[0].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[0].version()), + }, + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[2].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[2].version()), + }, + ] + ); + + // Observe an earlier version of the buffer. + db.observe_buffer_version( + buffers[1].id, + observer_id, + 1, + &[rpc::proto::VectorClockEntry { + replica_id: 0, + timestamp: 0, + }], + ) + .await + .unwrap(); + + let buffer_changes = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.unseen_channel_buffer_changes( + observer_id, + &[ + buffers[0].channel_id, + buffers[1].channel_id, + buffers[2].channel_id, + ], + &*tx, + ) + .await + } + }) + .await + .unwrap(); + + assert_eq!( + buffer_changes, + [ + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[0].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[0].version()), + }, + rpc::proto::UnseenChannelBufferChange { + channel_id: buffers[2].channel_id.to_proto(), + epoch: 0, + version: serialize_version(&text_buffers[2].version()), + }, + ] + ); +} + +async fn update_buffer( + channel_id: ChannelId, + user_id: UserId, + db: &Database, + operations: Vec, +) { + let operations = operations + .into_iter() + .map(|op| proto::serialize_operation(&language::Operation::Buffer(op))) + .collect::>(); + db.update_channel_buffer(channel_id, user_id, &operations) + .await + .unwrap(); +} + +fn assert_operations( + operations: &[buffer_operation::Model], + expected: &[(BufferId, i32, &text::Buffer)], +) { + let actual = operations + .iter() + .map(|op| buffer_operation::Model { + buffer_id: op.buffer_id, + epoch: op.epoch, + lamport_timestamp: op.lamport_timestamp, + replica_id: op.replica_id, + value: vec![], + }) + .collect::>(); + let expected = expected + .iter() + .map(|(buffer_id, epoch, buffer)| buffer_operation::Model { + buffer_id: *buffer_id, + epoch: *epoch, + lamport_timestamp: buffer.lamport_clock.value as i32 - 1, + replica_id: buffer.replica_id() as i32, + value: vec![], + }) + .collect::>(); + assert_eq!(actual, expected, "unexpected operations") +} diff --git a/crates/collab2/src/db/tests/channel_tests.rs b/crates/collab2/src/db/tests/channel_tests.rs new file mode 100644 index 0000000000..43526c7f24 --- /dev/null +++ b/crates/collab2/src/db/tests/channel_tests.rs @@ -0,0 +1,819 @@ +use crate::{ + db::{ + tests::{channel_tree, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL}, + Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, + }, + test_both_dbs, +}; +use rpc::{ + proto::{self}, + ConnectionId, +}; +use std::sync::Arc; + +test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); + +async fn test_channels(db: &Arc) { + let a_id = new_test_user(db, "user1@example.com").await; + let b_id = new_test_user(db, "user2@example.com").await; + + let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); + + // Make sure that people cannot read channels they haven't been invited to + assert!(db.get_channel(zed_id, b_id).await.is_err()); + + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) + .await + .unwrap(); + + db.respond_to_channel_invite(zed_id, b_id, true) + .await + .unwrap(); + + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); + let livestreaming_id = db + .create_sub_channel("livestreaming", zed_id, a_id) + .await + .unwrap(); + let replace_id = db + .create_sub_channel("replace", zed_id, a_id) + .await + .unwrap(); + + let mut members = db + .transaction(|tx| async move { + let channel = db.get_channel_internal(replace_id, &*tx).await?; + Ok(db.get_channel_participants(&channel, &*tx).await?) + }) + .await + .unwrap(); + members.sort(); + assert_eq!(members, &[a_id, b_id]); + + let rust_id = db.create_root_channel("rust", a_id).await.unwrap(); + let cargo_id = db.create_sub_channel("cargo", rust_id, a_id).await.unwrap(); + + let cargo_ra_id = db + .create_sub_channel("cargo-ra", cargo_id, a_id) + .await + .unwrap(); + + let result = db.get_channels_for_user(a_id).await.unwrap(); + assert_eq!( + result.channels, + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin), + (rust_id, &[], "rust", ChannelRole::Admin), + (cargo_id, &[rust_id], "cargo", ChannelRole::Admin), + ( + cargo_ra_id, + &[rust_id, cargo_id], + "cargo-ra", + ChannelRole::Admin + ) + ],) + ); + + let result = db.get_channels_for_user(b_id).await.unwrap(); + assert_eq!( + result.channels, + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Member), + (crdb_id, &[zed_id], "crdb", ChannelRole::Member), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Member + ), + (replace_id, &[zed_id], "replace", ChannelRole::Member) + ],) + ); + + // Update member permissions + let set_subchannel_admin = db + .set_channel_member_role(crdb_id, a_id, b_id, ChannelRole::Admin) + .await; + assert!(set_subchannel_admin.is_err()); + let set_channel_admin = db + .set_channel_member_role(zed_id, a_id, b_id, ChannelRole::Admin) + .await; + assert!(set_channel_admin.is_ok()); + + let result = db.get_channels_for_user(b_id).await.unwrap(); + assert_eq!( + result.channels, + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin) + ],) + ); + + // Remove a single channel + db.delete_channel(crdb_id, a_id).await.unwrap(); + assert!(db.get_channel(crdb_id, a_id).await.is_err()); + + // Remove a channel tree + let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); + channel_ids.sort(); + assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]); + assert_eq!(user_ids, &[a_id]); + + assert!(db.get_channel(rust_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_ra_id, a_id).await.is_err()); +} + +test_both_dbs!( + test_joining_channels, + test_joining_channels_postgres, + test_joining_channels_sqlite +); + +async fn test_joining_channels(db: &Arc) { + let owner_id = db.create_server("test").await.unwrap().0 as u32; + + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; + + let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); + + // can join a room with membership to its channel + let (joined_room, _, _) = db + .join_channel( + channel_1, + user_1, + ConnectionId { owner_id, id: 1 }, + TEST_RELEASE_CHANNEL, + ) + .await + .unwrap(); + assert_eq!(joined_room.room.participants.len(), 1); + + let room_id = RoomId::from_proto(joined_room.room.id); + drop(joined_room); + // cannot join a room without membership to its channel + assert!(db + .join_room( + room_id, + user_2, + ConnectionId { owner_id, id: 1 }, + TEST_RELEASE_CHANNEL + ) + .await + .is_err()); +} + +test_both_dbs!( + test_channel_invites, + test_channel_invites_postgres, + test_channel_invites_sqlite +); + +async fn test_channel_invites(db: &Arc) { + db.create_server("test").await.unwrap(); + + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; + let user_3 = new_test_user(db, "user3@example.com").await; + + let channel_1_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); + + let channel_1_2 = db.create_root_channel("channel_2", user_1).await.unwrap(); + + db.invite_channel_member(channel_1_1, user_2, user_1, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel_1_2, user_2, user_1, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel_1_1, user_3, user_1, ChannelRole::Admin) + .await + .unwrap(); + + let user_2_invites = db + .get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2] + .await + .unwrap() + .into_iter() + .map(|channel| channel.id) + .collect::>(); + + assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]); + + let user_3_invites = db + .get_channel_invites_for_user(user_3) // -> [channel_1_1] + .await + .unwrap() + .into_iter() + .map(|channel| channel.id) + .collect::>(); + + assert_eq!(user_3_invites, &[channel_1_1]); + + let mut members = db + .get_channel_participant_details(channel_1_1, user_1) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: user_1.to_proto(), + kind: proto::channel_member::Kind::Member.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: user_2.to_proto(), + kind: proto::channel_member::Kind::Invitee.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: user_3.to_proto(), + kind: proto::channel_member::Kind::Invitee.into(), + role: proto::ChannelRole::Admin.into(), + }, + ] + ); + + db.respond_to_channel_invite(channel_1_1, user_2, true) + .await + .unwrap(); + + let channel_1_3 = db + .create_sub_channel("channel_3", channel_1_1, user_1) + .await + .unwrap(); + + let members = db + .get_channel_participant_details(channel_1_3, user_1) + .await + .unwrap(); + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: user_1.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: user_2.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + ] + ); +} + +test_both_dbs!( + test_channel_renames, + test_channel_renames_postgres, + test_channel_renames_sqlite +); + +async fn test_channel_renames(db: &Arc) { + db.create_server("test").await.unwrap(); + + let user_1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".into(), + github_user_id: 5, + }, + ) + .await + .unwrap() + .user_id; + + let user_2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "user2".into(), + github_user_id: 6, + }, + ) + .await + .unwrap() + .user_id; + + let zed_id = db.create_root_channel("zed", user_1).await.unwrap(); + + db.rename_channel(zed_id, user_1, "#zed-archive") + .await + .unwrap(); + + let channel = db.get_channel(zed_id, user_1).await.unwrap(); + assert_eq!(channel.name, "zed-archive"); + + let non_permissioned_rename = db.rename_channel(zed_id, user_2, "hacked-lol").await; + assert!(non_permissioned_rename.is_err()); + + let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await; + assert!(bad_name_rename.is_err()) +} + +test_both_dbs!( + test_db_channel_moving, + test_channels_moving_postgres, + test_channels_moving_sqlite +); + +async fn test_db_channel_moving(db: &Arc) { + let a_id = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".into(), + github_user_id: 5, + }, + ) + .await + .unwrap() + .user_id; + + let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); + + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); + + let gpui2_id = db.create_sub_channel("gpui2", zed_id, a_id).await.unwrap(); + + let livestreaming_id = db + .create_sub_channel("livestreaming", crdb_id, a_id) + .await + .unwrap(); + + let livestreaming_dag_id = db + .create_sub_channel("livestreaming_dag", livestreaming_id, a_id) + .await + .unwrap(); + + // ======================================================================== + // sanity check + // Initial DAG: + // /- gpui2 + // zed -- crdb - livestreaming - livestreaming_dag + let result = db.get_channels_for_user(a_id).await.unwrap(); + assert_channel_tree( + result.channels, + &[ + (zed_id, &[]), + (crdb_id, &[zed_id]), + (livestreaming_id, &[zed_id, crdb_id]), + (livestreaming_dag_id, &[zed_id, crdb_id, livestreaming_id]), + (gpui2_id, &[zed_id]), + ], + ); +} + +test_both_dbs!( + test_db_channel_moving_bugs, + test_db_channel_moving_bugs_postgres, + test_db_channel_moving_bugs_sqlite +); + +async fn test_db_channel_moving_bugs(db: &Arc) { + let user_id = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".into(), + github_user_id: 5, + }, + ) + .await + .unwrap() + .user_id; + + let zed_id = db.create_root_channel("zed", user_id).await.unwrap(); + + let projects_id = db + .create_sub_channel("projects", zed_id, user_id) + .await + .unwrap(); + + let livestreaming_id = db + .create_sub_channel("livestreaming", projects_id, user_id) + .await + .unwrap(); + + // Dag is: zed - projects - livestreaming + + // Move to same parent should be a no-op + assert!(db + .move_channel(projects_id, Some(zed_id), user_id) + .await + .unwrap() + .is_none()); + + let result = db.get_channels_for_user(user_id).await.unwrap(); + assert_channel_tree( + result.channels, + &[ + (zed_id, &[]), + (projects_id, &[zed_id]), + (livestreaming_id, &[zed_id, projects_id]), + ], + ); + + // Move the project channel to the root + db.move_channel(projects_id, None, user_id).await.unwrap(); + let result = db.get_channels_for_user(user_id).await.unwrap(); + assert_channel_tree( + result.channels, + &[ + (zed_id, &[]), + (projects_id, &[]), + (livestreaming_id, &[projects_id]), + ], + ); +} + +test_both_dbs!( + test_user_is_channel_participant, + test_user_is_channel_participant_postgres, + test_user_is_channel_participant_sqlite +); + +async fn test_user_is_channel_participant(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + let member = new_test_user(db, "member@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + let active_channel_id = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + let vim_channel_id = db + .create_sub_channel("vim", active_channel_id, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim_channel_id, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + db.invite_channel_member(active_channel_id, member, admin, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(vim_channel_id, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + db.respond_to_channel_invite(active_channel_id, member, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + admin, + &*tx, + ) + .await + }) + .await + .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + member, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Invitee.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + db.respond_to_channel_invite(vim_channel_id, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree(channels, &[(vim_channel_id, &[])]); + let channels = db.get_channels_for_user(member).await.unwrap().channels; + assert_channel_tree( + channels, + &[ + (active_channel_id, &[]), + (vim_channel_id, &[active_channel_id]), + ], + ); + + db.set_channel_member_role(vim_channel_id, admin, guest, ChannelRole::Banned) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err()); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Member.into(), + role: proto::ChannelRole::Banned.into(), + }, + ] + ); + + db.remove_channel_member(vim_channel_id, guest, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.invite_channel_member(zed_channel, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + // currently people invited to parent channels are not shown here + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + ] + ); + + db.respond_to_channel_invite(zed_channel, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(zed_channel, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(active_channel_id, &*tx) + .await + .unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err(),); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree( + channels, + &[(zed_channel, &[]), (vim_channel_id, &[zed_channel])], + ) +} + +test_both_dbs!( + test_user_joins_correct_channel, + test_user_joins_correct_channel_postgres, + test_user_joins_correct_channel_sqlite +); + +async fn test_user_joins_correct_channel(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + + let active_channel = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + + let vim_channel = db + .create_sub_channel("vim", active_channel, admin) + .await + .unwrap(); + + let vim2_channel = db + .create_sub_channel("vim2", vim_channel, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim2_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + let most_public = db + .transaction(|tx| async move { + Ok(db + .public_ancestors_including_self( + &db.get_channel_internal(vim_channel, &*tx).await.unwrap(), + &tx, + ) + .await? + .first() + .cloned()) + }) + .await + .unwrap() + .unwrap() + .id; + + assert_eq!(most_public, zed_channel) +} + +test_both_dbs!( + test_guest_access, + test_guest_access_postgres, + test_guest_access_sqlite +); + +async fn test_guest_access(db: &Arc) { + let server = db.create_server("test").await.unwrap(); + + let admin = new_test_user(db, "admin@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; + let guest_connection = new_test_connection(server); + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_err()); + + db.join_channel(zed_channel, guest, guest_connection, TEST_RELEASE_CHANNEL) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_ok()) +} + +#[track_caller] +fn assert_channel_tree(actual: Vec, expected: &[(ChannelId, &[ChannelId])]) { + let actual = actual + .iter() + .map(|channel| (channel.id, channel.parent_path.as_slice())) + .collect::>(); + pretty_assertions::assert_eq!( + actual, + expected.to_vec(), + "wrong channel ids and parent paths" + ); +} diff --git a/crates/collab2/src/db/tests/db_tests.rs b/crates/collab2/src/db/tests/db_tests.rs new file mode 100644 index 0000000000..98d1fee8fa --- /dev/null +++ b/crates/collab2/src/db/tests/db_tests.rs @@ -0,0 +1,633 @@ +use super::*; +use crate::test_both_dbs; +use gpui::TestAppContext; +use pretty_assertions::{assert_eq, assert_ne}; +use std::sync::Arc; +use tests::TestDb; + +test_both_dbs!( + test_get_users, + test_get_users_by_ids_postgres, + test_get_users_by_ids_sqlite +); + +async fn test_get_users(db: &Arc) { + let mut user_ids = Vec::new(); + let mut user_metric_ids = Vec::new(); + for i in 1..=4 { + let user = db + .create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + }, + ) + .await + .unwrap(); + user_ids.push(user.user_id); + user_metric_ids.push(user.metrics_id); + } + + assert_eq!( + db.get_users_by_ids(user_ids.clone()).await.unwrap(), + vec![ + User { + id: user_ids[0], + github_login: "user1".to_string(), + github_user_id: Some(1), + email_address: Some("user1@example.com".to_string()), + admin: false, + metrics_id: user_metric_ids[0].parse().unwrap(), + ..Default::default() + }, + User { + id: user_ids[1], + github_login: "user2".to_string(), + github_user_id: Some(2), + email_address: Some("user2@example.com".to_string()), + admin: false, + metrics_id: user_metric_ids[1].parse().unwrap(), + ..Default::default() + }, + User { + id: user_ids[2], + github_login: "user3".to_string(), + github_user_id: Some(3), + email_address: Some("user3@example.com".to_string()), + admin: false, + metrics_id: user_metric_ids[2].parse().unwrap(), + ..Default::default() + }, + User { + id: user_ids[3], + github_login: "user4".to_string(), + github_user_id: Some(4), + email_address: Some("user4@example.com".to_string()), + admin: false, + metrics_id: user_metric_ids[3].parse().unwrap(), + ..Default::default() + } + ] + ); +} + +test_both_dbs!( + test_get_or_create_user_by_github_account, + test_get_or_create_user_by_github_account_postgres, + test_get_or_create_user_by_github_account_sqlite +); + +async fn test_get_or_create_user_by_github_account(db: &Arc) { + let user_id1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "login1".into(), + github_user_id: 101, + }, + ) + .await + .unwrap() + .user_id; + let user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + }, + ) + .await + .unwrap() + .user_id; + + let user = db + .get_or_create_user_by_github_account("login1", None, None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + assert_eq!(user.github_user_id, Some(101)); + + assert!(db + .get_or_create_user_by_github_account("non-existent-login", None, None) + .await + .unwrap() + .is_none()); + + let user = db + .get_or_create_user_by_github_account("the-new-login2", Some(102), None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, Some(102)); + + let user = db + .get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com")) + .await + .unwrap() + .unwrap(); + assert_eq!(&user.github_login, "login3"); + assert_eq!(user.github_user_id, Some(103)); + assert_eq!(user.email_address, Some("user3@example.com".into())); +} + +test_both_dbs!( + test_create_access_tokens, + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite +); + +async fn test_create_access_tokens(db: &Arc) { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + }, + ) + .await + .unwrap() + .user_id; + + let token_1 = db.create_access_token(user, "h1", 2).await.unwrap(); + let token_2 = db.create_access_token(user, "h2", 2).await.unwrap(); + assert_eq!( + db.get_access_token(token_1).await.unwrap(), + access_token::Model { + id: token_1, + user_id: user, + hash: "h1".into(), + } + ); + assert_eq!( + db.get_access_token(token_2).await.unwrap(), + access_token::Model { + id: token_2, + user_id: user, + hash: "h2".into() + } + ); + + let token_3 = db.create_access_token(user, "h3", 2).await.unwrap(); + assert_eq!( + db.get_access_token(token_3).await.unwrap(), + access_token::Model { + id: token_3, + user_id: user, + hash: "h3".into() + } + ); + assert_eq!( + db.get_access_token(token_2).await.unwrap(), + access_token::Model { + id: token_2, + user_id: user, + hash: "h2".into() + } + ); + assert!(db.get_access_token(token_1).await.is_err()); + + let token_4 = db.create_access_token(user, "h4", 2).await.unwrap(); + assert_eq!( + db.get_access_token(token_4).await.unwrap(), + access_token::Model { + id: token_4, + user_id: user, + hash: "h4".into() + } + ); + assert_eq!( + db.get_access_token(token_3).await.unwrap(), + access_token::Model { + id: token_3, + user_id: user, + hash: "h3".into() + } + ); + assert!(db.get_access_token(token_2).await.is_err()); + assert!(db.get_access_token(token_1).await.is_err()); +} + +test_both_dbs!( + test_add_contacts, + test_add_contacts_postgres, + test_add_contacts_sqlite +); + +async fn test_add_contacts(db: &Arc) { + let mut user_ids = Vec::new(); + for i in 0..3 { + user_ids.push( + db.create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + }, + ) + .await + .unwrap() + .user_id, + ); + } + + let user_1 = user_ids[0]; + let user_2 = user_ids[1]; + let user_3 = user_ids[2]; + + // User starts with no contacts + assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); + + // User requests a contact. Both users see the pending request. + db.send_contact_request(user_1, user_2).await.unwrap(); + assert!(!db.has_contact(user_1, user_2).await.unwrap()); + assert!(!db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Outgoing { user_id: user_2 }], + ); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { user_id: user_1 }] + ); + + // User 2 dismisses the contact request notification without accepting or rejecting. + // We shouldn't notify them again. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap_err(); + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { user_id: user_1 }] + ); + + // User can't accept their own contact request + db.respond_to_contact_request(user_1, user_2, true) + .await + .unwrap_err(); + + // User accepts a contact request. Both users see the contact. + db.respond_to_contact_request(user_2, user_1, true) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + busy: false, + }], + ); + assert!(db.has_contact(user_1, user_2).await.unwrap()); + assert!(db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + busy: false, + }] + ); + + // Users cannot re-request existing contacts. + db.send_contact_request(user_1, user_2).await.unwrap_err(); + db.send_contact_request(user_2, user_1).await.unwrap_err(); + + // Users can't dismiss notifications of them accepting other users' requests. + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap_err(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + busy: false, + }] + ); + + // Users can dismiss notifications of other users accepting their requests. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + busy: false, + }] + ); + + // Users send each other concurrent contact requests and + // see that they are immediately accepted. + db.send_contact_request(user_1, user_3).await.unwrap(); + db.send_contact_request(user_3, user_1).await.unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[ + Contact::Accepted { + user_id: user_2, + busy: false, + }, + Contact::Accepted { + user_id: user_3, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + busy: false, + }], + ); + + // User declines a contact request. Both users see that it is gone. + db.send_contact_request(user_2, user_3).await.unwrap(); + db.respond_to_contact_request(user_3, user_2, false) + .await + .unwrap(); + assert!(!db.has_contact(user_2, user_3).await.unwrap()); + assert!(!db.has_contact(user_3, user_2).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + busy: false, + }] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + busy: false, + }], + ); +} + +test_both_dbs!( + test_metrics_id, + test_metrics_id_postgres, + test_metrics_id_sqlite +); + +async fn test_metrics_id(db: &Arc) { + let NewUserResult { + user_id: user1, + metrics_id: metrics_id1, + .. + } = db + .create_user( + "person1@example.com", + false, + NewUserParams { + github_login: "person1".into(), + github_user_id: 101, + }, + ) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + metrics_id: metrics_id2, + .. + } = db + .create_user( + "person2@example.com", + false, + NewUserParams { + github_login: "person2".into(), + github_user_id: 102, + }, + ) + .await + .unwrap(); + + assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); + assert_eq!(metrics_id1.len(), 36); + assert_eq!(metrics_id2.len(), 36); + assert_ne!(metrics_id1, metrics_id2); +} + +test_both_dbs!( + test_project_count, + test_project_count_postgres, + test_project_count_sqlite +); + +async fn test_project_count(db: &Arc) { + let owner_id = db.create_server("test").await.unwrap().0 as u32; + + let user1 = db + .create_user( + &format!("admin@example.com"), + true, + NewUserParams { + github_login: "admin".into(), + github_user_id: 0, + }, + ) + .await + .unwrap(); + let user2 = db + .create_user( + &format!("user@example.com"), + false, + NewUserParams { + github_login: "user".into(), + github_user_id: 1, + }, + ) + .await + .unwrap(); + + let room_id = RoomId::from_proto( + db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "", "dev") + .await + .unwrap() + .id, + ); + db.call( + room_id, + user1.user_id, + ConnectionId { owner_id, id: 0 }, + user2.user_id, + None, + ) + .await + .unwrap(); + db.join_room( + room_id, + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "dev", + ) + .await + .unwrap(); + assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0); + + db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[]) + .await + .unwrap(); + assert_eq!(db.project_count_excluding_admins().await.unwrap(), 1); + + db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[]) + .await + .unwrap(); + assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2); + + // Projects shared by admins aren't counted. + db.share_project(room_id, ConnectionId { owner_id, id: 0 }, &[]) + .await + .unwrap(); + assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2); + + db.leave_room(ConnectionId { owner_id, id: 1 }) + .await + .unwrap(); + assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0); +} + +#[test] +fn test_fuzzy_like_string() { + assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); +} + +#[gpui::test] +async fn test_fuzzy_search_users(cx: &mut TestAppContext) { + let test_db = TestDb::postgres(cx.executor().clone()); + let db = test_db.db(); + for (i, github_login) in [ + "California", + "colorado", + "oregon", + "washington", + "florida", + "delaware", + "rhode-island", + ] + .into_iter() + .enumerate() + { + db.create_user( + &format!("{github_login}@example.com"), + false, + NewUserParams { + github_login: github_login.into(), + github_user_id: i as i32, + }, + ) + .await + .unwrap(); + } + + assert_eq!( + fuzzy_search_user_names(db, "clr").await, + &["colorado", "California"] + ); + assert_eq!( + fuzzy_search_user_names(db, "ro").await, + &["rhode-island", "colorado", "oregon"], + ); + + async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { + db.fuzzy_search_users(query, 10) + .await + .unwrap() + .into_iter() + .map(|user| user.github_login) + .collect::>() + } +} + +test_both_dbs!( + test_non_matching_release_channels, + test_non_matching_release_channels_postgres, + test_non_matching_release_channels_sqlite +); + +async fn test_non_matching_release_channels(db: &Arc) { + let owner_id = db.create_server("test").await.unwrap().0 as u32; + + let user1 = db + .create_user( + &format!("admin@example.com"), + true, + NewUserParams { + github_login: "admin".into(), + github_user_id: 0, + }, + ) + .await + .unwrap(); + let user2 = db + .create_user( + &format!("user@example.com"), + false, + NewUserParams { + github_login: "user".into(), + github_user_id: 1, + }, + ) + .await + .unwrap(); + + let room = db + .create_room( + user1.user_id, + ConnectionId { owner_id, id: 0 }, + "", + "stable", + ) + .await + .unwrap(); + + db.call( + RoomId::from_proto(room.id), + user1.user_id, + ConnectionId { owner_id, id: 0 }, + user2.user_id, + None, + ) + .await + .unwrap(); + + // User attempts to join from preview + let result = db + .join_room( + RoomId::from_proto(room.id), + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "preview", + ) + .await; + + assert!(result.is_err()); + + // User switches to stable + let result = db + .join_room( + RoomId::from_proto(room.id), + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "stable", + ) + .await; + + assert!(result.is_ok()) +} diff --git a/crates/collab2/src/db/tests/feature_flag_tests.rs b/crates/collab2/src/db/tests/feature_flag_tests.rs new file mode 100644 index 0000000000..0286a6308e --- /dev/null +++ b/crates/collab2/src/db/tests/feature_flag_tests.rs @@ -0,0 +1,58 @@ +use crate::{ + db::{Database, NewUserParams}, + test_both_dbs, +}; +use std::sync::Arc; + +test_both_dbs!( + test_get_user_flags, + test_get_user_flags_postgres, + test_get_user_flags_sqlite +); + +async fn test_get_user_flags(db: &Arc) { + let user_1 = db + .create_user( + &format!("user1@example.com"), + false, + NewUserParams { + github_login: format!("user1"), + github_user_id: 1, + }, + ) + .await + .unwrap() + .user_id; + + let user_2 = db + .create_user( + &format!("user2@example.com"), + false, + NewUserParams { + github_login: format!("user2"), + github_user_id: 2, + }, + ) + .await + .unwrap() + .user_id; + + const CHANNELS_ALPHA: &'static str = "channels-alpha"; + const NEW_SEARCH: &'static str = "new-search"; + + let channels_flag = db.create_user_flag(CHANNELS_ALPHA).await.unwrap(); + let search_flag = db.create_user_flag(NEW_SEARCH).await.unwrap(); + + db.add_user_flag(user_1, channels_flag).await.unwrap(); + db.add_user_flag(user_1, search_flag).await.unwrap(); + + db.add_user_flag(user_2, channels_flag).await.unwrap(); + + let mut user_1_flags = db.get_user_flags(user_1).await.unwrap(); + user_1_flags.sort(); + assert_eq!(user_1_flags, &[CHANNELS_ALPHA, NEW_SEARCH]); + + let mut user_2_flags = db.get_user_flags(user_2).await.unwrap(); + user_2_flags.sort(); + assert_eq!(user_2_flags, &[CHANNELS_ALPHA]); +} diff --git a/crates/collab2/src/db/tests/message_tests.rs b/crates/collab2/src/db/tests/message_tests.rs new file mode 100644 index 0000000000..10d9778612 --- /dev/null +++ b/crates/collab2/src/db/tests/message_tests.rs @@ -0,0 +1,454 @@ +use super::new_test_user; +use crate::{ + db::{ChannelRole, Database, MessageId}, + test_both_dbs, +}; +use channel::mentions_to_proto; +use std::sync::Arc; +use time::OffsetDateTime; + +test_both_dbs!( + test_channel_message_retrieval, + test_channel_message_retrieval_postgres, + test_channel_message_retrieval_sqlite +); + +async fn test_channel_message_retrieval(db: &Arc) { + let user = new_test_user(db, "user@example.com").await; + let result = db.create_channel("channel", None, user).await.unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat( + result.channel.id, + rpc::ConnectionId { owner_id, id: 0 }, + user, + ) + .await + .unwrap(); + + let mut all_messages = Vec::new(); + for i in 0..10 { + all_messages.push( + db.create_channel_message( + result.channel.id, + user, + &i.to_string(), + &[], + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap() + .message_id + .to_proto(), + ); + } + + let messages = db + .get_channel_messages(result.channel.id, user, 3, None) + .await + .unwrap() + .into_iter() + .map(|message| message.id) + .collect::>(); + assert_eq!(messages, &all_messages[7..10]); + + let messages = db + .get_channel_messages( + result.channel.id, + user, + 4, + Some(MessageId::from_proto(all_messages[6])), + ) + .await + .unwrap() + .into_iter() + .map(|message| message.id) + .collect::>(); + assert_eq!(messages, &all_messages[2..6]); +} + +test_both_dbs!( + test_channel_message_nonces, + test_channel_message_nonces_postgres, + test_channel_message_nonces_sqlite +); + +async fn test_channel_message_nonces(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + let channel = db.create_root_channel("channel", user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_c, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) + .await + .unwrap(); + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) + .await + .unwrap(); + + // As user A, create messages that re-use the same nonces. The requests + // succeed, but return the same ids. + let id1 = db + .create_channel_message( + channel, + user_a, + "hi @user_b", + &mentions_to_proto(&[(3..10, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id2 = db + .create_channel_message( + channel, + user_a, + "hello, fellow users", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + let id3 = db + .create_channel_message( + channel, + user_a, + "bye @user_c (same nonce as first message)", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id4 = db + .create_channel_message( + channel, + user_a, + "omg (same nonce as second message)", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + + // As a different user, reuse one of the same nonces. This request succeeds + // and returns a different id. + let id5 = db + .create_channel_message( + channel, + user_b, + "omg @user_a (same nonce as user_a's first message)", + &mentions_to_proto(&[(4..11, user_a.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + + assert_ne!(id1, id2); + assert_eq!(id1, id3); + assert_eq!(id2, id4); + assert_ne!(id5, id1); + + let messages = db + .get_channel_messages(channel, user_a, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.id, m.body, m.mentions)) + .collect::>(); + assert_eq!( + messages, + &[ + ( + id1.to_proto(), + "hi @user_b".into(), + mentions_to_proto(&[(3..10, user_b.to_proto())]), + ), + ( + id2.to_proto(), + "hello, fellow users".into(), + mentions_to_proto(&[]) + ), + ( + id5.to_proto(), + "omg @user_a (same nonce as user_a's first message)".into(), + mentions_to_proto(&[(4..11, user_a.to_proto())]), + ), + ] + ); +} + +test_both_dbs!( + test_unseen_channel_messages, + test_unseen_channel_messages_postgres, + test_unseen_channel_messages_sqlite +); + +async fn test_unseen_channel_messages(db: &Arc) { + let user = new_test_user(db, "user_a@example.com").await; + let observer = new_test_user(db, "user_b@example.com").await; + + let channel_1 = db.create_root_channel("channel", user).await.unwrap(); + let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); + + db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) + .await + .unwrap(); + + db.respond_to_channel_invite(channel_1, observer, true) + .await + .unwrap(); + db.respond_to_channel_invite(channel_2, observer, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let user_connection_id = rpc::ConnectionId { owner_id, id: 0 }; + + db.join_channel_chat(channel_1, user_connection_id, user) + .await + .unwrap(); + + let _ = db + .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + + let second_message = db + .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) + .await + .unwrap() + .message_id; + + let third_message = db + .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) + .await + .unwrap() + .message_id; + + db.join_channel_chat(channel_2, user_connection_id, user) + .await + .unwrap(); + + let fourth_message = db + .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) + .await + .unwrap() + .message_id; + + // Check that observer has new messages + let unseen_messages = db + .transaction(|tx| async move { + db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx) + .await + }) + .await + .unwrap(); + + assert_eq!( + unseen_messages, + [ + rpc::proto::UnseenChannelMessage { + channel_id: channel_1.to_proto(), + message_id: third_message.to_proto(), + }, + rpc::proto::UnseenChannelMessage { + channel_id: channel_2.to_proto(), + message_id: fourth_message.to_proto(), + }, + ] + ); + + // Observe the second message + db.observe_channel_message(channel_1, observer, second_message) + .await + .unwrap(); + + // Make sure the observer still has a new message + let unseen_messages = db + .transaction(|tx| async move { + db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx) + .await + }) + .await + .unwrap(); + assert_eq!( + unseen_messages, + [ + rpc::proto::UnseenChannelMessage { + channel_id: channel_1.to_proto(), + message_id: third_message.to_proto(), + }, + rpc::proto::UnseenChannelMessage { + channel_id: channel_2.to_proto(), + message_id: fourth_message.to_proto(), + }, + ] + ); + + // Observe the third message, + db.observe_channel_message(channel_1, observer, third_message) + .await + .unwrap(); + + // Make sure the observer does not have a new method + let unseen_messages = db + .transaction(|tx| async move { + db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx) + .await + }) + .await + .unwrap(); + + assert_eq!( + unseen_messages, + [rpc::proto::UnseenChannelMessage { + channel_id: channel_2.to_proto(), + message_id: fourth_message.to_proto(), + }] + ); + + // Observe the second message again, should not regress our observed state + db.observe_channel_message(channel_1, observer, second_message) + .await + .unwrap(); + + // Make sure the observer does not have a new message + let unseen_messages = db + .transaction(|tx| async move { + db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx) + .await + }) + .await + .unwrap(); + assert_eq!( + unseen_messages, + [rpc::proto::UnseenChannelMessage { + channel_id: channel_2.to_proto(), + message_id: fourth_message.to_proto(), + }] + ); +} + +test_both_dbs!( + test_channel_message_mentions, + test_channel_message_mentions_postgres, + test_channel_message_mentions_sqlite +); + +async fn test_channel_message_mentions(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + + let channel = db + .create_channel("channel", None, user_a) + .await + .unwrap() + .channel + .id; + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let connection_id = rpc::ConnectionId { owner_id, id: 0 }; + db.join_channel_chat(channel, connection_id, user_a) + .await + .unwrap(); + + db.create_channel_message( + channel, + user_a, + "hi @user_b and @user_c", + &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 1, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "bye @user_c", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 2, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "umm", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 3, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "@user_b, stop.", + &mentions_to_proto(&[(0..7, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 4, + ) + .await + .unwrap(); + + let messages = db + .get_channel_messages(channel, user_b, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.body, m.mentions)) + .collect::>(); + assert_eq!( + &messages, + &[ + ( + "hi @user_b and @user_c".into(), + mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + ), + ( + "bye @user_c".into(), + mentions_to_proto(&[(4..11, user_c.to_proto())]), + ), + ("umm".into(), mentions_to_proto(&[]),), + ( + "@user_b, stop.".into(), + mentions_to_proto(&[(0..7, user_b.to_proto())]), + ), + ] + ); +} diff --git a/crates/collab2/src/env.rs b/crates/collab2/src/env.rs new file mode 100644 index 0000000000..58c29b0205 --- /dev/null +++ b/crates/collab2/src/env.rs @@ -0,0 +1,20 @@ +use anyhow::anyhow; +use std::fs; + +pub fn load_dotenv() -> anyhow::Result<()> { + let env: toml::map::Map = toml::de::from_str( + &fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?, + )?; + + for (key, value) in env { + let value = match value { + toml::Value::String(value) => value, + toml::Value::Integer(value) => value.to_string(), + toml::Value::Float(value) => value.to_string(), + _ => panic!("unsupported TOML value in .env.toml for key {}", key), + }; + std::env::set_var(key, value); + } + + Ok(()) +} diff --git a/crates/collab2/src/errors.rs b/crates/collab2/src/errors.rs new file mode 100644 index 0000000000..93e46848a1 --- /dev/null +++ b/crates/collab2/src/errors.rs @@ -0,0 +1,29 @@ +// Allow tide Results to accept context like other Results do when +// using anyhow. +pub trait TideResultExt { + fn context(self, cx: C) -> Self + where + C: std::fmt::Display + Send + Sync + 'static; + + fn with_context(self, f: F) -> Self + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl TideResultExt for tide::Result { + fn context(self, cx: C) -> Self + where + C: std::fmt::Display + Send + Sync + 'static, + { + self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(cx))) + } + + fn with_context(self, f: F) -> Self + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(f()))) + } +} diff --git a/crates/collab2/src/executor.rs b/crates/collab2/src/executor.rs new file mode 100644 index 0000000000..81d5e977a6 --- /dev/null +++ b/crates/collab2/src/executor.rs @@ -0,0 +1,39 @@ +use std::{future::Future, time::Duration}; + +#[cfg(test)] +use gpui::BackgroundExecutor; + +#[derive(Clone)] +pub enum Executor { + Production, + #[cfg(test)] + Deterministic(BackgroundExecutor), +} + +impl Executor { + pub fn spawn_detached(&self, future: F) + where + F: 'static + Send + Future, + { + match self { + Executor::Production => { + tokio::spawn(future); + } + #[cfg(test)] + Executor::Deterministic(background) => { + background.spawn(future).detach(); + } + } + } + + pub fn sleep(&self, duration: Duration) -> impl Future { + let this = self.clone(); + async move { + match this { + Executor::Production => tokio::time::sleep(duration).await, + #[cfg(test)] + Executor::Deterministic(background) => background.timer(duration).await, + } + } + } +} diff --git a/crates/collab2/src/lib.rs b/crates/collab2/src/lib.rs new file mode 100644 index 0000000000..85216525b0 --- /dev/null +++ b/crates/collab2/src/lib.rs @@ -0,0 +1,147 @@ +pub mod api; +pub mod auth; +pub mod db; +pub mod env; +pub mod executor; +pub mod rpc; + +#[cfg(test)] +mod tests; + +use axum::{http::StatusCode, response::IntoResponse}; +use db::Database; +use executor::Executor; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; + +pub type Result = std::result::Result; + +pub enum Error { + Http(StatusCode, String), + Database(sea_orm::error::DbErr), + Internal(anyhow::Error), +} + +impl From for Error { + fn from(error: anyhow::Error) -> Self { + Self::Internal(error) + } +} + +impl From for Error { + fn from(error: sea_orm::error::DbErr) -> Self { + Self::Database(error) + } +} + +impl From for Error { + fn from(error: axum::Error) -> Self { + Self::Internal(error.into()) + } +} + +impl From for Error { + fn from(error: hyper::Error) -> Self { + Self::Internal(error.into()) + } +} + +impl From for Error { + fn from(error: serde_json::Error) -> Self { + Self::Internal(error.into()) + } +} + +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + match self { + Error::Http(code, message) => (code, message).into_response(), + Error::Database(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } + Error::Internal(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } + } + } +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Http(code, message) => (code, message).fmt(f), + Error::Database(error) => error.fmt(f), + Error::Internal(error) => error.fmt(f), + } + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Http(code, message) => write!(f, "{code}: {message}"), + Error::Database(error) => error.fmt(f), + Error::Internal(error) => error.fmt(f), + } + } +} + +impl std::error::Error for Error {} + +#[derive(Default, Deserialize)] +pub struct Config { + pub http_port: u16, + pub database_url: String, + pub database_max_connections: u32, + pub api_token: String, + pub invite_link_prefix: String, + pub live_kit_server: Option, + pub live_kit_key: Option, + pub live_kit_secret: Option, + pub rust_log: Option, + pub log_json: Option, + pub zed_environment: String, +} + +#[derive(Default, Deserialize)] +pub struct MigrateConfig { + pub database_url: String, + pub migrations_path: Option, +} + +pub struct AppState { + pub db: Arc, + pub live_kit_client: Option>, + pub config: Config, +} + +impl AppState { + pub async fn new(config: Config) -> Result> { + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(config.database_max_connections); + let mut db = Database::new(db_options, Executor::Production).await?; + db.initialize_notification_kinds().await?; + + let live_kit_client = if let Some(((server, key), secret)) = config + .live_kit_server + .as_ref() + .zip(config.live_kit_key.as_ref()) + .zip(config.live_kit_secret.as_ref()) + { + Some(Arc::new(live_kit_server::api::LiveKitClient::new( + server.clone(), + key.clone(), + secret.clone(), + )) as Arc) + } else { + None + }; + + let this = Self { + db: Arc::new(db), + live_kit_client, + config, + }; + Ok(Arc::new(this)) + } +} diff --git a/crates/collab2/src/main.rs b/crates/collab2/src/main.rs new file mode 100644 index 0000000000..a7167ef630 --- /dev/null +++ b/crates/collab2/src/main.rs @@ -0,0 +1,139 @@ +use anyhow::anyhow; +use axum::{routing::get, Extension, Router}; +use collab2::{db, env, executor::Executor, AppState, Config, MigrateConfig, Result}; +use db::Database; +use std::{ + env::args, + net::{SocketAddr, TcpListener}, + path::Path, + sync::Arc, +}; +use tokio::signal::unix::SignalKind; +use tracing_log::LogTracer; +use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer}; +use util::ResultExt; + +const VERSION: &'static str = env!("CARGO_PKG_VERSION"); + +#[tokio::main] +async fn main() -> Result<()> { + if let Err(error) = env::load_dotenv() { + eprintln!( + "error loading .env.toml (this is expected in production): {}", + error + ); + } + + match args().skip(1).next().as_deref() { + Some("version") => { + println!("collab v{VERSION}"); + } + Some("migrate") => { + let config = envy::from_env::().expect("error loading config"); + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options, Executor::Production).await?; + + let migrations_path = config + .migrations_path + .as_deref() + .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"))); + + let migrations = db.migrate(&migrations_path, false).await?; + for (migration, duration) in migrations { + println!( + "Ran {} {} {:?}", + migration.version, migration.description, duration + ); + } + + return Ok(()); + } + Some("serve") => { + let config = envy::from_env::().expect("error loading config"); + init_tracing(&config); + + let state = AppState::new(config).await?; + + let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) + .expect("failed to bind TCP listener"); + + let epoch = state + .db + .create_server(&state.config.zed_environment) + .await?; + let rpc_server = collab2::rpc::Server::new(epoch, state.clone(), Executor::Production); + rpc_server.start().await?; + + let app = collab2::api::routes(rpc_server.clone(), state.clone()) + .merge(collab2::rpc::routes(rpc_server.clone())) + .merge( + Router::new() + .route("/", get(handle_root)) + .route("/healthz", get(handle_liveness_probe)) + .layer(Extension(state.clone())), + ); + + axum::Server::from_tcp(listener)? + .serve(app.into_make_service_with_connect_info::()) + .with_graceful_shutdown(async move { + let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()) + .expect("failed to listen for interrupt signal"); + let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()) + .expect("failed to listen for interrupt signal"); + let sigterm = sigterm.recv(); + let sigint = sigint.recv(); + futures::pin_mut!(sigterm, sigint); + futures::future::select(sigterm, sigint).await; + tracing::info!("Received interrupt signal"); + rpc_server.teardown(); + }) + .await?; + } + _ => { + Err(anyhow!("usage: collab "))?; + } + } + Ok(()) +} + +async fn handle_root() -> String { + format!("collab v{VERSION}") +} + +async fn handle_liveness_probe(Extension(state): Extension>) -> Result { + state.db.get_all_users(0, 1).await?; + Ok("ok".to_string()) +} + +pub fn init_tracing(config: &Config) -> Option<()> { + use std::str::FromStr; + use tracing_subscriber::layer::SubscriberExt; + let rust_log = config.rust_log.clone()?; + + LogTracer::init().log_err()?; + + let subscriber = tracing_subscriber::Registry::default() + .with(if config.log_json.unwrap_or(false) { + Box::new( + tracing_subscriber::fmt::layer() + .fmt_fields(JsonFields::default()) + .event_format( + tracing_subscriber::fmt::format() + .json() + .flatten_event(true) + .with_span_list(true), + ), + ) as Box + Send + Sync> + } else { + Box::new( + tracing_subscriber::fmt::layer() + .event_format(tracing_subscriber::fmt::format().pretty()), + ) + }) + .with(EnvFilter::from_str(rust_log.as_str()).log_err()?); + + tracing::subscriber::set_global_default(subscriber).unwrap(); + + None +} diff --git a/crates/collab2/src/rpc.rs b/crates/collab2/src/rpc.rs new file mode 100644 index 0000000000..835b48809d --- /dev/null +++ b/crates/collab2/src/rpc.rs @@ -0,0 +1,3495 @@ +mod connection_pool; + +use crate::{ + auth, + db::{ + self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, + RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, + User, UserId, + }, + executor::Executor, + AppState, Result, +}; +use anyhow::anyhow; +use async_tungstenite::tungstenite::{ + protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage, +}; +use axum::{ + body::Body, + extract::{ + ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage}, + ConnectInfo, WebSocketUpgrade, + }, + headers::{Header, HeaderName}, + http::StatusCode, + middleware, + response::IntoResponse, + routing::get, + Extension, Router, TypedHeader, +}; +use collections::{HashMap, HashSet}; +pub use connection_pool::ConnectionPool; +use futures::{ + channel::oneshot, + future::{self, BoxFuture}, + stream::FuturesUnordered, + FutureExt, SinkExt, StreamExt, TryStreamExt, +}; +use lazy_static::lazy_static; +use prometheus::{register_int_gauge, IntGauge}; +use rpc::{ + proto::{ + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, UpdateChannelBufferCollaborators, + }, + Connection, ConnectionId, Peer, Receipt, TypedEnvelope, +}; +use serde::{Serialize, Serializer}; +use std::{ + any::TypeId, + fmt, + future::Future, + marker::PhantomData, + mem, + net::SocketAddr, + ops::{Deref, DerefMut}, + rc::Rc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, + time::{Duration, Instant}, +}; +use time::OffsetDateTime; +use tokio::sync::{watch, Semaphore}; +use tower::ServiceBuilder; +use tracing::{info_span, instrument, Instrument}; +use util::channel::RELEASE_CHANNEL_NAME; + +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); +pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); + +const MESSAGE_COUNT_PER_PAGE: usize = 100; +const MAX_MESSAGE_LEN: usize = 1024; +const NOTIFICATION_COUNT_PER_PAGE: usize = 50; + +lazy_static! { + static ref METRIC_CONNECTIONS: IntGauge = + register_int_gauge!("connections", "number of connections").unwrap(); + static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!( + "shared_projects", + "number of open projects with one or more guests" + ) + .unwrap(); +} + +type MessageHandler = + Box, Session) -> BoxFuture<'static, ()>>; + +struct Response { + peer: Arc, + receipt: Receipt, + responded: Arc, +} + +impl Response { + fn send(self, payload: R::Response) -> Result<()> { + self.responded.store(true, SeqCst); + self.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + +#[derive(Clone)] +struct Session { + user_id: UserId, + connection_id: ConnectionId, + db: Arc>, + peer: Arc, + connection_pool: Arc>, + live_kit_client: Option>, + _executor: Executor, +} + +impl Session { + async fn db(&self) -> tokio::sync::MutexGuard { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.db.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + guard + } + + async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.connection_pool.lock(); + ConnectionPoolGuard { + guard, + _not_send: PhantomData, + } + } +} + +impl fmt::Debug for Session { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Session") + .field("user_id", &self.user_id) + .field("connection_id", &self.connection_id) + .finish() + } +} + +struct DbHandle(Arc); + +impl Deref for DbHandle { + type Target = Database; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +pub struct Server { + id: parking_lot::Mutex, + peer: Arc, + pub(crate) connection_pool: Arc>, + app_state: Arc, + executor: Executor, + handlers: HashMap, + teardown: watch::Sender<()>, +} + +pub(crate) struct ConnectionPoolGuard<'a> { + guard: parking_lot::MutexGuard<'a, ConnectionPool>, + _not_send: PhantomData>, +} + +#[derive(Serialize)] +pub struct ServerSnapshot<'a> { + peer: &'a Peer, + #[serde(serialize_with = "serialize_deref")] + connection_pool: ConnectionPoolGuard<'a>, +} + +pub fn serialize_deref(value: &T, serializer: S) -> Result +where + S: Serializer, + T: Deref, + U: Serialize, +{ + Serialize::serialize(value.deref(), serializer) +} + +impl Server { + pub fn new(id: ServerId, app_state: Arc, executor: Executor) -> Arc { + let mut server = Self { + id: parking_lot::Mutex::new(id), + peer: Peer::new(id.0 as u32), + app_state, + executor, + connection_pool: Default::default(), + handlers: Default::default(), + teardown: watch::channel(()).0, + }; + + server + .add_request_handler(ping) + .add_request_handler(create_room) + .add_request_handler(join_room) + .add_request_handler(rejoin_room) + .add_request_handler(leave_room) + .add_request_handler(call) + .add_request_handler(cancel_call) + .add_message_handler(decline_call) + .add_request_handler(update_participant_location) + .add_request_handler(share_project) + .add_message_handler(unshare_project) + .add_request_handler(join_project) + .add_message_handler(leave_project) + .add_request_handler(update_project) + .add_request_handler(update_worktree) + .add_message_handler(start_language_server) + .add_message_handler(update_language_server) + .add_message_handler(update_diagnostic_summary) + .add_message_handler(update_worktree_settings) + .add_message_handler(refresh_inlay_hints) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_message_handler(create_buffer_for_peer) + .add_request_handler(update_buffer) + .add_message_handler(update_buffer_file) + .add_message_handler(buffer_reloaded) + .add_message_handler(buffer_saved) + .add_request_handler(forward_project_request::) + .add_request_handler(get_users) + .add_request_handler(fuzzy_search_users) + .add_request_handler(request_contact) + .add_request_handler(remove_contact) + .add_request_handler(respond_to_contact_request) + .add_request_handler(create_channel) + .add_request_handler(delete_channel) + .add_request_handler(invite_channel_member) + .add_request_handler(remove_channel_member) + .add_request_handler(set_channel_member_role) + .add_request_handler(set_channel_visibility) + .add_request_handler(rename_channel) + .add_request_handler(join_channel_buffer) + .add_request_handler(leave_channel_buffer) + .add_message_handler(update_channel_buffer) + .add_request_handler(rejoin_channel_buffers) + .add_request_handler(get_channel_members) + .add_request_handler(respond_to_channel_invite) + .add_request_handler(join_channel) + .add_request_handler(join_channel_chat) + .add_message_handler(leave_channel_chat) + .add_request_handler(send_channel_message) + .add_request_handler(remove_channel_message) + .add_request_handler(get_channel_messages) + .add_request_handler(get_channel_messages_by_id) + .add_request_handler(get_notifications) + .add_request_handler(mark_notification_as_read) + .add_request_handler(move_channel) + .add_request_handler(follow) + .add_message_handler(unfollow) + .add_message_handler(update_followers) + .add_message_handler(update_diff_base) + .add_request_handler(get_private_user_info) + .add_message_handler(acknowledge_channel_message) + .add_message_handler(acknowledge_buffer_version); + + Arc::new(server) + } + + pub async fn start(&self) -> Result<()> { + let server_id = *self.id.lock(); + let app_state = self.app_state.clone(); + let peer = self.peer.clone(); + let timeout = self.executor.sleep(CLEANUP_TIMEOUT); + let pool = self.connection_pool.clone(); + let live_kit_client = self.app_state.live_kit_client.clone(); + + let span = info_span!("start server"); + self.executor.spawn_detached( + async move { + tracing::info!("waiting for cleanup timeout"); + timeout.await; + tracing::info!("cleanup timeout expired, retrieving stale rooms"); + if let Some((room_ids, channel_ids)) = app_state + .db + .stale_server_resource_ids(&app_state.config.zed_environment, server_id) + .await + .trace_err() + { + tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms"); + tracing::info!( + stale_channel_buffer_count = channel_ids.len(), + "retrieved stale channel buffers" + ); + + for channel_id in channel_ids { + if let Some(refreshed_channel_buffer) = app_state + .db + .clear_stale_channel_buffer_collaborators(channel_id, server_id) + .await + .trace_err() + { + for connection_id in refreshed_channel_buffer.connection_ids { + peer.send( + connection_id, + proto::UpdateChannelBufferCollaborators { + channel_id: channel_id.to_proto(), + collaborators: refreshed_channel_buffer + .collaborators + .clone(), + }, + ) + .trace_err(); + } + } + } + + for room_id in room_ids { + let mut contacts_to_update = HashSet::default(); + let mut canceled_calls_to_user_ids = Vec::new(); + let mut live_kit_room = String::new(); + let mut delete_live_kit_room = false; + + if let Some(mut refreshed_room) = app_state + .db + .clear_stale_room_participants(room_id, server_id) + .await + .trace_err() + { + tracing::info!( + room_id = room_id.0, + new_participant_count = refreshed_room.room.participants.len(), + "refreshed room" + ); + room_updated(&refreshed_room.room, &peer); + if let Some(channel_id) = refreshed_room.channel_id { + channel_updated( + channel_id, + &refreshed_room.room, + &refreshed_room.channel_members, + &peer, + &*pool.lock(), + ); + } + contacts_to_update + .extend(refreshed_room.stale_participant_user_ids.iter().copied()); + contacts_to_update + .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied()); + canceled_calls_to_user_ids = + mem::take(&mut refreshed_room.canceled_calls_to_user_ids); + live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room); + delete_live_kit_room = refreshed_room.room.participants.is_empty(); + } + + { + let pool = pool.lock(); + for canceled_user_id in canceled_calls_to_user_ids { + for connection_id in pool.user_connection_ids(canceled_user_id) { + peer.send( + connection_id, + proto::CallCanceled { + room_id: room_id.to_proto(), + }, + ) + .trace_err(); + } + } + } + + for user_id in contacts_to_update { + let busy = app_state.db.is_user_busy(user_id).await.trace_err(); + let contacts = app_state.db.get_contacts(user_id).await.trace_err(); + if let Some((busy, contacts)) = busy.zip(contacts) { + let pool = pool.lock(); + let updated_contact = contact_for_user(user_id, busy, &pool); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. + } = contact + { + for contact_conn_id in + pool.user_connection_ids(contact_user_id) + { + peer.send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } + } + } + } + } + + if let Some(live_kit) = live_kit_client.as_ref() { + if delete_live_kit_room { + live_kit.delete_room(live_kit_room).await.trace_err(); + } + } + } + } + + app_state + .db + .delete_stale_servers(&app_state.config.zed_environment, server_id) + .await + .trace_err(); + } + .instrument(span), + ); + Ok(()) + } + + pub fn teardown(&self) { + self.peer.teardown(); + self.connection_pool.lock().reset(); + let _ = self.teardown.send(()); + } + + #[cfg(test)] + pub fn reset(&self, id: ServerId) { + self.teardown(); + *self.id.lock() = id; + self.peer.reset(id.0 as u32); + } + + #[cfg(test)] + pub fn id(&self) -> ServerId { + *self.id.lock() + } + + fn add_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, + Fut: 'static + Send + Future>, + M: EnvelopedMessage, + { + let prev_handler = self.handlers.insert( + TypeId::of::(), + Box::new(move |envelope, session| { + let envelope = envelope.into_any().downcast::>().unwrap(); + let span = info_span!( + "handle message", + payload_type = envelope.payload_type_name() + ); + span.in_scope(|| { + tracing::info!( + payload_type = envelope.payload_type_name(), + "message received" + ); + }); + let start_time = Instant::now(); + let future = (handler)(*envelope, session); + async move { + let result = future.await; + let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; + match result { + Err(error) => { + tracing::error!(%error, ?duration_ms, "error handling message") + } + Ok(()) => tracing::info!(?duration_ms, "finished handling message"), + } + } + .instrument(span) + .boxed() + }), + ); + if prev_handler.is_some() { + panic!("registered a handler for the same message twice"); + } + self + } + + fn add_message_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(M, Session) -> Fut, + Fut: 'static + Send + Future>, + M: EnvelopedMessage, + { + self.add_handler(move |envelope, session| handler(envelope.payload, session)); + self + } + + fn add_request_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, + Fut: Send + Future>, + M: RequestMessage, + { + let handler = Arc::new(handler); + self.add_handler(move |envelope, session| { + let receipt = envelope.receipt(); + let handler = handler.clone(); + async move { + let peer = session.peer.clone(); + let responded = Arc::new(AtomicBool::default()); + let response = Response { + peer: peer.clone(), + responded: responded.clone(), + receipt, + }; + match (handler)(envelope.payload, response, session).await { + Ok(()) => { + if responded.load(std::sync::atomic::Ordering::SeqCst) { + Ok(()) + } else { + Err(anyhow!("handler did not send a response"))? + } + } + Err(error) => { + peer.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + )?; + Err(error) + } + } + } + }) + } + + pub fn handle_connection( + self: &Arc, + connection: Connection, + address: String, + user: User, + mut send_connection_id: Option>, + executor: Executor, + ) -> impl Future> { + let this = self.clone(); + let user_id = user.id; + let login = user.github_login; + let span = info_span!("handle connection", %user_id, %login, %address); + let mut teardown = self.teardown.subscribe(); + async move { + let (connection_id, handle_io, mut incoming_rx) = this + .peer + .add_connection(connection, { + let executor = executor.clone(); + move |duration| executor.sleep(duration) + }); + + tracing::info!(%user_id, %login, %connection_id, %address, "connection opened"); + this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?; + tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message"); + + if let Some(send_connection_id) = send_connection_id.take() { + let _ = send_connection_id.send(connection_id); + } + + if !user.connected_once { + this.peer.send(connection_id, proto::ShowContacts {})?; + this.app_state.db.set_user_connected_once(user_id, true).await?; + } + + let (contacts, channels_for_user, channel_invites) = future::try_join3( + this.app_state.db.get_contacts(user_id), + this.app_state.db.get_channels_for_user(user_id), + this.app_state.db.get_channel_invites_for_user(user_id), + ).await?; + + { + let mut pool = this.connection_pool.lock(); + pool.add_connection(connection_id, user_id, user.admin); + this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; + this.peer.send(connection_id, build_channels_update( + channels_for_user, + channel_invites + ))?; + } + + if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? { + this.peer.send(connection_id, incoming_call)?; + } + + let session = Session { + user_id, + connection_id, + db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))), + peer: this.peer.clone(), + connection_pool: this.connection_pool.clone(), + live_kit_client: this.app_state.live_kit_client.clone(), + _executor: executor.clone() + }; + update_user_contacts(user_id, &session).await?; + + let handle_io = handle_io.fuse(); + futures::pin_mut!(handle_io); + + // Handlers for foreground messages are pushed into the following `FuturesUnordered`. + // This prevents deadlocks when e.g., client A performs a request to client B and + // client B performs a request to client A. If both clients stop processing further + // messages until their respective request completes, they won't have a chance to + // respond to the other client's request and cause a deadlock. + // + // This arrangement ensures we will attempt to process earlier messages first, but fall + // back to processing messages arrived later in the spirit of making progress. + let mut foreground_message_handlers = FuturesUnordered::new(); + let concurrent_handlers = Arc::new(Semaphore::new(256)); + loop { + let next_message = async { + let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); + let message = incoming_rx.next().await; + (permit, message) + }.fuse(); + futures::pin_mut!(next_message); + futures::select_biased! { + _ = teardown.changed().fuse() => return Ok(()), + result = handle_io => { + if let Err(error) = result { + tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O"); + } + break; + } + _ = foreground_message_handlers.next() => {} + next_message = next_message => { + let (permit, message) = next_message; + if let Some(message) = message { + let type_name = message.payload_type_name(); + let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name); + let span_enter = span.enter(); + if let Some(handler) = this.handlers.get(&message.payload_type_id()) { + let is_background = message.is_background(); + let handle_message = (handler)(message, session.clone()); + drop(span_enter); + + let handle_message = async move { + handle_message.await; + drop(permit); + }.instrument(span); + if is_background { + executor.spawn_detached(handle_message); + } else { + foreground_message_handlers.push(handle_message); + } + } else { + tracing::error!(%user_id, %login, %connection_id, %address, "no message handler"); + } + } else { + tracing::info!(%user_id, %login, %connection_id, %address, "connection closed"); + break; + } + } + } + } + + drop(foreground_message_handlers); + tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); + if let Err(error) = connection_lost(session, teardown, executor).await { + tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); + } + + Ok(()) + }.instrument(span) + } + + pub async fn invite_code_redeemed( + self: &Arc, + inviter_id: UserId, + invitee_id: UserId, + ) -> Result<()> { + if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { + if let Some(code) = &user.invite_code { + let pool = self.connection_pool.lock(); + let invitee_contact = contact_for_user(invitee_id, false, &pool); + for connection_id in pool.user_connection_ids(inviter_id) { + self.peer.send( + connection_id, + proto::UpdateContacts { + contacts: vec![invitee_contact.clone()], + ..Default::default() + }, + )?; + self.peer.send( + connection_id, + proto::UpdateInviteInfo { + url: format!("{}{}", self.app_state.config.invite_link_prefix, &code), + count: user.invite_count as u32, + }, + )?; + } + } + } + Ok(()) + } + + pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { + if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { + if let Some(invite_code) = &user.invite_code { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer.send( + connection_id, + proto::UpdateInviteInfo { + url: format!( + "{}{}", + self.app_state.config.invite_link_prefix, invite_code + ), + count: user.invite_count as u32, + }, + )?; + } + } + } + Ok(()) + } + + pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { + ServerSnapshot { + connection_pool: ConnectionPoolGuard { + guard: self.connection_pool.lock(), + _not_send: PhantomData, + }, + peer: &self.peer, + } + } +} + +impl<'a> Deref for ConnectionPoolGuard<'a> { + type Target = ConnectionPool; + + fn deref(&self) -> &Self::Target { + &*self.guard + } +} + +impl<'a> DerefMut for ConnectionPoolGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.guard + } +} + +impl<'a> Drop for ConnectionPoolGuard<'a> { + fn drop(&mut self) { + #[cfg(test)] + self.check_invariants(); + } +} + +fn broadcast( + sender_id: Option, + receiver_ids: impl IntoIterator, + mut f: F, +) where + F: FnMut(ConnectionId) -> anyhow::Result<()>, +{ + for receiver_id in receiver_ids { + if Some(receiver_id) != sender_id { + if let Err(error) = f(receiver_id) { + tracing::error!("failed to send to {:?} {}", receiver_id, error); + } + } + } +} + +lazy_static! { + static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version"); +} + +pub struct ProtocolVersion(u32); + +impl Header for ProtocolVersion { + fn name() -> &'static HeaderName { + &ZED_PROTOCOL_VERSION + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let version = values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? + .parse() + .map_err(|_| axum::headers::Error::invalid())?; + Ok(Self(version)) + } + + fn encode>(&self, values: &mut E) { + values.extend([self.0.to_string().parse().unwrap()]); + } +} + +pub fn routes(server: Arc) -> Router { + Router::new() + .route("/rpc", get(handle_websocket_request)) + .layer( + ServiceBuilder::new() + .layer(Extension(server.app_state.clone())) + .layer(middleware::from_fn(auth::validate_header)), + ) + .route("/metrics", get(handle_metrics)) + .layer(Extension(server)) +} + +pub async fn handle_websocket_request( + TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, + ConnectInfo(socket_address): ConnectInfo, + Extension(server): Extension>, + Extension(user): Extension, + ws: WebSocketUpgrade, +) -> axum::response::Response { + if protocol_version != rpc::PROTOCOL_VERSION { + return ( + StatusCode::UPGRADE_REQUIRED, + "client must be upgraded".to_string(), + ) + .into_response(); + } + let socket_address = socket_address.to_string(); + ws.on_upgrade(move |socket| { + use util::ResultExt; + let socket = socket + .map_ok(to_tungstenite_message) + .err_into() + .with(|message| async move { Ok(to_axum_message(message)) }); + let connection = Connection::new(Box::pin(socket)); + async move { + server + .handle_connection(connection, socket_address, user, None, Executor::Production) + .await + .log_err(); + } + }) +} + +pub async fn handle_metrics(Extension(server): Extension>) -> Result { + let connections = server + .connection_pool + .lock() + .connections() + .filter(|connection| !connection.admin) + .count(); + + METRIC_CONNECTIONS.set(connections as _); + + let shared_projects = server.app_state.db.project_count_excluding_admins().await?; + METRIC_SHARED_PROJECTS.set(shared_projects as _); + + let encoder = prometheus::TextEncoder::new(); + let metric_families = prometheus::gather(); + let encoded_metrics = encoder + .encode_to_string(&metric_families) + .map_err(|err| anyhow!("{}", err))?; + Ok(encoded_metrics) +} + +#[instrument(err, skip(executor))] +async fn connection_lost( + session: Session, + mut teardown: watch::Receiver<()>, + executor: Executor, +) -> Result<()> { + session.peer.disconnect(session.connection_id); + session + .connection_pool() + .await + .remove_connection(session.connection_id)?; + + session + .db() + .await + .connection_lost(session.connection_id) + .await + .trace_err(); + + futures::select_biased! { + _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id); + leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); + + if !session + .connection_pool() + .await + .is_user_online(session.user_id) + { + let db = session.db().await; + if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() { + room_updated(&room, &session.peer); + } + } + + update_user_contacts(session.user_id, &session).await?; + } + _ = teardown.changed().fuse() => {} + } + + Ok(()) +} + +async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { + response.send(proto::Ack {})?; + Ok(()) +} + +async fn create_room( + _request: proto::CreateRoom, + response: Response, + session: Session, +) -> Result<()> { + let live_kit_room = nanoid::nanoid!(30); + + let live_kit_connection_info = { + let live_kit_room = live_kit_room.clone(); + let live_kit = session.live_kit_client.as_ref(); + + util::async_maybe!({ + let live_kit = live_kit?; + + let token = live_kit + .room_token(&live_kit_room, &session.user_id.to_string()) + .trace_err()?; + + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + can_publish: true, + }) + }) + } + .await; + + let room = session + .db() + .await + .create_room( + session.user_id, + session.connection_id, + &live_kit_room, + RELEASE_CHANNEL_NAME.as_str(), + ) + .await?; + + response.send(proto::CreateRoomResponse { + room: Some(room.clone()), + live_kit_connection_info, + })?; + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn join_room( + request: proto::JoinRoom, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.id); + + let channel_id = session.db().await.channel_id_for_room(room_id).await?; + + if let Some(channel_id) = channel_id { + return join_channel_internal(channel_id, Box::new(response), session).await; + } + + let joined_room = { + let room = session + .db() + .await + .join_room( + room_id, + session.user_id, + session.connection_id, + RELEASE_CHANNEL_NAME.as_str(), + ) + .await?; + room_updated(&room.room, &session.peer); + room.into_inner() + }; + + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send( + connection_id, + proto::CallCanceled { + room_id: room_id.to_proto(), + }, + ) + .trace_err(); + } + + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + can_publish: true, + }) + } else { + None + } + } else { + None + }; + + response.send(proto::JoinRoomResponse { + room: Some(joined_room.room), + channel_id: None, + live_kit_connection_info, + })?; + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn rejoin_room( + request: proto::RejoinRoom, + response: Response, + session: Session, +) -> Result<()> { + let room; + let channel_id; + let channel_members; + { + let mut rejoined_room = session + .db() + .await + .rejoin_room(request, session.user_id, session.connection_id) + .await?; + + response.send(proto::RejoinRoomResponse { + room: Some(rejoined_room.room.clone()), + reshared_projects: rejoined_room + .reshared_projects + .iter() + .map(|project| proto::ResharedProject { + id: project.id.to_proto(), + collaborators: project + .collaborators + .iter() + .map(|collaborator| collaborator.to_proto()) + .collect(), + }) + .collect(), + rejoined_projects: rejoined_room + .rejoined_projects + .iter() + .map(|rejoined_project| proto::RejoinedProject { + id: rejoined_project.id.to_proto(), + worktrees: rejoined_project + .worktrees + .iter() + .map(|worktree| proto::WorktreeMetadata { + id: worktree.id, + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect(), + collaborators: rejoined_project + .collaborators + .iter() + .map(|collaborator| collaborator.to_proto()) + .collect(), + language_servers: rejoined_project.language_servers.clone(), + }) + .collect(), + })?; + room_updated(&rejoined_room.room, &session.peer); + + for project in &rejoined_room.reshared_projects { + for collaborator in &project.collaborators { + session + .peer + .send( + collaborator.connection_id, + proto::UpdateProjectCollaborator { + project_id: project.id.to_proto(), + old_peer_id: Some(project.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + ) + .trace_err(); + } + + broadcast( + Some(session.connection_id), + project + .collaborators + .iter() + .map(|collaborator| collaborator.connection_id), + |connection_id| { + session.peer.forward_send( + session.connection_id, + connection_id, + proto::UpdateProject { + project_id: project.id.to_proto(), + worktrees: project.worktrees.clone(), + }, + ) + }, + ); + } + + for project in &rejoined_room.rejoined_projects { + for collaborator in &project.collaborators { + session + .peer + .send( + collaborator.connection_id, + proto::UpdateProjectCollaborator { + project_id: project.id.to_proto(), + old_peer_id: Some(project.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + ) + .trace_err(); + } + } + + for project in &mut rejoined_room.rejoined_projects { + for worktree in mem::take(&mut project.worktrees) { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; + + // Stream this worktree's entries. + let message = proto::UpdateWorktree { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.updated_entries, + removed_entries: worktree.removed_entries, + scan_id: worktree.scan_id, + is_last_update: worktree.completed_scan_id == worktree.scan_id, + updated_repositories: worktree.updated_repositories, + removed_repositories: worktree.removed_repositories, + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; + } + + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + }, + )?; + } + + for settings_file in worktree.settings_files { + session.peer.send( + session.connection_id, + proto::UpdateWorktreeSettings { + project_id: project.id.to_proto(), + worktree_id: worktree.id, + path: settings_file.path, + content: Some(settings_file.content), + }, + )?; + } + } + + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project.id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; + } + } + + let rejoined_room = rejoined_room.into_inner(); + + room = rejoined_room.room; + channel_id = rejoined_room.channel_id; + channel_members = rejoined_room.channel_members; + } + + if let Some(channel_id) = channel_id { + channel_updated( + channel_id, + &room, + &channel_members, + &session.peer, + &*session.connection_pool().await, + ); + } + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn leave_room( + _: proto::LeaveRoom, + response: Response, + session: Session, +) -> Result<()> { + leave_room_for_session(&session).await?; + response.send(proto::Ack {})?; + Ok(()) +} + +async fn call( + request: proto::Call, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); + if !session + .db() + .await + .has_contact(calling_user_id, called_user_id) + .await? + { + return Err(anyhow!("cannot call a user who isn't a contact"))?; + } + + let incoming_call = { + let (room, incoming_call) = &mut *session + .db() + .await + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) + .await?; + room_updated(&room, &session.peer); + mem::take(incoming_call) + }; + update_user_contacts(called_user_id, &session).await?; + + let mut calls = session + .connection_pool() + .await + .user_connection_ids(called_user_id) + .map(|connection_id| session.peer.request(connection_id, incoming_call.clone())) + .collect::>(); + + while let Some(call_response) = calls.next().await { + match call_response.as_ref() { + Ok(_) => { + response.send(proto::Ack {})?; + return Ok(()); + } + Err(_) => { + call_response.trace_err(); + } + } + } + + { + let room = session + .db() + .await + .call_failed(room_id, called_user_id) + .await?; + room_updated(&room, &session.peer); + } + update_user_contacts(called_user_id, &session).await?; + + Err(anyhow!("failed to ring user"))? +} + +async fn cancel_call( + request: proto::CancelCall, + response: Response, + session: Session, +) -> Result<()> { + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); + { + let room = session + .db() + .await + .cancel_call(room_id, session.connection_id, called_user_id) + .await?; + room_updated(&room, &session.peer); + } + + for connection_id in session + .connection_pool() + .await + .user_connection_ids(called_user_id) + { + session + .peer + .send( + connection_id, + proto::CallCanceled { + room_id: room_id.to_proto(), + }, + ) + .trace_err(); + } + response.send(proto::Ack {})?; + + update_user_contacts(called_user_id, &session).await?; + Ok(()) +} + +async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); + { + let room = session + .db() + .await + .decline_call(Some(room_id), session.user_id) + .await? + .ok_or_else(|| anyhow!("failed to decline call"))?; + room_updated(&room, &session.peer); + } + + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send( + connection_id, + proto::CallCanceled { + room_id: room_id.to_proto(), + }, + ) + .trace_err(); + } + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn update_participant_location( + request: proto::UpdateParticipantLocation, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let location = request + .location + .ok_or_else(|| anyhow!("invalid location"))?; + + let db = session.db().await; + let room = db + .update_room_participant_location(room_id, session.connection_id, location) + .await?; + + room_updated(&room, &session.peer); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn share_project( + request: proto::ShareProject, + response: Response, + session: Session, +) -> Result<()> { + let (project_id, room) = &*session + .db() + .await + .share_project( + RoomId::from_proto(request.room_id), + session.connection_id, + &request.worktrees, + ) + .await?; + response.send(proto::ShareProjectResponse { + project_id: project_id.to_proto(), + })?; + room_updated(&room, &session.peer); + + Ok(()) +} + +async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(message.project_id); + + let (room, guest_connection_ids) = &*session + .db() + .await + .unshare_project(project_id, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |conn_id| session.peer.send(conn_id, message.clone()), + ); + room_updated(&room, &session.peer); + + Ok(()) +} + +async fn join_project( + request: proto::JoinProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; + + tracing::info!(%project_id, "join project"); + + let (project, replica_id) = &mut *session + .db() + .await + .join_project(project_id, session.connection_id) + .await?; + + let collaborators = project + .collaborators + .iter() + .filter(|collaborator| collaborator.connection_id != session.connection_id) + .map(|collaborator| collaborator.to_proto()) + .collect::>(); + + let worktrees = project + .worktrees + .iter() + .map(|(id, worktree)| proto::WorktreeMetadata { + id: *id, + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect::>(); + + for collaborator in &collaborators { + session + .peer + .send( + collaborator.peer_id.unwrap().into(), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: Some(session.connection_id.into()), + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); + } + + // First, we send the metadata associated with each worktree. + response.send(proto::JoinProjectResponse { + worktrees: worktrees.clone(), + replica_id: replica_id.0 as u32, + collaborators: collaborators.clone(), + language_servers: project.language_servers.clone(), + })?; + + for (worktree_id, worktree) in mem::take(&mut project.worktrees) { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; + + // Stream this worktree's entries. + let message = proto::UpdateWorktree { + project_id: project_id.to_proto(), + worktree_id, + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.entries, + removed_entries: Default::default(), + scan_id: worktree.scan_id, + is_last_update: worktree.scan_id == worktree.completed_scan_id, + updated_repositories: worktree.repository_entries.into_values().collect(), + removed_repositories: Default::default(), + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; + } + + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project_id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + }, + )?; + } + + for settings_file in worktree.settings_files { + session.peer.send( + session.connection_id, + proto::UpdateWorktreeSettings { + project_id: project_id.to_proto(), + worktree_id: worktree.id, + path: settings_file.path, + content: Some(settings_file.content), + }, + )?; + } + } + + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project_id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; + } + + Ok(()) +} + +async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { + let sender_id = session.connection_id; + let project_id = ProjectId::from_proto(request.project_id); + + let (room, project) = &*session + .db() + .await + .leave_project(project_id, sender_id) + .await?; + tracing::info!( + %project_id, + host_user_id = %project.host_user_id, + host_connection_id = %project.host_connection_id, + "leave project" + ); + + project_left(&project, &session); + room_updated(&room, &session.peer); + + Ok(()) +} + +async fn update_project( + request: proto::UpdateProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let (room, guest_connection_ids) = &*session + .db() + .await + .update_project(project_id, session.connection_id, &request.worktrees) + .await?; + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + room_updated(&room, &session.peer); + response.send(proto::Ack {})?; + + Ok(()) +} + +async fn update_worktree( + request: proto::UpdateWorktree, + response: Response, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_worktree(&request, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_diagnostic_summary( + message: proto::UpdateDiagnosticSummary, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_diagnostic_summary(&message, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, message.clone()) + }, + ); + + Ok(()) +} + +async fn update_worktree_settings( + message: proto::UpdateWorktreeSettings, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_worktree_settings(&message, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, message.clone()) + }, + ); + + Ok(()) +} + +async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> { + broadcast_project_message(request.project_id, request, session).await +} + +async fn start_language_server( + request: proto::StartLanguageServer, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .start_language_server(&request, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn update_language_server( + request: proto::UpdateLanguageServer, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + Some(session.connection_id), + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn forward_project_request( + request: T, + response: Response, + session: Session, +) -> Result<()> +where + T: EntityMessage + RequestMessage, +{ + let project_id = ProjectId::from_proto(request.remote_entity_id()); + let host_connection_id = { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))? + .connection_id + }; + + let payload = session + .peer + .forward_request(session.connection_id, host_connection_id, request) + .await?; + + response.send(payload)?; + Ok(()) +} + +async fn create_buffer_for_peer( + request: proto::CreateBufferForPeer, + session: Session, +) -> Result<()> { + let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?; + session + .peer + .forward_send(session.connection_id, peer_id.into(), request)?; + Ok(()) +} + +async fn update_buffer( + request: proto::UpdateBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let mut guest_connection_ids; + let mut host_connection_id = None; + { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + guest_connection_ids = Vec::with_capacity(collaborators.len() - 1); + for collaborator in collaborators.iter() { + if collaborator.is_host { + host_connection_id = Some(collaborator.connection_id); + } else { + guest_connection_ids.push(collaborator.connection_id); + } + } + } + let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?; + + broadcast( + Some(session.connection_id), + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + if host_connection_id != session.connection_id { + session + .peer + .forward_request(session.connection_id, host_connection_id, request.clone()) + .await?; + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + Some(session.connection_id), + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + Some(session.connection_id), + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> { + broadcast_project_message(request.project_id, request, session).await +} + +async fn broadcast_project_message( + project_id: u64, + request: T, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + Some(session.connection_id), + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn follow( + request: proto::Follow, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let project_id = request.project_id.map(ProjectId::from_proto); + let leader_id = request + .leader_id + .ok_or_else(|| anyhow!("invalid leader id"))? + .into(); + let follower_id = session.connection_id; + + session + .db() + .await + .check_room_participants(room_id, leader_id, session.connection_id) + .await?; + + let response_payload = session + .peer + .forward_request(session.connection_id, leader_id, request) + .await?; + response.send(response_payload)?; + + if let Some(project_id) = project_id { + let room = session + .db() + .await + .follow(room_id, project_id, leader_id, follower_id) + .await?; + room_updated(&room, &session.peer); + } + + Ok(()) +} + +async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let project_id = request.project_id.map(ProjectId::from_proto); + let leader_id = request + .leader_id + .ok_or_else(|| anyhow!("invalid leader id"))? + .into(); + let follower_id = session.connection_id; + + session + .db() + .await + .check_room_participants(room_id, leader_id, session.connection_id) + .await?; + + session + .peer + .forward_send(session.connection_id, leader_id, request)?; + + if let Some(project_id) = project_id { + let room = session + .db() + .await + .unfollow(room_id, project_id, leader_id, follower_id) + .await?; + room_updated(&room, &session.peer); + } + + Ok(()) +} + +async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let database = session.db.lock().await; + + let connection_ids = if let Some(project_id) = request.project_id { + let project_id = ProjectId::from_proto(project_id); + database + .project_connection_ids(project_id, session.connection_id) + .await? + } else { + database + .room_connection_ids(room_id, session.connection_id) + .await? + }; + + // For now, don't send view update messages back to that view's current leader. + let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + _ => None, + }); + + for follower_peer_id in request.follower_ids.iter().copied() { + let follower_connection_id = follower_peer_id.into(); + if Some(follower_peer_id) != connection_id_to_omit + && connection_ids.contains(&follower_connection_id) + { + session.peer.forward_send( + session.connection_id, + follower_connection_id, + request.clone(), + )?; + } + } + Ok(()) +} + +async fn get_users( + request: proto::GetUsers, + response: Response, + session: Session, +) -> Result<()> { + let user_ids = request + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); + let users = session + .db() + .await + .get_users_by_ids(user_ids) + .await? + .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} + +async fn fuzzy_search_users( + request: proto::FuzzySearchUsers, + response: Response, + session: Session, +) -> Result<()> { + let query = request.query; + let users = match query.len() { + 0 => vec![], + 1 | 2 => session + .db() + .await + .get_user_by_github_login(&query) + .await? + .into_iter() + .collect(), + _ => session.db().await.fuzzy_search_users(&query, 10).await?, + }; + let users = users + .into_iter() + .filter(|user| user.id != session.user_id) + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} + +async fn request_contact( + request: proto::RequestContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); + if requester_id == responder_id { + return Err(anyhow!("cannot add yourself as a contact"))?; + } + + let notifications = session + .db() + .await + .send_contact_request(requester_id, responder_id) + .await?; + + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update.outgoing_requests.push(responder_id.to_proto()); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(requester_id) + { + session.peer.send(connection_id, update.clone())?; + } + + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: requester_id.to_proto(), + }); + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + send_notifications(&*connection_pool, &session.peer, notifications); + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn respond_to_contact_request( + request: proto::RespondToContactRequest, + response: Response, + session: Session, +) -> Result<()> { + let responder_id = session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + let db = session.db().await; + if request.response == proto::ContactRequestResponse::Dismiss as i32 { + db.dismiss_contact_notification(responder_id, requester_id) + .await?; + } else { + let accept = request.response == proto::ContactRequestResponse::Accept as i32; + + let notifications = db + .respond_to_contact_request(responder_id, requester_id, accept) + .await?; + let requester_busy = db.is_user_busy(requester_id).await?; + let responder_busy = db.is_user_busy(responder_id).await?; + + let pool = session.connection_pool().await; + // Update responder with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(requester_id, requester_busy, &pool)); + } + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + // Update requester with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(responder_id, responder_busy, &pool)); + } + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } + + send_notifications(&*pool, &session.peer, notifications); + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn remove_contact( + request: proto::RemoveContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); + let db = session.db().await; + let (contact_accepted, deleted_notification_id) = + db.remove_contact(requester_id, responder_id).await?; + + let pool = session.connection_pool().await; + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + if contact_accepted { + update.remove_contacts.push(responder_id.to_proto()); + } else { + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + } + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } + + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + if contact_accepted { + update.remove_contacts.push(requester_id.to_proto()); + } else { + update + .remove_incoming_requests + .push(requester_id.to_proto()); + } + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + if let Some(notification_id) = deleted_notification_id { + session.peer.send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + )?; + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn create_channel( + request: proto::CreateChannel, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + + let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); + let CreateChannelResult { + channel, + participants_to_update, + } = db + .create_channel(&request.name, parent_id, session.user_id) + .await?; + + response.send(proto::CreateChannelResponse { + channel: Some(channel.to_proto()), + parent_id: request.parent_id, + })?; + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); + for connection_id in connection_pool.user_connection_ids(user_id) { + if user_id == session.user_id { + continue; + } + session.peer.send(connection_id, update.clone())?; + } + } + + Ok(()) +} + +async fn delete_channel( + request: proto::DeleteChannel, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + + let channel_id = request.channel_id; + let (removed_channels, member_ids) = db + .delete_channel(ChannelId::from_proto(channel_id), session.user_id) + .await?; + response.send(proto::Ack {})?; + + // Notify members of removed channels + let mut update = proto::UpdateChannels::default(); + update + .delete_channels + .extend(removed_channels.into_iter().map(|id| id.to_proto())); + + let connection_pool = session.connection_pool().await; + for member_id in member_ids { + for connection_id in connection_pool.user_connection_ids(member_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + Ok(()) +} + +async fn invite_channel_member( + request: proto::InviteChannelMember, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let invitee_id = UserId::from_proto(request.user_id); + let InviteMemberResult { + channel, + notifications, + } = db + .invite_channel_member( + channel_id, + invitee_id, + session.user_id, + request.role().into(), + ) + .await?; + + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; + + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(invitee_id) { + session.peer.send(connection_id, update.clone())?; + } + + send_notifications(&*connection_pool, &session.peer, notifications); + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn remove_channel_member( + request: proto::RemoveChannelMember, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let member_id = UserId::from_proto(request.user_id); + + let RemoveChannelMemberResult { + membership_update, + notification_id, + } = db + .remove_channel_member(channel_id, member_id, session.user_id) + .await?; + + let connection_pool = &session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ); + for connection_id in connection_pool.user_connection_ids(member_id) { + if let Some(notification_id) = notification_id { + session + .peer + .send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + ) + .trace_err(); + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn set_channel_visibility( + request: proto::SetChannelVisibility, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let visibility = request.visibility().into(); + + let SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + } = db + .set_channel_visibility(channel_id, visibility, session.user_id) + .await?; + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn set_channel_member_role( + request: proto::SetChannelMemberRole, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let member_id = UserId::from_proto(request.user_id); + let result = db + .set_channel_member_role( + channel_id, + session.user_id, + member_id, + request.role().into(), + ) + .await?; + + match result { + db::SetMemberRoleResult::MembershipUpdated(membership_update) => { + let connection_pool = session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ) + } + db::SetMemberRoleResult::InviteUpdated(channel) => { + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; + + for connection_id in session + .connection_pool() + .await + .user_connection_ids(member_id) + { + session.peer.send(connection_id, update.clone())?; + } + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn rename_channel( + request: proto::RenameChannel, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let RenameChannelResult { + channel, + participants_to_update, + } = db + .rename_channel(channel_id, session.user_id, &request.name) + .await?; + + response.send(proto::RenameChannelResponse { + channel: Some(channel.to_proto()), + })?; + + let connection_pool = session.connection_pool().await; + for (user_id, channel) in participants_to_update { + for connection_id in connection_pool.user_connection_ids(user_id) { + let update = proto::UpdateChannels { + channels: vec![channel.to_proto()], + ..Default::default() + }; + + session.peer.send(connection_id, update.clone())?; + } + } + + Ok(()) +} + +async fn move_channel( + request: proto::MoveChannel, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + let to = request.to.map(ChannelId::from_proto); + + let result = session + .db() + .await + .move_channel(channel_id, to, session.user_id) + .await?; + + notify_channel_moved(result, session).await?; + + response.send(Ack {})?; + Ok(()) +} + +async fn notify_channel_moved(result: Option, session: Session) -> Result<()> { + let Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + }) = result + else { + return Ok(()); + }; + let moved_channels: Vec = moved_channels.iter().map(|id| id.to_proto()).collect(); + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let mut update = build_channels_update(channels, vec![]); + update.delete_channels = moved_channels.clone(); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: moved_channels.clone(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + Ok(()) +} + +async fn get_channel_members( + request: proto::GetChannelMembers, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let members = db + .get_channel_participant_details(channel_id, session.user_id) + .await?; + response.send(proto::GetChannelMembersResponse { members })?; + Ok(()) +} + +async fn respond_to_channel_invite( + request: proto::RespondToChannelInvite, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let RespondToChannelInvite { + membership_update, + notifications, + } = db + .respond_to_channel_invite(channel_id, session.user_id, request.accept) + .await?; + + let connection_pool = session.connection_pool().await; + if let Some(membership_update) = membership_update { + notify_membership_updated( + &connection_pool, + membership_update, + session.user_id, + &session.peer, + ); + } else { + let update = proto::UpdateChannels { + remove_channel_invitations: vec![channel_id.to_proto()], + ..Default::default() + }; + + for connection_id in connection_pool.user_connection_ids(session.user_id) { + session.peer.send(connection_id, update.clone())?; + } + }; + + send_notifications(&*connection_pool, &session.peer, notifications); + + response.send(proto::Ack {})?; + + Ok(()) +} + +async fn join_channel( + request: proto::JoinChannel, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + join_channel_internal(channel_id, Box::new(response), session).await +} + +trait JoinChannelInternalResponse { + fn send(self, result: proto::JoinRoomResponse) -> Result<()>; +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} + +async fn join_channel_internal( + channel_id: ChannelId, + response: Box, + session: Session, +) -> Result<()> { + let joined_room = { + leave_room_for_session(&session).await?; + let db = session.db().await; + + let (joined_room, membership_updated, role) = db + .join_channel( + channel_id, + session.user_id, + session.connection_id, + RELEASE_CHANNEL_NAME.as_str(), + ) + .await?; + + let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { + let (can_publish, token) = if role == ChannelRole::Guest { + ( + false, + live_kit + .guest_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, + ) + } else { + ( + true, + live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, + ) + }; + + Some(LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + can_publish, + }) + }); + + response.send(proto::JoinRoomResponse { + room: Some(joined_room.room.clone()), + channel_id: joined_room.channel_id.map(|id| id.to_proto()), + live_kit_connection_info, + })?; + + let connection_pool = session.connection_pool().await; + if let Some(membership_updated) = membership_updated { + notify_membership_updated( + &connection_pool, + membership_updated, + session.user_id, + &session.peer, + ); + } + + room_updated(&joined_room.room, &session.peer); + + joined_room + }; + + channel_updated( + channel_id, + &joined_room.room, + &joined_room.channel_members, + &session.peer, + &*session.connection_pool().await, + ); + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} + +async fn join_channel_buffer( + request: proto::JoinChannelBuffer, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + + let open_response = db + .join_channel_buffer(channel_id, session.user_id, session.connection_id) + .await?; + + let collaborators = open_response.collaborators.clone(); + response.send(open_response)?; + + let update = UpdateChannelBufferCollaborators { + channel_id: channel_id.to_proto(), + collaborators: collaborators.clone(), + }; + channel_buffer_updated( + session.connection_id, + collaborators + .iter() + .filter_map(|collaborator| Some(collaborator.peer_id?.into())), + &update, + &session.peer, + ); + + Ok(()) +} + +async fn update_channel_buffer( + request: proto::UpdateChannelBuffer, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + + let (collaborators, non_collaborators, epoch, version) = db + .update_channel_buffer(channel_id, session.user_id, &request.operations) + .await?; + + channel_buffer_updated( + session.connection_id, + collaborators, + &proto::UpdateChannelBuffer { + channel_id: channel_id.to_proto(), + operations: request.operations, + }, + &session.peer, + ); + + let pool = &*session.connection_pool().await; + + broadcast( + None, + non_collaborators + .iter() + .flat_map(|user_id| pool.user_connection_ids(*user_id)), + |peer_id| { + session.peer.send( + peer_id.into(), + proto::UpdateChannels { + unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange { + channel_id: channel_id.to_proto(), + epoch: epoch as u64, + version: version.clone(), + }], + ..Default::default() + }, + ) + }, + ); + + Ok(()) +} + +async fn rejoin_channel_buffers( + request: proto::RejoinChannelBuffers, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let buffers = db + .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .await?; + + for rejoined_buffer in &buffers { + let collaborators_to_notify = rejoined_buffer + .buffer + .collaborators + .iter() + .filter_map(|c| Some(c.peer_id?.into())); + channel_buffer_updated( + session.connection_id, + collaborators_to_notify, + &proto::UpdateChannelBufferCollaborators { + channel_id: rejoined_buffer.buffer.channel_id, + collaborators: rejoined_buffer.buffer.collaborators.clone(), + }, + &session.peer, + ); + } + + response.send(proto::RejoinChannelBuffersResponse { + buffers: buffers.into_iter().map(|b| b.buffer).collect(), + })?; + + Ok(()) +} + +async fn leave_channel_buffer( + request: proto::LeaveChannelBuffer, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + + let left_buffer = db + .leave_channel_buffer(channel_id, session.connection_id) + .await?; + + response.send(Ack {})?; + + channel_buffer_updated( + session.connection_id, + left_buffer.connections, + &proto::UpdateChannelBufferCollaborators { + channel_id: channel_id.to_proto(), + collaborators: left_buffer.collaborators, + }, + &session.peer, + ); + + Ok(()) +} + +fn channel_buffer_updated( + sender_id: ConnectionId, + collaborators: impl IntoIterator, + message: &T, + peer: &Peer, +) { + broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| { + peer.send(peer_id.into(), message.clone()) + }); +} + +fn send_notifications( + connection_pool: &ConnectionPool, + peer: &Peer, + notifications: db::NotificationBatch, +) { + for (user_id, notification) in notifications { + for connection_id in connection_pool.user_connection_ids(user_id) { + if let Err(error) = peer.send( + connection_id, + proto::AddNotification { + notification: Some(notification.clone()), + }, + ) { + tracing::error!( + "failed to send notification to {:?} {}", + connection_id, + error + ); + } + } + } +} + +async fn send_channel_message( + request: proto::SendChannelMessage, + response: Response, + session: Session, +) -> Result<()> { + // Validate the message body. + let body = request.body.trim().to_string(); + if body.len() > MAX_MESSAGE_LEN { + return Err(anyhow!("message is too long"))?; + } + if body.is_empty() { + return Err(anyhow!("message can't be blank"))?; + } + + // TODO: adjust mentions if body is trimmed + + let timestamp = OffsetDateTime::now_utc(); + let nonce = request + .nonce + .ok_or_else(|| anyhow!("nonce can't be blank"))?; + + let channel_id = ChannelId::from_proto(request.channel_id); + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session + .db() + .await + .create_channel_message( + channel_id, + session.user_id, + &body, + &request.mentions, + timestamp, + nonce.clone().into(), + ) + .await?; + let message = proto::ChannelMessage { + sender_id: session.user_id.to_proto(), + id: message_id.to_proto(), + body, + mentions: request.mentions, + timestamp: timestamp.unix_timestamp() as u64, + nonce: Some(nonce), + }; + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); + response.send(proto::SendChannelMessageResponse { + message: Some(message), + })?; + + let pool = &*session.connection_pool().await; + broadcast( + None, + channel_members + .iter() + .flat_map(|user_id| pool.user_connection_ids(*user_id)), + |peer_id| { + session.peer.send( + peer_id.into(), + proto::UpdateChannels { + unseen_channel_messages: vec![proto::UnseenChannelMessage { + channel_id: channel_id.to_proto(), + message_id: message_id.to_proto(), + }], + ..Default::default() + }, + ) + }, + ); + send_notifications(pool, &session.peer, notifications); + + Ok(()) +} + +async fn remove_channel_message( + request: proto::RemoveChannelMessage, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + let message_id = MessageId::from_proto(request.message_id); + let connection_ids = session + .db() + .await + .remove_channel_message(channel_id, message_id, session.user_id) + .await?; + broadcast(Some(session.connection_id), connection_ids, |connection| { + session.peer.send(connection, request.clone()) + }); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn acknowledge_channel_message( + request: proto::AckChannelMessage, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + let message_id = MessageId::from_proto(request.message_id); + let notifications = session + .db() + .await + .observe_channel_message(channel_id, session.user_id, message_id) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + Ok(()) +} + +async fn acknowledge_buffer_version( + request: proto::AckBufferOperation, + session: Session, +) -> Result<()> { + let buffer_id = BufferId::from_proto(request.buffer_id); + session + .db() + .await + .observe_buffer_version( + buffer_id, + session.user_id, + request.epoch as i32, + &request.version, + ) + .await?; + Ok(()) +} + +async fn join_channel_chat( + request: proto::JoinChannelChat, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + + let db = session.db().await; + db.join_channel_chat(channel_id, session.connection_id, session.user_id) + .await?; + let messages = db + .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None) + .await?; + response.send(proto::JoinChannelChatResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + session + .db() + .await + .leave_channel_chat(channel_id, session.connection_id, session.user_id) + .await?; + Ok(()) +} + +async fn get_channel_messages( + request: proto::GetChannelMessages, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + let messages = session + .db() + .await + .get_channel_messages( + channel_id, + session.user_id, + MESSAGE_COUNT_PER_PAGE, + Some(MessageId::from_proto(request.before_message_id)), + ) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_notifications( + request: proto::GetNotifications, + response: Response, + session: Session, +) -> Result<()> { + let notifications = session + .db() + .await + .get_notifications( + session.user_id, + NOTIFICATION_COUNT_PER_PAGE, + request + .before_id + .map(|id| db::NotificationId::from_proto(id)), + ) + .await?; + response.send(proto::GetNotificationsResponse { + done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE, + notifications, + })?; + Ok(()) +} + +async fn mark_notification_as_read( + request: proto::MarkNotificationRead, + response: Response, + session: Session, +) -> Result<()> { + let database = &session.db().await; + let notifications = database + .mark_notification_as_read_by_id( + session.user_id, + NotificationId::from_proto(request.notification_id), + ) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + response.send(proto::Ack {})?; + Ok(()) +} + +async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + Some(session.connection_id), + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn get_private_user_info( + _request: proto::GetPrivateUserInfo, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + + let metrics_id = db.get_user_metrics_id(session.user_id).await?; + let user = db + .get_user_by_id(session.user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + let flags = db.get_user_flags(session.user_id).await?; + + response.send(proto::GetPrivateUserInfoResponse { + metrics_id, + staff: user.admin, + flags, + })?; + Ok(()) +} + +fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { + match message { + TungsteniteMessage::Text(payload) => AxumMessage::Text(payload), + TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload), + TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload), + TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload), + TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame { + code: frame.code.into(), + reason: frame.reason, + })), + } +} + +fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { + match message { + AxumMessage::Text(payload) => TungsteniteMessage::Text(payload), + AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload), + AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload), + AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload), + AxumMessage::Close(frame) => { + TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame { + code: frame.code.into(), + reason: frame.reason, + })) + } + } +} + +fn notify_membership_updated( + connection_pool: &ConnectionPool, + result: MembershipUpdated, + user_id: UserId, + peer: &Peer, +) { + let mut update = build_channels_update(result.new_channels, vec![]); + update.delete_channels = result + .removed_channels + .into_iter() + .map(|id| id.to_proto()) + .collect(); + update.remove_channel_invitations = vec![result.channel_id.to_proto()]; + + for connection_id in connection_pool.user_connection_ids(user_id) { + peer.send(connection_id, update.clone()).trace_err(); + } +} + +fn build_channels_update( + channels: ChannelsForUser, + channel_invites: Vec, +) -> proto::UpdateChannels { + let mut update = proto::UpdateChannels::default(); + + for channel in channels.channels { + update.channels.push(channel.to_proto()); + } + + update.unseen_channel_buffer_changes = channels.unseen_buffer_changes; + update.unseen_channel_messages = channels.channel_messages; + + for (channel_id, participants) in channels.channel_participants { + update + .channel_participants + .push(proto::ChannelParticipants { + channel_id: channel_id.to_proto(), + participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(), + }); + } + + for channel in channel_invites { + update.channel_invitations.push(channel.to_proto()); + } + + update +} + +fn build_initial_contacts_update( + contacts: Vec, + pool: &ConnectionPool, +) -> proto::UpdateContacts { + let mut update = proto::UpdateContacts::default(); + + for contact in contacts { + match contact { + db::Contact::Accepted { user_id, busy } => { + update.contacts.push(contact_for_user(user_id, busy, &pool)); + } + db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), + db::Contact::Incoming { user_id } => { + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + }) + } + } + } + + update +} + +fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact { + proto::Contact { + user_id: user_id.to_proto(), + online: pool.is_user_online(user_id), + busy, + } +} + +fn room_updated(room: &proto::Room, peer: &Peer) { + broadcast( + None, + room.participants + .iter() + .filter_map(|participant| Some(participant.peer_id?.into())), + |peer_id| { + peer.send( + peer_id.into(), + proto::RoomUpdated { + room: Some(room.clone()), + }, + ) + }, + ); +} + +fn channel_updated( + channel_id: ChannelId, + room: &proto::Room, + channel_members: &[UserId], + peer: &Peer, + pool: &ConnectionPool, +) { + let participants = room + .participants + .iter() + .map(|p| p.user_id) + .collect::>(); + + broadcast( + None, + channel_members + .iter() + .flat_map(|user_id| pool.user_connection_ids(*user_id)), + |peer_id| { + peer.send( + peer_id.into(), + proto::UpdateChannels { + channel_participants: vec![proto::ChannelParticipants { + channel_id: channel_id.to_proto(), + participant_user_ids: participants.clone(), + }], + ..Default::default() + }, + ) + }, + ); +} + +async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + + let contacts = db.get_contacts(user_id).await?; + let busy = db.is_user_busy(user_id).await?; + + let pool = session.connection_pool().await; + let updated_contact = contact_for_user(user_id, busy, &pool); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. + } = contact + { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { + session + .peer + .send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } + } + } + Ok(()) +} + +async fn leave_room_for_session(session: &Session) -> Result<()> { + let mut contacts_to_update = HashSet::default(); + + let room_id; + let canceled_calls_to_user_ids; + let live_kit_room; + let delete_live_kit_room; + let room; + let channel_members; + let channel_id; + + if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { + contacts_to_update.insert(session.user_id); + + for project in left_room.left_projects.values() { + project_left(project, session); + } + + room_id = RoomId::from_proto(left_room.room.id); + canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids); + live_kit_room = mem::take(&mut left_room.room.live_kit_room); + delete_live_kit_room = left_room.deleted; + room = mem::take(&mut left_room.room); + channel_members = mem::take(&mut left_room.channel_members); + channel_id = left_room.channel_id; + + room_updated(&room, &session.peer); + } else { + return Ok(()); + } + + if let Some(channel_id) = channel_id { + channel_updated( + channel_id, + &room, + &channel_members, + &session.peer, + &*session.connection_pool().await, + ); + } + + { + let pool = session.connection_pool().await; + for canceled_user_id in canceled_calls_to_user_ids { + for connection_id in pool.user_connection_ids(canceled_user_id) { + session + .peer + .send( + connection_id, + proto::CallCanceled { + room_id: room_id.to_proto(), + }, + ) + .trace_err(); + } + contacts_to_update.insert(canceled_user_id); + } + } + + for contact_user_id in contacts_to_update { + update_user_contacts(contact_user_id, &session).await?; + } + + if let Some(live_kit) = session.live_kit_client.as_ref() { + live_kit + .remove_participant(live_kit_room.clone(), session.user_id.to_string()) + .await + .trace_err(); + + if delete_live_kit_room { + live_kit.delete_room(live_kit_room).await.trace_err(); + } + } + + Ok(()) +} + +async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> { + let left_channel_buffers = session + .db() + .await + .leave_channel_buffers(session.connection_id) + .await?; + + for left_buffer in left_channel_buffers { + channel_buffer_updated( + session.connection_id, + left_buffer.connections, + &proto::UpdateChannelBufferCollaborators { + channel_id: left_buffer.channel_id.to_proto(), + collaborators: left_buffer.collaborators, + }, + &session.peer, + ); + } + + Ok(()) +} + +fn project_left(project: &db::LeftProject, session: &Session) { + for connection_id in &project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + *connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + *connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: Some(session.connection_id.into()), + }, + ) + .trace_err(); + } + } +} + +pub trait ResultExt { + type Ok; + + fn trace_err(self) -> Option; +} + +impl ResultExt for Result +where + E: std::fmt::Debug, +{ + type Ok = T; + + fn trace_err(self) -> Option { + match self { + Ok(value) => Some(value), + Err(error) => { + tracing::error!("{:?}", error); + None + } + } + } +} diff --git a/crates/collab2/src/rpc/connection_pool.rs b/crates/collab2/src/rpc/connection_pool.rs new file mode 100644 index 0000000000..30c4e144ed --- /dev/null +++ b/crates/collab2/src/rpc/connection_pool.rs @@ -0,0 +1,98 @@ +use crate::db::UserId; +use anyhow::{anyhow, Result}; +use collections::{BTreeMap, HashSet}; +use rpc::ConnectionId; +use serde::Serialize; +use tracing::instrument; + +#[derive(Default, Serialize)] +pub struct ConnectionPool { + connections: BTreeMap, + connected_users: BTreeMap, +} + +#[derive(Default, Serialize)] +struct ConnectedUser { + connection_ids: HashSet, +} + +#[derive(Serialize)] +pub struct Connection { + pub user_id: UserId, + pub admin: bool, +} + +impl ConnectionPool { + pub fn reset(&mut self) { + self.connections.clear(); + self.connected_users.clear(); + } + + #[instrument(skip(self))] + pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { + self.connections + .insert(connection_id, Connection { user_id, admin }); + let connected_user = self.connected_users.entry(user_id).or_default(); + connected_user.connection_ids.insert(connection_id); + } + + #[instrument(skip(self))] + pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> { + let connection = self + .connections + .get_mut(&connection_id) + .ok_or_else(|| anyhow!("no such connection"))?; + + let user_id = connection.user_id; + let connected_user = self.connected_users.get_mut(&user_id).unwrap(); + connected_user.connection_ids.remove(&connection_id); + if connected_user.connection_ids.is_empty() { + self.connected_users.remove(&user_id); + } + self.connections.remove(&connection_id).unwrap(); + Ok(()) + } + + pub fn connections(&self) -> impl Iterator { + self.connections.values() + } + + pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator + '_ { + self.connected_users + .get(&user_id) + .into_iter() + .map(|state| &state.connection_ids) + .flatten() + .copied() + } + + pub fn is_user_online(&self, user_id: UserId) -> bool { + !self + .connected_users + .get(&user_id) + .unwrap_or(&Default::default()) + .connection_ids + .is_empty() + } + + #[cfg(test)] + pub fn check_invariants(&self) { + for (connection_id, connection) in &self.connections { + assert!(self + .connected_users + .get(&connection.user_id) + .unwrap() + .connection_ids + .contains(connection_id)); + } + + for (user_id, state) in &self.connected_users { + for connection_id in &state.connection_ids { + assert_eq!( + self.connections.get(connection_id).unwrap().user_id, + *user_id + ); + } + } + } +} diff --git a/crates/collab2/src/tests.rs b/crates/collab2/src/tests.rs new file mode 100644 index 0000000000..cb25856551 --- /dev/null +++ b/crates/collab2/src/tests.rs @@ -0,0 +1,47 @@ +use call::Room; +use gpui::{Model, TestAppContext}; + +mod channel_buffer_tests; +mod channel_message_tests; +mod channel_tests; +mod following_tests; +mod integration_tests; +mod notification_tests; +mod random_channel_buffer_tests; +mod random_project_collaboration_tests; +mod randomized_test_helpers; +mod test_server; + +pub use crate as collab2; +pub use randomized_test_helpers::{ + run_randomized_test, save_randomized_test_plan, RandomizedTest, TestError, UserTestPlan, +}; +pub use test_server::{TestClient, TestServer}; + +#[derive(Debug, Eq, PartialEq)] +struct RoomParticipants { + remote: Vec, + pending: Vec, +} + +fn room_participants(room: &Model, cx: &mut TestAppContext) -> RoomParticipants { + room.read_with(cx, |room, _| { + let mut remote = room + .remote_participants() + .iter() + .map(|(_, participant)| participant.user.github_login.clone()) + .collect::>(); + let mut pending = room + .pending_participants() + .iter() + .map(|user| user.github_login.clone()) + .collect::>(); + remote.sort(); + pending.sort(); + RoomParticipants { remote, pending } + }) +} + +fn channel_id(room: &Model, cx: &mut TestAppContext) -> Option { + cx.read(|cx| room.read(cx).channel_id()) +} diff --git a/crates/collab2/src/tests/channel_buffer_tests.rs b/crates/collab2/src/tests/channel_buffer_tests.rs new file mode 100644 index 0000000000..ba891e6192 --- /dev/null +++ b/crates/collab2/src/tests/channel_buffer_tests.rs @@ -0,0 +1,872 @@ +use crate::{ + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::TestServer, +}; +use client::{Collaborator, UserId}; +use collections::HashMap; +use futures::future; +use gpui::{BackgroundExecutor, Model, TestAppContext}; +use rpc::{proto::PeerId, RECEIVE_TIMEOUT}; + +#[gpui::test] +async fn test_core_channel_buffers( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("zed", None, (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + // Client A joins the channel buffer + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + // Client A edits the buffer + let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer()); + buffer_a.update(cx_a, |buffer, cx| { + buffer.edit([(0..0, "hello world")], None, cx) + }); + buffer_a.update(cx_a, |buffer, cx| { + buffer.edit([(5..5, ", cruel")], None, cx) + }); + buffer_a.update(cx_a, |buffer, cx| { + buffer.edit([(0..5, "goodbye")], None, cx) + }); + buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx)); + assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world"); + executor.run_until_parked(); + + // Client B joins the channel buffer + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + channel_buffer_b.read_with(cx_b, |buffer, _| { + assert_collaborators( + buffer.collaborators(), + &[client_a.user_id(), client_b.user_id()], + ); + }); + + // Client B sees the correct text, and then edits it + let buffer_b = channel_buffer_b.read_with(cx_b, |buffer, _| buffer.buffer()); + assert_eq!( + buffer_b.read_with(cx_b, |buffer, _| buffer.remote_id()), + buffer_a.read_with(cx_a, |buffer, _| buffer.remote_id()) + ); + assert_eq!(buffer_text(&buffer_b, cx_b), "hello, cruel world"); + buffer_b.update(cx_b, |buffer, cx| { + buffer.edit([(7..12, "beautiful")], None, cx) + }); + + // Both A and B see the new edit + executor.run_until_parked(); + assert_eq!(buffer_text(&buffer_a, cx_a), "hello, beautiful world"); + assert_eq!(buffer_text(&buffer_b, cx_b), "hello, beautiful world"); + + // Client A closes the channel buffer. + cx_a.update(|_| drop(channel_buffer_a)); + executor.run_until_parked(); + + // Client B sees that client A is gone from the channel buffer. + channel_buffer_b.read_with(cx_b, |buffer, _| { + assert_collaborators(&buffer.collaborators(), &[client_b.user_id()]); + }); + + // Client A rejoins the channel buffer + let _channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + executor.run_until_parked(); + + // Sanity test, make sure we saw A rejoining + channel_buffer_b.read_with(cx_b, |buffer, _| { + assert_collaborators( + &buffer.collaborators(), + &[client_a.user_id(), client_b.user_id()], + ); + }); + + // Client A loses connection. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + // Client B observes A disconnect + channel_buffer_b.read_with(cx_b, |buffer, _| { + assert_collaborators(&buffer.collaborators(), &[client_b.user_id()]); + }); + + // TODO: + // - Test synchronizing offline updates, what happens to A's channel buffer when A disconnects + // - Test interaction with channel deletion while buffer is open +} + +// todo!("collab_ui") +// #[gpui::test] +// async fn test_channel_notes_participant_indices( +// executor: BackgroundExecutor, +// mut cx_a: &mut TestAppContext, +// mut cx_b: &mut TestAppContext, +// cx_c: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// let client_c = server.create_client(cx_c, "user_c").await; + +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); +// cx_c.update(editor::init); + +// let channel_id = server +// .make_channel( +// "the-channel", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b), (&client_c, cx_c)], +// ) +// .await; + +// client_a +// .fs() +// .insert_tree("/root", json!({"file.txt": "123"})) +// .await; +// let (project_a, worktree_id_a) = client_a.build_local_project("/root", cx_a).await; +// let project_b = client_b.build_empty_local_project(cx_b); +// let project_c = client_c.build_empty_local_project(cx_c); +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let workspace_c = client_c.build_workspace(&project_c, cx_c).root(cx_c); + +// // Clients A, B, and C open the channel notes +// let channel_view_a = cx_a +// .update(|cx| ChannelView::open(channel_id, workspace_a.clone(), cx)) +// .await +// .unwrap(); +// let channel_view_b = cx_b +// .update(|cx| ChannelView::open(channel_id, workspace_b.clone(), cx)) +// .await +// .unwrap(); +// let channel_view_c = cx_c +// .update(|cx| ChannelView::open(channel_id, workspace_c.clone(), cx)) +// .await +// .unwrap(); + +// // Clients A, B, and C all insert and select some text +// channel_view_a.update(cx_a, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// editor.insert("a", cx); +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![0..1]); +// }); +// }); +// }); +// executor.run_until_parked(); +// channel_view_b.update(cx_b, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// editor.move_down(&Default::default(), cx); +// editor.insert("b", cx); +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![1..2]); +// }); +// }); +// }); +// executor.run_until_parked(); +// channel_view_c.update(cx_c, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// editor.move_down(&Default::default(), cx); +// editor.insert("c", cx); +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![2..3]); +// }); +// }); +// }); + +// // Client A sees clients B and C without assigned colors, because they aren't +// // in a call together. +// executor.run_until_parked(); +// channel_view_a.update(cx_a, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// assert_remote_selections(editor, &[(None, 1..2), (None, 2..3)], cx); +// }); +// }); + +// // Clients A and B join the same call. +// for (call, cx) in [(&active_call_a, &mut cx_a), (&active_call_b, &mut cx_b)] { +// call.update(*cx, |call, cx| call.join_channel(channel_id, cx)) +// .await +// .unwrap(); +// } + +// // Clients A and B see each other with two different assigned colors. Client C +// // still doesn't have a color. +// executor.run_until_parked(); +// channel_view_a.update(cx_a, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// assert_remote_selections( +// editor, +// &[(Some(ParticipantIndex(1)), 1..2), (None, 2..3)], +// cx, +// ); +// }); +// }); +// channel_view_b.update(cx_b, |notes, cx| { +// notes.editor.update(cx, |editor, cx| { +// assert_remote_selections( +// editor, +// &[(Some(ParticipantIndex(0)), 0..1), (None, 2..3)], +// cx, +// ); +// }); +// }); + +// // Client A shares a project, and client B joins. +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); + +// // Clients A and B open the same file. +// let editor_a = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id_a, "file.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id_a, "file.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// editor_a.update(cx_a, |editor, cx| { +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![0..1]); +// }); +// }); +// editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![2..3]); +// }); +// }); +// executor.run_until_parked(); + +// // Clients A and B see each other with the same colors as in the channel notes. +// editor_a.update(cx_a, |editor, cx| { +// assert_remote_selections(editor, &[(Some(ParticipantIndex(1)), 2..3)], cx); +// }); +// editor_b.update(cx_b, |editor, cx| { +// assert_remote_selections(editor, &[(Some(ParticipantIndex(0)), 0..1)], cx); +// }); +// } + +//todo!(editor) +// #[track_caller] +// fn assert_remote_selections( +// editor: &mut Editor, +// expected_selections: &[(Option, Range)], +// cx: &mut ViewContext, +// ) { +// let snapshot = editor.snapshot(cx); +// let range = Anchor::min()..Anchor::max(); +// let remote_selections = snapshot +// .remote_selections_in_range(&range, editor.collaboration_hub().unwrap(), cx) +// .map(|s| { +// let start = s.selection.start.to_offset(&snapshot.buffer_snapshot); +// let end = s.selection.end.to_offset(&snapshot.buffer_snapshot); +// (s.participant_index, start..end) +// }) +// .collect::>(); +// assert_eq!( +// remote_selections, expected_selections, +// "incorrect remote selections" +// ); +// } + +#[gpui::test] +async fn test_multiple_handles_to_channel_buffer( + deterministic: BackgroundExecutor, + cx_a: &mut TestAppContext, +) { + let mut server = TestServer::start(deterministic.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + + let channel_id = server + .make_channel("the-channel", None, (&client_a, cx_a), &mut []) + .await; + + let channel_buffer_1 = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); + let channel_buffer_2 = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); + let channel_buffer_3 = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); + + // All concurrent tasks for opening a channel buffer return the same model handle. + let (channel_buffer, channel_buffer_2, channel_buffer_3) = + future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3) + .await + .unwrap(); + let channel_buffer_model_id = channel_buffer.entity_id(); + assert_eq!(channel_buffer, channel_buffer_2); + assert_eq!(channel_buffer, channel_buffer_3); + + channel_buffer.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "hello")], None, cx); + }) + }); + deterministic.run_until_parked(); + + cx_a.update(|_| { + drop(channel_buffer); + drop(channel_buffer_2); + drop(channel_buffer_3); + }); + deterministic.run_until_parked(); + + // The channel buffer can be reopened after dropping it. + let channel_buffer = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + assert_ne!(channel_buffer.entity_id(), channel_buffer_model_id); + channel_buffer.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, _| { + assert_eq!(buffer.text(), "hello"); + }) + }); +} + +#[gpui::test] +async fn test_channel_buffer_disconnect( + deterministic: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(deterministic.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel( + "the-channel", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b)], + ) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + channel_buffer_a.update(cx_a, |buffer, cx| { + assert_eq!(buffer.channel(cx).unwrap().name, "the-channel"); + assert!(!buffer.is_connected()); + }); + + deterministic.run_until_parked(); + + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + deterministic.run_until_parked(); + + client_a + .channel_store() + .update(cx_a, |channel_store, _| { + channel_store.remove_channel(channel_id) + }) + .await + .unwrap(); + deterministic.run_until_parked(); + + // Channel buffer observed the deletion + channel_buffer_b.update(cx_b, |buffer, cx| { + assert!(buffer.channel(cx).is_none()); + assert!(!buffer.is_connected()); + }); +} + +#[gpui::test] +async fn test_rejoin_channel_buffer( + deterministic: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(deterministic.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel( + "the-channel", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b)], + ) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client A disconnects. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + + // Both clients make an edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + + // Both clients see their own edit. + deterministic.run_until_parked(); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "12"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "01"); + }); + + // Client A reconnects. Both clients see each other's edits, and see + // the same collaborators. + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); +} + +#[gpui::test] +async fn test_channel_buffers_and_server_restarts( + deterministic: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(deterministic.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let channel_id = server + .make_channel( + "the-channel", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let _channel_buffer_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client C can't reconnect. + client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + + // Server stops. + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + + // While the server is down, both clients make an edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + + // Server restarts. + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + + // Clients reconnects. Clients A and B see each other's edits, and see + // that client C has disconnected. + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_collaborators( + buffer_a.collaborators(), + &[client_a.user_id(), client_b.user_id()], + ); + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); +} + +//todo!(collab_ui) +// #[gpui::test(iterations = 10)] +// async fn test_following_to_channel_notes_without_a_shared_project( +// deterministic: BackgroundExecutor, +// mut cx_a: &mut TestAppContext, +// mut cx_b: &mut TestAppContext, +// mut cx_c: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&deterministic).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; + +// let client_c = server.create_client(cx_c, "user_c").await; + +// cx_a.update(editor::init); +// cx_b.update(editor::init); +// cx_c.update(editor::init); +// cx_a.update(collab_ui::channel_view::init); +// cx_b.update(collab_ui::channel_view::init); +// cx_c.update(collab_ui::channel_view::init); + +// let channel_1_id = server +// .make_channel( +// "channel-1", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b), (&client_c, cx_c)], +// ) +// .await; +// let channel_2_id = server +// .make_channel( +// "channel-2", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b), (&client_c, cx_c)], +// ) +// .await; + +// // Clients A, B, and C join a channel. +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); +// let active_call_c = cx_c.read(ActiveCall::global); +// for (call, cx) in [ +// (&active_call_a, &mut cx_a), +// (&active_call_b, &mut cx_b), +// (&active_call_c, &mut cx_c), +// ] { +// call.update(*cx, |call, cx| call.join_channel(channel_1_id, cx)) +// .await +// .unwrap(); +// } +// deterministic.run_until_parked(); + +// // Clients A, B, and C all open their own unshared projects. +// client_a.fs().insert_tree("/a", json!({})).await; +// client_b.fs().insert_tree("/b", json!({})).await; +// client_c.fs().insert_tree("/c", json!({})).await; +// let (project_a, _) = client_a.build_local_project("/a", cx_a).await; +// let (project_b, _) = client_b.build_local_project("/b", cx_b).await; +// let (project_c, _) = client_b.build_local_project("/c", cx_c).await; +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let _workspace_c = client_c.build_workspace(&project_c, cx_c).root(cx_c); + +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); + +// // Client A opens the notes for channel 1. +// let channel_view_1_a = cx_a +// .update(|cx| ChannelView::open(channel_1_id, workspace_a.clone(), cx)) +// .await +// .unwrap(); +// channel_view_1_a.update(cx_a, |notes, cx| { +// assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); +// notes.editor.update(cx, |editor, cx| { +// editor.insert("Hello from A.", cx); +// editor.change_selections(None, cx, |selections| { +// selections.select_ranges(vec![3..4]); +// }); +// }); +// }); + +// // Client B follows client A. +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(client_a.peer_id().unwrap(), cx).unwrap() +// }) +// .await +// .unwrap(); + +// // Client B is taken to the notes for channel 1, with the same +// // text selected as client A. +// deterministic.run_until_parked(); +// let channel_view_1_b = workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!( +// workspace.leader_for_pane(workspace.active_pane()), +// Some(client_a.peer_id().unwrap()) +// ); +// workspace +// .active_item(cx) +// .expect("no active item") +// .downcast::() +// .expect("active item is not a channel view") +// }); +// channel_view_1_b.read_with(cx_b, |notes, cx| { +// assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); +// let editor = notes.editor.read(cx); +// assert_eq!(editor.text(cx), "Hello from A."); +// assert_eq!(editor.selections.ranges::(cx), &[3..4]); +// }); + +// // Client A opens the notes for channel 2. +// let channel_view_2_a = cx_a +// .update(|cx| ChannelView::open(channel_2_id, workspace_a.clone(), cx)) +// .await +// .unwrap(); +// channel_view_2_a.read_with(cx_a, |notes, cx| { +// assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); +// }); + +// // Client B is taken to the notes for channel 2. +// deterministic.run_until_parked(); +// let channel_view_2_b = workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!( +// workspace.leader_for_pane(workspace.active_pane()), +// Some(client_a.peer_id().unwrap()) +// ); +// workspace +// .active_item(cx) +// .expect("no active item") +// .downcast::() +// .expect("active item is not a channel view") +// }); +// channel_view_2_b.read_with(cx_b, |notes, cx| { +// assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); +// }); +// } + +//todo!(collab_ui) +// #[gpui::test] +// async fn test_channel_buffer_changes( +// deterministic: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&deterministic).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; + +// let channel_id = server +// .make_channel( +// "the-channel", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b)], +// ) +// .await; + +// let channel_buffer_a = client_a +// .channel_store() +// .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) +// .await +// .unwrap(); + +// // Client A makes an edit, and client B should see that the note has changed. +// channel_buffer_a.update(cx_a, |buffer, cx| { +// buffer.buffer().update(cx, |buffer, cx| { +// buffer.edit([(0..0, "1")], None, cx); +// }) +// }); +// deterministic.run_until_parked(); + +// let has_buffer_changed = cx_b.update(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_channel_buffer_changed(channel_id) +// .unwrap() +// }); +// assert!(has_buffer_changed); + +// // Opening the buffer should clear the changed flag. +// let project_b = client_b.build_empty_local_project(cx_b); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let channel_view_b = cx_b +// .update(|cx| ChannelView::open(channel_id, workspace_b.clone(), cx)) +// .await +// .unwrap(); +// deterministic.run_until_parked(); + +// let has_buffer_changed = cx_b.update(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_channel_buffer_changed(channel_id) +// .unwrap() +// }); +// assert!(!has_buffer_changed); + +// // Editing the channel while the buffer is open should not show that the buffer has changed. +// channel_buffer_a.update(cx_a, |buffer, cx| { +// buffer.buffer().update(cx, |buffer, cx| { +// buffer.edit([(0..0, "2")], None, cx); +// }) +// }); +// deterministic.run_until_parked(); + +// let has_buffer_changed = cx_b.read(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_channel_buffer_changed(channel_id) +// .unwrap() +// }); +// assert!(!has_buffer_changed); + +// deterministic.advance_clock(ACKNOWLEDGE_DEBOUNCE_INTERVAL); + +// // Test that the server is tracking things correctly, and we retain our 'not changed' +// // state across a disconnect +// server.simulate_long_connection_interruption(client_b.peer_id().unwrap(), &deterministic); +// let has_buffer_changed = cx_b.read(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_channel_buffer_changed(channel_id) +// .unwrap() +// }); +// assert!(!has_buffer_changed); + +// // Closing the buffer should re-enable change tracking +// cx_b.update(|cx| { +// workspace_b.update(cx, |workspace, cx| { +// workspace.close_all_items_and_panes(&Default::default(), cx) +// }); + +// drop(channel_view_b) +// }); + +// deterministic.run_until_parked(); + +// channel_buffer_a.update(cx_a, |buffer, cx| { +// buffer.buffer().update(cx, |buffer, cx| { +// buffer.edit([(0..0, "3")], None, cx); +// }) +// }); +// deterministic.run_until_parked(); + +// let has_buffer_changed = cx_b.read(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_channel_buffer_changed(channel_id) +// .unwrap() +// }); +// assert!(has_buffer_changed); +// } + +#[track_caller] +fn assert_collaborators(collaborators: &HashMap, ids: &[Option]) { + let mut user_ids = collaborators + .values() + .map(|collaborator| collaborator.user_id) + .collect::>(); + user_ids.sort(); + assert_eq!( + user_ids, + ids.into_iter().map(|id| id.unwrap()).collect::>() + ); +} + +fn buffer_text(channel_buffer: &Model, cx: &mut TestAppContext) -> String { + channel_buffer.read_with(cx, |buffer, _| buffer.text()) +} diff --git a/crates/collab2/src/tests/channel_message_tests.rs b/crates/collab2/src/tests/channel_message_tests.rs new file mode 100644 index 0000000000..4d030dd679 --- /dev/null +++ b/crates/collab2/src/tests/channel_message_tests.rs @@ -0,0 +1,408 @@ +use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; +use channel::{ChannelChat, ChannelMessageId}; +use gpui::{BackgroundExecutor, Model, TestAppContext}; + +// todo!(notifications) +// #[gpui::test] +// async fn test_basic_channel_messages( +// executor: BackgroundExecutor, +// mut cx_a: &mut TestAppContext, +// mut cx_b: &mut TestAppContext, +// mut cx_c: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(executor.clone()).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// let client_c = server.create_client(cx_c, "user_c").await; + +// let channel_id = server +// .make_channel( +// "the-channel", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b), (&client_c, cx_c)], +// ) +// .await; + +// let channel_chat_a = client_a +// .channel_store() +// .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) +// .await +// .unwrap(); +// let channel_chat_b = client_b +// .channel_store() +// .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) +// .await +// .unwrap(); + +// let message_id = channel_chat_a +// .update(cx_a, |c, cx| { +// c.send_message( +// MessageParams { +// text: "hi @user_c!".into(), +// mentions: vec![(3..10, client_c.id())], +// }, +// cx, +// ) +// .unwrap() +// }) +// .await +// .unwrap(); +// channel_chat_a +// .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); +// channel_chat_b +// .update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let channel_chat_c = client_c +// .channel_store() +// .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) +// .await +// .unwrap(); + +// for (chat, cx) in [ +// (&channel_chat_a, &mut cx_a), +// (&channel_chat_b, &mut cx_b), +// (&channel_chat_c, &mut cx_c), +// ] { +// chat.update(*cx, |c, _| { +// assert_eq!( +// c.messages() +// .iter() +// .map(|m| (m.body.as_str(), m.mentions.as_slice())) +// .collect::>(), +// vec![ +// ("hi @user_c!", [(3..10, client_c.id())].as_slice()), +// ("two", &[]), +// ("three", &[]) +// ], +// "results for user {}", +// c.client().id(), +// ); +// }); +// } + +// client_c.notification_store().update(cx_c, |store, _| { +// assert_eq!(store.notification_count(), 2); +// assert_eq!(store.unread_notification_count(), 1); +// assert_eq!( +// store.notification_at(0).unwrap().notification, +// Notification::ChannelMessageMention { +// message_id, +// sender_id: client_a.id(), +// channel_id, +// } +// ); +// assert_eq!( +// store.notification_at(1).unwrap().notification, +// Notification::ChannelInvitation { +// channel_id, +// channel_name: "the-channel".to_string(), +// inviter_id: client_a.id() +// } +// ); +// }); +// } + +#[gpui::test] +async fn test_rejoin_channel_chat( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel( + "the-channel", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b)], + ) + .await; + + let channel_chat_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + let channel_chat_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + channel_chat_a + .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + .await + .unwrap(); + channel_chat_b + .update(cx_b, |c, cx| c.send_message("two".into(), cx).unwrap()) + .await + .unwrap(); + + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + + // While client A is disconnected, clients A and B both send new messages. + channel_chat_a + .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) + .await + .unwrap_err(); + channel_chat_a + .update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap()) + .await + .unwrap_err(); + channel_chat_b + .update(cx_b, |c, cx| c.send_message("five".into(), cx).unwrap()) + .await + .unwrap(); + channel_chat_b + .update(cx_b, |c, cx| c.send_message("six".into(), cx).unwrap()) + .await + .unwrap(); + + // Client A reconnects. + server.allow_connections(); + executor.advance_clock(RECONNECT_TIMEOUT); + + // Client A fetches the messages that were sent while they were disconnected + // and resends their own messages which failed to send. + let expected_messages = &["one", "two", "five", "six", "three", "four"]; + assert_messages(&channel_chat_a, expected_messages, cx_a); + assert_messages(&channel_chat_b, expected_messages, cx_b); +} + +#[gpui::test] +async fn test_remove_channel_message( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let channel_id = server + .make_channel( + "the-channel", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_chat_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + let channel_chat_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + // Client A sends some messages. + channel_chat_a + .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + .await + .unwrap(); + channel_chat_a + .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) + .await + .unwrap(); + channel_chat_a + .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) + .await + .unwrap(); + + // Clients A and B see all of the messages. + executor.run_until_parked(); + let expected_messages = &["one", "two", "three"]; + assert_messages(&channel_chat_a, expected_messages, cx_a); + assert_messages(&channel_chat_b, expected_messages, cx_b); + + // Client A deletes one of their messages. + channel_chat_a + .update(cx_a, |c, cx| { + let ChannelMessageId::Saved(id) = c.message(1).id else { + panic!("message not saved") + }; + c.remove_message(id, cx) + }) + .await + .unwrap(); + + // Client B sees that the message is gone. + executor.run_until_parked(); + let expected_messages = &["one", "three"]; + assert_messages(&channel_chat_a, expected_messages, cx_a); + assert_messages(&channel_chat_b, expected_messages, cx_b); + + // Client C joins the channel chat, and does not see the deleted message. + let channel_chat_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + assert_messages(&channel_chat_c, expected_messages, cx_c); +} + +#[track_caller] +fn assert_messages(chat: &Model, messages: &[&str], cx: &mut TestAppContext) { + // todo!(don't directly borrow here) + assert_eq!( + chat.read_with(cx, |chat, _| { + chat.messages() + .iter() + .map(|m| m.body.clone()) + .collect::>() + }), + messages + ); +} + +//todo!(collab_ui) +// #[gpui::test] +// async fn test_channel_message_changes( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; + +// let channel_id = server +// .make_channel( +// "the-channel", +// None, +// (&client_a, cx_a), +// &mut [(&client_b, cx_b)], +// ) +// .await; + +// // Client A sends a message, client B should see that there is a new message. +// let channel_chat_a = client_a +// .channel_store() +// .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) +// .await +// .unwrap(); + +// channel_chat_a +// .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let b_has_messages = cx_b.read_with(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_new_messages(channel_id) +// .unwrap() +// }); + +// assert!(b_has_messages); + +// // Opening the chat should clear the changed flag. +// cx_b.update(|cx| { +// collab_ui::init(&client_b.app_state, cx); +// }); +// let project_b = client_b.build_empty_local_project(cx_b); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let chat_panel_b = workspace_b.update(cx_b, |workspace, cx| ChatPanel::new(workspace, cx)); +// chat_panel_b +// .update(cx_b, |chat_panel, cx| { +// chat_panel.set_active(true, cx); +// chat_panel.select_channel(channel_id, None, cx) +// }) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let b_has_messages = cx_b.read_with(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_new_messages(channel_id) +// .unwrap() +// }); + +// assert!(!b_has_messages); + +// // Sending a message while the chat is open should not change the flag. +// channel_chat_a +// .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let b_has_messages = cx_b.read_with(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_new_messages(channel_id) +// .unwrap() +// }); + +// assert!(!b_has_messages); + +// // Sending a message while the chat is closed should change the flag. +// chat_panel_b.update(cx_b, |chat_panel, cx| { +// chat_panel.set_active(false, cx); +// }); + +// // Sending a message while the chat is open should not change the flag. +// channel_chat_a +// .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let b_has_messages = cx_b.read_with(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_new_messages(channel_id) +// .unwrap() +// }); + +// assert!(b_has_messages); + +// // Closing the chat should re-enable change tracking +// cx_b.update(|_| drop(chat_panel_b)); + +// channel_chat_a +// .update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap()) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let b_has_messages = cx_b.read_with(|cx| { +// client_b +// .channel_store() +// .read(cx) +// .has_new_messages(channel_id) +// .unwrap() +// }); + +// assert!(b_has_messages); +// } diff --git a/crates/collab2/src/tests/channel_tests.rs b/crates/collab2/src/tests/channel_tests.rs new file mode 100644 index 0000000000..31c092bd08 --- /dev/null +++ b/crates/collab2/src/tests/channel_tests.rs @@ -0,0 +1,1541 @@ +use crate::{ + db::{self, UserId}, + rpc::RECONNECT_TIMEOUT, + tests::{room_participants, RoomParticipants, TestServer}, +}; +use call::ActiveCall; +use channel::{ChannelId, ChannelMembership, ChannelStore}; +use client::User; +use futures::future::try_join_all; +use gpui::{BackgroundExecutor, Model, TestAppContext}; +use rpc::{ + proto::{self, ChannelRole}, + RECEIVE_TIMEOUT, +}; +use std::sync::Arc; + +#[gpui::test] +async fn test_core_channels( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_a_id = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("channel-a", None, cx) + }) + .await + .unwrap(); + let channel_b_id = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("channel-b", Some(channel_a_id), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_channels( + client_a.channel_store(), + cx_a, + &[ + ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Admin, + }, + ExpectedChannel { + id: channel_b_id, + name: "channel-b".to_string(), + depth: 1, + role: ChannelRole::Admin, + }, + ], + ); + + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert!(channels.ordered_channels().collect::>().is_empty()) + }) + }); + + // Invite client B to channel A as client A. + client_a + .channel_store() + .update(cx_a, |store, cx| { + assert!(!store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); + + let invite = store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ); + + // Make sure we're synchronously storing the pending invite + assert!(store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); + invite + }) + .await + .unwrap(); + + // Client A sees that B has been invited. + executor.run_until_parked(); + assert_channel_invitations( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Member, + }], + ); + + let members = client_a + .channel_store() + .update(cx_a, |store, cx| { + assert!(!store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); + store.get_channel_member_details(channel_a_id, cx) + }) + .await + .unwrap(); + assert_members_eq( + &members, + &[ + ( + client_a.user_id().unwrap(), + proto::ChannelRole::Admin, + proto::channel_member::Kind::Member, + ), + ( + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + proto::channel_member::Kind::Invitee, + ), + ], + ); + + // Client B accepts the invitation. + client_b + .channel_store() + .update(cx_b, |channels, cx| { + channels.respond_to_channel_invite(channel_a_id, true, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + + // Client B now sees that they are a member of channel A and its existing subchannels. + assert_channel_invitations(client_b.channel_store(), cx_b, &[]); + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + role: ChannelRole::Member, + depth: 0, + }, + ExpectedChannel { + id: channel_b_id, + name: "channel-b".to_string(), + role: ChannelRole::Member, + depth: 1, + }, + ], + ); + + let channel_c_id = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("channel-c", Some(channel_b_id), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + role: ChannelRole::Member, + depth: 0, + }, + ExpectedChannel { + id: channel_b_id, + name: "channel-b".to_string(), + role: ChannelRole::Member, + depth: 1, + }, + ExpectedChannel { + id: channel_c_id, + name: "channel-c".to_string(), + role: ChannelRole::Member, + depth: 2, + }, + ], + ); + + // Update client B's membership to channel A to be an admin. + client_a + .channel_store() + .update(cx_a, |store, cx| { + store.set_member_role( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) + }) + .await + .unwrap(); + executor.run_until_parked(); + + // Observe that client B is now an admin of channel A, and that + // their admin priveleges extend to subchannels of channel A. + assert_channel_invitations(client_b.channel_store(), cx_b, &[]); + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Admin, + }, + ExpectedChannel { + id: channel_b_id, + name: "channel-b".to_string(), + depth: 1, + role: ChannelRole::Admin, + }, + ExpectedChannel { + id: channel_c_id, + name: "channel-c".to_string(), + depth: 2, + role: ChannelRole::Admin, + }, + ], + ); + + // Client A deletes the channel, deletion also deletes subchannels. + client_a + .channel_store() + .update(cx_a, |channel_store, _| { + channel_store.remove_channel(channel_b_id) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_channels( + client_a.channel_store(), + cx_a, + &[ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Admin, + }], + ); + assert_channels( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Admin, + }], + ); + + // Remove client B + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.remove_member(channel_a_id, client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // Client A still has their channel + assert_channels( + client_a.channel_store(), + cx_a, + &[ExpectedChannel { + id: channel_a_id, + name: "channel-a".to_string(), + depth: 0, + role: ChannelRole::Admin, + }], + ); + + // Client B no longer has access to the channel + assert_channels(client_b.channel_store(), cx_b, &[]); + + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + server + .app_state + .db + .rename_channel( + db::ChannelId::from_proto(channel_a_id), + UserId::from_proto(client_a.id()), + "channel-a-renamed", + ) + .await + .unwrap(); + + server.allow_connections(); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert_channels( + client_a.channel_store(), + cx_a, + &[ExpectedChannel { + id: channel_a_id, + name: "channel-a-renamed".to_string(), + depth: 0, + role: ChannelRole::Admin, + }], + ); +} + +#[track_caller] +fn assert_participants_eq(participants: &[Arc], expected_partitipants: &[u64]) { + assert_eq!( + participants.iter().map(|p| p.id).collect::>(), + expected_partitipants + ); +} + +#[track_caller] +fn assert_members_eq( + members: &[ChannelMembership], + expected_members: &[(u64, proto::ChannelRole, proto::channel_member::Kind)], +) { + assert_eq!( + members + .iter() + .map(|member| (member.user.id, member.role, member.kind)) + .collect::>(), + expected_members + ); +} + +#[gpui::test] +async fn test_joining_channel_ancestor_member( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let parent_id = server + .make_channel("parent", None, (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let sub_id = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("sub_channel", Some(parent_id), cx) + }) + .await + .unwrap(); + + let active_call_b = cx_b.read(ActiveCall::global); + + assert!(active_call_b + .update(cx_b, |active_call, cx| active_call.join_channel(sub_id, cx)) + .await + .is_ok()); +} + +#[gpui::test] +async fn test_channel_room( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let zed_id = server + .make_channel( + "zed", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + active_call_a + .update(cx_a, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + // Give everyone a chance to observe user A joining + executor.run_until_parked(); + let room_a = + cx_a.read(|cx| active_call_a.read_with(cx, |call, _| call.room().unwrap().clone())); + cx_a.read(|cx| room_a.read_with(cx, |room, _| assert!(room.is_connected()))); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap()], + ); + }) + }); + + assert_channels( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + id: zed_id, + name: "zed".to_string(), + depth: 0, + role: ChannelRole::Member, + }], + ); + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap()], + ); + }) + }); + + cx_c.read(|cx| { + client_c.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap()], + ); + }) + }); + + active_call_b + .update(cx_b, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); + + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); + + cx_c.read(|cx| { + client_c.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); + + let room_a = + cx_a.read(|cx| active_call_a.read_with(cx, |call, _| call.room().unwrap().clone())); + cx_a.read(|cx| room_a.read_with(cx, |room, _| assert!(room.is_connected()))); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec![] + } + ); + + let room_b = + cx_b.read(|cx| active_call_b.read_with(cx, |call, _| call.room().unwrap().clone())); + cx_b.read(|cx| room_b.read_with(cx, |room, _| assert!(room.is_connected()))); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec![] + } + ); + + // Make sure that leaving and rejoining works + + active_call_a + .update(cx_a, |active_call, cx| active_call.hang_up(cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_b.user_id().unwrap()], + ); + }) + }); + + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_b.user_id().unwrap()], + ); + }) + }); + + cx_c.read(|cx| { + client_c.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_b.user_id().unwrap()], + ); + }) + }); + + active_call_b + .update(cx_b, |active_call, cx| active_call.hang_up(cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq(channels.channel_participants(zed_id), &[]); + }) + }); + + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert_participants_eq(channels.channel_participants(zed_id), &[]); + }) + }); + + cx_c.read(|cx| { + client_c.channel_store().read_with(cx, |channels, _| { + assert_participants_eq(channels.channel_participants(zed_id), &[]); + }) + }); + + active_call_a + .update(cx_a, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + active_call_b + .update(cx_b, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + let room_a = + cx_a.read(|cx| active_call_a.read_with(cx, |call, _| call.room().unwrap().clone())); + cx_a.read(|cx| room_a.read_with(cx, |room, _| assert!(room.is_connected()))); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec![] + } + ); + + let room_b = + cx_b.read(|cx| active_call_b.read_with(cx, |call, _| call.room().unwrap().clone())); + cx_b.read(|cx| room_b.read_with(cx, |room, _| assert!(room.is_connected()))); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec![] + } + ); +} + +#[gpui::test] +async fn test_channel_jumping(executor: BackgroundExecutor, cx_a: &mut TestAppContext) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + + let zed_id = server + .make_channel("zed", None, (&client_a, cx_a), &mut []) + .await; + let rust_id = server + .make_channel("rust", None, (&client_a, cx_a), &mut []) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + + active_call_a + .update(cx_a, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + // Give everything a chance to observe user A joining + executor.run_until_parked(); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap()], + ); + assert_participants_eq(channels.channel_participants(rust_id), &[]); + }) + }); + + active_call_a + .update(cx_a, |active_call, cx| { + active_call.join_channel(rust_id, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq(channels.channel_participants(zed_id), &[]); + assert_participants_eq( + channels.channel_participants(rust_id), + &[client_a.user_id().unwrap()], + ); + }) + }); +} + +#[gpui::test] +async fn test_permissions_update_while_invited( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let rust_id = server + .make_channel("rust", None, (&client_a, cx_a), &mut []) + .await; + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + assert_channel_invitations( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + depth: 0, + id: rust_id, + name: "rust".to_string(), + role: ChannelRole::Member, + }], + ); + assert_channels(client_b.channel_store(), cx_b, &[]); + + // Update B's invite before they've accepted it + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_member_role( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + assert_channel_invitations( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + depth: 0, + id: rust_id, + name: "rust".to_string(), + role: ChannelRole::Member, + }], + ); + assert_channels(client_b.channel_store(), cx_b, &[]); +} + +#[gpui::test] +async fn test_channel_rename( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let rust_id = server + .make_channel("rust", None, (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + // Rename the channel + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.rename(rust_id, "#rust-archive", cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // Client A sees the channel with its new name. + assert_channels( + client_a.channel_store(), + cx_a, + &[ExpectedChannel { + depth: 0, + id: rust_id, + name: "rust-archive".to_string(), + role: ChannelRole::Admin, + }], + ); + + // Client B sees the channel with its new name. + assert_channels( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + depth: 0, + id: rust_id, + name: "rust-archive".to_string(), + role: ChannelRole::Member, + }], + ); +} + +#[gpui::test] +async fn test_call_from_channel( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let channel_id = server + .make_channel( + "x", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + active_call_a + .update(cx_a, |call, cx| call.join_channel(channel_id, cx)) + .await + .unwrap(); + + // Client A calls client B while in the channel. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + // Client B accepts the call. + executor.run_until_parked(); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + // Client B sees that they are now in the channel + executor.run_until_parked(); + cx_b.read(|cx| { + active_call_b.read_with(cx, |call, cx| { + assert_eq!(call.channel_id(cx), Some(channel_id)); + }) + }); + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(channel_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); + + // Clients A and C also see that client B is in the channel. + cx_a.read(|cx| { + client_a.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(channel_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); + cx_c.read(|cx| { + client_c.channel_store().read_with(cx, |channels, _| { + assert_participants_eq( + channels.channel_participants(channel_id), + &[client_a.user_id().unwrap(), client_b.user_id().unwrap()], + ); + }) + }); +} + +#[gpui::test] +async fn test_lost_channel_creation( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let channel_id = server + .make_channel("x", None, (&client_a, cx_a), &mut []) + .await; + + // Invite a member + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + channel_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // Sanity check, B has the invitation + assert_channel_invitations( + client_b.channel_store(), + cx_b, + &[ExpectedChannel { + depth: 0, + id: channel_id, + name: "x".to_string(), + role: ChannelRole::Member, + }], + ); + + // A creates a subchannel while the invite is still pending. + let subchannel_id = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("subchannel", Some(channel_id), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // Make sure A sees their new channel + assert_channels( + client_a.channel_store(), + cx_a, + &[ + ExpectedChannel { + depth: 0, + id: channel_id, + name: "x".to_string(), + role: ChannelRole::Admin, + }, + ExpectedChannel { + depth: 1, + id: subchannel_id, + name: "subchannel".to_string(), + role: ChannelRole::Admin, + }, + ], + ); + + // Client B accepts the invite + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel_id, true, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // Client B should now see the channel + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: channel_id, + name: "x".to_string(), + role: ChannelRole::Member, + }, + ExpectedChannel { + depth: 1, + id: subchannel_id, + name: "subchannel".to_string(), + role: ChannelRole::Member, + }, + ], + ); +} + +#[gpui::test] +async fn test_channel_link_notifications( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let user_b = client_b.user_id().unwrap(); + let user_c = client_c.user_id().unwrap(); + + let channels = server + .make_channel_tree(&[("zed", None)], (&client_a, cx_a)) + .await; + let zed_channel = channels[0]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_c, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + executor.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_c + .channel_store() + .update(cx_c, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // we have an admin (a), member (b) and guest (c) all part of the zed channel. + + // create a new private channel, make it public, and move it under the previous one, and verify it shows for b and not c + let active_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("active", Some(zed_channel), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // the new channel shows for b and not c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); + + let vim_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("vim", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(vim_channel, Some(active_channel), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1)], + ); + + let helix_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("helix", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(helix_channel, Some(vim_channel), cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility( + helix_channel, + proto::ChannelVisibility::Public, + cx, + ) + }) + .await + .unwrap(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1), (helix_channel, 2)], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + // the members-only channel is still shown for c, but hidden for b + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + cx_b.read(|cx| { + client_b.channel_store().read_with(cx, |channel_store, _| { + assert_eq!( + channel_store + .channel_for_id(vim_channel) + .unwrap() + .visibility, + proto::ChannelVisibility::Members + ) + }) + }); + + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); +} + +#[gpui::test] +async fn test_channel_membership_notifications( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_c").await; + + let user_b = client_b.user_id().unwrap(); + + let channels = server + .make_channel_tree( + &[ + ("zed", None), + ("active", Some("zed")), + ("vim", Some("active")), + ], + (&client_a, cx_a), + ) + .await; + let zed_channel = channels[0]; + let _active_channel = channels[1]; + let vim_channel = channels[2]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(vim_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + executor.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(vim_channel, true, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // we have an admin (a), and a guest (b) with access to all of zed, and membership in vim. + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Member, + }, + ], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.remove_member(vim_channel, user_b, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Guest, + }, + ], + ) +} + +#[gpui::test] +async fn test_guest_access( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a = channels[0]; + let channel_b = channels[1]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // Non-members should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .is_err()); + + // Make channels A and B public + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_b, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + // Client B joins channel A as a guest + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .unwrap(); + + executor.run_until_parked(); + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(channel_a, 0), (channel_b, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(channel_a, 0), (channel_b, 1)], + ); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_a); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + assert_channels_list_shape(client_b.channel_store(), cx_b, &[]); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b, cx)) + .await + .unwrap(); + + executor.run_until_parked(); + assert_channels_list_shape(client_b.channel_store(), cx_b, &[(channel_b, 0)]); +} + +#[gpui::test] +async fn test_invite_access( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a_id = channels[0]; + let channel_b_id = channels[0]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .is_err()); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + client_b.channel_store().update(cx_b, |channel_store, _| { + assert!(channel_store.channel_for_id(channel_b_id).is_some()); + assert!(channel_store.channel_for_id(channel_a_id).is_some()); + }); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_b_id); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }) +} + +#[gpui::test] +async fn test_channel_moving( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + _cx_b: &mut TestAppContext, + _cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + // let client_b = server.create_client(cx_b, "user_b").await; + // let client_c = server.create_client(cx_c, "user_c").await; + + let channels = server + .make_channel_tree( + &[ + ("channel-a", None), + ("channel-b", Some("channel-a")), + ("channel-c", Some("channel-b")), + ("channel-d", Some("channel-c")), + ], + (&client_a, cx_a), + ) + .await; + let channel_a_id = channels[0]; + let channel_b_id = channels[1]; + let channel_c_id = channels[2]; + let channel_d_id = channels[3]; + + // Current shape: + // a - b - c - d + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[ + (channel_a_id, 0), + (channel_b_id, 1), + (channel_c_id, 2), + (channel_d_id, 3), + ], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(channel_d_id, Some(channel_b_id), cx) + }) + .await + .unwrap(); + + // Current shape: + // /- d + // a - b -- c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[ + (channel_a_id, 0), + (channel_b_id, 1), + (channel_c_id, 2), + (channel_d_id, 2), + ], + ); +} + +#[derive(Debug, PartialEq)] +struct ExpectedChannel { + depth: usize, + id: ChannelId, + name: String, + role: ChannelRole, +} + +#[track_caller] +fn assert_channel_invitations( + channel_store: &Model, + cx: &TestAppContext, + expected_channels: &[ExpectedChannel], +) { + let actual = cx.read(|cx| { + channel_store.read_with(cx, |store, _| { + store + .channel_invitations() + .iter() + .map(|channel| ExpectedChannel { + depth: 0, + name: channel.name.clone(), + id: channel.id, + role: channel.role, + }) + .collect::>() + }) + }); + assert_eq!(actual, expected_channels); +} + +#[track_caller] +fn assert_channels( + channel_store: &Model, + cx: &TestAppContext, + expected_channels: &[ExpectedChannel], +) { + let actual = cx.read(|cx| { + channel_store.read_with(cx, |store, _| { + store + .ordered_channels() + .map(|(depth, channel)| ExpectedChannel { + depth, + name: channel.name.clone(), + id: channel.id, + role: channel.role, + }) + .collect::>() + }) + }); + pretty_assertions::assert_eq!(actual, expected_channels); +} + +#[track_caller] +fn assert_channels_list_shape( + channel_store: &Model, + cx: &TestAppContext, + expected_channels: &[(u64, usize)], +) { + let actual = cx.read(|cx| { + channel_store.read_with(cx, |store, _| { + store + .ordered_channels() + .map(|(depth, channel)| (channel.id, depth)) + .collect::>() + }) + }); + pretty_assertions::assert_eq!(actual, expected_channels); +} diff --git a/crates/collab2/src/tests/editor_tests.rs b/crates/collab2/src/tests/editor_tests.rs new file mode 100644 index 0000000000..4900cb20f6 --- /dev/null +++ b/crates/collab2/src/tests/editor_tests.rs @@ -0,0 +1,1108 @@ +// use editor::{ +// test::editor_test_context::EditorTestContext, ConfirmCodeAction, ConfirmCompletion, +// ConfirmRename, Editor, Redo, Rename, ToggleCodeActions, Undo, +// }; + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_host_disconnect( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// cx_c: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// let client_c = server.create_client(cx_c, "user_c").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) +// .await; + +// cx_b.update(editor::init); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "a.txt": "a-contents", +// "b.txt": "b-contents", +// }), +// ) +// .await; + +// let active_call_a = cx_a.read(ActiveCall::global); +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; + +// let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap()); +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// executor.run_until_parked(); + +// assert!(worktree_a.read_with(cx_a, |tree, _| tree.as_local().unwrap().is_shared())); + +// let window_b = +// cx_b.add_window(|cx| Workspace::new(0, project_b.clone(), client_b.app_state.clone(), cx)); +// let workspace_b = window_b.root(cx_b); +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "b.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// assert!(window_b.read_with(cx_b, |cx| editor_b.is_focused(cx))); +// editor_b.update(cx_b, |editor, cx| editor.insert("X", cx)); +// assert!(window_b.is_edited(cx_b)); + +// // Drop client A's connection. Collaborators should disappear and the project should not be shown as shared. +// server.forbid_connections(); +// server.disconnect_client(client_a.peer_id().unwrap()); +// executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + +// project_a.read_with(cx_a, |project, _| project.collaborators().is_empty()); + +// project_a.read_with(cx_a, |project, _| assert!(!project.is_shared())); + +// project_b.read_with(cx_b, |project, _| project.is_read_only()); + +// assert!(worktree_a.read_with(cx_a, |tree, _| !tree.as_local().unwrap().is_shared())); + +// // Ensure client B's edited state is reset and that the whole window is blurred. + +// window_b.read_with(cx_b, |cx| { +// assert_eq!(cx.focused_view_id(), None); +// }); +// assert!(!window_b.is_edited(cx_b)); + +// // Ensure client B is not prompted to save edits when closing window after disconnecting. +// let can_close = workspace_b +// .update(cx_b, |workspace, cx| workspace.prepare_to_close(true, cx)) +// .await +// .unwrap(); +// assert!(can_close); + +// // Allow client A to reconnect to the server. +// server.allow_connections(); +// executor.advance_clock(RECEIVE_TIMEOUT); + +// // Client B calls client A again after they reconnected. +// let active_call_b = cx_b.read(ActiveCall::global); +// active_call_b +// .update(cx_b, |call, cx| { +// call.invite(client_a.user_id().unwrap(), None, cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); +// active_call_a +// .update(cx_a, |call, cx| call.accept_incoming(cx)) +// .await +// .unwrap(); + +// active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// // Drop client A's connection again. We should still unshare it successfully. +// server.forbid_connections(); +// server.disconnect_client(client_a.peer_id().unwrap()); +// executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + +// project_a.read_with(cx_a, |project, _| assert!(!project.is_shared())); +// } + +//todo!(editor) +// #[gpui::test] +// async fn test_newline_above_or_below_does_not_move_guest_cursor( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// client_a +// .fs() +// .insert_tree("/dir", json!({ "a.txt": "Some text\n" })) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// // Open a buffer as client A +// let buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) +// .await +// .unwrap(); +// let window_a = cx_a.add_window(|_| EmptyView); +// let editor_a = window_a.add_view(cx_a, |cx| Editor::for_buffer(buffer_a, Some(project_a), cx)); +// let mut editor_cx_a = EditorTestContext { +// cx: cx_a, +// window: window_a.into(), +// editor: editor_a, +// }; + +// // Open a buffer as client B +// let buffer_b = project_b +// .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) +// .await +// .unwrap(); +// let window_b = cx_b.add_window(|_| EmptyView); +// let editor_b = window_b.add_view(cx_b, |cx| Editor::for_buffer(buffer_b, Some(project_b), cx)); +// let mut editor_cx_b = EditorTestContext { +// cx: cx_b, +// window: window_b.into(), +// editor: editor_b, +// }; + +// // Test newline above +// editor_cx_a.set_selections_state(indoc! {" +// Some textˇ +// "}); +// editor_cx_b.set_selections_state(indoc! {" +// Some textˇ +// "}); +// editor_cx_a.update_editor(|editor, cx| editor.newline_above(&editor::NewlineAbove, cx)); +// executor.run_until_parked(); +// editor_cx_a.assert_editor_state(indoc! {" +// ˇ +// Some text +// "}); +// editor_cx_b.assert_editor_state(indoc! {" + +// Some textˇ +// "}); + +// // Test newline below +// editor_cx_a.set_selections_state(indoc! {" + +// Some textˇ +// "}); +// editor_cx_b.set_selections_state(indoc! {" + +// Some textˇ +// "}); +// editor_cx_a.update_editor(|editor, cx| editor.newline_below(&editor::NewlineBelow, cx)); +// executor.run_until_parked(); +// editor_cx_a.assert_editor_state(indoc! {" + +// Some text +// ˇ +// "}); +// editor_cx_b.assert_editor_state(indoc! {" + +// Some textˇ + +// "}); +// } + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_collaborating_with_completion( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// completion_provider: Some(lsp::CompletionOptions { +// trigger_characters: Some(vec![".".to_string()]), +// resolve_provider: Some(true), +// ..Default::default() +// }), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "fn main() { a }", +// "other.rs": "", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// // Open a file in an editor as the guest. +// let buffer_b = project_b +// .update(cx_b, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); +// let window_b = cx_b.add_window(|_| EmptyView); +// let editor_b = window_b.add_view(cx_b, |cx| { +// Editor::for_buffer(buffer_b.clone(), Some(project_b.clone()), cx) +// }); + +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// cx_a.foreground().run_until_parked(); + +// buffer_b.read_with(cx_b, |buffer, _| { +// assert!(!buffer.completion_triggers().is_empty()) +// }); + +// // Type a completion trigger character as the guest. +// editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([13..13])); +// editor.handle_input(".", cx); +// cx.focus(&editor_b); +// }); + +// // Receive a completion request as the host's language server. +// // Return some completions from the host's language server. +// cx_a.foreground().start_waiting(); +// fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!( +// params.text_document_position.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// assert_eq!( +// params.text_document_position.position, +// lsp::Position::new(0, 14), +// ); + +// Ok(Some(lsp::CompletionResponse::Array(vec![ +// lsp::CompletionItem { +// label: "first_method(…)".into(), +// detail: Some("fn(&mut self, B) -> C".into()), +// text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { +// new_text: "first_method($1)".to_string(), +// range: lsp::Range::new( +// lsp::Position::new(0, 14), +// lsp::Position::new(0, 14), +// ), +// })), +// insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), +// ..Default::default() +// }, +// lsp::CompletionItem { +// label: "second_method(…)".into(), +// detail: Some("fn(&mut self, C) -> D".into()), +// text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { +// new_text: "second_method()".to_string(), +// range: lsp::Range::new( +// lsp::Position::new(0, 14), +// lsp::Position::new(0, 14), +// ), +// })), +// insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), +// ..Default::default() +// }, +// ]))) +// }) +// .next() +// .await +// .unwrap(); +// cx_a.foreground().finish_waiting(); + +// // Open the buffer on the host. +// let buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); +// cx_a.foreground().run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a. }") +// }); + +// // Confirm a completion on the guest. + +// editor_b.read_with(cx_b, |editor, _| assert!(editor.context_menu_visible())); +// editor_b.update(cx_b, |editor, cx| { +// editor.confirm_completion(&ConfirmCompletion { item_ix: Some(0) }, cx); +// assert_eq!(editor.text(cx), "fn main() { a.first_method() }"); +// }); + +// // Return a resolved completion from the host's language server. +// // The resolved completion has an additional text edit. +// fake_language_server.handle_request::( +// |params, _| async move { +// assert_eq!(params.label, "first_method(…)"); +// Ok(lsp::CompletionItem { +// label: "first_method(…)".into(), +// detail: Some("fn(&mut self, B) -> C".into()), +// text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { +// new_text: "first_method($1)".to_string(), +// range: lsp::Range::new(lsp::Position::new(0, 14), lsp::Position::new(0, 14)), +// })), +// additional_text_edits: Some(vec![lsp::TextEdit { +// new_text: "use d::SomeTrait;\n".to_string(), +// range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)), +// }]), +// insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), +// ..Default::default() +// }) +// }, +// ); + +// // The additional edit is applied. +// cx_a.foreground().run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!( +// buffer.text(), +// "use d::SomeTrait;\nfn main() { a.first_method() }" +// ); +// }); + +// buffer_b.read_with(cx_b, |buffer, _| { +// assert_eq!( +// buffer.text(), +// "use d::SomeTrait;\nfn main() { a.first_method() }" +// ); +// }); +// } +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_collaborating_with_code_actions( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// // +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// cx_b.update(editor::init); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "mod other;\nfn main() { let foo = other::foo(); }", +// "other.rs": "pub fn foo() -> usize { 4 }", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// // Join the project as client B. +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// let window_b = +// cx_b.add_window(|cx| Workspace::new(0, project_b.clone(), client_b.app_state.clone(), cx)); +// let workspace_b = window_b.root(cx_b); +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "main.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// let mut fake_language_server = fake_language_servers.next().await.unwrap(); +// let mut requests = fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!( +// params.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// assert_eq!(params.range.start, lsp::Position::new(0, 0)); +// assert_eq!(params.range.end, lsp::Position::new(0, 0)); +// Ok(None) +// }); +// executor.advance_clock(editor::CODE_ACTIONS_DEBOUNCE_TIMEOUT * 2); +// requests.next().await; + +// // Move cursor to a location that contains code actions. +// editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |s| { +// s.select_ranges([Point::new(1, 31)..Point::new(1, 31)]) +// }); +// cx.focus(&editor_b); +// }); + +// let mut requests = fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!( +// params.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// assert_eq!(params.range.start, lsp::Position::new(1, 31)); +// assert_eq!(params.range.end, lsp::Position::new(1, 31)); + +// Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( +// lsp::CodeAction { +// title: "Inline into all callers".to_string(), +// edit: Some(lsp::WorkspaceEdit { +// changes: Some( +// [ +// ( +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// vec![lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(1, 22), +// lsp::Position::new(1, 34), +// ), +// "4".to_string(), +// )], +// ), +// ( +// lsp::Url::from_file_path("/a/other.rs").unwrap(), +// vec![lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(0, 0), +// lsp::Position::new(0, 27), +// ), +// "".to_string(), +// )], +// ), +// ] +// .into_iter() +// .collect(), +// ), +// ..Default::default() +// }), +// data: Some(json!({ +// "codeActionParams": { +// "range": { +// "start": {"line": 1, "column": 31}, +// "end": {"line": 1, "column": 31}, +// } +// } +// })), +// ..Default::default() +// }, +// )])) +// }); +// executor.advance_clock(editor::CODE_ACTIONS_DEBOUNCE_TIMEOUT * 2); +// requests.next().await; + +// // Toggle code actions and wait for them to display. +// editor_b.update(cx_b, |editor, cx| { +// editor.toggle_code_actions( +// &ToggleCodeActions { +// deployed_from_indicator: false, +// }, +// cx, +// ); +// }); +// cx_a.foreground().run_until_parked(); + +// editor_b.read_with(cx_b, |editor, _| assert!(editor.context_menu_visible())); + +// fake_language_server.remove_request_handler::(); + +// // Confirming the code action will trigger a resolve request. +// let confirm_action = workspace_b +// .update(cx_b, |workspace, cx| { +// Editor::confirm_code_action(workspace, &ConfirmCodeAction { item_ix: Some(0) }, cx) +// }) +// .unwrap(); +// fake_language_server.handle_request::( +// |_, _| async move { +// Ok(lsp::CodeAction { +// title: "Inline into all callers".to_string(), +// edit: Some(lsp::WorkspaceEdit { +// changes: Some( +// [ +// ( +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// vec![lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(1, 22), +// lsp::Position::new(1, 34), +// ), +// "4".to_string(), +// )], +// ), +// ( +// lsp::Url::from_file_path("/a/other.rs").unwrap(), +// vec![lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(0, 0), +// lsp::Position::new(0, 27), +// ), +// "".to_string(), +// )], +// ), +// ] +// .into_iter() +// .collect(), +// ), +// ..Default::default() +// }), +// ..Default::default() +// }) +// }, +// ); + +// // After the action is confirmed, an editor containing both modified files is opened. +// confirm_action.await.unwrap(); + +// let code_action_editor = workspace_b.read_with(cx_b, |workspace, cx| { +// workspace +// .active_item(cx) +// .unwrap() +// .downcast::() +// .unwrap() +// }); +// code_action_editor.update(cx_b, |editor, cx| { +// assert_eq!(editor.text(cx), "mod other;\nfn main() { let foo = 4; }\n"); +// editor.undo(&Undo, cx); +// assert_eq!( +// editor.text(cx), +// "mod other;\nfn main() { let foo = other::foo(); }\npub fn foo() -> usize { 4 }" +// ); +// editor.redo(&Redo, cx); +// assert_eq!(editor.text(cx), "mod other;\nfn main() { let foo = 4; }\n"); +// }); +// } + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_collaborating_with_renames( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// cx_b.update(editor::init); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// rename_provider: Some(lsp::OneOf::Right(lsp::RenameOptions { +// prepare_provider: Some(true), +// work_done_progress_options: Default::default(), +// })), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/dir", +// json!({ +// "one.rs": "const ONE: usize = 1;", +// "two.rs": "const TWO: usize = one::ONE + one::ONE;" +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// let window_b = +// cx_b.add_window(|cx| Workspace::new(0, project_b.clone(), client_b.app_state.clone(), cx)); +// let workspace_b = window_b.root(cx_b); +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "one.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); +// let fake_language_server = fake_language_servers.next().await.unwrap(); + +// // Move cursor to a location that can be renamed. +// let prepare_rename = editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([7..7])); +// editor.rename(&Rename, cx).unwrap() +// }); + +// fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!(params.text_document.uri.as_str(), "file:///dir/one.rs"); +// assert_eq!(params.position, lsp::Position::new(0, 7)); +// Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( +// lsp::Position::new(0, 6), +// lsp::Position::new(0, 9), +// )))) +// }) +// .next() +// .await +// .unwrap(); +// prepare_rename.await.unwrap(); +// editor_b.update(cx_b, |editor, cx| { +// use editor::ToOffset; +// let rename = editor.pending_rename().unwrap(); +// let buffer = editor.buffer().read(cx).snapshot(cx); +// assert_eq!( +// rename.range.start.to_offset(&buffer)..rename.range.end.to_offset(&buffer), +// 6..9 +// ); +// rename.editor.update(cx, |rename_editor, cx| { +// rename_editor.buffer().update(cx, |rename_buffer, cx| { +// rename_buffer.edit([(0..3, "THREE")], None, cx); +// }); +// }); +// }); + +// let confirm_rename = workspace_b.update(cx_b, |workspace, cx| { +// Editor::confirm_rename(workspace, &ConfirmRename, cx).unwrap() +// }); +// fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!( +// params.text_document_position.text_document.uri.as_str(), +// "file:///dir/one.rs" +// ); +// assert_eq!( +// params.text_document_position.position, +// lsp::Position::new(0, 6) +// ); +// assert_eq!(params.new_name, "THREE"); +// Ok(Some(lsp::WorkspaceEdit { +// changes: Some( +// [ +// ( +// lsp::Url::from_file_path("/dir/one.rs").unwrap(), +// vec![lsp::TextEdit::new( +// lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), +// "THREE".to_string(), +// )], +// ), +// ( +// lsp::Url::from_file_path("/dir/two.rs").unwrap(), +// vec![ +// lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(0, 24), +// lsp::Position::new(0, 27), +// ), +// "THREE".to_string(), +// ), +// lsp::TextEdit::new( +// lsp::Range::new( +// lsp::Position::new(0, 35), +// lsp::Position::new(0, 38), +// ), +// "THREE".to_string(), +// ), +// ], +// ), +// ] +// .into_iter() +// .collect(), +// ), +// ..Default::default() +// })) +// }) +// .next() +// .await +// .unwrap(); +// confirm_rename.await.unwrap(); + +// let rename_editor = workspace_b.read_with(cx_b, |workspace, cx| { +// workspace +// .active_item(cx) +// .unwrap() +// .downcast::() +// .unwrap() +// }); +// rename_editor.update(cx_b, |editor, cx| { +// assert_eq!( +// editor.text(cx), +// "const THREE: usize = 1;\nconst TWO: usize = one::THREE + one::THREE;" +// ); +// editor.undo(&Undo, cx); +// assert_eq!( +// editor.text(cx), +// "const ONE: usize = 1;\nconst TWO: usize = one::ONE + one::ONE;" +// ); +// editor.redo(&Redo, cx); +// assert_eq!( +// editor.text(cx), +// "const THREE: usize = 1;\nconst TWO: usize = one::THREE + one::THREE;" +// ); +// }); + +// // Ensure temporary rename edits cannot be undone/redone. +// editor_b.update(cx_b, |editor, cx| { +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "const ONE: usize = 1;"); +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "const ONE: usize = 1;"); +// editor.redo(&Redo, cx); +// assert_eq!(editor.text(cx), "const THREE: usize = 1;"); +// }) +// } + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_language_server_statuses( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// cx_b.update(editor::init); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// name: "the-language-server", +// ..Default::default() +// })) +// .await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/dir", +// json!({ +// "main.rs": "const ONE: usize = 1;", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + +// let _buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); + +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// fake_language_server.start_progress("the-token").await; +// fake_language_server.notify::(lsp::ProgressParams { +// token: lsp::NumberOrString::String("the-token".to_string()), +// value: lsp::ProgressParamsValue::WorkDone(lsp::WorkDoneProgress::Report( +// lsp::WorkDoneProgressReport { +// message: Some("the-message".to_string()), +// ..Default::default() +// }, +// )), +// }); +// executor.run_until_parked(); + +// project_a.read_with(cx_a, |project, _| { +// let status = project.language_server_statuses().next().unwrap(); +// assert_eq!(status.name, "the-language-server"); +// assert_eq!(status.pending_work.len(), 1); +// assert_eq!( +// status.pending_work["the-token"].message.as_ref().unwrap(), +// "the-message" +// ); +// }); + +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// executor.run_until_parked(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// project_b.read_with(cx_b, |project, _| { +// let status = project.language_server_statuses().next().unwrap(); +// assert_eq!(status.name, "the-language-server"); +// }); + +// fake_language_server.notify::(lsp::ProgressParams { +// token: lsp::NumberOrString::String("the-token".to_string()), +// value: lsp::ProgressParamsValue::WorkDone(lsp::WorkDoneProgress::Report( +// lsp::WorkDoneProgressReport { +// message: Some("the-message-2".to_string()), +// ..Default::default() +// }, +// )), +// }); +// executor.run_until_parked(); + +// project_a.read_with(cx_a, |project, _| { +// let status = project.language_server_statuses().next().unwrap(); +// assert_eq!(status.name, "the-language-server"); +// assert_eq!(status.pending_work.len(), 1); +// assert_eq!( +// status.pending_work["the-token"].message.as_ref().unwrap(), +// "the-message-2" +// ); +// }); + +// project_b.read_with(cx_b, |project, _| { +// let status = project.language_server_statuses().next().unwrap(); +// assert_eq!(status.name, "the-language-server"); +// assert_eq!(status.pending_work.len(), 1); +// assert_eq!( +// status.pending_work["the-token"].message.as_ref().unwrap(), +// "the-message-2" +// ); +// }); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_share_project( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// cx_c: &mut TestAppContext, +// ) { +// let window_b = cx_b.add_window(|_| EmptyView); +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// let client_c = server.create_client(cx_c, "user_c").await; +// server +// .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); +// let active_call_c = cx_c.read(ActiveCall::global); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// ".gitignore": "ignored-dir", +// "a.txt": "a-contents", +// "b.txt": "b-contents", +// "ignored-dir": { +// "c.txt": "", +// "d.txt": "", +// } +// }), +// ) +// .await; + +// // Invite client B to collaborate on a project +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| { +// call.invite(client_b.user_id().unwrap(), Some(project_a.clone()), cx) +// }) +// .await +// .unwrap(); + +// // Join that project as client B + +// let incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); +// executor.run_until_parked(); +// let call = incoming_call_b.borrow().clone().unwrap(); +// assert_eq!(call.calling_user.github_login, "user_a"); +// let initial_project = call.initial_project.unwrap(); +// active_call_b +// .update(cx_b, |call, cx| call.accept_incoming(cx)) +// .await +// .unwrap(); +// let client_b_peer_id = client_b.peer_id().unwrap(); +// let project_b = client_b +// .build_remote_project(initial_project.id, cx_b) +// .await; + +// let replica_id_b = project_b.read_with(cx_b, |project, _| project.replica_id()); + +// executor.run_until_parked(); + +// project_a.read_with(cx_a, |project, _| { +// let client_b_collaborator = project.collaborators().get(&client_b_peer_id).unwrap(); +// assert_eq!(client_b_collaborator.replica_id, replica_id_b); +// }); + +// project_b.read_with(cx_b, |project, cx| { +// let worktree = project.worktrees().next().unwrap().read(cx); +// assert_eq!( +// worktree.paths().map(AsRef::as_ref).collect::>(), +// [ +// Path::new(".gitignore"), +// Path::new("a.txt"), +// Path::new("b.txt"), +// Path::new("ignored-dir"), +// ] +// ); +// }); + +// project_b +// .update(cx_b, |project, cx| { +// let worktree = project.worktrees().next().unwrap(); +// let entry = worktree.read(cx).entry_for_path("ignored-dir").unwrap(); +// project.expand_entry(worktree_id, entry.id, cx).unwrap() +// }) +// .await +// .unwrap(); + +// project_b.read_with(cx_b, |project, cx| { +// let worktree = project.worktrees().next().unwrap().read(cx); +// assert_eq!( +// worktree.paths().map(AsRef::as_ref).collect::>(), +// [ +// Path::new(".gitignore"), +// Path::new("a.txt"), +// Path::new("b.txt"), +// Path::new("ignored-dir"), +// Path::new("ignored-dir/c.txt"), +// Path::new("ignored-dir/d.txt"), +// ] +// ); +// }); + +// // Open the same file as client B and client A. +// let buffer_b = project_b +// .update(cx_b, |p, cx| p.open_buffer((worktree_id, "b.txt"), cx)) +// .await +// .unwrap(); + +// buffer_b.read_with(cx_b, |buf, _| assert_eq!(buf.text(), "b-contents")); + +// project_a.read_with(cx_a, |project, cx| { +// assert!(project.has_open_buffer((worktree_id, "b.txt"), cx)) +// }); +// let buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "b.txt"), cx)) +// .await +// .unwrap(); + +// let editor_b = window_b.add_view(cx_b, |cx| Editor::for_buffer(buffer_b, None, cx)); + +// // Client A sees client B's selection +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// buffer +// .snapshot() +// .remote_selections_in_range(Anchor::MIN..Anchor::MAX) +// .count() +// == 1 +// }); + +// // Edit the buffer as client B and see that edit as client A. +// editor_b.update(cx_b, |editor, cx| editor.handle_input("ok, ", cx)); +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!(buffer.text(), "ok, b-contents") +// }); + +// // Client B can invite client C on a project shared by client A. +// active_call_b +// .update(cx_b, |call, cx| { +// call.invite(client_c.user_id().unwrap(), Some(project_b.clone()), cx) +// }) +// .await +// .unwrap(); + +// let incoming_call_c = active_call_c.read_with(cx_c, |call, _| call.incoming()); +// executor.run_until_parked(); +// let call = incoming_call_c.borrow().clone().unwrap(); +// assert_eq!(call.calling_user.github_login, "user_b"); +// let initial_project = call.initial_project.unwrap(); +// active_call_c +// .update(cx_c, |call, cx| call.accept_incoming(cx)) +// .await +// .unwrap(); +// let _project_c = client_c +// .build_remote_project(initial_project.id, cx_c) +// .await; + +// // Client B closes the editor, and client A sees client B's selections removed. +// cx_b.update(move |_| drop(editor_b)); +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// buffer +// .snapshot() +// .remote_selections_in_range(Anchor::MIN..Anchor::MAX) +// .count() +// == 0 +// }); +// } diff --git a/crates/collab2/src/tests/following_tests.rs b/crates/collab2/src/tests/following_tests.rs new file mode 100644 index 0000000000..61d14c25c4 --- /dev/null +++ b/crates/collab2/src/tests/following_tests.rs @@ -0,0 +1,1677 @@ +//todo!(workspace) + +// use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; +// use call::ActiveCall; +// use collab_ui::notifications::project_shared_notification::ProjectSharedNotification; +// use editor::{Editor, ExcerptRange, MultiBuffer}; +// use gpui::{BackgroundExecutor, TestAppContext, View}; +// use live_kit_client::MacOSDisplay; +// use rpc::proto::PeerId; +// use serde_json::json; +// use std::borrow::Cow; +// use workspace::{ +// dock::{test::TestPanel, DockPosition}, +// item::{test::TestItem, ItemHandle as _}, +// shared_screen::SharedScreen, +// SplitDirection, Workspace, +// }; + +// #[gpui::test(iterations = 10)] +// async fn test_basic_following( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// cx_c: &mut TestAppContext, +// cx_d: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// let client_c = server.create_client(cx_c, "user_c").await; +// let client_d = server.create_client(cx_d, "user_d").await; +// server +// .create_room(&mut [ +// (&client_a, cx_a), +// (&client_b, cx_b), +// (&client_c, cx_c), +// (&client_d, cx_d), +// ]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "1.txt": "one\none\none", +// "2.txt": "two\ntwo\ntwo", +// "3.txt": "three\nthree\nthree", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); + +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// let window_a = client_a.build_workspace(&project_a, cx_a); +// let workspace_a = window_a.root(cx_a); +// let window_b = client_b.build_workspace(&project_b, cx_b); +// let workspace_b = window_b.root(cx_b); + +// // Client A opens some editors. +// let pane_a = workspace_a.read_with(cx_a, |workspace, _| workspace.active_pane().clone()); +// let editor_a1 = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); +// let editor_a2 = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "2.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// // Client B opens an editor. +// let editor_b1 = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// let peer_id_a = client_a.peer_id().unwrap(); +// let peer_id_b = client_b.peer_id().unwrap(); +// let peer_id_c = client_c.peer_id().unwrap(); +// let peer_id_d = client_d.peer_id().unwrap(); + +// // Client A updates their selections in those editors +// editor_a1.update(cx_a, |editor, cx| { +// editor.handle_input("a", cx); +// editor.handle_input("b", cx); +// editor.handle_input("c", cx); +// editor.select_left(&Default::default(), cx); +// assert_eq!(editor.selections.ranges(cx), vec![3..2]); +// }); +// editor_a2.update(cx_a, |editor, cx| { +// editor.handle_input("d", cx); +// editor.handle_input("e", cx); +// editor.select_left(&Default::default(), cx); +// assert_eq!(editor.selections.ranges(cx), vec![2..1]); +// }); + +// // When client B starts following client A, all visible view states are replicated to client B. +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(peer_id_a, cx).unwrap() +// }) +// .await +// .unwrap(); + +// cx_c.foreground().run_until_parked(); +// let editor_b2 = workspace_b.read_with(cx_b, |workspace, cx| { +// workspace +// .active_item(cx) +// .unwrap() +// .downcast::() +// .unwrap() +// }); +// assert_eq!( +// cx_b.read(|cx| editor_b2.project_path(cx)), +// Some((worktree_id, "2.txt").into()) +// ); +// assert_eq!( +// editor_b2.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), +// vec![2..1] +// ); +// assert_eq!( +// editor_b1.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), +// vec![3..2] +// ); + +// cx_c.foreground().run_until_parked(); +// let active_call_c = cx_c.read(ActiveCall::global); +// let project_c = client_c.build_remote_project(project_id, cx_c).await; +// let window_c = client_c.build_workspace(&project_c, cx_c); +// let workspace_c = window_c.root(cx_c); +// active_call_c +// .update(cx_c, |call, cx| call.set_location(Some(&project_c), cx)) +// .await +// .unwrap(); +// drop(project_c); + +// // Client C also follows client A. +// workspace_c +// .update(cx_c, |workspace, cx| { +// workspace.follow(peer_id_a, cx).unwrap() +// }) +// .await +// .unwrap(); + +// cx_d.foreground().run_until_parked(); +// let active_call_d = cx_d.read(ActiveCall::global); +// let project_d = client_d.build_remote_project(project_id, cx_d).await; +// let workspace_d = client_d.build_workspace(&project_d, cx_d).root(cx_d); +// active_call_d +// .update(cx_d, |call, cx| call.set_location(Some(&project_d), cx)) +// .await +// .unwrap(); +// drop(project_d); + +// // All clients see that clients B and C are following client A. +// cx_c.foreground().run_until_parked(); +// for (name, cx) in [("A", &cx_a), ("B", &cx_b), ("C", &cx_c), ("D", &cx_d)] { +// assert_eq!( +// followers_by_leader(project_id, cx), +// &[(peer_id_a, vec![peer_id_b, peer_id_c])], +// "followers seen by {name}" +// ); +// } + +// // Client C unfollows client A. +// workspace_c.update(cx_c, |workspace, cx| { +// workspace.unfollow(&workspace.active_pane().clone(), cx); +// }); + +// // All clients see that clients B is following client A. +// cx_c.foreground().run_until_parked(); +// for (name, cx) in [("A", &cx_a), ("B", &cx_b), ("C", &cx_c), ("D", &cx_d)] { +// assert_eq!( +// followers_by_leader(project_id, cx), +// &[(peer_id_a, vec![peer_id_b])], +// "followers seen by {name}" +// ); +// } + +// // Client C re-follows client A. +// workspace_c +// .update(cx_c, |workspace, cx| { +// workspace.follow(peer_id_a, cx).unwrap() +// }) +// .await +// .unwrap(); + +// // All clients see that clients B and C are following client A. +// cx_c.foreground().run_until_parked(); +// for (name, cx) in [("A", &cx_a), ("B", &cx_b), ("C", &cx_c), ("D", &cx_d)] { +// assert_eq!( +// followers_by_leader(project_id, cx), +// &[(peer_id_a, vec![peer_id_b, peer_id_c])], +// "followers seen by {name}" +// ); +// } + +// // Client D follows client B, then switches to following client C. +// workspace_d +// .update(cx_d, |workspace, cx| { +// workspace.follow(peer_id_b, cx).unwrap() +// }) +// .await +// .unwrap(); +// workspace_d +// .update(cx_d, |workspace, cx| { +// workspace.follow(peer_id_c, cx).unwrap() +// }) +// .await +// .unwrap(); + +// // All clients see that D is following C +// cx_d.foreground().run_until_parked(); +// for (name, cx) in [("A", &cx_a), ("B", &cx_b), ("C", &cx_c), ("D", &cx_d)] { +// assert_eq!( +// followers_by_leader(project_id, cx), +// &[ +// (peer_id_a, vec![peer_id_b, peer_id_c]), +// (peer_id_c, vec![peer_id_d]) +// ], +// "followers seen by {name}" +// ); +// } + +// // Client C closes the project. +// window_c.remove(cx_c); +// cx_c.drop_last(workspace_c); + +// // Clients A and B see that client B is following A, and client C is not present in the followers. +// cx_c.foreground().run_until_parked(); +// for (name, cx) in [("A", &cx_a), ("B", &cx_b), ("C", &cx_c), ("D", &cx_d)] { +// assert_eq!( +// followers_by_leader(project_id, cx), +// &[(peer_id_a, vec![peer_id_b]),], +// "followers seen by {name}" +// ); +// } + +// // When client A activates a different editor, client B does so as well. +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.activate_item(&editor_a1, cx) +// }); +// executor.run_until_parked(); +// workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!(workspace.active_item(cx).unwrap().id(), editor_b1.id()); +// }); + +// // When client A opens a multibuffer, client B does so as well. +// let multibuffer_a = cx_a.add_model(|cx| { +// let buffer_a1 = project_a.update(cx, |project, cx| { +// project +// .get_open_buffer(&(worktree_id, "1.txt").into(), cx) +// .unwrap() +// }); +// let buffer_a2 = project_a.update(cx, |project, cx| { +// project +// .get_open_buffer(&(worktree_id, "2.txt").into(), cx) +// .unwrap() +// }); +// let mut result = MultiBuffer::new(0); +// result.push_excerpts( +// buffer_a1, +// [ExcerptRange { +// context: 0..3, +// primary: None, +// }], +// cx, +// ); +// result.push_excerpts( +// buffer_a2, +// [ExcerptRange { +// context: 4..7, +// primary: None, +// }], +// cx, +// ); +// result +// }); +// let multibuffer_editor_a = workspace_a.update(cx_a, |workspace, cx| { +// let editor = +// cx.add_view(|cx| Editor::for_multibuffer(multibuffer_a, Some(project_a.clone()), cx)); +// workspace.add_item(Box::new(editor.clone()), cx); +// editor +// }); +// executor.run_until_parked(); +// let multibuffer_editor_b = workspace_b.read_with(cx_b, |workspace, cx| { +// workspace +// .active_item(cx) +// .unwrap() +// .downcast::() +// .unwrap() +// }); +// assert_eq!( +// multibuffer_editor_a.read_with(cx_a, |editor, cx| editor.text(cx)), +// multibuffer_editor_b.read_with(cx_b, |editor, cx| editor.text(cx)), +// ); + +// // When client A navigates back and forth, client B does so as well. +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.go_back(workspace.active_pane().downgrade(), cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); +// workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!(workspace.active_item(cx).unwrap().id(), editor_b1.id()); +// }); + +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.go_back(workspace.active_pane().downgrade(), cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); +// workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!(workspace.active_item(cx).unwrap().id(), editor_b2.id()); +// }); + +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.go_forward(workspace.active_pane().downgrade(), cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); +// workspace_b.read_with(cx_b, |workspace, cx| { +// assert_eq!(workspace.active_item(cx).unwrap().id(), editor_b1.id()); +// }); + +// // Changes to client A's editor are reflected on client B. +// editor_a1.update(cx_a, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([1..1, 2..2])); +// }); +// executor.run_until_parked(); +// editor_b1.read_with(cx_b, |editor, cx| { +// assert_eq!(editor.selections.ranges(cx), &[1..1, 2..2]); +// }); + +// editor_a1.update(cx_a, |editor, cx| editor.set_text("TWO", cx)); +// executor.run_until_parked(); +// editor_b1.read_with(cx_b, |editor, cx| assert_eq!(editor.text(cx), "TWO")); + +// editor_a1.update(cx_a, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([3..3])); +// editor.set_scroll_position(vec2f(0., 100.), cx); +// }); +// executor.run_until_parked(); +// editor_b1.read_with(cx_b, |editor, cx| { +// assert_eq!(editor.selections.ranges(cx), &[3..3]); +// }); + +// // After unfollowing, client B stops receiving updates from client A. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.unfollow(&workspace.active_pane().clone(), cx) +// }); +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.activate_item(&editor_a2, cx) +// }); +// executor.run_until_parked(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, cx| workspace +// .active_item(cx) +// .unwrap() +// .id()), +// editor_b1.id() +// ); + +// // Client A starts following client B. +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.follow(peer_id_b, cx).unwrap() +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_a.read_with(cx_a, |workspace, _| workspace.leader_for_pane(&pane_a)), +// Some(peer_id_b) +// ); +// assert_eq!( +// workspace_a.read_with(cx_a, |workspace, cx| workspace +// .active_item(cx) +// .unwrap() +// .id()), +// editor_a1.id() +// ); + +// // Client B activates an external window, which causes a new screen-sharing item to be added to the pane. +// let display = MacOSDisplay::new(); +// active_call_b +// .update(cx_b, |call, cx| call.set_location(None, cx)) +// .await +// .unwrap(); +// active_call_b +// .update(cx_b, |call, cx| { +// call.room().unwrap().update(cx, |room, cx| { +// room.set_display_sources(vec![display.clone()]); +// room.share_screen(cx) +// }) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); +// let shared_screen = workspace_a.read_with(cx_a, |workspace, cx| { +// workspace +// .active_item(cx) +// .expect("no active item") +// .downcast::() +// .expect("active item isn't a shared screen") +// }); + +// // Client B activates Zed again, which causes the previous editor to become focused again. +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); +// executor.run_until_parked(); +// workspace_a.read_with(cx_a, |workspace, cx| { +// assert_eq!(workspace.active_item(cx).unwrap().id(), editor_a1.id()) +// }); + +// // Client B activates a multibuffer that was created by following client A. Client A returns to that multibuffer. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.activate_item(&multibuffer_editor_b, cx) +// }); +// executor.run_until_parked(); +// workspace_a.read_with(cx_a, |workspace, cx| { +// assert_eq!( +// workspace.active_item(cx).unwrap().id(), +// multibuffer_editor_a.id() +// ) +// }); + +// // Client B activates a panel, and the previously-opened screen-sharing item gets activated. +// let panel = window_b.add_view(cx_b, |_| TestPanel::new(DockPosition::Left)); +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.add_panel(panel, cx); +// workspace.toggle_panel_focus::(cx); +// }); +// executor.run_until_parked(); +// assert_eq!( +// workspace_a.read_with(cx_a, |workspace, cx| workspace +// .active_item(cx) +// .unwrap() +// .id()), +// shared_screen.id() +// ); + +// // Toggling the focus back to the pane causes client A to return to the multibuffer. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.toggle_panel_focus::(cx); +// }); +// executor.run_until_parked(); +// workspace_a.read_with(cx_a, |workspace, cx| { +// assert_eq!( +// workspace.active_item(cx).unwrap().id(), +// multibuffer_editor_a.id() +// ) +// }); + +// // Client B activates an item that doesn't implement following, +// // so the previously-opened screen-sharing item gets activated. +// let unfollowable_item = window_b.add_view(cx_b, |_| TestItem::new()); +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.add_item(Box::new(unfollowable_item), true, true, None, cx) +// }) +// }); +// executor.run_until_parked(); +// assert_eq!( +// workspace_a.read_with(cx_a, |workspace, cx| workspace +// .active_item(cx) +// .unwrap() +// .id()), +// shared_screen.id() +// ); + +// // Following interrupts when client B disconnects. +// client_b.disconnect(&cx_b.to_async()); +// executor.advance_clock(RECONNECT_TIMEOUT); +// assert_eq!( +// workspace_a.read_with(cx_a, |workspace, _| workspace.leader_for_pane(&pane_a)), +// None +// ); +// } + +// #[gpui::test] +// async fn test_following_tab_order( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "1.txt": "one", +// "2.txt": "two", +// "3.txt": "three", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); + +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let pane_a = workspace_a.read_with(cx_a, |workspace, _| workspace.active_pane().clone()); + +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let pane_b = workspace_b.read_with(cx_b, |workspace, _| workspace.active_pane().clone()); + +// let client_b_id = project_a.read_with(cx_a, |project, _| { +// project.collaborators().values().next().unwrap().peer_id +// }); + +// //Open 1, 3 in that order on client A +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap(); +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "3.txt"), None, true, cx) +// }) +// .await +// .unwrap(); + +// let pane_paths = |pane: &ViewHandle, cx: &mut TestAppContext| { +// pane.update(cx, |pane, cx| { +// pane.items() +// .map(|item| { +// item.project_path(cx) +// .unwrap() +// .path +// .to_str() +// .unwrap() +// .to_owned() +// }) +// .collect::>() +// }) +// }; + +// //Verify that the tabs opened in the order we expect +// assert_eq!(&pane_paths(&pane_a, cx_a), &["1.txt", "3.txt"]); + +// //Follow client B as client A +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.follow(client_b_id, cx).unwrap() +// }) +// .await +// .unwrap(); + +// //Open just 2 on client B +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "2.txt"), None, true, cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); + +// // Verify that newly opened followed file is at the end +// assert_eq!(&pane_paths(&pane_a, cx_a), &["1.txt", "3.txt", "2.txt"]); + +// //Open just 1 on client B +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap(); +// assert_eq!(&pane_paths(&pane_b, cx_b), &["2.txt", "1.txt"]); +// executor.run_until_parked(); + +// // Verify that following into 1 did not reorder +// assert_eq!(&pane_paths(&pane_a, cx_a), &["1.txt", "3.txt", "2.txt"]); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_peers_following_each_other( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// // Client A shares a project. +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "1.txt": "one", +// "2.txt": "two", +// "3.txt": "three", +// "4.txt": "four", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// // Client B joins the project. +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// // Client A opens a file. +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// // Client B opens a different file. +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "2.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// // Clients A and B follow each other in split panes +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.split_and_clone(workspace.active_pane().clone(), SplitDirection::Right, cx); +// }); +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.follow(client_b.peer_id().unwrap(), cx).unwrap() +// }) +// .await +// .unwrap(); +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.split_and_clone(workspace.active_pane().clone(), SplitDirection::Right, cx); +// }); +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(client_a.peer_id().unwrap(), cx).unwrap() +// }) +// .await +// .unwrap(); + +// // Clients A and B return focus to the original files they had open +// workspace_a.update(cx_a, |workspace, cx| workspace.activate_next_pane(cx)); +// workspace_b.update(cx_b, |workspace, cx| workspace.activate_next_pane(cx)); +// executor.run_until_parked(); + +// // Both clients see the other client's focused file in their right pane. +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(true, "1.txt".into())] +// }, +// PaneSummary { +// active: false, +// leader: client_b.peer_id(), +// items: vec![(false, "1.txt".into()), (true, "2.txt".into())] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(true, "2.txt".into())] +// }, +// PaneSummary { +// active: false, +// leader: client_a.peer_id(), +// items: vec![(false, "2.txt".into()), (true, "1.txt".into())] +// }, +// ] +// ); + +// // Clients A and B each open a new file. +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "3.txt"), None, true, cx) +// }) +// .await +// .unwrap(); + +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "4.txt"), None, true, cx) +// }) +// .await +// .unwrap(); +// executor.run_until_parked(); + +// // Both client's see the other client open the new file, but keep their +// // focus on their own active pane. +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: false, +// leader: client_b.peer_id(), +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (true, "4.txt".into()) +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: false, +// leader: client_a.peer_id(), +// items: vec![ +// (false, "2.txt".into()), +// (false, "1.txt".into()), +// (true, "3.txt".into()) +// ] +// }, +// ] +// ); + +// // Client A focuses their right pane, in which they're following client B. +// workspace_a.update(cx_a, |workspace, cx| workspace.activate_next_pane(cx)); +// executor.run_until_parked(); + +// // Client B sees that client A is now looking at the same file as them. +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_b.peer_id(), +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (true, "4.txt".into()) +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: false, +// leader: client_a.peer_id(), +// items: vec![ +// (false, "2.txt".into()), +// (false, "1.txt".into()), +// (false, "3.txt".into()), +// (true, "4.txt".into()) +// ] +// }, +// ] +// ); + +// // Client B focuses their right pane, in which they're following client A, +// // who is following them. +// workspace_b.update(cx_b, |workspace, cx| workspace.activate_next_pane(cx)); +// executor.run_until_parked(); + +// // Client A sees that client B is now looking at the same file as them. +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_a.peer_id(), +// items: vec![ +// (false, "2.txt".into()), +// (false, "1.txt".into()), +// (false, "3.txt".into()), +// (true, "4.txt".into()) +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_b.peer_id(), +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (true, "4.txt".into()) +// ] +// }, +// ] +// ); + +// // Client B focuses a file that they previously followed A to, breaking +// // the follow. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.activate_prev_item(true, cx); +// }); +// }); +// executor.run_until_parked(); + +// // Both clients see that client B is looking at that previous file. +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: None, +// items: vec![ +// (false, "2.txt".into()), +// (false, "1.txt".into()), +// (true, "3.txt".into()), +// (false, "4.txt".into()) +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_b.peer_id(), +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (false, "4.txt".into()), +// (true, "3.txt".into()), +// ] +// }, +// ] +// ); + +// // Client B closes tabs, some of which were originally opened by client A, +// // and some of which were originally opened by client B. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.close_inactive_items(&Default::default(), cx) +// .unwrap() +// .detach(); +// }); +// }); + +// executor.run_until_parked(); + +// // Both clients see that Client B is looking at the previous tab. +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: None, +// items: vec![(true, "3.txt".into()),] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_b.peer_id(), +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (false, "4.txt".into()), +// (true, "3.txt".into()), +// ] +// }, +// ] +// ); + +// // Client B follows client A again. +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(client_a.peer_id().unwrap(), cx).unwrap() +// }) +// .await +// .unwrap(); + +// // Client A cycles through some tabs. +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.activate_prev_item(true, cx); +// }); +// }); +// executor.run_until_parked(); + +// // Client B follows client A into those tabs. +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: None, +// items: vec![ +// (false, "1.txt".into()), +// (false, "2.txt".into()), +// (true, "4.txt".into()), +// (false, "3.txt".into()), +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_a.peer_id(), +// items: vec![(false, "3.txt".into()), (true, "4.txt".into())] +// }, +// ] +// ); + +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.activate_prev_item(true, cx); +// }); +// }); +// executor.run_until_parked(); + +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: None, +// items: vec![ +// (false, "1.txt".into()), +// (true, "2.txt".into()), +// (false, "4.txt".into()), +// (false, "3.txt".into()), +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_a.peer_id(), +// items: vec![ +// (false, "3.txt".into()), +// (false, "4.txt".into()), +// (true, "2.txt".into()) +// ] +// }, +// ] +// ); + +// workspace_a.update(cx_a, |workspace, cx| { +// workspace.active_pane().update(cx, |pane, cx| { +// pane.activate_prev_item(true, cx); +// }); +// }); +// executor.run_until_parked(); + +// assert_eq!( +// pane_summaries(&workspace_a, cx_a), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "1.txt".into()), (true, "3.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: None, +// items: vec![ +// (true, "1.txt".into()), +// (false, "2.txt".into()), +// (false, "4.txt".into()), +// (false, "3.txt".into()), +// ] +// }, +// ] +// ); +// assert_eq!( +// pane_summaries(&workspace_b, cx_b), +// &[ +// PaneSummary { +// active: false, +// leader: None, +// items: vec![(false, "2.txt".into()), (true, "4.txt".into())] +// }, +// PaneSummary { +// active: true, +// leader: client_a.peer_id(), +// items: vec![ +// (false, "3.txt".into()), +// (false, "4.txt".into()), +// (false, "2.txt".into()), +// (true, "1.txt".into()), +// ] +// }, +// ] +// ); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_auto_unfollowing( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// // 2 clients connect to a server. +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// // Client A shares a project. +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "1.txt": "one", +// "2.txt": "two", +// "3.txt": "three", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); + +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// // Client A opens some editors. +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let _editor_a1 = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "1.txt"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// // Client B starts following client A. +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let pane_b = workspace_b.read_with(cx_b, |workspace, _| workspace.active_pane().clone()); +// let leader_id = project_b.read_with(cx_b, |project, _| { +// project.collaborators().values().next().unwrap().peer_id +// }); +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(leader_id, cx).unwrap() +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); +// let editor_b2 = workspace_b.read_with(cx_b, |workspace, cx| { +// workspace +// .active_item(cx) +// .unwrap() +// .downcast::() +// .unwrap() +// }); + +// // When client B moves, it automatically stops following client A. +// editor_b2.update(cx_b, |editor, cx| editor.move_right(&editor::MoveRight, cx)); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// None +// ); + +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(leader_id, cx).unwrap() +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); + +// // When client B edits, it automatically stops following client A. +// editor_b2.update(cx_b, |editor, cx| editor.insert("X", cx)); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// None +// ); + +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(leader_id, cx).unwrap() +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); + +// // When client B scrolls, it automatically stops following client A. +// editor_b2.update(cx_b, |editor, cx| { +// editor.set_scroll_position(vec2f(0., 3.), cx) +// }); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// None +// ); + +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.follow(leader_id, cx).unwrap() +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); + +// // When client B activates a different pane, it continues following client A in the original pane. +// workspace_b.update(cx_b, |workspace, cx| { +// workspace.split_and_clone(pane_b.clone(), SplitDirection::Right, cx) +// }); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); + +// workspace_b.update(cx_b, |workspace, cx| workspace.activate_next_pane(cx)); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// Some(leader_id) +// ); + +// // When client B activates a different item in the original pane, it automatically stops following client A. +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "2.txt"), None, true, cx) +// }) +// .await +// .unwrap(); +// assert_eq!( +// workspace_b.read_with(cx_b, |workspace, _| workspace.leader_for_pane(&pane_b)), +// None +// ); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_peers_simultaneously_following_each_other( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// client_a.fs().insert_tree("/a", json!({})).await; +// let (project_a, _) = client_a.build_local_project("/a", cx_a).await; +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); + +// executor.run_until_parked(); +// let client_a_id = project_b.read_with(cx_b, |project, _| { +// project.collaborators().values().next().unwrap().peer_id +// }); +// let client_b_id = project_a.read_with(cx_a, |project, _| { +// project.collaborators().values().next().unwrap().peer_id +// }); + +// let a_follow_b = workspace_a.update(cx_a, |workspace, cx| { +// workspace.follow(client_b_id, cx).unwrap() +// }); +// let b_follow_a = workspace_b.update(cx_b, |workspace, cx| { +// workspace.follow(client_a_id, cx).unwrap() +// }); + +// futures::try_join!(a_follow_b, b_follow_a).unwrap(); +// workspace_a.read_with(cx_a, |workspace, _| { +// assert_eq!( +// workspace.leader_for_pane(workspace.active_pane()), +// Some(client_b_id) +// ); +// }); +// workspace_b.read_with(cx_b, |workspace, _| { +// assert_eq!( +// workspace.leader_for_pane(workspace.active_pane()), +// Some(client_a_id) +// ); +// }); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_following_across_workspaces( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// // a and b join a channel/call +// // a shares project 1 +// // b shares project 2 +// // +// // b follows a: causes project 2 to be joined, and b to follow a. +// // b opens a different file in project 2, a follows b +// // b opens a different file in project 1, a cannot follow b +// // b shares the project, a joins the project and follows b +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "w.rs": "", +// "x.rs": "", +// }), +// ) +// .await; + +// client_b +// .fs() +// .insert_tree( +// "/b", +// json!({ +// "y.rs": "", +// "z.rs": "", +// }), +// ) +// .await; + +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// let (project_a, worktree_id_a) = client_a.build_local_project("/a", cx_a).await; +// let (project_b, worktree_id_b) = client_b.build_local_project("/b", cx_b).await; + +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); + +// cx_a.update(|cx| collab_ui::init(&client_a.app_state, cx)); +// cx_b.update(|cx| collab_ui::init(&client_b.app_state, cx)); + +// active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id_a, "w.rs"), None, true, cx) +// }) +// .await +// .unwrap(); + +// executor.run_until_parked(); +// assert_eq!(visible_push_notifications(cx_b).len(), 1); + +// workspace_b.update(cx_b, |workspace, cx| { +// workspace +// .follow(client_a.peer_id().unwrap(), cx) +// .unwrap() +// .detach() +// }); + +// executor.run_until_parked(); +// let workspace_b_project_a = cx_b +// .windows() +// .iter() +// .max_by_key(|window| window.id()) +// .unwrap() +// .downcast::() +// .unwrap() +// .root(cx_b); + +// // assert that b is following a in project a in w.rs +// workspace_b_project_a.update(cx_b, |workspace, cx| { +// assert!(workspace.is_being_followed(client_a.peer_id().unwrap())); +// assert_eq!( +// client_a.peer_id(), +// workspace.leader_for_pane(workspace.active_pane()) +// ); +// let item = workspace.active_item(cx).unwrap(); +// assert_eq!(item.tab_description(0, cx).unwrap(), Cow::Borrowed("w.rs")); +// }); + +// // TODO: in app code, this would be done by the collab_ui. +// active_call_b +// .update(cx_b, |call, cx| { +// let project = workspace_b_project_a.read(cx).project().clone(); +// call.set_location(Some(&project), cx) +// }) +// .await +// .unwrap(); + +// // assert that there are no share notifications open +// assert_eq!(visible_push_notifications(cx_b).len(), 0); + +// // b moves to x.rs in a's project, and a follows +// workspace_b_project_a +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id_a, "x.rs"), None, true, cx) +// }) +// .await +// .unwrap(); + +// executor.run_until_parked(); +// workspace_b_project_a.update(cx_b, |workspace, cx| { +// let item = workspace.active_item(cx).unwrap(); +// assert_eq!(item.tab_description(0, cx).unwrap(), Cow::Borrowed("x.rs")); +// }); + +// workspace_a.update(cx_a, |workspace, cx| { +// workspace +// .follow(client_b.peer_id().unwrap(), cx) +// .unwrap() +// .detach() +// }); + +// executor.run_until_parked(); +// workspace_a.update(cx_a, |workspace, cx| { +// assert!(workspace.is_being_followed(client_b.peer_id().unwrap())); +// assert_eq!( +// client_b.peer_id(), +// workspace.leader_for_pane(workspace.active_pane()) +// ); +// let item = workspace.active_pane().read(cx).active_item().unwrap(); +// assert_eq!(item.tab_description(0, cx).unwrap(), Cow::Borrowed("x.rs")); +// }); + +// // b moves to y.rs in b's project, a is still following but can't yet see +// workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id_b, "y.rs"), None, true, cx) +// }) +// .await +// .unwrap(); + +// // TODO: in app code, this would be done by the collab_ui. +// active_call_b +// .update(cx_b, |call, cx| { +// let project = workspace_b.read(cx).project().clone(); +// call.set_location(Some(&project), cx) +// }) +// .await +// .unwrap(); + +// let project_b_id = active_call_b +// .update(cx_b, |call, cx| call.share_project(project_b.clone(), cx)) +// .await +// .unwrap(); + +// executor.run_until_parked(); +// assert_eq!(visible_push_notifications(cx_a).len(), 1); +// cx_a.update(|cx| { +// workspace::join_remote_project( +// project_b_id, +// client_b.user_id().unwrap(), +// client_a.app_state.clone(), +// cx, +// ) +// }) +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// assert_eq!(visible_push_notifications(cx_a).len(), 0); +// let workspace_a_project_b = cx_a +// .windows() +// .iter() +// .max_by_key(|window| window.id()) +// .unwrap() +// .downcast::() +// .unwrap() +// .root(cx_a); + +// workspace_a_project_b.update(cx_a, |workspace, cx| { +// assert_eq!(workspace.project().read(cx).remote_id(), Some(project_b_id)); +// assert!(workspace.is_being_followed(client_b.peer_id().unwrap())); +// assert_eq!( +// client_b.peer_id(), +// workspace.leader_for_pane(workspace.active_pane()) +// ); +// let item = workspace.active_item(cx).unwrap(); +// assert_eq!(item.tab_description(0, cx).unwrap(), Cow::Borrowed("y.rs")); +// }); +// } + +// fn visible_push_notifications( +// cx: &mut TestAppContext, +// ) -> Vec> { +// let mut ret = Vec::new(); +// for window in cx.windows() { +// window.read_with(cx, |window| { +// if let Some(handle) = window +// .root_view() +// .clone() +// .downcast::() +// { +// ret.push(handle) +// } +// }); +// } +// ret +// } + +// #[derive(Debug, PartialEq, Eq)] +// struct PaneSummary { +// active: bool, +// leader: Option, +// items: Vec<(bool, String)>, +// } + +// fn followers_by_leader(project_id: u64, cx: &TestAppContext) -> Vec<(PeerId, Vec)> { +// cx.read(|cx| { +// let active_call = ActiveCall::global(cx).read(cx); +// let peer_id = active_call.client().peer_id(); +// let room = active_call.room().unwrap().read(cx); +// let mut result = room +// .remote_participants() +// .values() +// .map(|participant| participant.peer_id) +// .chain(peer_id) +// .filter_map(|peer_id| { +// let followers = room.followers_for(peer_id, project_id); +// if followers.is_empty() { +// None +// } else { +// Some((peer_id, followers.to_vec())) +// } +// }) +// .collect::>(); +// result.sort_by_key(|e| e.0); +// result +// }) +// } + +// fn pane_summaries(workspace: &ViewHandle, cx: &mut TestAppContext) -> Vec { +// workspace.read_with(cx, |workspace, cx| { +// let active_pane = workspace.active_pane(); +// workspace +// .panes() +// .iter() +// .map(|pane| { +// let leader = workspace.leader_for_pane(pane); +// let active = pane == active_pane; +// let pane = pane.read(cx); +// let active_ix = pane.active_item_index(); +// PaneSummary { +// active, +// leader, +// items: pane +// .items() +// .enumerate() +// .map(|(ix, item)| { +// ( +// ix == active_ix, +// item.tab_description(0, cx) +// .map_or(String::new(), |s| s.to_string()), +// ) +// }) +// .collect(), +// } +// }) +// .collect() +// }) +// } diff --git a/crates/collab2/src/tests/integration_tests.rs b/crates/collab2/src/tests/integration_tests.rs new file mode 100644 index 0000000000..f681e4877f --- /dev/null +++ b/crates/collab2/src/tests/integration_tests.rs @@ -0,0 +1,6474 @@ +use crate::{ + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::{channel_id, room_participants, RoomParticipants, TestClient, TestServer}, +}; +use call::{room, ActiveCall, ParticipantLocation, Room}; +use client::{User, RECEIVE_TIMEOUT}; +use collections::{HashMap, HashSet}; +use fs::{repository::GitFileStatus, FakeFs, Fs as _, RemoveOptions}; +use futures::StreamExt as _; +use gpui::{AppContext, BackgroundExecutor, Model, TestAppContext}; +use language::{ + language_settings::{AllLanguageSettings, Formatter}, + tree_sitter_rust, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, LanguageConfig, + LineEnding, OffsetRangeExt, Point, Rope, +}; +use live_kit_client::MacOSDisplay; +use lsp::LanguageServerId; +use project::{ + search::SearchQuery, DiagnosticSummary, FormatTrigger, HoverBlockKind, Project, ProjectPath, +}; +use rand::prelude::*; +use serde_json::json; +use settings::SettingsStore; +use std::{ + cell::{Cell, RefCell}, + env, future, mem, + path::{Path, PathBuf}, + rc::Rc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; +use unindent::Unindent as _; + +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} + +#[gpui::test(iterations = 10)] +async fn test_basic_calls( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_b2: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + let active_call_c = cx_c.read(ActiveCall::global); + + // Call user B from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string()] + } + ); + + // User B receives the call. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + let call_b = incoming_call_b.next().await.unwrap().unwrap(); + assert_eq!(call_b.calling_user.github_login, "user_a"); + + // User B connects via another client and also receives a ring on the newly-connected client. + let _client_b2 = server.create_client(cx_b2, "user_b").await; + let active_call_b2 = cx_b2.read(ActiveCall::global); + + let mut incoming_call_b2 = active_call_b2.read_with(cx_b2, |call, _| call.incoming()); + executor.run_until_parked(); + let call_b2 = incoming_call_b2.next().await.unwrap().unwrap(); + assert_eq!(call_b2.calling_user.github_login, "user_a"); + + // User B joins the room using the first client. + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + assert!(incoming_call_b.next().await.unwrap().is_none()); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // Call user C from client B. + + let mut incoming_call_c = active_call_c.read_with(cx_c, |call, _| call.incoming()); + active_call_b + .update(cx_b, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec!["user_c".to_string()] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec!["user_c".to_string()] + } + ); + + // User C receives the call, but declines it. + let call_c = incoming_call_c.next().await.unwrap().unwrap(); + assert_eq!(call_c.calling_user.github_login, "user_b"); + active_call_c.update(cx_c, |call, cx| call.decline_incoming(cx).unwrap()); + assert!(incoming_call_c.next().await.unwrap().is_none()); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // Call user C again from user A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec!["user_c".to_string()] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec!["user_c".to_string()] + } + ); + + // User C accepts the call. + let call_c = incoming_call_c.next().await.unwrap().unwrap(); + assert_eq!(call_c.calling_user.github_login, "user_a"); + active_call_c + .update(cx_c, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + assert!(incoming_call_c.next().await.unwrap().is_none()); + + let room_c = active_call_c.read_with(cx_c, |call, _| call.room().unwrap().clone()); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string(), "user_c".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_c".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_b".to_string()], + pending: Default::default() + } + ); + + // User A shares their screen + let display = MacOSDisplay::new(); + let events_b = active_call_events(cx_b); + let events_c = active_call_events(cx_c); + active_call_a + .update(cx_a, |call, cx| { + call.room().unwrap().update(cx, |room, cx| { + room.set_display_sources(vec![display.clone()]); + room.share_screen(cx) + }) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // User B observes the remote screen sharing track. + assert_eq!(events_b.borrow().len(), 1); + let event_b = events_b.borrow().first().unwrap().clone(); + if let call::room::Event::RemoteVideoTracksChanged { participant_id } = event_b { + assert_eq!(participant_id, client_a.peer_id().unwrap()); + + room_b.read_with(cx_b, |room, _| { + assert_eq!( + room.remote_participants()[&client_a.user_id().unwrap()] + .video_tracks + .len(), + 1 + ); + }); + } else { + panic!("unexpected event") + } + + // User C observes the remote screen sharing track. + assert_eq!(events_c.borrow().len(), 1); + let event_c = events_c.borrow().first().unwrap().clone(); + if let call::room::Event::RemoteVideoTracksChanged { participant_id } = event_c { + assert_eq!(participant_id, client_a.peer_id().unwrap()); + + room_c.read_with(cx_c, |room, _| { + assert_eq!( + room.remote_participants()[&client_a.user_id().unwrap()] + .video_tracks + .len(), + 1 + ); + }); + } else { + panic!("unexpected event") + } + + // User A leaves the room. + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + assert!(call.room().is_none()); + hang_up + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_c".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + + // User B gets disconnected from the LiveKit server, which causes them + // to automatically leave the room. User C leaves the room as well because + // nobody else is in there. + server + .test_live_kit_server + .disconnect_client(client_b.user_id().unwrap().to_string()) + .await; + executor.run_until_parked(); + + active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none())); + + active_call_c.read_with(cx_c, |call, _| assert!(call.room().is_none())); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); +} + +#[gpui::test(iterations = 10)] +async fn test_calling_multiple_users_simultaneously( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, + cx_d: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + let client_d = server.create_client(cx_d, "user_d").await; + server + .make_contacts(&mut [ + (&client_a, cx_a), + (&client_b, cx_b), + (&client_c, cx_c), + (&client_d, cx_d), + ]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + let active_call_c = cx_c.read(ActiveCall::global); + let active_call_d = cx_d.read(ActiveCall::global); + + // Simultaneously call user B and user C from client A. + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + b_invite.await.unwrap(); + c_invite.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string(), "user_c".to_string()] + } + ); + + // Call client D from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_d.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec![ + "user_b".to_string(), + "user_c".to_string(), + "user_d".to_string() + ] + } + ); + + // Accept the call on all clients simultaneously. + let accept_b = active_call_b.update(cx_b, |call, cx| call.accept_incoming(cx)); + let accept_c = active_call_c.update(cx_c, |call, cx| call.accept_incoming(cx)); + let accept_d = active_call_d.update(cx_d, |call, cx| call.accept_incoming(cx)); + accept_b.await.unwrap(); + accept_c.await.unwrap(); + accept_d.await.unwrap(); + + executor.run_until_parked(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + + let room_c = active_call_c.read_with(cx_c, |call, _| call.room().unwrap().clone()); + + let room_d = active_call_d.read_with(cx_d, |call, _| call.room().unwrap().clone()); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec![ + "user_b".to_string(), + "user_c".to_string(), + "user_d".to_string(), + ], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec![ + "user_a".to_string(), + "user_c".to_string(), + "user_d".to_string(), + ], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec![ + "user_a".to_string(), + "user_b".to_string(), + "user_d".to_string(), + ], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_d, cx_d), + RoomParticipants { + remote: vec![ + "user_a".to_string(), + "user_b".to_string(), + "user_c".to_string(), + ], + pending: Default::default() + } + ); +} + +#[gpui::test(iterations = 10)] +async fn test_joining_channels_and_calling_multiple_users_simultaneously( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let channel_1 = server + .make_channel( + "channel1", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_2 = server + .make_channel( + "channel2", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + + // Simultaneously join channel 1 and then channel 2 + active_call_a + .update(cx_a, |call, cx| call.join_channel(channel_1, cx)) + .detach(); + let join_channel_2 = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_2, cx)); + + join_channel_2.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + + assert_eq!(channel_id(&room_a, cx_a), Some(channel_2)); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Initiating invites and then joining a channel should fail gracefully + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + b_invite.await.unwrap(); + c_invite.await.unwrap(); + join_channel.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string(), "user_c".to_string()] + } + ); + + assert_eq!(channel_id(&room_a, cx_a), None); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Simultaneously join channel 1 and call user B and user C from client A. + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + join_channel.await.unwrap(); + b_invite.await.unwrap(); + c_invite.await.unwrap(); + + active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); +} + +#[gpui::test(iterations = 10)] +async fn test_room_uniqueness( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_a2: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_b2: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let _client_a2 = server.create_client(cx_a2, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let _client_b2 = server.create_client(cx_b2, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_a2 = cx_a2.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + let active_call_b2 = cx_b2.read(ActiveCall::global); + let active_call_c = cx_c.read(ActiveCall::global); + + // Call user B from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + // Ensure a new room can't be created given user A just created one. + active_call_a2 + .update(cx_a2, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + + active_call_a2.read_with(cx_a2, |call, _| assert!(call.room().is_none())); + + // User B receives the call from user A. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + let call_b1 = incoming_call_b.next().await.unwrap().unwrap(); + assert_eq!(call_b1.calling_user.github_login, "user_a"); + + // Ensure calling users A and B from client C fails. + active_call_c + .update(cx_c, |call, cx| { + call.invite(client_a.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + active_call_c + .update(cx_c, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + + // Ensure User B can't create a room while they still have an incoming call. + active_call_b2 + .update(cx_b2, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + + active_call_b2.read_with(cx_b2, |call, _| assert!(call.room().is_none())); + + // User B joins the room and calling them after they've joined still fails. + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + active_call_c + .update(cx_c, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + + // Ensure User B can't create a room while they belong to another room. + active_call_b2 + .update(cx_b2, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }) + .await + .unwrap_err(); + + active_call_b2.read_with(cx_b2, |call, _| assert!(call.room().is_none())); + + // Client C can successfully call client B after client B leaves the room. + active_call_b + .update(cx_b, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + active_call_c + .update(cx_c, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + let call_b2 = incoming_call_b.next().await.unwrap().unwrap(); + assert_eq!(call_b2.calling_user.github_login, "user_c"); +} + +#[gpui::test(iterations = 10)] +async fn test_client_disconnecting_from_room( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + // Call user B from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + + // User B receives the call and joins the room. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + incoming_call_b.next().await.unwrap().unwrap(); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // User A automatically reconnects to the room upon disconnection. + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // When user A disconnects, both client A and B clear their room on the active call. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none())); + + active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none())); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + + // Allow user A to reconnect to the server. + server.allow_connections(); + executor.advance_clock(RECEIVE_TIMEOUT); + + // Call user B again from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + + // User B receives the call and joins the room. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + incoming_call_b.next().await.unwrap().unwrap(); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: Default::default() + } + ); + + // User B gets disconnected from the LiveKit server, which causes it + // to automatically leave the room. + server + .test_live_kit_server + .disconnect_client(client_b.user_id().unwrap().to_string()) + .await; + executor.run_until_parked(); + active_call_a.update(cx_a, |call, _| assert!(call.room().is_none())); + active_call_b.update(cx_b, |call, _| assert!(call.room().is_none())); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: Default::default(), + pending: Default::default() + } + ); +} + +#[gpui::test(iterations = 10)] +async fn test_server_restarts( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, + cx_d: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + client_a + .fs() + .insert_tree("/a", json!({ "a.txt": "a-contents" })) + .await; + + // Invite client B to collaborate on a project + let (project_a, _) = client_a.build_local_project("/a", cx_a).await; + + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + let client_d = server.create_client(cx_d, "user_d").await; + server + .make_contacts(&mut [ + (&client_a, cx_a), + (&client_b, cx_b), + (&client_c, cx_c), + (&client_d, cx_d), + ]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + let active_call_c = cx_c.read(ActiveCall::global); + let active_call_d = cx_d.read(ActiveCall::global); + + // User A calls users B, C, and D. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), Some(project_a.clone()), cx) + }) + .await + .unwrap(); + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), Some(project_a.clone()), cx) + }) + .await + .unwrap(); + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_d.user_id().unwrap(), Some(project_a.clone()), cx) + }) + .await + .unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + + // User B receives the call and joins the room. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + assert!(incoming_call_b.next().await.unwrap().is_some()); + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + + // User C receives the call and joins the room. + + let mut incoming_call_c = active_call_c.read_with(cx_c, |call, _| call.incoming()); + assert!(incoming_call_c.next().await.unwrap().is_some()); + active_call_c + .update(cx_c, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_c = active_call_c.read_with(cx_c, |call, _| call.room().unwrap().clone()); + + // User D receives the call but doesn't join the room yet. + + let mut incoming_call_d = active_call_d.read_with(cx_d, |call, _| call.incoming()); + assert!(incoming_call_d.next().await.unwrap().is_some()); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string(), "user_c".to_string()], + pending: vec!["user_d".to_string()] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_c".to_string()], + pending: vec!["user_d".to_string()] + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_b".to_string()], + pending: vec!["user_d".to_string()] + } + ); + + // The server is torn down. + server.reset().await; + + // Users A and B reconnect to the call. User C has troubles reconnecting, so it leaves the room. + client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + executor.advance_clock(RECONNECT_TIMEOUT); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string(), "user_c".to_string()], + pending: vec!["user_d".to_string()] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_c".to_string()], + pending: vec!["user_d".to_string()] + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + + // User D is notified again of the incoming call and accepts it. + assert!(incoming_call_d.next().await.unwrap().is_some()); + active_call_d + .update(cx_d, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + executor.run_until_parked(); + + let room_d = active_call_d.read_with(cx_d, |call, _| call.room().unwrap().clone()); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec![ + "user_b".to_string(), + "user_c".to_string(), + "user_d".to_string(), + ], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec![ + "user_a".to_string(), + "user_c".to_string(), + "user_d".to_string(), + ], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_d, cx_d), + RoomParticipants { + remote: vec![ + "user_a".to_string(), + "user_b".to_string(), + "user_c".to_string(), + ], + pending: vec![] + } + ); + + // The server finishes restarting, cleaning up stale connections. + server.start().await.unwrap(); + executor.advance_clock(CLEANUP_TIMEOUT); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string(), "user_d".to_string()], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_d".to_string()], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_d, cx_d), + RoomParticipants { + remote: vec!["user_a".to_string(), "user_b".to_string()], + pending: vec![] + } + ); + + // User D hangs up. + active_call_d + .update(cx_d, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_c, cx_c), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_d, cx_d), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + + // User B calls user D again. + active_call_b + .update(cx_b, |call, cx| { + call.invite(client_d.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + // User D receives the call but doesn't join the room yet. + + let mut incoming_call_d = active_call_d.read_with(cx_d, |call, _| call.incoming()); + assert!(incoming_call_d.next().await.unwrap().is_some()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec!["user_d".to_string()] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec!["user_d".to_string()] + } + ); + + // The server is torn down. + server.reset().await; + + // Users A and B have troubles reconnecting, so they leave the room. + client_a.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + client_b.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + executor.advance_clock(RECONNECT_TIMEOUT); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec![], + pending: vec![] + } + ); + + // User D is notified again of the incoming call but doesn't accept it. + assert!(incoming_call_d.next().await.unwrap().is_some()); + + // The server finishes restarting, cleaning up stale connections and canceling the + // call to user D because the room has become empty. + server.start().await.unwrap(); + executor.advance_clock(CLEANUP_TIMEOUT); + assert!(incoming_call_d.next().await.unwrap().is_none()); +} + +#[gpui::test(iterations = 10)] +async fn test_calls_on_multiple_connections( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b1: &mut TestAppContext, + cx_b2: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b1 = server.create_client(cx_b1, "user_b").await; + let client_b2 = server.create_client(cx_b2, "user_b").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b1, cx_b1)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b1 = cx_b1.read(ActiveCall::global); + let active_call_b2 = cx_b2.read(ActiveCall::global); + + let mut incoming_call_b1 = active_call_b1.read_with(cx_b1, |call, _| call.incoming()); + + let mut incoming_call_b2 = active_call_b2.read_with(cx_b2, |call, _| call.incoming()); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // Call user B from client A, ensuring both clients for user B ring. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User B declines the call on one of the two connections, causing both connections + // to stop ringing. + active_call_b2.update(cx_b2, |call, cx| call.decline_incoming(cx).unwrap()); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // Call user B again from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User B accepts the call on one of the two connections, causing both connections + // to stop ringing. + active_call_b2 + .update(cx_b2, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // User B disconnects the client that is not on the call. Everything should be fine. + client_b1.disconnect(&cx_b1.to_async()); + executor.advance_clock(RECEIVE_TIMEOUT); + client_b1 + .authenticate_and_connect(false, &cx_b1.to_async()) + .await + .unwrap(); + + // User B hangs up, and user A calls them again. + active_call_b2 + .update(cx_b2, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User A cancels the call, causing both connections to stop ringing. + active_call_a + .update(cx_a, |call, cx| { + call.cancel_invite(client_b1.user_id().unwrap(), cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // User A calls user B again. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User A hangs up, causing both connections to stop ringing. + active_call_a + .update(cx_a, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // User A calls user B again. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User A disconnects, causing both connections to stop ringing. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert!(incoming_call_b1.next().await.unwrap().is_none()); + assert!(incoming_call_b2.next().await.unwrap().is_none()); + + // User A reconnects automatically, then calls user B again. + server.allow_connections(); + executor.advance_clock(RECEIVE_TIMEOUT); + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b1.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert!(incoming_call_b1.next().await.unwrap().is_some()); + assert!(incoming_call_b2.next().await.unwrap().is_some()); + + // User B disconnects all clients, causing user A to no longer see a pending call for them. + server.forbid_connections(); + server.disconnect_client(client_b1.peer_id().unwrap()); + server.disconnect_client(client_b2.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + + active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none())); +} + +#[gpui::test(iterations = 10)] +async fn test_unshare_project( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/a", + json!({ + "a.txt": "a-contents", + "b.txt": "b-contents", + }), + ) + .await; + + let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap()); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + executor.run_until_parked(); + + assert!(worktree_a.read_with(cx_a, |tree, _| tree.as_local().unwrap().is_shared())); + + project_b + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // When client B leaves the room, the project becomes read-only. + active_call_b + .update(cx_b, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + + assert!(project_b.read_with(cx_b, |project, _| project.is_read_only())); + + // Client C opens the project. + let project_c = client_c.build_remote_project(project_id, cx_c).await; + + // When client A unshares the project, client C's project becomes read-only. + project_a + .update(cx_a, |project, cx| project.unshare(cx)) + .unwrap(); + executor.run_until_parked(); + + assert!(worktree_a.read_with(cx_a, |tree, _| !tree.as_local().unwrap().is_shared())); + + assert!(project_c.read_with(cx_c, |project, _| project.is_read_only())); + + // Client C can open the project again after client A re-shares. + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_c2 = client_c.build_remote_project(project_id, cx_c).await; + executor.run_until_parked(); + + assert!(worktree_a.read_with(cx_a, |tree, _| tree.as_local().unwrap().is_shared())); + project_c2 + .update(cx_c, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // When client A (the host) leaves the room, the project gets unshared and guests are notified. + active_call_a + .update(cx_a, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| assert!(!project.is_shared())); + + project_c2.read_with(cx_c, |project, _| { + assert!(project.is_read_only()); + assert!(project.collaborators().is_empty()); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_project_reconnect( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + cx_b.update(editor::init); + + client_a + .fs() + .insert_tree( + "/root-1", + json!({ + "dir1": { + "a.txt": "a", + "b.txt": "b", + "subdir1": { + "c.txt": "c", + "d.txt": "d", + "e.txt": "e", + } + }, + "dir2": { + "v.txt": "v", + }, + "dir3": { + "w.txt": "w", + "x.txt": "x", + "y.txt": "y", + }, + "dir4": { + "z.txt": "z", + }, + }), + ) + .await; + client_a + .fs() + .insert_tree( + "/root-2", + json!({ + "2.txt": "2", + }), + ) + .await; + client_a + .fs() + .insert_tree( + "/root-3", + json!({ + "3.txt": "3", + }), + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let (project_a1, _) = client_a.build_local_project("/root-1/dir1", cx_a).await; + let (project_a2, _) = client_a.build_local_project("/root-2", cx_a).await; + let (project_a3, _) = client_a.build_local_project("/root-3", cx_a).await; + let worktree_a1 = project_a1.read_with(cx_a, |project, _| project.worktrees().next().unwrap()); + let project1_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a1.clone(), cx)) + .await + .unwrap(); + let project2_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a2.clone(), cx)) + .await + .unwrap(); + let project3_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a3.clone(), cx)) + .await + .unwrap(); + + let project_b1 = client_b.build_remote_project(project1_id, cx_b).await; + let project_b2 = client_b.build_remote_project(project2_id, cx_b).await; + let project_b3 = client_b.build_remote_project(project3_id, cx_b).await; + executor.run_until_parked(); + + let worktree1_id = worktree_a1.read_with(cx_a, |worktree, _| { + assert!(worktree.as_local().unwrap().is_shared()); + worktree.id() + }); + let (worktree_a2, _) = project_a1 + .update(cx_a, |p, cx| { + p.find_or_create_local_worktree("/root-1/dir2", true, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + + let worktree2_id = worktree_a2.read_with(cx_a, |tree, _| { + assert!(tree.as_local().unwrap().is_shared()); + tree.id() + }); + executor.run_until_parked(); + + project_b1.read_with(cx_b, |project, cx| { + assert!(project.worktree_for_id(worktree2_id, cx).is_some()) + }); + + let buffer_a1 = project_a1 + .update(cx_a, |p, cx| p.open_buffer((worktree1_id, "a.txt"), cx)) + .await + .unwrap(); + let buffer_b1 = project_b1 + .update(cx_b, |p, cx| p.open_buffer((worktree1_id, "a.txt"), cx)) + .await + .unwrap(); + + // Drop client A's connection. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT); + + project_a1.read_with(cx_a, |project, _| { + assert!(project.is_shared()); + assert_eq!(project.collaborators().len(), 1); + }); + + project_b1.read_with(cx_b, |project, _| { + assert!(!project.is_read_only()); + assert_eq!(project.collaborators().len(), 1); + }); + + worktree_a1.read_with(cx_a, |tree, _| { + assert!(tree.as_local().unwrap().is_shared()) + }); + + // While client A is disconnected, add and remove files from client A's project. + client_a + .fs() + .insert_tree( + "/root-1/dir1/subdir2", + json!({ + "f.txt": "f-contents", + "g.txt": "g-contents", + "h.txt": "h-contents", + "i.txt": "i-contents", + }), + ) + .await; + client_a + .fs() + .remove_dir( + "/root-1/dir1/subdir1".as_ref(), + RemoveOptions { + recursive: true, + ..Default::default() + }, + ) + .await + .unwrap(); + + // While client A is disconnected, add and remove worktrees from client A's project. + project_a1.update(cx_a, |project, cx| { + project.remove_worktree(worktree2_id, cx) + }); + let (worktree_a3, _) = project_a1 + .update(cx_a, |p, cx| { + p.find_or_create_local_worktree("/root-1/dir3", true, cx) + }) + .await + .unwrap(); + worktree_a3 + .read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + + let worktree3_id = worktree_a3.read_with(cx_a, |tree, _| { + assert!(!tree.as_local().unwrap().is_shared()); + tree.id() + }); + executor.run_until_parked(); + + // While client A is disconnected, close project 2 + cx_a.update(|_| drop(project_a2)); + + // While client A is disconnected, mutate a buffer on both the host and the guest. + buffer_a1.update(cx_a, |buf, cx| buf.edit([(0..0, "W")], None, cx)); + buffer_b1.update(cx_b, |buf, cx| buf.edit([(1..1, "Z")], None, cx)); + executor.run_until_parked(); + + // Client A reconnects. Their project is re-shared, and client B re-joins it. + server.allow_connections(); + client_a + .authenticate_and_connect(false, &cx_a.to_async()) + .await + .unwrap(); + executor.run_until_parked(); + + project_a1.read_with(cx_a, |project, cx| { + assert!(project.is_shared()); + assert!(worktree_a1.read(cx).as_local().unwrap().is_shared()); + assert_eq!( + worktree_a1 + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec![ + "a.txt", + "b.txt", + "subdir2", + "subdir2/f.txt", + "subdir2/g.txt", + "subdir2/h.txt", + "subdir2/i.txt" + ] + ); + assert!(worktree_a3.read(cx).as_local().unwrap().is_shared()); + assert_eq!( + worktree_a3 + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec!["w.txt", "x.txt", "y.txt"] + ); + }); + + project_b1.read_with(cx_b, |project, cx| { + assert!(!project.is_read_only()); + assert_eq!( + project + .worktree_for_id(worktree1_id, cx) + .unwrap() + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec![ + "a.txt", + "b.txt", + "subdir2", + "subdir2/f.txt", + "subdir2/g.txt", + "subdir2/h.txt", + "subdir2/i.txt" + ] + ); + assert!(project.worktree_for_id(worktree2_id, cx).is_none()); + assert_eq!( + project + .worktree_for_id(worktree3_id, cx) + .unwrap() + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec!["w.txt", "x.txt", "y.txt"] + ); + }); + + project_b2.read_with(cx_b, |project, _| assert!(project.is_read_only())); + + project_b3.read_with(cx_b, |project, _| assert!(!project.is_read_only())); + + buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WaZ")); + + buffer_b1.read_with(cx_b, |buffer, _| assert_eq!(buffer.text(), "WaZ")); + + // Drop client B's connection. + server.forbid_connections(); + server.disconnect_client(client_b.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT); + + // While client B is disconnected, add and remove files from client A's project + client_a + .fs() + .insert_file("/root-1/dir1/subdir2/j.txt", "j-contents".into()) + .await; + client_a + .fs() + .remove_file("/root-1/dir1/subdir2/i.txt".as_ref(), Default::default()) + .await + .unwrap(); + + // While client B is disconnected, add and remove worktrees from client A's project. + let (worktree_a4, _) = project_a1 + .update(cx_a, |p, cx| { + p.find_or_create_local_worktree("/root-1/dir4", true, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + + let worktree4_id = worktree_a4.read_with(cx_a, |tree, _| { + assert!(tree.as_local().unwrap().is_shared()); + tree.id() + }); + project_a1.update(cx_a, |project, cx| { + project.remove_worktree(worktree3_id, cx) + }); + executor.run_until_parked(); + + // While client B is disconnected, mutate a buffer on both the host and the guest. + buffer_a1.update(cx_a, |buf, cx| buf.edit([(1..1, "X")], None, cx)); + buffer_b1.update(cx_b, |buf, cx| buf.edit([(2..2, "Y")], None, cx)); + executor.run_until_parked(); + + // While disconnected, close project 3 + cx_a.update(|_| drop(project_a3)); + + // Client B reconnects. They re-join the room and the remaining shared project. + server.allow_connections(); + client_b + .authenticate_and_connect(false, &cx_b.to_async()) + .await + .unwrap(); + executor.run_until_parked(); + + project_b1.read_with(cx_b, |project, cx| { + assert!(!project.is_read_only()); + assert_eq!( + project + .worktree_for_id(worktree1_id, cx) + .unwrap() + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec![ + "a.txt", + "b.txt", + "subdir2", + "subdir2/f.txt", + "subdir2/g.txt", + "subdir2/h.txt", + "subdir2/j.txt" + ] + ); + assert!(project.worktree_for_id(worktree2_id, cx).is_none()); + assert_eq!( + project + .worktree_for_id(worktree4_id, cx) + .unwrap() + .read(cx) + .snapshot() + .paths() + .map(|p| p.to_str().unwrap()) + .collect::>(), + vec!["z.txt"] + ); + }); + + project_b3.read_with(cx_b, |project, _| assert!(project.is_read_only())); + + buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WXaYZ")); + + buffer_b1.read_with(cx_b, |buffer, _| assert_eq!(buffer.text(), "WXaYZ")); +} + +#[gpui::test(iterations = 10)] +async fn test_active_call_events( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + client_a.fs().insert_tree("/a", json!({})).await; + client_b.fs().insert_tree("/b", json!({})).await; + + let (project_a, _) = client_a.build_local_project("/a", cx_a).await; + let (project_b, _) = client_b.build_local_project("/b", cx_b).await; + + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + let events_a = active_call_events(cx_a); + let events_b = active_call_events(cx_b); + + let project_a_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!(mem::take(&mut *events_a.borrow_mut()), vec![]); + assert_eq!( + mem::take(&mut *events_b.borrow_mut()), + vec![room::Event::RemoteProjectShared { + owner: Arc::new(User { + id: client_a.user_id().unwrap(), + github_login: "user_a".to_string(), + avatar: None, + }), + project_id: project_a_id, + worktree_root_names: vec!["a".to_string()], + }] + ); + + let project_b_id = active_call_b + .update(cx_b, |call, cx| call.share_project(project_b.clone(), cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + mem::take(&mut *events_a.borrow_mut()), + vec![room::Event::RemoteProjectShared { + owner: Arc::new(User { + id: client_b.user_id().unwrap(), + github_login: "user_b".to_string(), + avatar: None, + }), + project_id: project_b_id, + worktree_root_names: vec!["b".to_string()] + }] + ); + assert_eq!(mem::take(&mut *events_b.borrow_mut()), vec![]); + + // Sharing a project twice is idempotent. + let project_b_id_2 = active_call_b + .update(cx_b, |call, cx| call.share_project(project_b.clone(), cx)) + .await + .unwrap(); + assert_eq!(project_b_id_2, project_b_id); + executor.run_until_parked(); + assert_eq!(mem::take(&mut *events_a.borrow_mut()), vec![]); + assert_eq!(mem::take(&mut *events_b.borrow_mut()), vec![]); +} + +fn active_call_events(cx: &mut TestAppContext) -> Rc>> { + let events = Rc::new(RefCell::new(Vec::new())); + let active_call = cx.read(ActiveCall::global); + cx.update({ + let events = events.clone(); + |cx| { + cx.subscribe(&active_call, move |_, event, _| { + events.borrow_mut().push(event.clone()) + }) + .detach() + } + }); + events +} + +#[gpui::test(iterations = 10)] +async fn test_room_location( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + client_a.fs().insert_tree("/a", json!({})).await; + client_b.fs().insert_tree("/b", json!({})).await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + let a_notified = Rc::new(Cell::new(false)); + cx_a.update({ + let notified = a_notified.clone(); + |cx| { + cx.observe(&active_call_a, move |_, _| notified.set(true)) + .detach() + } + }); + + let b_notified = Rc::new(Cell::new(false)); + cx_b.update({ + let b_notified = b_notified.clone(); + |cx| { + cx.observe(&active_call_b, move |_, _| b_notified.set(true)) + .detach() + } + }); + + let (project_a, _) = client_a.build_local_project("/a", cx_a).await; + active_call_a + .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) + .await + .unwrap(); + let (project_b, _) = client_b.build_local_project("/b", cx_b).await; + + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert!(a_notified.take()); + assert_eq!( + participant_locations(&room_a, cx_a), + vec![("user_b".to_string(), ParticipantLocation::External)] + ); + assert!(b_notified.take()); + assert_eq!( + participant_locations(&room_b, cx_b), + vec![("user_a".to_string(), ParticipantLocation::UnsharedProject)] + ); + + let project_a_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(a_notified.take()); + assert_eq!( + participant_locations(&room_a, cx_a), + vec![("user_b".to_string(), ParticipantLocation::External)] + ); + assert!(b_notified.take()); + assert_eq!( + participant_locations(&room_b, cx_b), + vec![( + "user_a".to_string(), + ParticipantLocation::SharedProject { + project_id: project_a_id + } + )] + ); + + let project_b_id = active_call_b + .update(cx_b, |call, cx| call.share_project(project_b.clone(), cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(a_notified.take()); + assert_eq!( + participant_locations(&room_a, cx_a), + vec![("user_b".to_string(), ParticipantLocation::External)] + ); + assert!(b_notified.take()); + assert_eq!( + participant_locations(&room_b, cx_b), + vec![( + "user_a".to_string(), + ParticipantLocation::SharedProject { + project_id: project_a_id + } + )] + ); + + active_call_b + .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(a_notified.take()); + assert_eq!( + participant_locations(&room_a, cx_a), + vec![( + "user_b".to_string(), + ParticipantLocation::SharedProject { + project_id: project_b_id + } + )] + ); + assert!(b_notified.take()); + assert_eq!( + participant_locations(&room_b, cx_b), + vec![( + "user_a".to_string(), + ParticipantLocation::SharedProject { + project_id: project_a_id + } + )] + ); + + active_call_b + .update(cx_b, |call, cx| call.set_location(None, cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert!(a_notified.take()); + assert_eq!( + participant_locations(&room_a, cx_a), + vec![("user_b".to_string(), ParticipantLocation::External)] + ); + assert!(b_notified.take()); + assert_eq!( + participant_locations(&room_b, cx_b), + vec![( + "user_a".to_string(), + ParticipantLocation::SharedProject { + project_id: project_a_id + } + )] + ); + + fn participant_locations( + room: &Model, + cx: &TestAppContext, + ) -> Vec<(String, ParticipantLocation)> { + room.read_with(cx, |room, _| { + room.remote_participants() + .values() + .map(|participant| { + ( + participant.user.github_login.to_string(), + participant.location, + ) + }) + .collect() + }) + } +} + +#[gpui::test(iterations = 10)] +async fn test_propagate_saves_and_fs_changes( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + let rust = Arc::new(Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + )); + let javascript = Arc::new(Language::new( + LanguageConfig { + name: "JavaScript".into(), + path_suffixes: vec!["js".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + )); + for client in [&client_a, &client_b, &client_c] { + client.language_registry().add(rust.clone()); + client.language_registry().add(javascript.clone()); + } + + client_a + .fs() + .insert_tree( + "/a", + json!({ + "file1.rs": "", + "file2": "" + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; + + let worktree_a = project_a.read_with(cx_a, |p, _| p.worktrees().next().unwrap()); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + // Join that worktree as clients B and C. + let project_b = client_b.build_remote_project(project_id, cx_b).await; + let project_c = client_c.build_remote_project(project_id, cx_c).await; + + let worktree_b = project_b.read_with(cx_b, |p, _| p.worktrees().next().unwrap()); + + let worktree_c = project_c.read_with(cx_c, |p, _| p.worktrees().next().unwrap()); + + // Open and edit a buffer as both guests B and C. + let buffer_b = project_b + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "file1.rs"), cx)) + .await + .unwrap(); + let buffer_c = project_c + .update(cx_c, |p, cx| p.open_buffer((worktree_id, "file1.rs"), cx)) + .await + .unwrap(); + + buffer_b.read_with(cx_b, |buffer, _| { + assert_eq!(&*buffer.language().unwrap().name(), "Rust"); + }); + + buffer_c.read_with(cx_c, |buffer, _| { + assert_eq!(&*buffer.language().unwrap().name(), "Rust"); + }); + buffer_b.update(cx_b, |buf, cx| buf.edit([(0..0, "i-am-b, ")], None, cx)); + buffer_c.update(cx_c, |buf, cx| buf.edit([(0..0, "i-am-c, ")], None, cx)); + + // Open and edit that buffer as the host. + let buffer_a = project_a + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "file1.rs"), cx)) + .await + .unwrap(); + + executor.run_until_parked(); + + buffer_a.read_with(cx_a, |buf, _| assert_eq!(buf.text(), "i-am-c, i-am-b, ")); + buffer_a.update(cx_a, |buf, cx| { + buf.edit([(buf.len()..buf.len(), "i-am-a")], None, cx) + }); + + executor.run_until_parked(); + + buffer_a.read_with(cx_a, |buf, _| { + assert_eq!(buf.text(), "i-am-c, i-am-b, i-am-a"); + }); + + buffer_b.read_with(cx_b, |buf, _| { + assert_eq!(buf.text(), "i-am-c, i-am-b, i-am-a"); + }); + + buffer_c.read_with(cx_c, |buf, _| { + assert_eq!(buf.text(), "i-am-c, i-am-b, i-am-a"); + }); + + // Edit the buffer as the host and concurrently save as guest B. + let save_b = project_b.update(cx_b, |project, cx| { + project.save_buffer(buffer_b.clone(), cx) + }); + buffer_a.update(cx_a, |buf, cx| buf.edit([(0..0, "hi-a, ")], None, cx)); + save_b.await.unwrap(); + assert_eq!( + client_a.fs().load("/a/file1.rs".as_ref()).await.unwrap(), + "hi-a, i-am-c, i-am-b, i-am-a" + ); + + executor.run_until_parked(); + + buffer_a.read_with(cx_a, |buf, _| assert!(!buf.is_dirty())); + + buffer_b.read_with(cx_b, |buf, _| assert!(!buf.is_dirty())); + + buffer_c.read_with(cx_c, |buf, _| assert!(!buf.is_dirty())); + + // Make changes on host's file system, see those changes on guest worktrees. + client_a + .fs() + .rename( + "/a/file1.rs".as_ref(), + "/a/file1.js".as_ref(), + Default::default(), + ) + .await + .unwrap(); + client_a + .fs() + .rename("/a/file2".as_ref(), "/a/file3".as_ref(), Default::default()) + .await + .unwrap(); + client_a.fs().insert_file("/a/file4", "4".into()).await; + executor.run_until_parked(); + + worktree_a.read_with(cx_a, |tree, _| { + assert_eq!( + tree.paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["file1.js", "file3", "file4"] + ) + }); + + worktree_b.read_with(cx_b, |tree, _| { + assert_eq!( + tree.paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["file1.js", "file3", "file4"] + ) + }); + + worktree_c.read_with(cx_c, |tree, _| { + assert_eq!( + tree.paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["file1.js", "file3", "file4"] + ) + }); + + // Ensure buffer files are updated as well. + + buffer_a.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.file().unwrap().path().to_str(), Some("file1.js")); + assert_eq!(&*buffer.language().unwrap().name(), "JavaScript"); + }); + + buffer_b.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.file().unwrap().path().to_str(), Some("file1.js")); + assert_eq!(&*buffer.language().unwrap().name(), "JavaScript"); + }); + + buffer_c.read_with(cx_c, |buffer, _| { + assert_eq!(buffer.file().unwrap().path().to_str(), Some("file1.js")); + assert_eq!(&*buffer.language().unwrap().name(), "JavaScript"); + }); + + let new_buffer_a = project_a + .update(cx_a, |p, cx| p.create_buffer("", None, cx)) + .unwrap(); + + let new_buffer_id = new_buffer_a.read_with(cx_a, |buffer, _| buffer.remote_id()); + let new_buffer_b = project_b + .update(cx_b, |p, cx| p.open_buffer_by_id(new_buffer_id, cx)) + .await + .unwrap(); + + new_buffer_b.read_with(cx_b, |buffer, _| { + assert!(buffer.file().is_none()); + }); + + new_buffer_a.update(cx_a, |buffer, cx| { + buffer.edit([(0..0, "ok")], None, cx); + }); + project_a + .update(cx_a, |project, cx| { + project.save_buffer_as(new_buffer_a.clone(), "/a/file3.rs".into(), cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + new_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!( + buffer_b.file().unwrap().path().as_ref(), + Path::new("file3.rs") + ); + + new_buffer_a.read_with(cx_a, |buffer_a, _| { + assert_eq!(buffer_b.saved_mtime(), buffer_a.saved_mtime()); + assert_eq!(buffer_b.saved_version(), buffer_a.saved_version()); + }); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_git_diff_base_change( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + ".git": {}, + "sub": { + ".git": {}, + "b.txt": " + one + two + three + ".unindent(), + }, + "a.txt": " + one + two + three + ".unindent(), + }), + ) + .await; + + let (project_local, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| { + call.share_project(project_local.clone(), cx) + }) + .await + .unwrap(); + + let project_remote = client_b.build_remote_project(project_id, cx_b).await; + + let diff_base = " + one + three + " + .unindent(); + + let new_diff_base = " + one + two + " + .unindent(); + + client_a.fs().set_index_for_repo( + Path::new("/dir/.git"), + &[(Path::new("a.txt"), diff_base.clone())], + ); + + // Create the buffer + let buffer_local_a = project_local + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // Wait for it to catch up to the new diff + executor.run_until_parked(); + + // Smoke test diffing + + buffer_local_a.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(1..2, "", "two\n")], + ); + }); + + // Create remote buffer + let buffer_remote_a = project_remote + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // Wait remote buffer to catch up to the new diff + executor.run_until_parked(); + + // Smoke test diffing + + buffer_remote_a.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(1..2, "", "two\n")], + ); + }); + + client_a.fs().set_index_for_repo( + Path::new("/dir/.git"), + &[(Path::new("a.txt"), new_diff_base.clone())], + ); + + // Wait for buffer_local_a to receive it + executor.run_until_parked(); + + // Smoke test new diffing + + buffer_local_a.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(new_diff_base.as_ref())); + + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(2..3, "", "three\n")], + ); + }); + + // Smoke test B + + buffer_remote_a.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(new_diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(2..3, "", "three\n")], + ); + }); + + //Nested git dir + + let diff_base = " + one + three + " + .unindent(); + + let new_diff_base = " + one + two + " + .unindent(); + + client_a.fs().set_index_for_repo( + Path::new("/dir/sub/.git"), + &[(Path::new("b.txt"), diff_base.clone())], + ); + + // Create the buffer + let buffer_local_b = project_local + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "sub/b.txt"), cx)) + .await + .unwrap(); + + // Wait for it to catch up to the new diff + executor.run_until_parked(); + + // Smoke test diffing + + buffer_local_b.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(1..2, "", "two\n")], + ); + }); + + // Create remote buffer + let buffer_remote_b = project_remote + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "sub/b.txt"), cx)) + .await + .unwrap(); + + // Wait remote buffer to catch up to the new diff + executor.run_until_parked(); + + // Smoke test diffing + + buffer_remote_b.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(1..2, "", "two\n")], + ); + }); + + client_a.fs().set_index_for_repo( + Path::new("/dir/sub/.git"), + &[(Path::new("b.txt"), new_diff_base.clone())], + ); + + // Wait for buffer_local_b to receive it + executor.run_until_parked(); + + // Smoke test new diffing + + buffer_local_b.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(new_diff_base.as_ref())); + println!("{:?}", buffer.as_rope().to_string()); + println!("{:?}", buffer.diff_base()); + println!( + "{:?}", + buffer + .snapshot() + .git_diff_hunks_in_row_range(0..4) + .collect::>() + ); + + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(2..3, "", "three\n")], + ); + }); + + // Smoke test B + + buffer_remote_b.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.diff_base(), Some(new_diff_base.as_ref())); + git::diff::assert_hunks( + buffer.snapshot().git_diff_hunks_in_row_range(0..4), + &buffer, + &diff_base, + &[(2..3, "", "three\n")], + ); + }); +} + +#[gpui::test] +async fn test_git_branch_name( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + ".git": {}, + }), + ) + .await; + + let (project_local, _worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| { + call.share_project(project_local.clone(), cx) + }) + .await + .unwrap(); + + let project_remote = client_b.build_remote_project(project_id, cx_b).await; + client_a + .fs() + .set_branch_name(Path::new("/dir/.git"), Some("branch-1")); + + // Wait for it to catch up to the new branch + executor.run_until_parked(); + + #[track_caller] + fn assert_branch(branch_name: Option>, project: &Project, cx: &AppContext) { + let branch_name = branch_name.map(Into::into); + let worktrees = project.visible_worktrees(cx).collect::>(); + assert_eq!(worktrees.len(), 1); + let worktree = worktrees[0].clone(); + let root_entry = worktree.read(cx).snapshot().root_git_entry().unwrap(); + assert_eq!(root_entry.branch(), branch_name.map(Into::into)); + } + + // Smoke test branch reading + + project_local.read_with(cx_a, |project, cx| { + assert_branch(Some("branch-1"), project, cx) + }); + + project_remote.read_with(cx_b, |project, cx| { + assert_branch(Some("branch-1"), project, cx) + }); + + client_a + .fs() + .set_branch_name(Path::new("/dir/.git"), Some("branch-2")); + + // Wait for buffer_local_a to receive it + executor.run_until_parked(); + + // Smoke test branch reading + + project_local.read_with(cx_a, |project, cx| { + assert_branch(Some("branch-2"), project, cx) + }); + + project_remote.read_with(cx_b, |project, cx| { + assert_branch(Some("branch-2"), project, cx) + }); + + let project_remote_c = client_c.build_remote_project(project_id, cx_c).await; + executor.run_until_parked(); + + project_remote_c.read_with(cx_c, |project, cx| { + assert_branch(Some("branch-2"), project, cx) + }); +} + +#[gpui::test] +async fn test_git_status_sync( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + ".git": {}, + "a.txt": "a", + "b.txt": "b", + }), + ) + .await; + + const A_TXT: &'static str = "a.txt"; + const B_TXT: &'static str = "b.txt"; + + client_a.fs().set_status_for_repo_via_git_operation( + Path::new("/dir/.git"), + &[ + (&Path::new(A_TXT), GitFileStatus::Added), + (&Path::new(B_TXT), GitFileStatus::Added), + ], + ); + + let (project_local, _worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| { + call.share_project(project_local.clone(), cx) + }) + .await + .unwrap(); + + let project_remote = client_b.build_remote_project(project_id, cx_b).await; + + // Wait for it to catch up to the new status + executor.run_until_parked(); + + #[track_caller] + fn assert_status( + file: &impl AsRef, + status: Option, + project: &Project, + cx: &AppContext, + ) { + let file = file.as_ref(); + let worktrees = project.visible_worktrees(cx).collect::>(); + assert_eq!(worktrees.len(), 1); + let worktree = worktrees[0].clone(); + let snapshot = worktree.read(cx).snapshot(); + assert_eq!(snapshot.status_for_file(file), status); + } + + // Smoke test status reading + + project_local.read_with(cx_a, |project, cx| { + assert_status(&Path::new(A_TXT), Some(GitFileStatus::Added), project, cx); + assert_status(&Path::new(B_TXT), Some(GitFileStatus::Added), project, cx); + }); + + project_remote.read_with(cx_b, |project, cx| { + assert_status(&Path::new(A_TXT), Some(GitFileStatus::Added), project, cx); + assert_status(&Path::new(B_TXT), Some(GitFileStatus::Added), project, cx); + }); + + client_a.fs().set_status_for_repo_via_working_copy_change( + Path::new("/dir/.git"), + &[ + (&Path::new(A_TXT), GitFileStatus::Modified), + (&Path::new(B_TXT), GitFileStatus::Modified), + ], + ); + + // Wait for buffer_local_a to receive it + executor.run_until_parked(); + + // Smoke test status reading + + project_local.read_with(cx_a, |project, cx| { + assert_status( + &Path::new(A_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + assert_status( + &Path::new(B_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + }); + + project_remote.read_with(cx_b, |project, cx| { + assert_status( + &Path::new(A_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + assert_status( + &Path::new(B_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + }); + + // And synchronization while joining + let project_remote_c = client_c.build_remote_project(project_id, cx_c).await; + executor.run_until_parked(); + + project_remote_c.read_with(cx_c, |project, cx| { + assert_status( + &Path::new(A_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + assert_status( + &Path::new(B_TXT), + Some(GitFileStatus::Modified), + project, + cx, + ); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_fs_operations( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + "a.txt": "a-contents", + "b.txt": "b-contents", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap()); + + let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap()); + + let entry = project_b + .update(cx_b, |project, cx| { + project + .create_entry((worktree_id, "c.txt"), false, cx) + .unwrap() + }) + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "c.txt"] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "c.txt"] + ); + }); + + project_b + .update(cx_b, |project, cx| { + project.rename_entry(entry.id, Path::new("d.txt"), cx) + }) + .unwrap() + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "d.txt"] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "d.txt"] + ); + }); + + let dir_entry = project_b + .update(cx_b, |project, cx| { + project + .create_entry((worktree_id, "DIR"), true, cx) + .unwrap() + }) + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["DIR", "a.txt", "b.txt", "d.txt"] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["DIR", "a.txt", "b.txt", "d.txt"] + ); + }); + + project_b + .update(cx_b, |project, cx| { + project + .create_entry((worktree_id, "DIR/e.txt"), false, cx) + .unwrap() + }) + .await + .unwrap(); + project_b + .update(cx_b, |project, cx| { + project + .create_entry((worktree_id, "DIR/SUBDIR"), true, cx) + .unwrap() + }) + .await + .unwrap(); + project_b + .update(cx_b, |project, cx| { + project + .create_entry((worktree_id, "DIR/SUBDIR/f.txt"), false, cx) + .unwrap() + }) + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + [ + "DIR", + "DIR/SUBDIR", + "DIR/SUBDIR/f.txt", + "DIR/e.txt", + "a.txt", + "b.txt", + "d.txt" + ] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + [ + "DIR", + "DIR/SUBDIR", + "DIR/SUBDIR/f.txt", + "DIR/e.txt", + "a.txt", + "b.txt", + "d.txt" + ] + ); + }); + + project_b + .update(cx_b, |project, cx| { + project + .copy_entry(entry.id, Path::new("f.txt"), cx) + .unwrap() + }) + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + [ + "DIR", + "DIR/SUBDIR", + "DIR/SUBDIR/f.txt", + "DIR/e.txt", + "a.txt", + "b.txt", + "d.txt", + "f.txt" + ] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + [ + "DIR", + "DIR/SUBDIR", + "DIR/SUBDIR/f.txt", + "DIR/e.txt", + "a.txt", + "b.txt", + "d.txt", + "f.txt" + ] + ); + }); + + project_b + .update(cx_b, |project, cx| { + project.delete_entry(dir_entry.id, cx).unwrap() + }) + .await + .unwrap(); + executor.run_until_parked(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "d.txt", "f.txt"] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "d.txt", "f.txt"] + ); + }); + + project_b + .update(cx_b, |project, cx| { + project.delete_entry(entry.id, cx).unwrap() + }) + .await + .unwrap(); + + worktree_a.read_with(cx_a, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "f.txt"] + ); + }); + + worktree_b.read_with(cx_b, |worktree, _| { + assert_eq!( + worktree + .paths() + .map(|p| p.to_string_lossy()) + .collect::>(), + ["a.txt", "b.txt", "f.txt"] + ); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_local_settings( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // As client A, open a project that contains some local settings files + client_a + .fs() + .insert_tree( + "/dir", + json!({ + ".zed": { + "settings.json": r#"{ "tab_size": 2 }"# + }, + "a": { + ".zed": { + "settings.json": r#"{ "tab_size": 8 }"# + }, + "a.txt": "a-contents", + }, + "b": { + "b.txt": "b-contents", + } + }), + ) + .await; + let (project_a, _) = client_a.build_local_project("/dir", cx_a).await; + executor.run_until_parked(); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + // As client B, join that project and observe the local settings. + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap()); + executor.run_until_parked(); + cx_b.read(|cx| { + let store = cx.global::(); + assert_eq!( + store + .local_settings(worktree_b.read(cx).id().to_usize()) + .collect::>(), + &[ + (Path::new("").into(), r#"{"tab_size":2}"#.to_string()), + (Path::new("a").into(), r#"{"tab_size":8}"#.to_string()), + ] + ) + }); + + // As client A, update a settings file. As Client B, see the changed settings. + client_a + .fs() + .insert_file("/dir/.zed/settings.json", r#"{}"#.into()) + .await; + executor.run_until_parked(); + cx_b.read(|cx| { + let store = cx.global::(); + assert_eq!( + store + .local_settings(worktree_b.read(cx).id().to_usize()) + .collect::>(), + &[ + (Path::new("").into(), r#"{}"#.to_string()), + (Path::new("a").into(), r#"{"tab_size":8}"#.to_string()), + ] + ) + }); + + // As client A, create and remove some settings files. As client B, see the changed settings. + client_a + .fs() + .remove_file("/dir/.zed/settings.json".as_ref(), Default::default()) + .await + .unwrap(); + client_a + .fs() + .create_dir("/dir/b/.zed".as_ref()) + .await + .unwrap(); + client_a + .fs() + .insert_file("/dir/b/.zed/settings.json", r#"{"tab_size": 4}"#.into()) + .await; + executor.run_until_parked(); + cx_b.read(|cx| { + let store = cx.global::(); + assert_eq!( + store + .local_settings(worktree_b.read(cx).id().to_usize()) + .collect::>(), + &[ + (Path::new("a").into(), r#"{"tab_size":8}"#.to_string()), + (Path::new("b").into(), r#"{"tab_size":4}"#.to_string()), + ] + ) + }); + + // As client B, disconnect. + server.forbid_connections(); + server.disconnect_client(client_b.peer_id().unwrap()); + + // As client A, change and remove settings files while client B is disconnected. + client_a + .fs() + .insert_file("/dir/a/.zed/settings.json", r#"{"hard_tabs":true}"#.into()) + .await; + client_a + .fs() + .remove_file("/dir/b/.zed/settings.json".as_ref(), Default::default()) + .await + .unwrap(); + executor.run_until_parked(); + + // As client B, reconnect and see the changed settings. + server.allow_connections(); + executor.advance_clock(RECEIVE_TIMEOUT); + cx_b.read(|cx| { + let store = cx.global::(); + assert_eq!( + store + .local_settings(worktree_b.read(cx).id().to_usize()) + .collect::>(), + &[(Path::new("a").into(), r#"{"hard_tabs":true}"#.to_string()),] + ) + }); +} + +#[gpui::test(iterations = 10)] +async fn test_buffer_conflict_after_save( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + "a.txt": "a-contents", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open a buffer as client B + let buffer_b = project_b + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + buffer_b.update(cx_b, |buf, cx| buf.edit([(0..0, "world ")], None, cx)); + + buffer_b.read_with(cx_b, |buf, _| { + assert!(buf.is_dirty()); + assert!(!buf.has_conflict()); + }); + + project_b + .update(cx_b, |project, cx| { + project.save_buffer(buffer_b.clone(), cx) + }) + .await + .unwrap(); + + buffer_b.read_with(cx_b, |buffer_b, _| assert!(!buffer_b.is_dirty())); + + buffer_b.read_with(cx_b, |buf, _| { + assert!(!buf.has_conflict()); + }); + + buffer_b.update(cx_b, |buf, cx| buf.edit([(0..0, "hello ")], None, cx)); + + buffer_b.read_with(cx_b, |buf, _| { + assert!(buf.is_dirty()); + assert!(!buf.has_conflict()); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_buffer_reloading( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + "a.txt": "a\nb\nc", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open a buffer as client B + let buffer_b = project_b + .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + buffer_b.read_with(cx_b, |buf, _| { + assert!(!buf.is_dirty()); + assert!(!buf.has_conflict()); + assert_eq!(buf.line_ending(), LineEnding::Unix); + }); + + let new_contents = Rope::from("d\ne\nf"); + client_a + .fs() + .save("/dir/a.txt".as_ref(), &new_contents, LineEnding::Windows) + .await + .unwrap(); + + executor.run_until_parked(); + + buffer_b.read_with(cx_b, |buf, _| { + assert_eq!(buf.text(), new_contents.to_string()); + assert!(!buf.is_dirty()); + assert!(!buf.has_conflict()); + assert_eq!(buf.line_ending(), LineEnding::Windows); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_editing_while_guest_opens_buffer( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree("/dir", json!({ "a.txt": "a-contents" })) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open a buffer as client A + let buffer_a = project_a + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // Start opening the same buffer as client B + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer); + + // Edit the buffer as client A while client B is still opening it. + cx_b.executor().simulate_random_delay().await; + buffer_a.update(cx_a, |buf, cx| buf.edit([(0..0, "X")], None, cx)); + cx_b.executor().simulate_random_delay().await; + buffer_a.update(cx_a, |buf, cx| buf.edit([(1..1, "Y")], None, cx)); + + let text = buffer_a.read_with(cx_a, |buf, _| buf.text()); + let buffer_b = buffer_b.await.unwrap(); + executor.run_until_parked(); + + buffer_b.read_with(cx_b, |buf, _| assert_eq!(buf.text(), text)); +} + +#[gpui::test(iterations = 10)] +async fn test_leaving_worktree_while_opening_buffer( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree("/dir", json!({ "a.txt": "a-contents" })) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // See that a guest has joined as client A. + executor.run_until_parked(); + + project_a.read_with(cx_a, |p, _| assert_eq!(p.collaborators().len(), 1)); + + // Begin opening a buffer as client B, but leave the project before the open completes. + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer); + cx_b.update(|_| drop(project_b)); + drop(buffer_b); + + // See that the guest has left. + executor.run_until_parked(); + + project_a.read_with(cx_a, |p, _| assert!(p.collaborators().is_empty())); +} + +#[gpui::test(iterations = 10)] +async fn test_canceling_buffer_opening( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/dir", + json!({ + "a.txt": "abc", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/dir", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let buffer_a = project_a + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) + .await + .unwrap(); + + // Open a buffer as client B but cancel after a random amount of time. + let buffer_b = project_b.update(cx_b, |p, cx| { + p.open_buffer_by_id(buffer_a.read_with(cx_a, |a, _| a.remote_id()), cx) + }); + executor.simulate_random_delay().await; + drop(buffer_b); + + // Try opening the same buffer again as client B, and ensure we can + // still do it despite the cancellation above. + let buffer_b = project_b + .update(cx_b, |p, cx| { + p.open_buffer_by_id(buffer_a.read_with(cx_a, |a, _| a.remote_id()), cx) + }) + .await + .unwrap(); + + buffer_b.read_with(cx_b, |buf, _| assert_eq!(buf.text(), "abc")); +} + +#[gpui::test(iterations = 10)] +async fn test_leaving_project( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/a", + json!({ + "a.txt": "a-contents", + "b.txt": "b-contents", + }), + ) + .await; + let (project_a, _) = client_a.build_local_project("/a", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b1 = client_b.build_remote_project(project_id, cx_b).await; + let project_c = client_c.build_remote_project(project_id, cx_c).await; + + // Client A sees that a guest has joined. + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + project_b1.read_with(cx_b, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + project_c.read_with(cx_c, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + // Client B opens a buffer. + let buffer_b1 = project_b1 + .update(cx_b, |project, cx| { + let worktree_id = project.worktrees().next().unwrap().read(cx).id(); + project.open_buffer((worktree_id, "a.txt"), cx) + }) + .await + .unwrap(); + + buffer_b1.read_with(cx_b, |buffer, _| assert_eq!(buffer.text(), "a-contents")); + + // Drop client B's project and ensure client A and client C observe client B leaving. + cx_b.update(|_| drop(project_b1)); + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + + project_c.read_with(cx_c, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + + // Client B re-joins the project and can open buffers as before. + let project_b2 = client_b.build_remote_project(project_id, cx_b).await; + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + project_b2.read_with(cx_b, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + project_c.read_with(cx_c, |project, _| { + assert_eq!(project.collaborators().len(), 2); + }); + + let buffer_b2 = project_b2 + .update(cx_b, |project, cx| { + let worktree_id = project.worktrees().next().unwrap().read(cx).id(); + project.open_buffer((worktree_id, "a.txt"), cx) + }) + .await + .unwrap(); + + buffer_b2.read_with(cx_b, |buffer, _| assert_eq!(buffer.text(), "a-contents")); + + // Drop client B's connection and ensure client A and client C observe client B leaving. + client_b.disconnect(&cx_b.to_async()); + executor.advance_clock(RECONNECT_TIMEOUT); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + + project_b2.read_with(cx_b, |project, _| { + assert!(project.is_read_only()); + }); + + project_c.read_with(cx_c, |project, _| { + assert_eq!(project.collaborators().len(), 1); + }); + + // Client B can't join the project, unless they re-join the room. + cx_b.spawn(|cx| { + Project::remote( + project_id, + client_b.app_state.client.clone(), + client_b.user_store().clone(), + client_b.language_registry().clone(), + FakeFs::new(cx.background_executor().clone()), + cx, + ) + }) + .await + .unwrap_err(); + + // Simulate connection loss for client C and ensure client A observes client C leaving the project. + client_c.wait_for_current_user(cx_c).await; + server.forbid_connections(); + server.disconnect_client(client_c.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, _| { + assert_eq!(project.collaborators().len(), 0); + }); + + project_b2.read_with(cx_b, |project, _| { + assert!(project.is_read_only()); + }); + + project_c.read_with(cx_c, |project, _| { + assert!(project.is_read_only()); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_collaborating_with_diagnostics( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + // Share a project as client A + client_a + .fs() + .insert_tree( + "/a", + json!({ + "a.rs": "let one = two", + "other.rs": "", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; + + // Cause the language server to start. + let _buffer = project_a + .update(cx_a, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id, + path: Path::new("other.rs").into(), + }, + cx, + ) + }) + .await + .unwrap(); + + // Simulate a language server reporting errors for a file. + let mut fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server + .receive_notification::() + .await; + fake_language_server.notify::( + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path("/a/a.rs").unwrap(), + version: None, + diagnostics: vec![lsp::Diagnostic { + severity: Some(lsp::DiagnosticSeverity::WARNING), + range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 7)), + message: "message 0".to_string(), + ..Default::default() + }], + }, + ); + + // Client A shares the project and, simultaneously, the language server + // publishes a diagnostic. This is done to ensure that the server always + // observes the latest diagnostics for a worktree. + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + fake_language_server.notify::( + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path("/a/a.rs").unwrap(), + version: None, + diagnostics: vec![lsp::Diagnostic { + severity: Some(lsp::DiagnosticSeverity::ERROR), + range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 7)), + message: "message 1".to_string(), + ..Default::default() + }], + }, + ); + + // Join the worktree as client B. + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Wait for server to see the diagnostics update. + executor.run_until_parked(); + + // Ensure client B observes the new diagnostics. + + project_b.read_with(cx_b, |project, cx| { + assert_eq!( + project.diagnostic_summaries(cx).collect::>(), + &[( + ProjectPath { + worktree_id, + path: Arc::from(Path::new("a.rs")), + }, + LanguageServerId(0), + DiagnosticSummary { + error_count: 1, + warning_count: 0, + ..Default::default() + }, + )] + ) + }); + + // Join project as client C and observe the diagnostics. + let project_c = client_c.build_remote_project(project_id, cx_c).await; + let project_c_diagnostic_summaries = + Rc::new(RefCell::new(project_c.read_with(cx_c, |project, cx| { + project.diagnostic_summaries(cx).collect::>() + }))); + project_c.update(cx_c, |_, cx| { + let summaries = project_c_diagnostic_summaries.clone(); + cx.subscribe(&project_c, { + move |p, _, event, cx| { + if let project::Event::DiskBasedDiagnosticsFinished { .. } = event { + *summaries.borrow_mut() = p.diagnostic_summaries(cx).collect(); + } + } + }) + .detach(); + }); + + executor.run_until_parked(); + assert_eq!( + project_c_diagnostic_summaries.borrow().as_slice(), + &[( + ProjectPath { + worktree_id, + path: Arc::from(Path::new("a.rs")), + }, + LanguageServerId(0), + DiagnosticSummary { + error_count: 1, + warning_count: 0, + ..Default::default() + }, + )] + ); + + // Simulate a language server reporting more errors for a file. + fake_language_server.notify::( + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path("/a/a.rs").unwrap(), + version: None, + diagnostics: vec![ + lsp::Diagnostic { + severity: Some(lsp::DiagnosticSeverity::ERROR), + range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 7)), + message: "message 1".to_string(), + ..Default::default() + }, + lsp::Diagnostic { + severity: Some(lsp::DiagnosticSeverity::WARNING), + range: lsp::Range::new(lsp::Position::new(0, 10), lsp::Position::new(0, 13)), + message: "message 2".to_string(), + ..Default::default() + }, + ], + }, + ); + + // Clients B and C get the updated summaries + executor.run_until_parked(); + + project_b.read_with(cx_b, |project, cx| { + assert_eq!( + project.diagnostic_summaries(cx).collect::>(), + [( + ProjectPath { + worktree_id, + path: Arc::from(Path::new("a.rs")), + }, + LanguageServerId(0), + DiagnosticSummary { + error_count: 1, + warning_count: 1, + }, + )] + ); + }); + + project_c.read_with(cx_c, |project, cx| { + assert_eq!( + project.diagnostic_summaries(cx).collect::>(), + [( + ProjectPath { + worktree_id, + path: Arc::from(Path::new("a.rs")), + }, + LanguageServerId(0), + DiagnosticSummary { + error_count: 1, + warning_count: 1, + }, + )] + ); + }); + + // Open the file with the errors on client B. They should be present. + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + buffer_b.read_with(cx_b, |buffer, _| { + assert_eq!( + buffer + .snapshot() + .diagnostics_in_range::<_, Point>(0..buffer.len(), false) + .collect::>(), + &[ + DiagnosticEntry { + range: Point::new(0, 4)..Point::new(0, 7), + diagnostic: Diagnostic { + group_id: 2, + message: "message 1".to_string(), + severity: lsp::DiagnosticSeverity::ERROR, + is_primary: true, + ..Default::default() + } + }, + DiagnosticEntry { + range: Point::new(0, 10)..Point::new(0, 13), + diagnostic: Diagnostic { + group_id: 3, + severity: lsp::DiagnosticSeverity::WARNING, + message: "message 2".to_string(), + is_primary: true, + ..Default::default() + } + } + ] + ); + }); + + // Simulate a language server reporting no errors for a file. + fake_language_server.notify::( + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path("/a/a.rs").unwrap(), + version: None, + diagnostics: vec![], + }, + ); + executor.run_until_parked(); + + project_a.read_with(cx_a, |project, cx| { + assert_eq!(project.diagnostic_summaries(cx).collect::>(), []) + }); + + project_b.read_with(cx_b, |project, cx| { + assert_eq!(project.diagnostic_summaries(cx).collect::>(), []) + }); + + project_c.read_with(cx_c, |project, cx| { + assert_eq!(project.diagnostic_summaries(cx).collect::>(), []) + }); +} + +#[gpui::test(iterations = 10)] +async fn test_collaborating_with_lsp_progress_updates_and_diagnostics_ordering( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + disk_based_diagnostics_progress_token: Some("the-disk-based-token".into()), + disk_based_diagnostics_sources: vec!["the-disk-based-diagnostics-source".into()], + ..Default::default() + })) + .await; + client_a.language_registry().add(Arc::new(language)); + + let file_names = &["one.rs", "two.rs", "three.rs", "four.rs", "five.rs"]; + client_a + .fs() + .insert_tree( + "/test", + json!({ + "one.rs": "const ONE: usize = 1;", + "two.rs": "const TWO: usize = 2;", + "three.rs": "const THREE: usize = 3;", + "four.rs": "const FOUR: usize = 3;", + "five.rs": "const FIVE: usize = 3;", + }), + ) + .await; + + let (project_a, worktree_id) = client_a.build_local_project("/test", cx_a).await; + + // Share a project as client A + let active_call_a = cx_a.read(ActiveCall::global); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + // Join the project as client B and open all three files. + let project_b = client_b.build_remote_project(project_id, cx_b).await; + let guest_buffers = futures::future::try_join_all(file_names.iter().map(|file_name| { + project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, file_name), cx)) + })) + .await + .unwrap(); + + // Simulate a language server reporting errors for a file. + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server + .request::(lsp::WorkDoneProgressCreateParams { + token: lsp::NumberOrString::String("the-disk-based-token".to_string()), + }) + .await + .unwrap(); + fake_language_server.notify::(lsp::ProgressParams { + token: lsp::NumberOrString::String("the-disk-based-token".to_string()), + value: lsp::ProgressParamsValue::WorkDone(lsp::WorkDoneProgress::Begin( + lsp::WorkDoneProgressBegin { + title: "Progress Began".into(), + ..Default::default() + }, + )), + }); + for file_name in file_names { + fake_language_server.notify::( + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path(Path::new("/test").join(file_name)).unwrap(), + version: None, + diagnostics: vec![lsp::Diagnostic { + severity: Some(lsp::DiagnosticSeverity::WARNING), + source: Some("the-disk-based-diagnostics-source".into()), + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)), + message: "message one".to_string(), + ..Default::default() + }], + }, + ); + } + fake_language_server.notify::(lsp::ProgressParams { + token: lsp::NumberOrString::String("the-disk-based-token".to_string()), + value: lsp::ProgressParamsValue::WorkDone(lsp::WorkDoneProgress::End( + lsp::WorkDoneProgressEnd { message: None }, + )), + }); + + // When the "disk base diagnostics finished" message is received, the buffers' + // diagnostics are expected to be present. + let disk_based_diagnostics_finished = Arc::new(AtomicBool::new(false)); + project_b.update(cx_b, { + let project_b = project_b.clone(); + let disk_based_diagnostics_finished = disk_based_diagnostics_finished.clone(); + move |_, cx| { + cx.subscribe(&project_b, move |_, _, event, cx| { + if let project::Event::DiskBasedDiagnosticsFinished { .. } = event { + disk_based_diagnostics_finished.store(true, SeqCst); + for buffer in &guest_buffers { + assert_eq!( + buffer + .read(cx) + .snapshot() + .diagnostics_in_range::<_, usize>(0..5, false) + .count(), + 1, + "expected a diagnostic for buffer {:?}", + buffer.read(cx).file().unwrap().path(), + ); + } + } + }) + .detach(); + } + }); + + executor.run_until_parked(); + assert!(disk_based_diagnostics_finished.load(SeqCst)); +} + +#[gpui::test(iterations = 10)] +async fn test_reloading_buffer_manually( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree("/a", json!({ "a.rs": "let one = 1;" })) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; + let buffer_a = project_a + .update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)) + .await + .unwrap(); + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + buffer_b.update(cx_b, |buffer, cx| { + buffer.edit([(4..7, "six")], None, cx); + buffer.edit([(10..11, "6")], None, cx); + assert_eq!(buffer.text(), "let six = 6;"); + assert!(buffer.is_dirty()); + assert!(!buffer.has_conflict()); + }); + executor.run_until_parked(); + + buffer_a.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "let six = 6;")); + + client_a + .fs() + .save( + "/a/a.rs".as_ref(), + &Rope::from("let seven = 7;"), + LineEnding::Unix, + ) + .await + .unwrap(); + executor.run_until_parked(); + + buffer_a.read_with(cx_a, |buffer, _| assert!(buffer.has_conflict())); + + buffer_b.read_with(cx_b, |buffer, _| assert!(buffer.has_conflict())); + + project_b + .update(cx_b, |project, cx| { + project.reload_buffers(HashSet::from_iter([buffer_b.clone()]), true, cx) + }) + .await + .unwrap(); + + buffer_a.read_with(cx_a, |buffer, _| { + assert_eq!(buffer.text(), "let seven = 7;"); + assert!(!buffer.is_dirty()); + assert!(!buffer.has_conflict()); + }); + + buffer_b.read_with(cx_b, |buffer, _| { + assert_eq!(buffer.text(), "let seven = 7;"); + assert!(!buffer.is_dirty()); + assert!(!buffer.has_conflict()); + }); + + buffer_a.update(cx_a, |buffer, cx| { + // Undoing on the host is a no-op when the reload was initiated by the guest. + buffer.undo(cx); + assert_eq!(buffer.text(), "let seven = 7;"); + assert!(!buffer.is_dirty()); + assert!(!buffer.has_conflict()); + }); + buffer_b.update(cx_b, |buffer, cx| { + // Undoing on the guest rolls back the buffer to before it was reloaded but the conflict gets cleared. + buffer.undo(cx); + assert_eq!(buffer.text(), "let six = 6;"); + assert!(buffer.is_dirty()); + assert!(!buffer.has_conflict()); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_formatting_buffer( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + executor.allow_parking(); + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + // Here we insert a fake tree with a directory that exists on disk. This is needed + // because later we'll invoke a command, which requires passing a working directory + // that points to a valid location on disk. + let directory = env::current_dir().unwrap(); + client_a + .fs() + .insert_tree(&directory, json!({ "a.rs": "let one = \"two\"" })) + .await; + let (project_a, worktree_id) = client_a.build_local_project(&directory, cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|_, _| async move { + Ok(Some(vec![ + lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 4)), + new_text: "h".to_string(), + }, + lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 7), lsp::Position::new(0, 7)), + new_text: "y".to_string(), + }, + ])) + }); + + project_b + .update(cx_b, |project, cx| { + project.format( + HashSet::from_iter([buffer_b.clone()]), + true, + FormatTrigger::Save, + cx, + ) + }) + .await + .unwrap(); + + // The edits from the LSP are applied, and a final newline is added. + assert_eq!( + buffer_b.read_with(cx_b, |buffer, _| buffer.text()), + "let honey = \"two\"\n" + ); + + // Ensure buffer can be formatted using an external command. Notice how the + // host's configuration is honored as opposed to using the guest's settings. + cx_a.update(|cx| { + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings::(cx, |file| { + file.defaults.formatter = Some(Formatter::External { + command: "awk".into(), + arguments: vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(), + }); + }); + }); + }); + project_b + .update(cx_b, |project, cx| { + project.format( + HashSet::from_iter([buffer_b.clone()]), + true, + FormatTrigger::Save, + cx, + ) + }) + .await + .unwrap(); + assert_eq!( + buffer_b.read_with(cx_b, |buffer, _| buffer.text()), + format!("let honey = \"{}/a.rs\"\n", directory.to_str().unwrap()) + ); +} + +#[gpui::test(iterations = 10)] +async fn test_prettier_formatting_buffer( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + prettier_parser_name: Some("test_parser".to_string()), + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let test_plugin = "test_plugin"; + let mut fake_language_servers = language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + prettier_plugins: vec![test_plugin], + ..Default::default() + })) + .await; + let language = Arc::new(language); + client_a.language_registry().add(Arc::clone(&language)); + + // Here we insert a fake tree with a directory that exists on disk. This is needed + // because later we'll invoke a command, which requires passing a working directory + // that points to a valid location on disk. + let directory = env::current_dir().unwrap(); + let buffer_text = "let one = \"two\""; + client_a + .fs() + .insert_tree(&directory, json!({ "a.rs": buffer_text })) + .await; + let (project_a, worktree_id) = client_a.build_local_project(&directory, cx_a).await; + let prettier_format_suffix = project::TEST_PRETTIER_FORMAT_SUFFIX; + let open_buffer = project_a.update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_a = cx_a.executor().spawn(open_buffer).await.unwrap(); + + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + cx_a.update(|cx| { + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings::(cx, |file| { + file.defaults.formatter = Some(Formatter::Auto); + }); + }); + }); + cx_b.update(|cx| { + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings::(cx, |file| { + file.defaults.formatter = Some(Formatter::LanguageServer); + }); + }); + }); + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|_, _| async move { + panic!( + "Unexpected: prettier should be preferred since it's enabled and language supports it" + ) + }); + + project_b + .update(cx_b, |project, cx| { + project.format( + HashSet::from_iter([buffer_b.clone()]), + true, + FormatTrigger::Save, + cx, + ) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_eq!( + buffer_b.read_with(cx_b, |buffer, _| buffer.text()), + buffer_text.to_string() + "\n" + prettier_format_suffix, + "Prettier formatting was not applied to client buffer after client's request" + ); + + project_a + .update(cx_a, |project, cx| { + project.format( + HashSet::from_iter([buffer_a.clone()]), + true, + FormatTrigger::Manual, + cx, + ) + }) + .await + .unwrap(); + + executor.run_until_parked(); + assert_eq!( + buffer_b.read_with(cx_b, |buffer, _| buffer.text()), + buffer_text.to_string() + "\n" + prettier_format_suffix + "\n" + prettier_format_suffix, + "Prettier formatting was not applied to client buffer after host's request" + ); +} + +#[gpui::test(iterations = 10)] +async fn test_definition( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + client_a + .fs() + .insert_tree( + "/root", + json!({ + "dir-1": { + "a.rs": "const ONE: usize = b::TWO + b::THREE;", + }, + "dir-2": { + "b.rs": "const TWO: c::T2 = 2;\nconst THREE: usize = 3;", + "c.rs": "type T2 = usize;", + } + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/root/dir-1", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open the file on client B. + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + // Request the definition of a symbol as the guest. + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|_, _| async move { + Ok(Some(lsp::GotoDefinitionResponse::Scalar( + lsp::Location::new( + lsp::Url::from_file_path("/root/dir-2/b.rs").unwrap(), + lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), + ), + ))) + }); + + let definitions_1 = project_b + .update(cx_b, |p, cx| p.definition(&buffer_b, 23, cx)) + .await + .unwrap(); + cx_b.read(|cx| { + assert_eq!(definitions_1.len(), 1); + assert_eq!(project_b.read(cx).worktrees().count(), 2); + let target_buffer = definitions_1[0].target.buffer.read(cx); + assert_eq!( + target_buffer.text(), + "const TWO: c::T2 = 2;\nconst THREE: usize = 3;" + ); + assert_eq!( + definitions_1[0].target.range.to_point(target_buffer), + Point::new(0, 6)..Point::new(0, 9) + ); + }); + + // Try getting more definitions for the same buffer, ensuring the buffer gets reused from + // the previous call to `definition`. + fake_language_server.handle_request::(|_, _| async move { + Ok(Some(lsp::GotoDefinitionResponse::Scalar( + lsp::Location::new( + lsp::Url::from_file_path("/root/dir-2/b.rs").unwrap(), + lsp::Range::new(lsp::Position::new(1, 6), lsp::Position::new(1, 11)), + ), + ))) + }); + + let definitions_2 = project_b + .update(cx_b, |p, cx| p.definition(&buffer_b, 33, cx)) + .await + .unwrap(); + cx_b.read(|cx| { + assert_eq!(definitions_2.len(), 1); + assert_eq!(project_b.read(cx).worktrees().count(), 2); + let target_buffer = definitions_2[0].target.buffer.read(cx); + assert_eq!( + target_buffer.text(), + "const TWO: c::T2 = 2;\nconst THREE: usize = 3;" + ); + assert_eq!( + definitions_2[0].target.range.to_point(target_buffer), + Point::new(1, 6)..Point::new(1, 11) + ); + }); + assert_eq!( + definitions_1[0].target.buffer, + definitions_2[0].target.buffer + ); + + fake_language_server.handle_request::( + |req, _| async move { + assert_eq!( + req.text_document_position_params.position, + lsp::Position::new(0, 7) + ); + Ok(Some(lsp::GotoDefinitionResponse::Scalar( + lsp::Location::new( + lsp::Url::from_file_path("/root/dir-2/c.rs").unwrap(), + lsp::Range::new(lsp::Position::new(0, 5), lsp::Position::new(0, 7)), + ), + ))) + }, + ); + + let type_definitions = project_b + .update(cx_b, |p, cx| p.type_definition(&buffer_b, 7, cx)) + .await + .unwrap(); + cx_b.read(|cx| { + assert_eq!(type_definitions.len(), 1); + let target_buffer = type_definitions[0].target.buffer.read(cx); + assert_eq!(target_buffer.text(), "type T2 = usize;"); + assert_eq!( + type_definitions[0].target.range.to_point(target_buffer), + Point::new(0, 5)..Point::new(0, 7) + ); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_references( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + client_a + .fs() + .insert_tree( + "/root", + json!({ + "dir-1": { + "one.rs": "const ONE: usize = 1;", + "two.rs": "const TWO: usize = one::ONE + one::ONE;", + }, + "dir-2": { + "three.rs": "const THREE: usize = two::TWO + one::ONE;", + } + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/root/dir-1", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open the file on client B. + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "one.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + // Request references to a symbol as the guest. + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|params, _| async move { + assert_eq!( + params.text_document_position.text_document.uri.as_str(), + "file:///root/dir-1/one.rs" + ); + Ok(Some(vec![ + lsp::Location { + uri: lsp::Url::from_file_path("/root/dir-1/two.rs").unwrap(), + range: lsp::Range::new(lsp::Position::new(0, 24), lsp::Position::new(0, 27)), + }, + lsp::Location { + uri: lsp::Url::from_file_path("/root/dir-1/two.rs").unwrap(), + range: lsp::Range::new(lsp::Position::new(0, 35), lsp::Position::new(0, 38)), + }, + lsp::Location { + uri: lsp::Url::from_file_path("/root/dir-2/three.rs").unwrap(), + range: lsp::Range::new(lsp::Position::new(0, 37), lsp::Position::new(0, 40)), + }, + ])) + }); + + let references = project_b + .update(cx_b, |p, cx| p.references(&buffer_b, 7, cx)) + .await + .unwrap(); + cx_b.read(|cx| { + assert_eq!(references.len(), 3); + assert_eq!(project_b.read(cx).worktrees().count(), 2); + + let two_buffer = references[0].buffer.read(cx); + let three_buffer = references[2].buffer.read(cx); + assert_eq!( + two_buffer.file().unwrap().path().as_ref(), + Path::new("two.rs") + ); + assert_eq!(references[1].buffer, references[0].buffer); + assert_eq!( + three_buffer.file().unwrap().full_path(cx), + Path::new("/root/dir-2/three.rs") + ); + + assert_eq!(references[0].range.to_offset(two_buffer), 24..27); + assert_eq!(references[1].range.to_offset(two_buffer), 35..38); + assert_eq!(references[2].range.to_offset(three_buffer), 37..40); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_project_search( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/root", + json!({ + "dir-1": { + "a": "hello world", + "b": "goodnight moon", + "c": "a world of goo", + "d": "world champion of clown world", + }, + "dir-2": { + "e": "disney world is fun", + } + }), + ) + .await; + let (project_a, _) = client_a.build_local_project("/root/dir-1", cx_a).await; + let (worktree_2, _) = project_a + .update(cx_a, |p, cx| { + p.find_or_create_local_worktree("/root/dir-2", true, cx) + }) + .await + .unwrap(); + worktree_2 + .read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Perform a search as the guest. + let mut results = HashMap::default(); + let mut search_rx = project_b.update(cx_b, |project, cx| { + project.search( + SearchQuery::text("world", false, false, Vec::new(), Vec::new()).unwrap(), + cx, + ) + }); + while let Some((buffer, ranges)) = search_rx.next().await { + results.entry(buffer).or_insert(ranges); + } + + let mut ranges_by_path = results + .into_iter() + .map(|(buffer, ranges)| { + buffer.read_with(cx_b, |buffer, cx| { + let path = buffer.file().unwrap().full_path(cx); + let offset_ranges = ranges + .into_iter() + .map(|range| range.to_offset(buffer)) + .collect::>(); + (path, offset_ranges) + }) + }) + .collect::>(); + ranges_by_path.sort_by_key(|(path, _)| path.clone()); + + assert_eq!( + ranges_by_path, + &[ + (PathBuf::from("dir-1/a"), vec![6..11]), + (PathBuf::from("dir-1/c"), vec![2..7]), + (PathBuf::from("dir-1/d"), vec![0..5, 24..29]), + (PathBuf::from("dir-2/e"), vec![7..12]), + ] + ); +} + +#[gpui::test(iterations = 10)] +async fn test_document_highlights( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/root-1", + json!({ + "main.rs": "fn double(number: i32) -> i32 { number + number }", + }), + ) + .await; + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + let (project_a, worktree_id) = client_a.build_local_project("/root-1", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open the file on client B. + let open_b = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_b).await.unwrap(); + + // Request document highlights as the guest. + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::( + |params, _| async move { + assert_eq!( + params + .text_document_position_params + .text_document + .uri + .as_str(), + "file:///root-1/main.rs" + ); + assert_eq!( + params.text_document_position_params.position, + lsp::Position::new(0, 34) + ); + Ok(Some(vec![ + lsp::DocumentHighlight { + kind: Some(lsp::DocumentHighlightKind::WRITE), + range: lsp::Range::new(lsp::Position::new(0, 10), lsp::Position::new(0, 16)), + }, + lsp::DocumentHighlight { + kind: Some(lsp::DocumentHighlightKind::READ), + range: lsp::Range::new(lsp::Position::new(0, 32), lsp::Position::new(0, 38)), + }, + lsp::DocumentHighlight { + kind: Some(lsp::DocumentHighlightKind::READ), + range: lsp::Range::new(lsp::Position::new(0, 41), lsp::Position::new(0, 47)), + }, + ])) + }, + ); + + let highlights = project_b + .update(cx_b, |p, cx| p.document_highlights(&buffer_b, 34, cx)) + .await + .unwrap(); + + buffer_b.read_with(cx_b, |buffer, _| { + let snapshot = buffer.snapshot(); + + let highlights = highlights + .into_iter() + .map(|highlight| (highlight.kind, highlight.range.to_offset(&snapshot))) + .collect::>(); + assert_eq!( + highlights, + &[ + (lsp::DocumentHighlightKind::WRITE, 10..16), + (lsp::DocumentHighlightKind::READ, 32..38), + (lsp::DocumentHighlightKind::READ, 41..47) + ] + ) + }); +} + +#[gpui::test(iterations = 10)] +async fn test_lsp_hover( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + "/root-1", + json!({ + "main.rs": "use std::collections::HashMap;", + }), + ) + .await; + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + let (project_a, worktree_id) = client_a.build_local_project("/root-1", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Open the file as the guest + let open_buffer = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)); + let buffer_b = cx_b.executor().spawn(open_buffer).await.unwrap(); + + // Request hover information as the guest. + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::( + |params, _| async move { + assert_eq!( + params + .text_document_position_params + .text_document + .uri + .as_str(), + "file:///root-1/main.rs" + ); + assert_eq!( + params.text_document_position_params.position, + lsp::Position::new(0, 22) + ); + Ok(Some(lsp::Hover { + contents: lsp::HoverContents::Array(vec![ + lsp::MarkedString::String("Test hover content.".to_string()), + lsp::MarkedString::LanguageString(lsp::LanguageString { + language: "Rust".to_string(), + value: "let foo = 42;".to_string(), + }), + ]), + range: Some(lsp::Range::new( + lsp::Position::new(0, 22), + lsp::Position::new(0, 29), + )), + })) + }, + ); + + let hover_info = project_b + .update(cx_b, |p, cx| p.hover(&buffer_b, 22, cx)) + .await + .unwrap() + .unwrap(); + + buffer_b.read_with(cx_b, |buffer, _| { + let snapshot = buffer.snapshot(); + assert_eq!(hover_info.range.unwrap().to_offset(&snapshot), 22..29); + assert_eq!( + hover_info.contents, + vec![ + project::HoverBlock { + text: "Test hover content.".to_string(), + kind: HoverBlockKind::Markdown, + }, + project::HoverBlock { + text: "let foo = 42;".to_string(), + kind: HoverBlockKind::Code { + language: "Rust".to_string() + }, + } + ] + ); + }); +} + +#[gpui::test(iterations = 10)] +async fn test_project_symbols( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + client_a + .fs() + .insert_tree( + "/code", + json!({ + "crate-1": { + "one.rs": "const ONE: usize = 1;", + }, + "crate-2": { + "two.rs": "const TWO: usize = 2; const THREE: usize = 3;", + }, + "private": { + "passwords.txt": "the-password", + } + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/code/crate-1", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + // Cause the language server to start. + let open_buffer_task = + project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "one.rs"), cx)); + let _buffer = cx_b.executor().spawn(open_buffer_task).await.unwrap(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|_, _| async move { + Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![ + #[allow(deprecated)] + lsp::SymbolInformation { + name: "TWO".into(), + location: lsp::Location { + uri: lsp::Url::from_file_path("/code/crate-2/two.rs").unwrap(), + range: lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), + }, + kind: lsp::SymbolKind::CONSTANT, + tags: None, + container_name: None, + deprecated: None, + }, + ]))) + }); + + // Request the definition of a symbol as the guest. + let symbols = project_b + .update(cx_b, |p, cx| p.symbols("two", cx)) + .await + .unwrap(); + assert_eq!(symbols.len(), 1); + assert_eq!(symbols[0].name, "TWO"); + + // Open one of the returned symbols. + let buffer_b_2 = project_b + .update(cx_b, |project, cx| { + project.open_buffer_for_symbol(&symbols[0], cx) + }) + .await + .unwrap(); + + buffer_b_2.read_with(cx_b, |buffer, _| { + assert_eq!( + buffer.file().unwrap().path().as_ref(), + Path::new("../crate-2/two.rs") + ); + }); + + // Attempt to craft a symbol and violate host's privacy by opening an arbitrary file. + let mut fake_symbol = symbols[0].clone(); + fake_symbol.path.path = Path::new("/code/secrets").into(); + let error = project_b + .update(cx_b, |project, cx| { + project.open_buffer_for_symbol(&fake_symbol, cx) + }) + .await + .unwrap_err(); + assert!(error.to_string().contains("invalid symbol signature")); +} + +#[gpui::test(iterations = 10)] +async fn test_open_buffer_while_getting_definition_pointing_to_it( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + mut rng: StdRng, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + // Set up a fake language server. + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ); + let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + client_a.language_registry().add(Arc::new(language)); + + client_a + .fs() + .insert_tree( + "/root", + json!({ + "a.rs": "const ONE: usize = b::TWO;", + "b.rs": "const TWO: usize = 2", + }), + ) + .await; + let (project_a, worktree_id) = client_a.build_local_project("/root", cx_a).await; + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.build_remote_project(project_id, cx_b).await; + + let open_buffer_task = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx)); + let buffer_b1 = cx_b.executor().spawn(open_buffer_task).await.unwrap(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.handle_request::(|_, _| async move { + Ok(Some(lsp::GotoDefinitionResponse::Scalar( + lsp::Location::new( + lsp::Url::from_file_path("/root/b.rs").unwrap(), + lsp::Range::new(lsp::Position::new(0, 6), lsp::Position::new(0, 9)), + ), + ))) + }); + + let definitions; + let buffer_b2; + if rng.gen() { + definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx)); + buffer_b2 = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "b.rs"), cx)); + } else { + buffer_b2 = project_b.update(cx_b, |p, cx| p.open_buffer((worktree_id, "b.rs"), cx)); + definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx)); + } + + let buffer_b2 = buffer_b2.await.unwrap(); + let definitions = definitions.await.unwrap(); + assert_eq!(definitions.len(), 1); + assert_eq!(definitions[0].target.buffer, buffer_b2); +} + +#[gpui::test(iterations = 10)] +async fn test_contacts( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, + cx_d: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + let client_d = server.create_client(cx_d, "user_d").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + let active_call_c = cx_c.read(ActiveCall::global); + let _active_call_d = cx_d.read(ActiveCall::global); + + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "free"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!(contacts(&client_d, cx_d), []); + + server.disconnect_client(client_c.peer_id().unwrap()); + server.forbid_connections(); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "offline", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "free"), + ("user_c".to_string(), "offline", "free") + ] + ); + assert_eq!(contacts(&client_c, cx_c), []); + assert_eq!(contacts(&client_d, cx_d), []); + + server.allow_connections(); + client_c + .authenticate_and_connect(false, &cx_c.to_async()) + .await + .unwrap(); + + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "free"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!(contacts(&client_d, cx_d), []); + + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "busy") + ] + ); + assert_eq!(contacts(&client_d, cx_d), []); + + // Client B and client D become contacts while client B is being called. + server + .make_contacts(&mut [(&client_b, cx_b), (&client_d, cx_d)]) + .await; + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free"), + ("user_d".to_string(), "online", "free"), + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "busy")] + ); + + active_call_b.update(cx_b, |call, cx| call.decline_incoming(cx).unwrap()); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "free"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "free")] + ); + + active_call_c + .update(cx_c, |call, cx| { + call.invite(client_a.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "busy"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "free")] + ); + + active_call_a + .update(cx_a, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "busy"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "free")] + ); + + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "busy"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "busy")] + ); + + active_call_a + .update(cx_a, |call, cx| call.hang_up(cx)) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "free"), + ("user_c".to_string(), "online", "free"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "free"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "free")] + ); + + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_a, cx_a), + [ + ("user_b".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "online", "busy"), + ("user_c".to_string(), "online", "free"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "online", "busy"), + ("user_b".to_string(), "online", "busy") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "busy")] + ); + + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert_eq!(contacts(&client_a, cx_a), []); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "offline", "free"), + ("user_c".to_string(), "online", "free"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [ + ("user_a".to_string(), "offline", "free"), + ("user_b".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_d, cx_d), + [("user_b".to_string(), "online", "free")] + ); + + // Test removing a contact + client_b + .user_store() + .update(cx_b, |store, cx| { + store.remove_contact(client_c.user_id().unwrap(), cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + assert_eq!( + contacts(&client_b, cx_b), + [ + ("user_a".to_string(), "offline", "free"), + ("user_d".to_string(), "online", "free") + ] + ); + assert_eq!( + contacts(&client_c, cx_c), + [("user_a".to_string(), "offline", "free"),] + ); + + fn contacts( + client: &TestClient, + cx: &TestAppContext, + ) -> Vec<(String, &'static str, &'static str)> { + client.user_store().read_with(cx, |store, _| { + store + .contacts() + .iter() + .map(|contact| { + ( + contact.user.github_login.clone(), + if contact.online { "online" } else { "offline" }, + if contact.busy { "busy" } else { "free" }, + ) + }) + .collect() + }) + } +} + +#[gpui::test(iterations = 10)] +async fn test_contact_requests( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_a2: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_b2: &mut TestAppContext, + cx_c: &mut TestAppContext, + cx_c2: &mut TestAppContext, +) { + // Connect to a server as 3 clients. + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_a2 = server.create_client(cx_a2, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_b2 = server.create_client(cx_b2, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + let client_c2 = server.create_client(cx_c2, "user_c").await; + + assert_eq!(client_a.user_id().unwrap(), client_a2.user_id().unwrap()); + assert_eq!(client_b.user_id().unwrap(), client_b2.user_id().unwrap()); + assert_eq!(client_c.user_id().unwrap(), client_c2.user_id().unwrap()); + + // User A and User C request that user B become their contact. + client_a + .user_store() + .update(cx_a, |store, cx| { + store.request_contact(client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + client_c + .user_store() + .update(cx_c, |store, cx| { + store.request_contact(client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + executor.run_until_parked(); + + // All users see the pending request appear in all their clients. + assert_eq!( + client_a.summarize_contacts(cx_a).outgoing_requests, + &["user_b"] + ); + assert_eq!( + client_a2.summarize_contacts(cx_a2).outgoing_requests, + &["user_b"] + ); + assert_eq!( + client_b.summarize_contacts(cx_b).incoming_requests, + &["user_a", "user_c"] + ); + assert_eq!( + client_b2.summarize_contacts(cx_b2).incoming_requests, + &["user_a", "user_c"] + ); + assert_eq!( + client_c.summarize_contacts(cx_c).outgoing_requests, + &["user_b"] + ); + assert_eq!( + client_c2.summarize_contacts(cx_c2).outgoing_requests, + &["user_b"] + ); + + // Contact requests are present upon connecting (tested here via disconnect/reconnect) + disconnect_and_reconnect(&client_a, cx_a).await; + disconnect_and_reconnect(&client_b, cx_b).await; + disconnect_and_reconnect(&client_c, cx_c).await; + executor.run_until_parked(); + assert_eq!( + client_a.summarize_contacts(cx_a).outgoing_requests, + &["user_b"] + ); + assert_eq!( + client_b.summarize_contacts(cx_b).incoming_requests, + &["user_a", "user_c"] + ); + assert_eq!( + client_c.summarize_contacts(cx_c).outgoing_requests, + &["user_b"] + ); + + // User B accepts the request from user A. + client_b + .user_store() + .update(cx_b, |store, cx| { + store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // User B sees user A as their contact now in all client, and the incoming request from them is removed. + let contacts_b = client_b.summarize_contacts(cx_b); + assert_eq!(contacts_b.current, &["user_a"]); + assert_eq!(contacts_b.incoming_requests, &["user_c"]); + let contacts_b2 = client_b2.summarize_contacts(cx_b2); + assert_eq!(contacts_b2.current, &["user_a"]); + assert_eq!(contacts_b2.incoming_requests, &["user_c"]); + + // User A sees user B as their contact now in all clients, and the outgoing request to them is removed. + let contacts_a = client_a.summarize_contacts(cx_a); + assert_eq!(contacts_a.current, &["user_b"]); + assert!(contacts_a.outgoing_requests.is_empty()); + let contacts_a2 = client_a2.summarize_contacts(cx_a2); + assert_eq!(contacts_a2.current, &["user_b"]); + assert!(contacts_a2.outgoing_requests.is_empty()); + + // Contacts are present upon connecting (tested here via disconnect/reconnect) + disconnect_and_reconnect(&client_a, cx_a).await; + disconnect_and_reconnect(&client_b, cx_b).await; + disconnect_and_reconnect(&client_c, cx_c).await; + executor.run_until_parked(); + assert_eq!(client_a.summarize_contacts(cx_a).current, &["user_b"]); + assert_eq!(client_b.summarize_contacts(cx_b).current, &["user_a"]); + assert_eq!( + client_b.summarize_contacts(cx_b).incoming_requests, + &["user_c"] + ); + assert!(client_c.summarize_contacts(cx_c).current.is_empty()); + assert_eq!( + client_c.summarize_contacts(cx_c).outgoing_requests, + &["user_b"] + ); + + // User B rejects the request from user C. + client_b + .user_store() + .update(cx_b, |store, cx| { + store.respond_to_contact_request(client_c.user_id().unwrap(), false, cx) + }) + .await + .unwrap(); + + executor.run_until_parked(); + + // User B doesn't see user C as their contact, and the incoming request from them is removed. + let contacts_b = client_b.summarize_contacts(cx_b); + assert_eq!(contacts_b.current, &["user_a"]); + assert!(contacts_b.incoming_requests.is_empty()); + let contacts_b2 = client_b2.summarize_contacts(cx_b2); + assert_eq!(contacts_b2.current, &["user_a"]); + assert!(contacts_b2.incoming_requests.is_empty()); + + // User C doesn't see user B as their contact, and the outgoing request to them is removed. + let contacts_c = client_c.summarize_contacts(cx_c); + assert!(contacts_c.current.is_empty()); + assert!(contacts_c.outgoing_requests.is_empty()); + let contacts_c2 = client_c2.summarize_contacts(cx_c2); + assert!(contacts_c2.current.is_empty()); + assert!(contacts_c2.outgoing_requests.is_empty()); + + // Incoming/outgoing requests are not present upon connecting (tested here via disconnect/reconnect) + disconnect_and_reconnect(&client_a, cx_a).await; + disconnect_and_reconnect(&client_b, cx_b).await; + disconnect_and_reconnect(&client_c, cx_c).await; + executor.run_until_parked(); + assert_eq!(client_a.summarize_contacts(cx_a).current, &["user_b"]); + assert_eq!(client_b.summarize_contacts(cx_b).current, &["user_a"]); + assert!(client_b + .summarize_contacts(cx_b) + .incoming_requests + .is_empty()); + assert!(client_c.summarize_contacts(cx_c).current.is_empty()); + assert!(client_c + .summarize_contacts(cx_c) + .outgoing_requests + .is_empty()); + + async fn disconnect_and_reconnect(client: &TestClient, cx: &mut TestAppContext) { + client.disconnect(&cx.to_async()); + client.clear_contacts(cx).await; + client + .authenticate_and_connect(false, &cx.to_async()) + .await + .unwrap(); + } +} + +#[gpui::test(iterations = 10)] +async fn test_join_call_after_screen_was_shared( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + let active_call_b = cx_b.read(ActiveCall::global); + + // Call users B and C from client A. + active_call_a + .update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }) + .await + .unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string()] + } + ); + + // User B receives the call. + + let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); + let call_b = incoming_call_b.next().await.unwrap().unwrap(); + assert_eq!(call_b.calling_user.github_login, "user_a"); + + // User A shares their screen + let display = MacOSDisplay::new(); + active_call_a + .update(cx_a, |call, cx| { + call.room().unwrap().update(cx, |room, cx| { + room.set_display_sources(vec![display.clone()]); + room.share_screen(cx) + }) + }) + .await + .unwrap(); + + client_b.user_store().update(cx_b, |user_store, _| { + user_store.clear_cache(); + }); + + // User B joins the room + active_call_b + .update(cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + + let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone()); + assert!(incoming_call_b.next().await.unwrap().is_none()); + + executor.run_until_parked(); + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: vec!["user_b".to_string()], + pending: vec![], + } + ); + assert_eq!( + room_participants(&room_b, cx_b), + RoomParticipants { + remote: vec!["user_a".to_string()], + pending: vec![], + } + ); + + // Ensure User B sees User A's screenshare. + + room_b.read_with(cx_b, |room, _| { + assert_eq!( + room.remote_participants() + .get(&client_a.user_id().unwrap()) + .unwrap() + .video_tracks + .len(), + 1 + ); + }); +} + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_on_input_format_from_host_to_guest( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { +// first_trigger_character: ":".to_string(), +// more_trigger_character: Some(vec![">".to_string()]), +// }), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "fn main() { a }", +// "other.rs": "// Test file", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// // Open a file in an editor as the host. +// let buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); +// let window_a = cx_a.add_window(|_| EmptyView); +// let editor_a = window_a.add_view(cx_a, |cx| { +// Editor::for_buffer(buffer_a, Some(project_a.clone()), cx) +// }); + +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// executor.run_until_parked(); + +// // Receive an OnTypeFormatting request as the host's language server. +// // Return some formattings from the host's language server. +// fake_language_server.handle_request::( +// |params, _| async move { +// assert_eq!( +// params.text_document_position.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// assert_eq!( +// params.text_document_position.position, +// lsp::Position::new(0, 14), +// ); + +// Ok(Some(vec![lsp::TextEdit { +// new_text: "~<".to_string(), +// range: lsp::Range::new(lsp::Position::new(0, 14), lsp::Position::new(0, 14)), +// }])) +// }, +// ); + +// // Open the buffer on the guest and see that the formattings worked +// let buffer_b = project_b +// .update(cx_b, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); + +// // Type a on type formatting trigger character as the guest. +// editor_a.update(cx_a, |editor, cx| { +// cx.focus(&editor_a); +// editor.change_selections(None, cx, |s| s.select_ranges([13..13])); +// editor.handle_input(">", cx); +// }); + +// executor.run_until_parked(); + +// buffer_b.read_with(cx_b, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a>~< }") +// }); + +// // Undo should remove LSP edits first +// editor_a.update(cx_a, |editor, cx| { +// assert_eq!(editor.text(cx), "fn main() { a>~< }"); +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "fn main() { a> }"); +// }); +// executor.run_until_parked(); + +// buffer_b.read_with(cx_b, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a> }") +// }); + +// editor_a.update(cx_a, |editor, cx| { +// assert_eq!(editor.text(cx), "fn main() { a> }"); +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "fn main() { a }"); +// }); +// executor.run_until_parked(); + +// buffer_b.read_with(cx_b, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a }") +// }); +// } + +// #[gpui::test(iterations = 10)] +// async fn test_on_input_format_from_guest_to_host( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); + +// // Set up a fake language server. +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// document_on_type_formatting_provider: Some(lsp::DocumentOnTypeFormattingOptions { +// first_trigger_character: ":".to_string(), +// more_trigger_character: Some(vec![">".to_string()]), +// }), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// client_a.language_registry().add(Arc::new(language)); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "fn main() { a }", +// "other.rs": "// Test file", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); +// let project_b = client_b.build_remote_project(project_id, cx_b).await; + +// // Open a file in an editor as the guest. +// let buffer_b = project_b +// .update(cx_b, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); +// let window_b = cx_b.add_window(|_| EmptyView); +// let editor_b = window_b.add_view(cx_b, |cx| { +// Editor::for_buffer(buffer_b, Some(project_b.clone()), cx) +// }); + +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// executor.run_until_parked(); +// // Type a on type formatting trigger character as the guest. +// editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([13..13])); +// editor.handle_input(":", cx); +// cx.focus(&editor_b); +// }); + +// // Receive an OnTypeFormatting request as the host's language server. +// // Return some formattings from the host's language server. +// cx_a.foreground().start_waiting(); +// fake_language_server +// .handle_request::(|params, _| async move { +// assert_eq!( +// params.text_document_position.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// assert_eq!( +// params.text_document_position.position, +// lsp::Position::new(0, 14), +// ); + +// Ok(Some(vec![lsp::TextEdit { +// new_text: "~:".to_string(), +// range: lsp::Range::new(lsp::Position::new(0, 14), lsp::Position::new(0, 14)), +// }])) +// }) +// .next() +// .await +// .unwrap(); +// cx_a.foreground().finish_waiting(); + +// // Open the buffer on the host and see that the formattings worked +// let buffer_a = project_a +// .update(cx_a, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx)) +// .await +// .unwrap(); +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a:~: }") +// }); + +// // Undo should remove LSP edits first +// editor_b.update(cx_b, |editor, cx| { +// assert_eq!(editor.text(cx), "fn main() { a:~: }"); +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "fn main() { a: }"); +// }); +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a: }") +// }); + +// editor_b.update(cx_b, |editor, cx| { +// assert_eq!(editor.text(cx), "fn main() { a: }"); +// editor.undo(&Undo, cx); +// assert_eq!(editor.text(cx), "fn main() { a }"); +// }); +// executor.run_until_parked(); + +// buffer_a.read_with(cx_a, |buffer, _| { +// assert_eq!(buffer.text(), "fn main() { a }") +// }); +// } + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_mutual_editor_inlay_hint_cache_update( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// cx_a.update(|cx| { +// cx.update_global(|store: &mut SettingsStore, cx| { +// store.update_user_settings::(cx, |settings| { +// settings.defaults.inlay_hints = Some(InlayHintSettings { +// enabled: true, +// show_type_hints: true, +// show_parameter_hints: false, +// show_other_hints: true, +// }) +// }); +// }); +// }); +// cx_b.update(|cx| { +// cx.update_global(|store: &mut SettingsStore, cx| { +// store.update_user_settings::(cx, |settings| { +// settings.defaults.inlay_hints = Some(InlayHintSettings { +// enabled: true, +// show_type_hints: true, +// show_parameter_hints: false, +// show_other_hints: true, +// }) +// }); +// }); +// }); + +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// inlay_hint_provider: Some(lsp::OneOf::Left(true)), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// let language = Arc::new(language); +// client_a.language_registry().add(Arc::clone(&language)); +// client_b.language_registry().add(language); + +// // Client A opens a project. +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "fn main() { a } // and some long comment to ensure inlay hints are not trimmed out", +// "other.rs": "// Test file", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// // Client B joins the project +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// cx_a.foreground().start_waiting(); + +// // The host opens a rust file. +// let _buffer_a = project_a +// .update(cx_a, |project, cx| { +// project.open_local_buffer("/a/main.rs", cx) +// }) +// .await +// .unwrap(); +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// let editor_a = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "main.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// // Set up the language server to return an additional inlay hint on each request. +// let edits_made = Arc::new(AtomicUsize::new(0)); +// let closure_edits_made = Arc::clone(&edits_made); +// fake_language_server +// .handle_request::(move |params, _| { +// let task_edits_made = Arc::clone(&closure_edits_made); +// async move { +// assert_eq!( +// params.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// let edits_made = task_edits_made.load(atomic::Ordering::Acquire); +// Ok(Some(vec![lsp::InlayHint { +// position: lsp::Position::new(0, edits_made as u32), +// label: lsp::InlayHintLabel::String(edits_made.to_string()), +// kind: None, +// text_edits: None, +// tooltip: None, +// padding_left: None, +// padding_right: None, +// data: None, +// }])) +// } +// }) +// .next() +// .await +// .unwrap(); + +// executor.run_until_parked(); + +// let initial_edit = edits_made.load(atomic::Ordering::Acquire); +// editor_a.update(cx_a, |editor, _| { +// assert_eq!( +// vec![initial_edit.to_string()], +// extract_hint_labels(editor), +// "Host should get its first hints when opens an editor" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 1, +// "Host editor update the cache version after every cache/view change", +// ); +// }); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "main.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// executor.run_until_parked(); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec![initial_edit.to_string()], +// extract_hint_labels(editor), +// "Client should get its first hints when opens an editor" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 1, +// "Guest editor update the cache version after every cache/view change" +// ); +// }); + +// let after_client_edit = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; +// editor_b.update(cx_b, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([13..13].clone())); +// editor.handle_input(":", cx); +// cx.focus(&editor_b); +// }); + +// executor.run_until_parked(); +// editor_a.update(cx_a, |editor, _| { +// assert_eq!( +// vec![after_client_edit.to_string()], +// extract_hint_labels(editor), +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!(inlay_cache.version(), 2); +// }); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec![after_client_edit.to_string()], +// extract_hint_labels(editor), +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!(inlay_cache.version(), 2); +// }); + +// let after_host_edit = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; +// editor_a.update(cx_a, |editor, cx| { +// editor.change_selections(None, cx, |s| s.select_ranges([13..13])); +// editor.handle_input("a change to increment both buffers' versions", cx); +// cx.focus(&editor_a); +// }); + +// executor.run_until_parked(); +// editor_a.update(cx_a, |editor, _| { +// assert_eq!( +// vec![after_host_edit.to_string()], +// extract_hint_labels(editor), +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!(inlay_cache.version(), 3); +// }); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec![after_host_edit.to_string()], +// extract_hint_labels(editor), +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!(inlay_cache.version(), 3); +// }); + +// let after_special_edit_for_refresh = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; +// fake_language_server +// .request::(()) +// .await +// .expect("inlay refresh request failed"); + +// executor.run_until_parked(); +// editor_a.update(cx_a, |editor, _| { +// assert_eq!( +// vec![after_special_edit_for_refresh.to_string()], +// extract_hint_labels(editor), +// "Host should react to /refresh LSP request" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 4, +// "Host should accepted all edits and bump its cache version every time" +// ); +// }); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec![after_special_edit_for_refresh.to_string()], +// extract_hint_labels(editor), +// "Guest should get a /refresh LSP request propagated by host" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 4, +// "Guest should accepted all edits and bump its cache version every time" +// ); +// }); +// } + +//todo!(editor) +// #[gpui::test(iterations = 10)] +// async fn test_inlay_hint_refresh_is_forwarded( +// executor: BackgroundExecutor, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// let mut server = TestServer::start(&executor).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; +// server +// .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) +// .await; +// let active_call_a = cx_a.read(ActiveCall::global); +// let active_call_b = cx_b.read(ActiveCall::global); + +// cx_a.update(editor::init); +// cx_b.update(editor::init); + +// cx_a.update(|cx| { +// cx.update_global(|store: &mut SettingsStore, cx| { +// store.update_user_settings::(cx, |settings| { +// settings.defaults.inlay_hints = Some(InlayHintSettings { +// enabled: false, +// show_type_hints: false, +// show_parameter_hints: false, +// show_other_hints: false, +// }) +// }); +// }); +// }); +// cx_b.update(|cx| { +// cx.update_global(|store: &mut SettingsStore, cx| { +// store.update_user_settings::(cx, |settings| { +// settings.defaults.inlay_hints = Some(InlayHintSettings { +// enabled: true, +// show_type_hints: true, +// show_parameter_hints: true, +// show_other_hints: true, +// }) +// }); +// }); +// }); + +// let mut language = Language::new( +// LanguageConfig { +// name: "Rust".into(), +// path_suffixes: vec!["rs".to_string()], +// ..Default::default() +// }, +// Some(tree_sitter_rust::language()), +// ); +// let mut fake_language_servers = language +// .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { +// capabilities: lsp::ServerCapabilities { +// inlay_hint_provider: Some(lsp::OneOf::Left(true)), +// ..Default::default() +// }, +// ..Default::default() +// })) +// .await; +// let language = Arc::new(language); +// client_a.language_registry().add(Arc::clone(&language)); +// client_b.language_registry().add(language); + +// client_a +// .fs() +// .insert_tree( +// "/a", +// json!({ +// "main.rs": "fn main() { a } // and some long comment to ensure inlay hints are not trimmed out", +// "other.rs": "// Test file", +// }), +// ) +// .await; +// let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await; +// active_call_a +// .update(cx_a, |call, cx| call.set_location(Some(&project_a), cx)) +// .await +// .unwrap(); +// let project_id = active_call_a +// .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) +// .await +// .unwrap(); + +// let project_b = client_b.build_remote_project(project_id, cx_b).await; +// active_call_b +// .update(cx_b, |call, cx| call.set_location(Some(&project_b), cx)) +// .await +// .unwrap(); + +// let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a); +// let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b); +// cx_a.foreground().start_waiting(); +// cx_b.foreground().start_waiting(); + +// let editor_a = workspace_a +// .update(cx_a, |workspace, cx| { +// workspace.open_path((worktree_id, "main.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// let editor_b = workspace_b +// .update(cx_b, |workspace, cx| { +// workspace.open_path((worktree_id, "main.rs"), None, true, cx) +// }) +// .await +// .unwrap() +// .downcast::() +// .unwrap(); + +// let other_hints = Arc::new(AtomicBool::new(false)); +// let fake_language_server = fake_language_servers.next().await.unwrap(); +// let closure_other_hints = Arc::clone(&other_hints); +// fake_language_server +// .handle_request::(move |params, _| { +// let task_other_hints = Arc::clone(&closure_other_hints); +// async move { +// assert_eq!( +// params.text_document.uri, +// lsp::Url::from_file_path("/a/main.rs").unwrap(), +// ); +// let other_hints = task_other_hints.load(atomic::Ordering::Acquire); +// let character = if other_hints { 0 } else { 2 }; +// let label = if other_hints { +// "other hint" +// } else { +// "initial hint" +// }; +// Ok(Some(vec![lsp::InlayHint { +// position: lsp::Position::new(0, character), +// label: lsp::InlayHintLabel::String(label.to_string()), +// kind: None, +// text_edits: None, +// tooltip: None, +// padding_left: None, +// padding_right: None, +// data: None, +// }])) +// } +// }) +// .next() +// .await +// .unwrap(); +// cx_a.foreground().finish_waiting(); +// cx_b.foreground().finish_waiting(); + +// executor.run_until_parked(); +// editor_a.update(cx_a, |editor, _| { +// assert!( +// extract_hint_labels(editor).is_empty(), +// "Host should get no hints due to them turned off" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 0, +// "Turned off hints should not generate version updates" +// ); +// }); + +// executor.run_until_parked(); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec!["initial hint".to_string()], +// extract_hint_labels(editor), +// "Client should get its first hints when opens an editor" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 1, +// "Should update cache verison after first hints" +// ); +// }); + +// other_hints.fetch_or(true, atomic::Ordering::Release); +// fake_language_server +// .request::(()) +// .await +// .expect("inlay refresh request failed"); +// executor.run_until_parked(); +// editor_a.update(cx_a, |editor, _| { +// assert!( +// extract_hint_labels(editor).is_empty(), +// "Host should get nop hints due to them turned off, even after the /refresh" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 0, +// "Turned off hints should not generate version updates, again" +// ); +// }); + +// executor.run_until_parked(); +// editor_b.update(cx_b, |editor, _| { +// assert_eq!( +// vec!["other hint".to_string()], +// extract_hint_labels(editor), +// "Guest should get a /refresh LSP request propagated by host despite host hints are off" +// ); +// let inlay_cache = editor.inlay_hint_cache(); +// assert_eq!( +// inlay_cache.version(), +// 2, +// "Guest should accepted all edits and bump its cache version every time" +// ); +// }); +// } + +// fn extract_hint_labels(editor: &Editor) -> Vec { +// let mut labels = Vec::new(); +// for hint in editor.inlay_hint_cache().hints() { +// match hint.label { +// project::InlayHintLabel::String(s) => labels.push(s), +// _ => unreachable!(), +// } +// } +// labels +// } diff --git a/crates/collab2/src/tests/notification_tests.rs b/crates/collab2/src/tests/notification_tests.rs new file mode 100644 index 0000000000..021591ee09 --- /dev/null +++ b/crates/collab2/src/tests/notification_tests.rs @@ -0,0 +1,160 @@ +//todo!(notifications) +// use crate::tests::TestServer; +// use gpui::{executor::Deterministic, TestAppContext}; +// use notifications::NotificationEvent; +// use parking_lot::Mutex; +// use rpc::{proto, Notification}; +// use std::sync::Arc; + +// #[gpui::test] +// async fn test_notifications( +// deterministic: Arc, +// cx_a: &mut TestAppContext, +// cx_b: &mut TestAppContext, +// ) { +// deterministic.forbid_parking(); +// let mut server = TestServer::start(&deterministic).await; +// let client_a = server.create_client(cx_a, "user_a").await; +// let client_b = server.create_client(cx_b, "user_b").await; + +// let notification_events_a = Arc::new(Mutex::new(Vec::new())); +// let notification_events_b = Arc::new(Mutex::new(Vec::new())); +// client_a.notification_store().update(cx_a, |_, cx| { +// let events = notification_events_a.clone(); +// cx.subscribe(&cx.handle(), move |_, _, event, _| { +// events.lock().push(event.clone()); +// }) +// .detach() +// }); +// client_b.notification_store().update(cx_b, |_, cx| { +// let events = notification_events_b.clone(); +// cx.subscribe(&cx.handle(), move |_, _, event, _| { +// events.lock().push(event.clone()); +// }) +// .detach() +// }); + +// // Client A sends a contact request to client B. +// client_a +// .user_store() +// .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx)) +// .await +// .unwrap(); + +// // Client B receives a contact request notification and responds to the +// // request, accepting it. +// deterministic.run_until_parked(); +// client_b.notification_store().update(cx_b, |store, cx| { +// assert_eq!(store.notification_count(), 1); +// assert_eq!(store.unread_notification_count(), 1); + +// let entry = store.notification_at(0).unwrap(); +// assert_eq!( +// entry.notification, +// Notification::ContactRequest { +// sender_id: client_a.id() +// } +// ); +// assert!(!entry.is_read); +// assert_eq!( +// ¬ification_events_b.lock()[0..], +// &[ +// NotificationEvent::NewNotification { +// entry: entry.clone(), +// }, +// NotificationEvent::NotificationsUpdated { +// old_range: 0..0, +// new_count: 1 +// } +// ] +// ); + +// store.respond_to_notification(entry.notification.clone(), true, cx); +// }); + +// // Client B sees the notification is now read, and that they responded. +// deterministic.run_until_parked(); +// client_b.notification_store().read_with(cx_b, |store, _| { +// assert_eq!(store.notification_count(), 1); +// assert_eq!(store.unread_notification_count(), 0); + +// let entry = store.notification_at(0).unwrap(); +// assert!(entry.is_read); +// assert_eq!(entry.response, Some(true)); +// assert_eq!( +// ¬ification_events_b.lock()[2..], +// &[ +// NotificationEvent::NotificationRead { +// entry: entry.clone(), +// }, +// NotificationEvent::NotificationsUpdated { +// old_range: 0..1, +// new_count: 1 +// } +// ] +// ); +// }); + +// // Client A receives a notification that client B accepted their request. +// client_a.notification_store().read_with(cx_a, |store, _| { +// assert_eq!(store.notification_count(), 1); +// assert_eq!(store.unread_notification_count(), 1); + +// let entry = store.notification_at(0).unwrap(); +// assert_eq!( +// entry.notification, +// Notification::ContactRequestAccepted { +// responder_id: client_b.id() +// } +// ); +// assert!(!entry.is_read); +// }); + +// // Client A creates a channel and invites client B to be a member. +// let channel_id = client_a +// .channel_store() +// .update(cx_a, |store, cx| { +// store.create_channel("the-channel", None, cx) +// }) +// .await +// .unwrap(); +// client_a +// .channel_store() +// .update(cx_a, |store, cx| { +// store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx) +// }) +// .await +// .unwrap(); + +// // Client B receives a channel invitation notification and responds to the +// // invitation, accepting it. +// deterministic.run_until_parked(); +// client_b.notification_store().update(cx_b, |store, cx| { +// assert_eq!(store.notification_count(), 2); +// assert_eq!(store.unread_notification_count(), 1); + +// let entry = store.notification_at(0).unwrap(); +// assert_eq!( +// entry.notification, +// Notification::ChannelInvitation { +// channel_id, +// channel_name: "the-channel".to_string(), +// inviter_id: client_a.id() +// } +// ); +// assert!(!entry.is_read); + +// store.respond_to_notification(entry.notification.clone(), true, cx); +// }); + +// // Client B sees the notification is now read, and that they responded. +// deterministic.run_until_parked(); +// client_b.notification_store().read_with(cx_b, |store, _| { +// assert_eq!(store.notification_count(), 2); +// assert_eq!(store.unread_notification_count(), 0); + +// let entry = store.notification_at(0).unwrap(); +// assert!(entry.is_read); +// assert_eq!(entry.response, Some(true)); +// }); +// } diff --git a/crates/collab2/src/tests/random_channel_buffer_tests.rs b/crates/collab2/src/tests/random_channel_buffer_tests.rs new file mode 100644 index 0000000000..01f8daa5d2 --- /dev/null +++ b/crates/collab2/src/tests/random_channel_buffer_tests.rs @@ -0,0 +1,296 @@ +use crate::db::ChannelRole; + +use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; +use anyhow::Result; +use async_trait::async_trait; +use gpui::{BackgroundExecutor, TestAppContext}; +use rand::prelude::*; +use serde_derive::{Deserialize, Serialize}; +use std::{ + ops::{Deref, DerefMut, Range}, + rc::Rc, + sync::Arc, +}; +use text::Bias; + +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] +async fn test_random_channel_buffers( + cx: &mut TestAppContext, + executor: BackgroundExecutor, + rng: StdRng, +) { + run_randomized_test::(cx, executor, rng).await; +} + +struct RandomChannelBufferTest; + +#[derive(Clone, Serialize, Deserialize)] +enum ChannelBufferOperation { + JoinChannelNotes { + channel_name: String, + }, + LeaveChannelNotes { + channel_name: String, + }, + EditChannelNotes { + channel_name: String, + edits: Vec<(Range, Arc)>, + }, + Noop, +} + +const CHANNEL_COUNT: usize = 3; + +#[async_trait(?Send)] +impl RandomizedTest for RandomChannelBufferTest { + type Operation = ChannelBufferOperation; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for ix in 0..CHANNEL_COUNT { + let id = db + .create_root_channel(&format!("channel-{ix}"), users[0].user_id) + .await + .unwrap(); + for user in &users[1..] { + db.invite_channel_member(id, user.user_id, users[0].user_id, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(id, user.user_id, true) + .await + .unwrap(); + } + } + } + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + _: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ChannelBufferOperation { + let channel_store = client.channel_store().clone(); + let mut channel_buffers = client.channel_buffers(); + + // When signed out, we can't do anything unless a channel buffer is + // already open. + if channel_buffers.deref_mut().is_empty() + && channel_store.read_with(cx, |store, _| store.channel_count() == 0) + { + return ChannelBufferOperation::Noop; + } + + loop { + match rng.gen_range(0..100_u32) { + 0..=29 => { + let channel_name = client.channel_store().read_with(cx, |store, cx| { + store.ordered_channels().find_map(|(_, channel)| { + if store.has_open_channel_buffer(channel.id, cx) { + None + } else { + Some(channel.name.clone()) + } + }) + }); + if let Some(channel_name) = channel_name { + break ChannelBufferOperation::JoinChannelNotes { channel_name }; + } + } + + 30..=40 => { + if let Some(buffer) = channel_buffers.deref().iter().choose(rng) { + let channel_name = + buffer.read_with(cx, |b, cx| b.channel(cx).unwrap().name.clone()); + break ChannelBufferOperation::LeaveChannelNotes { channel_name }; + } + } + + _ => { + if let Some(buffer) = channel_buffers.deref().iter().choose(rng) { + break buffer.read_with(cx, |b, cx| { + let channel_name = b.channel(cx).unwrap().name.clone(); + let edits = b + .buffer() + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } + }); + } + } + } + } + } + + async fn apply_operation( + client: &TestClient, + operation: ChannelBufferOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + match operation { + ChannelBufferOperation::JoinChannelNotes { channel_name } => { + let buffer = client.channel_store().update(cx, |store, cx| { + let channel_id = store + .ordered_channels() + .find(|(_, c)| c.name == channel_name) + .unwrap() + .1 + .id; + if store.has_open_channel_buffer(channel_id, cx) { + Err(TestError::Inapplicable) + } else { + Ok(store.open_channel_buffer(channel_id, cx)) + } + })?; + + log::info!( + "{}: opening notes for channel {channel_name}", + client.username + ); + client.channel_buffers().deref_mut().insert(buffer.await?); + } + + ChannelBufferOperation::LeaveChannelNotes { channel_name } => { + let buffer = cx.update(|cx| { + let mut left_buffer = Err(TestError::Inapplicable); + client.channel_buffers().deref_mut().retain(|buffer| { + if buffer.read(cx).channel(cx).unwrap().name == channel_name { + left_buffer = Ok(buffer.clone()); + false + } else { + true + } + }); + left_buffer + })?; + + log::info!( + "{}: closing notes for channel {channel_name}", + client.username + ); + cx.update(|_| drop(buffer)); + } + + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } => { + let channel_buffer = cx + .read(|cx| { + client + .channel_buffers() + .deref() + .iter() + .find(|buffer| { + buffer.read(cx).channel(cx).unwrap().name == channel_name + }) + .cloned() + }) + .ok_or_else(|| TestError::Inapplicable)?; + + log::info!( + "{}: editing notes for channel {channel_name} with {:?}", + client.username, + edits + ); + + channel_buffer.update(cx, |buffer, cx| { + let buffer = buffer.buffer(); + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + }); + } + + ChannelBufferOperation::Noop => Err(TestError::Inapplicable)?, + } + Ok(()) + } + + async fn on_client_added(client: &Rc, cx: &mut TestAppContext) { + let channel_store = client.channel_store(); + while channel_store.read_with(cx, |store, _| store.channel_count() == 0) { + // todo!(notifications) + // channel_store.next_notification(cx).await; + } + } + + async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + let channels = server.app_state.db.all_channels().await.unwrap(); + + for (client, client_cx) in clients.iter_mut() { + client_cx.update(|cx| { + client + .channel_buffers() + .deref_mut() + .retain(|b| b.read(cx).is_connected()); + }); + } + + for (channel_id, channel_name) in channels { + let mut prev_text: Option<(u64, String)> = None; + + let mut collaborator_user_ids = server + .app_state + .db + .get_channel_buffer_collaborators(channel_id) + .await + .unwrap() + .into_iter() + .map(|id| id.to_proto()) + .collect::>(); + collaborator_user_ids.sort(); + + for (client, client_cx) in clients.iter() { + let user_id = client.user_id().unwrap(); + client_cx.read(|cx| { + if let Some(channel_buffer) = client + .channel_buffers() + .deref() + .iter() + .find(|b| b.read(cx).channel_id == channel_id.to_proto()) + { + let channel_buffer = channel_buffer.read(cx); + + // Assert that channel buffer's text matches other clients' copies. + let text = channel_buffer.buffer().read(cx).text(); + if let Some((prev_user_id, prev_text)) = &prev_text { + assert_eq!( + &text, + prev_text, + "client {user_id} has different text than client {prev_user_id} for channel {channel_name}", + ); + } else { + prev_text = Some((user_id, text.clone())); + } + + // Assert that all clients and the server agree about who is present in the + // channel buffer. + let collaborators = channel_buffer.collaborators(); + let mut user_ids = + collaborators.values().map(|c| c.user_id).collect::>(); + user_ids.sort(); + assert_eq!( + user_ids, + collaborator_user_ids, + "client {user_id} has different user ids for channel {channel_name} than the server", + ); + } + }); + } + } + } +} diff --git a/crates/collab2/src/tests/random_project_collaboration_tests.rs b/crates/collab2/src/tests/random_project_collaboration_tests.rs new file mode 100644 index 0000000000..361ca00c33 --- /dev/null +++ b/crates/collab2/src/tests/random_project_collaboration_tests.rs @@ -0,0 +1,1587 @@ +use super::{RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; +use crate::{db::UserId, tests::run_randomized_test}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use call::ActiveCall; +use collections::{BTreeMap, HashMap}; +use editor::Bias; +use fs::{repository::GitFileStatus, FakeFs, Fs as _}; +use futures::StreamExt; +use gpui::{BackgroundExecutor, Model, TestAppContext}; +use language::{range_to_lsp, FakeLspAdapter, Language, LanguageConfig, PointUtf16}; +use lsp::FakeLanguageServer; +use pretty_assertions::assert_eq; +use project::{search::SearchQuery, Project, ProjectPath}; +use rand::{ + distributions::{Alphanumeric, DistString}, + prelude::*, +}; +use serde::{Deserialize, Serialize}; +use std::{ + ops::{Deref, Range}, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use util::ResultExt; + +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] +async fn test_random_project_collaboration( + cx: &mut TestAppContext, + executor: BackgroundExecutor, + rng: StdRng, +) { + run_randomized_test::(cx, executor, rng).await; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ClientOperation { + AcceptIncomingCall, + RejectIncomingCall, + LeaveCall, + InviteContactToCall { + user_id: UserId, + }, + OpenLocalProject { + first_root_name: String, + }, + OpenRemoteProject { + host_id: UserId, + first_root_name: String, + }, + AddWorktreeToProject { + project_root_name: String, + new_root_path: PathBuf, + }, + CloseRemoteProject { + project_root_name: String, + }, + OpenBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SearchProject { + project_root_name: String, + is_local: bool, + query: String, + detach: bool, + }, + EditBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + edits: Vec<(Range, Arc)>, + }, + CloseBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SaveBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + detach: bool, + }, + RequestLspDataInBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + offset: usize, + kind: LspRequestKind, + detach: bool, + }, + CreateWorktreeEntry { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + is_dir: bool, + }, + WriteFsEntry { + path: PathBuf, + is_dir: bool, + content: String, + }, + GitOperation { + operation: GitOperation, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum GitOperation { + WriteGitIndex { + repo_path: PathBuf, + contents: Vec<(PathBuf, String)>, + }, + WriteGitBranch { + repo_path: PathBuf, + new_branch: Option, + }, + WriteGitStatuses { + repo_path: PathBuf, + statuses: Vec<(PathBuf, GitFileStatus)>, + git_operation: bool, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum LspRequestKind { + Rename, + Completion, + CodeAction, + Definition, + Highlights, +} + +struct ProjectCollaborationTest; + +#[async_trait(?Send)] +impl RandomizedTest for ProjectCollaborationTest { + type Operation = ClientOperation; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for (ix, user_a) in users.iter().enumerate() { + for user_b in &users[ix + 1..] { + db.send_contact_request(user_a.user_id, user_b.user_id) + .await + .unwrap(); + db.respond_to_contact_request(user_b.user_id, user_a.user_id, true) + .await + .unwrap(); + } + } + } + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ClientOperation { + let call = cx.read(ActiveCall::global); + loop { + match rng.gen_range(0..100_u32) { + // Mutate the call + 0..=29 => { + // Respond to an incoming call + if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { + break if rng.gen_bool(0.7) { + ClientOperation::AcceptIncomingCall + } else { + ClientOperation::RejectIncomingCall + }; + } + + match rng.gen_range(0..100_u32) { + // Invite a contact to the current call + 0..=70 => { + let available_contacts = + client.user_store().read_with(cx, |user_store, _| { + user_store + .contacts() + .iter() + .filter(|contact| contact.online && !contact.busy) + .cloned() + .collect::>() + }); + if !available_contacts.is_empty() { + let contact = available_contacts.choose(rng).unwrap(); + break ClientOperation::InviteContactToCall { + user_id: UserId(contact.user.id as i32), + }; + } + } + + // Leave the current call + 71.. => { + if plan.allow_client_disconnection + && call.read_with(cx, |call, _| call.room().is_some()) + { + break ClientOperation::LeaveCall; + } + } + } + } + + // Mutate projects + 30..=59 => match rng.gen_range(0..100_u32) { + // Open a new project + 0..=70 => { + // Open a remote project + if let Some(room) = call.read_with(cx, |call, _| call.room().cloned()) { + let existing_remote_project_ids = cx.read(|cx| { + client + .remote_projects() + .iter() + .map(|p| p.read(cx).remote_id().unwrap()) + .collect::>() + }); + let new_remote_projects = room.read_with(cx, |room, _| { + room.remote_participants() + .values() + .flat_map(|participant| { + participant.projects.iter().filter_map(|project| { + if existing_remote_project_ids.contains(&project.id) { + None + } else { + Some(( + UserId::from_proto(participant.user.id), + project.worktree_root_names[0].clone(), + )) + } + }) + }) + .collect::>() + }); + if !new_remote_projects.is_empty() { + let (host_id, first_root_name) = + new_remote_projects.choose(rng).unwrap().clone(); + break ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + }; + } + } + // Open a local project + else { + let first_root_name = plan.next_root_dir_name(); + break ClientOperation::OpenLocalProject { first_root_name }; + } + } + + // Close a remote project + 71..=80 => { + if !client.remote_projects().is_empty() { + let project = client.remote_projects().choose(rng).unwrap().clone(); + let first_root_name = root_name_for_project(&project, cx); + break ClientOperation::CloseRemoteProject { + project_root_name: first_root_name, + }; + } + } + + // Mutate project worktrees + 81.. => match rng.gen_range(0..100_u32) { + // Add a worktree to a local project + 0..=50 => { + let Some(project) = client.local_projects().choose(rng).cloned() else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let mut paths = client.fs().paths(false); + paths.remove(0); + let new_root_path = if paths.is_empty() || rng.gen() { + Path::new("/").join(&plan.next_root_dir_name()) + } else { + paths.choose(rng).unwrap().clone() + }; + break ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + }; + } + + // Add an entry to a worktree + _ => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees() + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + && worktree.root_entry().map_or(false, |e| e.is_dir()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let is_dir = rng.gen::(); + let mut full_path = + worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); + full_path.push(gen_file_name(rng)); + if !is_dir { + full_path.set_extension("rs"); + } + break ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + }; + } + }, + }, + + // Query and mutate buffers + 60..=90 => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + + match rng.gen_range(0..100_u32) { + // Manipulate an existing buffer + 0..=70 => { + let Some(buffer) = client + .buffers_for_project(&project) + .iter() + .choose(rng) + .cloned() + else { + continue; + }; + + let full_path = buffer + .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + match rng.gen_range(0..100_u32) { + // Close the buffer + 0..=15 => { + break ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + }; + } + // Save the buffer + 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { + let detach = rng.gen_bool(0.3); + break ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + }; + } + // Edit the buffer + 30..=69 => { + let edits = buffer + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + break ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + }; + } + // Make an LSP request + _ => { + let offset = buffer.read_with(cx, |buffer, _| { + buffer.clip_offset( + rng.gen_range(0..=buffer.len()), + language::Bias::Left, + ) + }); + let detach = rng.gen(); + break ClientOperation::RequestLspDataInBuffer { + project_root_name, + full_path, + offset, + is_local, + kind: match rng.gen_range(0..5_u32) { + 0 => LspRequestKind::Rename, + 1 => LspRequestKind::Highlights, + 2 => LspRequestKind::Definition, + 3 => LspRequestKind::CodeAction, + 4.. => LspRequestKind::Completion, + }, + detach, + }; + } + } + } + + 71..=80 => { + let query = rng.gen_range('a'..='z').to_string(); + let detach = rng.gen_bool(0.3); + break ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + }; + } + + // Open a buffer + 81.. => { + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees() + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let full_path = worktree.read_with(cx, |worktree, _| { + let entry = worktree + .entries(false) + .filter(|e| e.is_file()) + .choose(rng) + .unwrap(); + if entry.path.as_ref() == Path::new("") { + Path::new(worktree.root_name()).into() + } else { + Path::new(worktree.root_name()).join(&entry.path) + } + }); + break ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + }; + } + } + } + + // Update a git related action + 91..=95 => { + break ClientOperation::GitOperation { + operation: generate_git_operation(rng, client), + }; + } + + // Create or update a file or directory + 96.. => { + let is_dir = rng.gen::(); + let content; + let mut path; + let dir_paths = client.fs().directories(false); + + if is_dir { + content = String::new(); + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + } else { + content = Alphanumeric.sample_string(rng, 16); + + // Create a new file or overwrite an existing file + let file_paths = client.fs().files(); + if file_paths.is_empty() || rng.gen_bool(0.5) { + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + path.set_extension("rs"); + } else { + path = file_paths.choose(rng).unwrap().clone() + }; + } + break ClientOperation::WriteFsEntry { + path, + is_dir, + content, + }; + } + } + } + } + + async fn apply_operation( + client: &TestClient, + operation: ClientOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + match operation { + ClientOperation::AcceptIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: accepting incoming call", client.username); + active_call + .update(cx, |call, cx| call.accept_incoming(cx)) + .await?; + } + + ClientOperation::RejectIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: declining incoming call", client.username); + active_call.update(cx, |call, cx| call.decline_incoming(cx))?; + } + + ClientOperation::LeaveCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: hanging up", client.username); + active_call.update(cx, |call, cx| call.hang_up(cx)).await?; + } + + ClientOperation::InviteContactToCall { user_id } => { + let active_call = cx.read(ActiveCall::global); + + log::info!("{}: inviting {}", client.username, user_id,); + active_call + .update(cx, |call, cx| call.invite(user_id.to_proto(), None, cx)) + .await + .log_err(); + } + + ClientOperation::OpenLocalProject { first_root_name } => { + log::info!( + "{}: opening local project at {:?}", + client.username, + first_root_name + ); + + let root_path = Path::new("/").join(&first_root_name); + client.fs().create_dir(&root_path).await.unwrap(); + client + .fs() + .create_file(&root_path.join("main.rs"), Default::default()) + .await + .unwrap(); + let project = client.build_local_project(root_path, cx).await.0; + ensure_project_shared(&project, client, cx).await; + client.local_projects_mut().push(project.clone()); + } + + ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: finding/creating local worktree at {:?} to project with root path {}", + client.username, + new_root_path, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + if !client.fs().paths(false).contains(&new_root_path) { + client.fs().create_dir(&new_root_path).await.unwrap(); + } + project + .update(cx, |project, cx| { + project.find_or_create_local_worktree(&new_root_path, true, cx) + }) + .await + .unwrap(); + } + + ClientOperation::CloseRemoteProject { project_root_name } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing remote project with root path {}", + client.username, + project_root_name, + ); + + let ix = client + .remote_projects() + .iter() + .position(|p| p == &project) + .unwrap(); + cx.update(|_| { + client.remote_projects_mut().remove(ix); + client.buffers().retain(|p, _| *p != project); + drop(project); + }); + } + + ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + } => { + let active_call = cx.read(ActiveCall::global); + let project = active_call + .update(cx, |call, cx| { + let room = call.room().cloned()?; + let participant = room + .read(cx) + .remote_participants() + .get(&host_id.to_proto())?; + let project_id = participant + .projects + .iter() + .find(|project| project.worktree_root_names[0] == first_root_name)? + .id; + Some(room.update(cx, |room, cx| { + room.join_project( + project_id, + client.language_registry().clone(), + FakeFs::new(cx.background_executor().clone()), + cx, + ) + })) + }) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: joining remote project of user {}, root name {}", + client.username, + host_id, + first_root_name, + ); + + let project = project.await?; + client.remote_projects_mut().push(project.clone()); + } + + ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: creating {} at path {:?} in {} project {}", + client.username, + if is_dir { "dir" } else { "file" }, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + project + .update(cx, |p, cx| p.create_entry(project_path, is_dir, cx)) + .unwrap() + .await?; + } + + ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: opening buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await?; + client.buffers_for_project(&project).insert(buffer); + } + + ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: editing buffer {:?} in {} project {} with {:?}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + edits + ); + + ensure_project_shared(&project, client, cx).await; + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + } + + ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + cx.update(|_| { + client.buffers_for_project(&project).remove(&buffer); + drop(buffer); + }); + } + + ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: saving buffer {:?} in {} project {}, {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + ensure_project_shared(&project, client, cx).await; + let requested_version = buffer.read_with(cx, |buffer, _| buffer.version()); + let save = + project.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + let save = cx.spawn(|cx| async move { + save.await + .map_err(|err| anyhow!("save request failed: {:?}", err))?; + assert!(buffer + .read_with(&cx, |buffer, _| { buffer.saved_version().to_owned() }) + .expect("App should not be dropped") + .observed_all(&requested_version)); + anyhow::Ok(()) + }); + if detach { + cx.update(|cx| save.detach_and_log_err(cx)); + } else { + save.await?; + } + } + + ClientOperation::RequestLspDataInBuffer { + project_root_name, + is_local, + full_path, + offset, + kind, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: request LSP {:?} for buffer {:?} in {} project {}, {}", + client.username, + kind, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + use futures::{FutureExt as _, TryFutureExt as _}; + let offset = buffer.read_with(cx, |b, _| b.clip_offset(offset, Bias::Left)); + + let process_lsp_request = project.update(cx, |project, cx| match kind { + LspRequestKind::Rename => project + .prepare_rename(buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Completion => project + .completions(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::CodeAction => project + .code_actions(&buffer, offset..offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Definition => project + .definition(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Highlights => project + .document_highlights(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + }); + let request = cx.foreground_executor().spawn(process_lsp_request); + if detach { + request.detach(); + } else { + request.await?; + } + } + + ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: search {} project {} for {:?}, {}", + client.username, + if is_local { "local" } else { "remote" }, + project_root_name, + query, + if detach { "detaching" } else { "awaiting" } + ); + + let mut search = project.update(cx, |project, cx| { + project.search( + SearchQuery::text(query, false, false, Vec::new(), Vec::new()).unwrap(), + cx, + ) + }); + drop(project); + let search = cx.executor().spawn(async move { + let mut results = HashMap::default(); + while let Some((buffer, ranges)) = search.next().await { + results.entry(buffer).or_insert(ranges); + } + results + }); + search.await; + } + + ClientOperation::WriteFsEntry { + path, + is_dir, + content, + } => { + if !client + .fs() + .directories(false) + .contains(&path.parent().unwrap().to_owned()) + { + return Err(TestError::Inapplicable); + } + + if is_dir { + log::info!("{}: creating dir at {:?}", client.username, path); + client.fs().create_dir(&path).await.unwrap(); + } else { + let exists = client.fs().metadata(&path).await?.is_some(); + let verb = if exists { "updating" } else { "creating" }; + log::info!("{}: {} file at {:?}", verb, client.username, path); + + client + .fs() + .save(&path, &content.as_str().into(), text::LineEnding::Unix) + .await + .unwrap(); + } + } + + ClientOperation::GitOperation { operation } => match operation { + GitOperation::WriteGitIndex { + repo_path, + contents, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + for (path, _) in contents.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git index for repo {:?}: {:?}", + client.username, + repo_path, + contents + ); + + let dot_git_dir = repo_path.join(".git"); + let contents = contents + .iter() + .map(|(path, contents)| (path.as_path(), contents.clone())) + .collect::>(); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client.fs().set_index_for_repo(&dot_git_dir, &contents); + } + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + log::info!( + "{}: writing git branch for repo {:?}: {:?}", + client.username, + repo_path, + new_branch + ); + + let dot_git_dir = repo_path.join(".git"); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client + .fs() + .set_branch_name(&dot_git_dir, new_branch.clone()); + } + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + for (path, _) in statuses.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git statuses for repo {:?}: {:?}", + client.username, + repo_path, + statuses + ); + + let dot_git_dir = repo_path.join(".git"); + + let statuses = statuses + .iter() + .map(|(path, val)| (path.as_path(), val.clone())) + .collect::>(); + + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + + if git_operation { + client.fs().set_status_for_repo_via_git_operation( + &dot_git_dir, + statuses.as_slice(), + ); + } else { + client.fs().set_status_for_repo_via_working_copy_change( + &dot_git_dir, + statuses.as_slice(), + ); + } + } + }, + } + Ok(()) + } + + async fn on_client_added(client: &Rc, _: &mut TestAppContext) { + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + None, + ); + language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + name: "the-fake-language-server", + capabilities: lsp::LanguageServer::full_capabilities(), + initializer: Some(Box::new({ + let fs = client.app_state.fs.clone(); + move |fake_server: &mut FakeLanguageServer| { + fake_server.handle_request::( + |_, _| async move { + Ok(Some(lsp::CompletionResponse::Array(vec![ + lsp::CompletionItem { + text_edit: Some(lsp::CompletionTextEdit::Edit( + lsp::TextEdit { + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 0), + ), + new_text: "the-new-text".to_string(), + }, + )), + ..Default::default() + }, + ]))) + }, + ); + + fake_server.handle_request::( + |_, _| async move { + Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( + lsp::CodeAction { + title: "the-code-action".to_string(), + ..Default::default() + }, + )])) + }, + ); + + fake_server.handle_request::( + |params, _| async move { + Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( + params.position, + params.position, + )))) + }, + ); + + fake_server.handle_request::({ + let fs = fs.clone(); + move |_, cx| { + let background = cx.background_executor(); + let mut rng = background.rng(); + let count = rng.gen_range::(1..3); + let files = fs.as_fake().files(); + let files = (0..count) + .map(|_| files.choose(&mut rng).unwrap().clone()) + .collect::>(); + async move { + log::info!("LSP: Returning definitions in files {:?}", &files); + Ok(Some(lsp::GotoDefinitionResponse::Array( + files + .into_iter() + .map(|file| lsp::Location { + uri: lsp::Url::from_file_path(file).unwrap(), + range: Default::default(), + }) + .collect(), + ))) + } + } + }); + + fake_server.handle_request::( + move |_, cx| { + let mut highlights = Vec::new(); + let background = cx.background_executor(); + let mut rng = background.rng(); + + let highlight_count = rng.gen_range(1..=5); + for _ in 0..highlight_count { + let start_row = rng.gen_range(0..100); + let start_column = rng.gen_range(0..100); + let end_row = rng.gen_range(0..100); + let end_column = rng.gen_range(0..100); + let start = PointUtf16::new(start_row, start_column); + let end = PointUtf16::new(end_row, end_column); + let range = if start > end { end..start } else { start..end }; + highlights.push(lsp::DocumentHighlight { + range: range_to_lsp(range.clone()), + kind: Some(lsp::DocumentHighlightKind::READ), + }); + } + highlights.sort_unstable_by_key(|highlight| { + (highlight.range.start, highlight.range.end) + }); + async move { Ok(Some(highlights)) } + }, + ); + } + })), + ..Default::default() + })) + .await; + client.app_state.languages.add(Arc::new(language)); + } + + async fn on_quiesce(_: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + for (client, client_cx) in clients.iter() { + for guest_project in client.remote_projects().iter() { + guest_project.read_with(client_cx, |guest_project, cx| { + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == guest_project.remote_id() + }) + })? + .clone(); + Some((project, cx)) + }); + + if !guest_project.is_read_only() { + if let Some((host_project, host_cx)) = host_project { + let host_worktree_snapshots = + host_project.read_with(host_cx, |host_project, cx| { + host_project + .worktrees() + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>() + }); + let guest_worktree_snapshots = guest_project + .worktrees() + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>(); + + assert_eq!( + guest_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + host_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + "{} has different worktrees than the host for project {:?}", + client.username, guest_project.remote_id(), + ); + + for (id, host_snapshot) in &host_worktree_snapshots { + let guest_snapshot = &guest_worktree_snapshots[id]; + assert_eq!( + guest_snapshot.root_name(), + host_snapshot.root_name(), + "{} has different root name than the host for worktree {}, project {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.abs_path(), + host_snapshot.abs_path(), + "{} has different abs path than the host for worktree {}, project: {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.entries(false).collect::>(), + host_snapshot.entries(false).collect::>(), + "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", + client.username, + host_snapshot.abs_path(), + id, + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.repositories().collect::>(), host_snapshot.repositories().collect::>(), + "{} has different repositories than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.scan_id(), host_snapshot.scan_id(), + "{} has different scan id than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + } + } + } + + for buffer in guest_project.opened_buffers() { + let buffer = buffer.read(cx); + assert_eq!( + buffer.deferred_ops_len(), + 0, + "{} has deferred operations for buffer {:?} in project {:?}", + client.username, + buffer.file().unwrap().full_path(cx), + guest_project.remote_id(), + ); + } + }); + } + + let buffers = client.buffers().clone(); + for (guest_project, guest_buffers) in &buffers { + let project_id = if guest_project.read_with(client_cx, |project, _| { + project.is_local() || project.is_read_only() + }) { + continue; + } else { + guest_project + .read_with(client_cx, |project, _| project.remote_id()) + .unwrap() + }; + let guest_user_id = client.user_id().unwrap(); + + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == Some(project_id) + }) + })? + .clone(); + Some((client.user_id().unwrap(), project, cx)) + }); + + let (host_user_id, host_project, host_cx) = + if let Some((host_user_id, host_project, host_cx)) = host_project { + (host_user_id, host_project, host_cx) + } else { + continue; + }; + + for guest_buffer in guest_buffers { + let buffer_id = + guest_buffer.read_with(client_cx, |buffer, _| buffer.remote_id()); + let host_buffer = host_project.read_with(host_cx, |project, _| { + project.buffer_for_id(buffer_id).unwrap_or_else(|| { + panic!( + "host does not have buffer for guest:{}, peer:{:?}, id:{}", + client.username, + client.peer_id(), + buffer_id + ) + }) + }); + let path = host_buffer + .read_with(host_cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.deferred_ops_len()), + 0, + "{}, buffer {}, path {:?} has deferred operations", + client.username, + buffer_id, + path, + ); + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.text()), + host_buffer.read_with(host_cx, |buffer, _| buffer.text()), + "{}, buffer {}, path {:?}, differs from the host's buffer", + client.username, + buffer_id, + path + ); + + let host_file = host_buffer.read_with(host_cx, |b, _| b.file().cloned()); + let guest_file = guest_buffer.read_with(client_cx, |b, _| b.file().cloned()); + match (host_file, guest_file) { + (Some(host_file), Some(guest_file)) => { + assert_eq!(guest_file.path(), host_file.path()); + assert_eq!(guest_file.is_deleted(), host_file.is_deleted()); + assert_eq!( + guest_file.mtime(), + host_file.mtime(), + "guest {} mtime does not match host {} for path {:?} in project {}", + guest_user_id, + host_user_id, + guest_file.path(), + project_id, + ); + } + (None, None) => {} + (None, _) => panic!("host's file is None, guest's isn't"), + (_, None) => panic!("guest's file is None, hosts's isn't"), + } + + let host_diff_base = host_buffer + .read_with(host_cx, |b, _| b.diff_base().map(ToString::to_string)); + let guest_diff_base = guest_buffer + .read_with(client_cx, |b, _| b.diff_base().map(ToString::to_string)); + assert_eq!( + guest_diff_base, host_diff_base, + "guest {} diff base does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version = + host_buffer.read_with(host_cx, |b, _| b.saved_version().clone()); + let guest_saved_version = + guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone()); + assert_eq!( + guest_saved_version, host_saved_version, + "guest {} saved version does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version_fingerprint = + host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint()); + let guest_saved_version_fingerprint = + guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint()); + assert_eq!( + guest_saved_version_fingerprint, host_saved_version_fingerprint, + "guest {} saved fingerprint does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime()); + let guest_saved_mtime = + guest_buffer.read_with(client_cx, |b, _| b.saved_mtime()); + assert_eq!( + guest_saved_mtime, host_saved_mtime, + "guest {} saved mtime does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty()); + let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty()); + assert_eq!(guest_is_dirty, host_is_dirty, + "guest {} dirty status does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict()); + let guest_has_conflict = + guest_buffer.read_with(client_cx, |b, _| b.has_conflict()); + assert_eq!(guest_has_conflict, host_has_conflict, + "guest {} conflict status does not match host's for path {path:?} in project {project_id}", + client.username + ); + } + } + } + } +} + +fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation { + fn generate_file_paths( + repo_path: &Path, + rng: &mut StdRng, + client: &TestClient, + ) -> Vec { + let mut paths = client + .fs() + .files() + .into_iter() + .filter(|path| path.starts_with(repo_path)) + .collect::>(); + + let count = rng.gen_range(0..=paths.len()); + paths.shuffle(rng); + paths.truncate(count); + + paths + .iter() + .map(|path| path.strip_prefix(repo_path).unwrap().to_path_buf()) + .collect::>() + } + + let repo_path = client.fs().directories(false).choose(rng).unwrap().clone(); + + match rng.gen_range(0..100_u32) { + 0..=25 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let contents = file_paths + .into_iter() + .map(|path| (path, Alphanumeric.sample_string(rng, 16))) + .collect(); + + GitOperation::WriteGitIndex { + repo_path, + contents, + } + } + 26..=63 => { + let new_branch = (rng.gen_range(0..10) > 3).then(|| Alphanumeric.sample_string(rng, 8)); + + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } + } + 64..=100 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let statuses = file_paths + .into_iter() + .map(|paths| { + ( + paths, + match rng.gen_range(0..3_u32) { + 0 => GitFileStatus::Added, + 1 => GitFileStatus::Modified, + 2 => GitFileStatus::Conflict, + _ => unreachable!(), + }, + ) + }) + .collect::>(); + + let git_operation = rng.gen::(); + + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } + } + _ => unreachable!(), + } +} + +fn buffer_for_full_path( + client: &TestClient, + project: &Model, + full_path: &PathBuf, + cx: &TestAppContext, +) -> Option> { + client + .buffers_for_project(project) + .iter() + .find(|buffer| { + buffer.read_with(cx, |buffer, cx| { + buffer.file().unwrap().full_path(cx) == *full_path + }) + }) + .cloned() +} + +fn project_for_root_name( + client: &TestClient, + root_name: &str, + cx: &TestAppContext, +) -> Option> { + if let Some(ix) = project_ix_for_root_name(&*client.local_projects().deref(), root_name, cx) { + return Some(client.local_projects()[ix].clone()); + } + if let Some(ix) = project_ix_for_root_name(&*client.remote_projects().deref(), root_name, cx) { + return Some(client.remote_projects()[ix].clone()); + } + None +} + +fn project_ix_for_root_name( + projects: &[Model], + root_name: &str, + cx: &TestAppContext, +) -> Option { + projects.iter().position(|project| { + project.read_with(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next().unwrap(); + worktree.read(cx).root_name() == root_name + }) + }) +} + +fn root_name_for_project(project: &Model, cx: &TestAppContext) -> String { + project.read_with(cx, |project, cx| { + project + .visible_worktrees(cx) + .next() + .unwrap() + .read(cx) + .root_name() + .to_string() + }) +} + +fn project_path_for_full_path( + project: &Model, + full_path: &Path, + cx: &TestAppContext, +) -> Option { + let mut components = full_path.components(); + let root_name = components.next().unwrap().as_os_str().to_str().unwrap(); + let path = components.as_path().into(); + let worktree_id = project.read_with(cx, |project, cx| { + project.worktrees().find_map(|worktree| { + let worktree = worktree.read(cx); + if worktree.root_name() == root_name { + Some(worktree.id()) + } else { + None + } + }) + })?; + Some(ProjectPath { worktree_id, path }) +} + +async fn ensure_project_shared( + project: &Model, + client: &TestClient, + cx: &mut TestAppContext, +) { + let first_root_name = root_name_for_project(project, cx); + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_some()) + && project.read_with(cx, |project, _| project.is_local() && !project.is_shared()) + { + match active_call + .update(cx, |call, cx| call.share_project(project.clone(), cx)) + .await + { + Ok(project_id) => { + log::info!( + "{}: shared project {} with id {}", + client.username, + first_root_name, + project_id + ); + } + Err(error) => { + log::error!( + "{}: error sharing project {}: {:?}", + client.username, + first_root_name, + error + ); + } + } + } +} + +fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option> { + client + .local_projects() + .deref() + .iter() + .chain(client.remote_projects().iter()) + .choose(rng) + .cloned() +} + +fn gen_file_name(rng: &mut StdRng) -> String { + let mut name = String::new(); + for _ in 0..10 { + let letter = rng.gen_range('a'..='z'); + name.push(letter); + } + name +} diff --git a/crates/collab2/src/tests/randomized_test_helpers.rs b/crates/collab2/src/tests/randomized_test_helpers.rs new file mode 100644 index 0000000000..ac63738a36 --- /dev/null +++ b/crates/collab2/src/tests/randomized_test_helpers.rs @@ -0,0 +1,677 @@ +use crate::{ + db::{self, NewUserParams, UserId}, + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::{TestClient, TestServer}, +}; +use async_trait::async_trait; +use futures::StreamExt; +use gpui::{BackgroundExecutor, Task, TestAppContext}; +use parking_lot::Mutex; +use rand::prelude::*; +use rpc::RECEIVE_TIMEOUT; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use settings::SettingsStore; +use std::{ + env, + path::PathBuf, + rc::Rc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; + +lazy_static::lazy_static! { + static ref PLAN_LOAD_PATH: Option = path_env_var("LOAD_PLAN"); + static ref PLAN_SAVE_PATH: Option = path_env_var("SAVE_PLAN"); + static ref MAX_PEERS: usize = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(3); + static ref MAX_OPERATIONS: usize = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(10); + +} + +static LOADED_PLAN_JSON: Mutex>> = Mutex::new(None); +static LAST_PLAN: Mutex Vec>>> = Mutex::new(None); + +struct TestPlan { + rng: StdRng, + replay: bool, + stored_operations: Vec<(StoredOperation, Arc)>, + max_operations: usize, + operation_ix: usize, + users: Vec, + next_batch_id: usize, + allow_server_restarts: bool, + allow_client_reconnection: bool, + allow_client_disconnection: bool, +} + +pub struct UserTestPlan { + pub user_id: UserId, + pub username: String, + pub allow_client_reconnection: bool, + pub allow_client_disconnection: bool, + next_root_id: usize, + operation_ix: usize, + online: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum StoredOperation { + Server(ServerOperation), + Client { + user_id: UserId, + batch_id: usize, + operation: T, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ServerOperation { + AddConnection { + user_id: UserId, + }, + RemoveConnection { + user_id: UserId, + }, + BounceConnection { + user_id: UserId, + }, + RestartServer, + MutateClients { + batch_id: usize, + #[serde(skip_serializing)] + #[serde(skip_deserializing)] + user_ids: Vec, + quiesce: bool, + }, +} + +pub enum TestError { + Inapplicable, + Other(anyhow::Error), +} + +#[async_trait(?Send)] +pub trait RandomizedTest: 'static + Sized { + type Operation: Send + Clone + Serialize + DeserializeOwned; + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> Self::Operation; + + async fn apply_operation( + client: &TestClient, + operation: Self::Operation, + cx: &mut TestAppContext, + ) -> Result<(), TestError>; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]); + + async fn on_client_added(client: &Rc, cx: &mut TestAppContext); + + async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc, TestAppContext)]); +} + +pub async fn run_randomized_test( + cx: &mut TestAppContext, + executor: BackgroundExecutor, + rng: StdRng, +) { + let mut server = TestServer::start(executor.clone()).await; + let plan = TestPlan::::new(&mut server, rng).await; + + LAST_PLAN.lock().replace({ + let plan = plan.clone(); + Box::new(move || plan.lock().serialize()) + }); + + let mut clients = Vec::new(); + let mut client_tasks = Vec::new(); + let mut operation_channels = Vec::new(); + loop { + let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else { + break; + }; + applied.store(true, SeqCst); + let did_apply = TestPlan::apply_server_operation( + plan.clone(), + executor.clone(), + &mut server, + &mut clients, + &mut client_tasks, + &mut operation_channels, + next_operation, + cx, + ) + .await; + if !did_apply { + applied.store(false, SeqCst); + } + } + + drop(operation_channels); + executor.start_waiting(); + futures::future::join_all(client_tasks).await; + executor.finish_waiting(); + + executor.run_until_parked(); + T::on_quiesce(&mut server, &mut clients).await; + + for (client, cx) in clients { + cx.update(|cx| { + let store = cx.remove_global::(); + cx.clear_globals(); + cx.set_global(store); + drop(client); + }); + } + executor.run_until_parked(); + + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, plan.lock().serialize()).unwrap(); + } +} + +pub fn save_randomized_test_plan() { + if let Some(serialize_plan) = LAST_PLAN.lock().take() { + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, serialize_plan()).unwrap(); + } + } +} + +impl TestPlan { + pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc> { + let allow_server_restarts = rng.gen_bool(0.7); + let allow_client_reconnection = rng.gen_bool(0.7); + let allow_client_disconnection = rng.gen_bool(0.1); + + let mut users = Vec::new(); + for ix in 0..*MAX_PEERS { + let username = format!("user-{}", ix + 1); + let user_id = server + .app_state + .db + .create_user( + &format!("{username}@example.com"), + false, + NewUserParams { + github_login: username.clone(), + github_user_id: ix as i32, + }, + ) + .await + .unwrap() + .user_id; + users.push(UserTestPlan { + user_id, + username, + online: false, + next_root_id: 0, + operation_ix: 0, + allow_client_disconnection, + allow_client_reconnection, + }); + } + + T::initialize(server, &users).await; + + let plan = Arc::new(Mutex::new(Self { + replay: false, + allow_server_restarts, + allow_client_reconnection, + allow_client_disconnection, + stored_operations: Vec::new(), + operation_ix: 0, + next_batch_id: 0, + max_operations: *MAX_OPERATIONS, + users, + rng, + })); + + if let Some(path) = &*PLAN_LOAD_PATH { + let json = LOADED_PLAN_JSON + .lock() + .get_or_insert_with(|| { + eprintln!("loaded test plan from path {:?}", path); + std::fs::read(path).unwrap() + }) + .clone(); + plan.lock().deserialize(json); + } + + plan + } + + fn deserialize(&mut self, json: Vec) { + let stored_operations: Vec> = + serde_json::from_slice(&json).unwrap(); + self.replay = true; + self.stored_operations = stored_operations + .iter() + .cloned() + .enumerate() + .map(|(i, mut operation)| { + let did_apply = Arc::new(AtomicBool::new(false)); + if let StoredOperation::Server(ServerOperation::MutateClients { + batch_id: current_batch_id, + user_ids, + .. + }) = &mut operation + { + assert!(user_ids.is_empty()); + user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| { + if let StoredOperation::Client { + user_id, batch_id, .. + } = operation + { + if batch_id == current_batch_id { + return Some(user_id); + } + } + None + })); + user_ids.sort_unstable(); + } + (operation, did_apply) + }) + .collect() + } + + fn serialize(&mut self) -> Vec { + // Format each operation as one line + let mut json = Vec::new(); + json.push(b'['); + for (operation, applied) in &self.stored_operations { + if !applied.load(SeqCst) { + continue; + } + if json.len() > 1 { + json.push(b','); + } + json.extend_from_slice(b"\n "); + serde_json::to_writer(&mut json, operation).unwrap(); + } + json.extend_from_slice(b"\n]\n"); + json + } + + fn next_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option<(ServerOperation, Arc)> { + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) { + self.operation_ix += 1; + if let (StoredOperation::Server(operation), applied) = stored_operation { + return Some((operation.clone(), applied.clone())); + } + } + None + } else { + let operation = self.generate_server_operation(clients)?; + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations + .push((StoredOperation::Server(operation.clone()), applied.clone())); + Some((operation, applied)) + } + } + + fn next_client_operation( + &mut self, + client: &TestClient, + current_batch_id: usize, + cx: &TestAppContext, + ) -> Option<(T::Operation, Arc)> { + let current_user_id = client.current_user_id(cx); + let user_ix = self + .users + .iter() + .position(|user| user.user_id == current_user_id) + .unwrap(); + let user_plan = &mut self.users[user_ix]; + + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) { + user_plan.operation_ix += 1; + if let ( + StoredOperation::Client { + user_id, operation, .. + }, + applied, + ) = stored_operation + { + if user_id == ¤t_user_id { + return Some((operation.clone(), applied.clone())); + } + } + } + None + } else { + if self.operation_ix == self.max_operations { + return None; + } + self.operation_ix += 1; + let operation = T::generate_operation( + client, + &mut self.rng, + self.users + .iter_mut() + .find(|user| user.user_id == current_user_id) + .unwrap(), + cx, + ); + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations.push(( + StoredOperation::Client { + user_id: current_user_id, + batch_id: current_batch_id, + operation: operation.clone(), + }, + applied.clone(), + )); + Some((operation, applied)) + } + } + + fn generate_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option { + if self.operation_ix == self.max_operations { + return None; + } + + Some(loop { + break match self.rng.gen_range(0..100) { + 0..=29 if clients.len() < self.users.len() => { + let user = self + .users + .iter() + .filter(|u| !u.online) + .choose(&mut self.rng) + .unwrap(); + self.operation_ix += 1; + ServerOperation::AddConnection { + user_id: user.user_id, + } + } + 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::RemoveConnection { user_id } + } + 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::BounceConnection { user_id } + } + 40..=44 if self.allow_server_restarts && clients.len() > 1 => { + self.operation_ix += 1; + ServerOperation::RestartServer + } + _ if !clients.is_empty() => { + let count = self + .rng + .gen_range(1..10) + .min(self.max_operations - self.operation_ix); + let batch_id = util::post_inc(&mut self.next_batch_id); + let mut user_ids = (0..count) + .map(|_| { + let ix = self.rng.gen_range(0..clients.len()); + let (client, cx) = &clients[ix]; + client.current_user_id(cx) + }) + .collect::>(); + user_ids.sort_unstable(); + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce: self.rng.gen_bool(0.7), + } + } + _ => continue, + }; + }) + } + + async fn apply_server_operation( + plan: Arc>, + deterministic: BackgroundExecutor, + server: &mut TestServer, + clients: &mut Vec<(Rc, TestAppContext)>, + client_tasks: &mut Vec>, + operation_channels: &mut Vec>, + operation: ServerOperation, + cx: &mut TestAppContext, + ) -> bool { + match operation { + ServerOperation::AddConnection { user_id } => { + let username; + { + let mut plan = plan.lock(); + let user = plan.user(user_id); + if user.online { + return false; + } + user.online = true; + username = user.username.clone(); + }; + log::info!("adding new connection for {}", username); + + let mut client_cx = cx.new_app(); + + let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded(); + let client = Rc::new(server.create_client(&mut client_cx, &username).await); + operation_channels.push(operation_tx); + clients.push((client.clone(), client_cx.clone())); + + let foreground_executor = client_cx.foreground_executor().clone(); + let simulate_client = + Self::simulate_client(plan.clone(), client, operation_rx, client_cx); + client_tasks.push(foreground_executor.spawn(simulate_client)); + + log::info!("added connection for {}", username); + } + + ServerOperation::RemoveConnection { + user_id: removed_user_id, + } => { + log::info!("simulating full disconnection of user {}", removed_user_id); + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == removed_user_id); + let Some(client_ix) = client_ix else { + return false; + }; + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(removed_user_id) + .collect::>(); + assert_eq!(user_connection_ids.len(), 1); + let removed_peer_id = user_connection_ids[0].into(); + let (client, client_cx) = clients.remove(client_ix); + let client_task = client_tasks.remove(client_ix); + operation_channels.remove(client_ix); + server.forbid_connections(); + server.disconnect_client(removed_peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + deterministic.start_waiting(); + log::info!("waiting for user {} to exit...", removed_user_id); + client_task.await; + deterministic.finish_waiting(); + server.allow_connections(); + + for project in client.remote_projects().iter() { + project.read_with(&client_cx, |project, _| { + assert!( + project.is_read_only(), + "project {:?} should be read only", + project.remote_id() + ) + }); + } + + for (client, cx) in clients { + let contacts = server + .app_state + .db + .get_contacts(client.current_user_id(cx)) + .await + .unwrap(); + let pool = server.connection_pool.lock(); + for contact in contacts { + if let db::Contact::Accepted { user_id, busy, .. } = contact { + if user_id == removed_user_id { + assert!(!pool.is_user_online(user_id)); + assert!(!busy); + } + } + } + } + + log::info!("{} removed", client.username); + plan.lock().user(removed_user_id).online = false; + client_cx.update(|cx| { + cx.clear_globals(); + drop(client); + }); + } + + ServerOperation::BounceConnection { user_id } => { + log::info!("simulating temporary disconnection of user {}", user_id); + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(user_id) + .collect::>(); + if user_connection_ids.is_empty() { + return false; + } + assert_eq!(user_connection_ids.len(), 1); + let peer_id = user_connection_ids[0].into(); + server.disconnect_client(peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + } + + ServerOperation::RestartServer => { + log::info!("simulating server restart"); + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + let environment = &server.app_state.config.zed_environment; + let (stale_room_ids, _) = server + .app_state + .db + .stale_server_resource_ids(environment, server.id()) + .await + .unwrap(); + assert_eq!(stale_room_ids, vec![]); + } + + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce, + } => { + let mut applied = false; + for user_id in user_ids { + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == user_id); + let Some(client_ix) = client_ix else { continue }; + applied = true; + if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) { + log::error!("error signaling user {user_id}: {err}"); + } + } + + if quiesce && applied { + deterministic.run_until_parked(); + T::on_quiesce(server, clients).await; + } + + return applied; + } + } + true + } + + async fn simulate_client( + plan: Arc>, + client: Rc, + mut operation_rx: futures::channel::mpsc::UnboundedReceiver, + mut cx: TestAppContext, + ) { + T::on_client_added(&client, &mut cx).await; + + while let Some(batch_id) = operation_rx.next().await { + let Some((operation, applied)) = + plan.lock().next_client_operation(&client, batch_id, &cx) + else { + break; + }; + applied.store(true, SeqCst); + match T::apply_operation(&client, operation, &mut cx).await { + Ok(()) => {} + Err(TestError::Inapplicable) => { + applied.store(false, SeqCst); + log::info!("skipped operation"); + } + Err(TestError::Other(error)) => { + log::error!("{} error: {}", client.username, error); + } + } + cx.executor().simulate_random_delay().await; + } + log::info!("{}: done", client.username); + } + + fn user(&mut self, user_id: UserId) -> &mut UserTestPlan { + self.users + .iter_mut() + .find(|user| user.user_id == user_id) + .unwrap() + } +} + +impl UserTestPlan { + pub fn next_root_dir_name(&mut self) -> String { + let user_id = self.user_id; + let root_id = util::post_inc(&mut self.next_root_id); + format!("dir-{user_id}-{root_id}") + } +} + +impl From for TestError { + fn from(value: anyhow::Error) -> Self { + Self::Other(value) + } +} + +fn path_env_var(name: &str) -> Option { + let value = env::var(name).ok()?; + let mut path = PathBuf::from(value); + if path.is_relative() { + let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + abs_path.pop(); + abs_path.pop(); + abs_path.push(path); + path = abs_path + } + Some(path) +} diff --git a/crates/collab2/src/tests/test_server.rs b/crates/collab2/src/tests/test_server.rs new file mode 100644 index 0000000000..76a587ffde --- /dev/null +++ b/crates/collab2/src/tests/test_server.rs @@ -0,0 +1,624 @@ +use crate::{ + db::{tests::TestDb, NewUserParams, UserId}, + executor::Executor, + rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + AppState, +}; +use anyhow::anyhow; +use call::ActiveCall; +use channel::{ChannelBuffer, ChannelStore}; +use client::{ + self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, +}; +use collections::{HashMap, HashSet}; +use fs::FakeFs; +use futures::{channel::oneshot, StreamExt as _}; +use gpui::{BackgroundExecutor, Context, Model, TestAppContext, WindowHandle}; +use language::LanguageRegistry; +use node_runtime::FakeNodeRuntime; + +use parking_lot::Mutex; +use project::{Project, WorktreeId}; +use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT}; +use settings::SettingsStore; +use std::{ + cell::{Ref, RefCell, RefMut}, + env, + ops::{Deref, DerefMut}, + path::Path, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + Arc, + }, +}; +use util::http::FakeHttpClient; +use workspace::{Workspace, WorkspaceStore}; + +pub struct TestServer { + pub app_state: Arc, + pub test_live_kit_server: Arc, + server: Arc, + connection_killers: Arc>>>, + forbid_connections: Arc, + _test_db: TestDb, +} + +pub struct TestClient { + pub username: String, + pub app_state: Arc, + channel_store: Model, + // todo!(notifications) + // notification_store: Model, + state: RefCell, +} + +#[derive(Default)] +struct TestClientState { + local_projects: Vec>, + remote_projects: Vec>, + buffers: HashMap, HashSet>>, + channel_buffers: HashSet>, +} + +pub struct ContactsSummary { + pub current: Vec, + pub outgoing_requests: Vec, + pub incoming_requests: Vec, +} + +impl TestServer { + pub async fn start(deterministic: BackgroundExecutor) -> Self { + static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); + + let use_postgres = env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(deterministic.clone()) + } else { + TestDb::sqlite(deterministic.clone()) + }; + let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); + let live_kit_server = live_kit_client::TestServer::create( + format!("http://livekit.{}.test", live_kit_server_id), + format!("devkey-{}", live_kit_server_id), + format!("secret-{}", live_kit_server_id), + deterministic.clone(), + ) + .unwrap(); + let app_state = Self::build_app_state(&test_db, &live_kit_server).await; + let epoch = app_state + .db + .create_server(&app_state.config.zed_environment) + .await + .unwrap(); + let server = Server::new( + epoch, + app_state.clone(), + Executor::Deterministic(deterministic.clone()), + ); + server.start().await.unwrap(); + // Advance clock to ensure the server's cleanup task is finished. + deterministic.advance_clock(CLEANUP_TIMEOUT); + Self { + app_state, + server, + connection_killers: Default::default(), + forbid_connections: Default::default(), + _test_db: test_db, + test_live_kit_server: live_kit_server, + } + } + + pub async fn reset(&self) { + self.app_state.db.reset(); + let epoch = self + .app_state + .db + .create_server(&self.app_state.config.zed_environment) + .await + .unwrap(); + self.server.reset(epoch); + } + + pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + cx.update(|cx| { + if cx.has_global::() { + panic!("Same cx used to create two test clients") + } + let settings = SettingsStore::test(cx); + cx.set_global(settings); + }); + + let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await + { + user.id + } else { + self.app_state + .db + .create_user( + &format!("{name}@example.com"), + false, + NewUserParams { + github_login: name.into(), + github_user_id: 0, + }, + ) + .await + .expect("creating user failed") + .user_id + }; + let client_name = name.to_string(); + let mut client = cx.read(|cx| Client::new(http.clone(), cx)); + let server = self.server.clone(); + let db = self.app_state.db.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); + + Arc::get_mut(&mut client) + .unwrap() + .set_id(user_id.to_proto()) + .override_authenticate(move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok(Credentials { + user_id: user_id.to_proto(), + access_token, + }) + }) + }) + .override_establish_connection(move |credentials, cx| { + assert_eq!(credentials.user_id, user_id.0 as u64); + assert_eq!(credentials.access_token, "the-token"); + + let server = server.clone(); + let db = db.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(EstablishConnectionError::other(anyhow!( + "server is forbidding connections" + ))) + } else { + let (client_conn, server_conn, killed) = + Connection::in_memory(cx.background_executor().clone()); + let (connection_id_tx, connection_id_rx) = oneshot::channel(); + let user = db + .get_user_by_id(user_id) + .await + .expect("retrieving user failed") + .unwrap(); + cx.background_executor() + .spawn(server.handle_connection( + server_conn, + client_name, + user, + Some(connection_id_tx), + Executor::Deterministic(cx.background_executor().clone()), + )) + .detach(); + let connection_id = connection_id_rx.await.unwrap(); + connection_killers + .lock() + .insert(connection_id.into(), killed); + Ok(client_conn) + } + }) + }); + + let fs = FakeFs::new(cx.executor().clone()); + let user_store = cx.build_model(|cx| UserStore::new(client.clone(), http, cx)); + let workspace_store = cx.build_model(|cx| WorkspaceStore::new(client.clone(), cx)); + let mut language_registry = LanguageRegistry::test(); + language_registry.set_executor(cx.executor().clone()); + let app_state = Arc::new(workspace::AppState { + client: client.clone(), + user_store: user_store.clone(), + workspace_store, + languages: Arc::new(language_registry), + fs: fs.clone(), + build_window_options: |_, _, _| Default::default(), + initialize_workspace: |_, _, _, _| gpui::Task::ready(Ok(())), + node_runtime: FakeNodeRuntime::new(), + }); + + cx.update(|cx| { + theme::init(cx); + Project::init(&client, cx); + client::init(&client, cx); + language::init(cx); + editor::init_settings(cx); + workspace::init(app_state.clone(), cx); + audio::init((), cx); + call::init(client.clone(), user_store.clone(), cx); + channel::init(&client, user_store.clone(), cx); + //todo(notifications) + // notifications::init(client.clone(), user_store, cx); + }); + + client + .authenticate_and_connect(false, &cx.to_async()) + .await + .unwrap(); + + let client = TestClient { + app_state, + username: name.to_string(), + channel_store: cx.read(ChannelStore::global).clone(), + // todo!(notifications) + // notification_store: cx.read(NotificationStore::global).clone(), + state: Default::default(), + }; + client.wait_for_current_user(cx).await; + client + } + + pub fn disconnect_client(&self, peer_id: PeerId) { + self.connection_killers + .lock() + .remove(&peer_id) + .unwrap() + .store(true, SeqCst); + } + + //todo!(workspace) + #[allow(dead_code)] + pub fn simulate_long_connection_interruption( + &self, + peer_id: PeerId, + deterministic: BackgroundExecutor, + ) { + self.forbid_connections(); + self.disconnect_client(peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + self.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + deterministic.run_until_parked(); + } + + pub fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + pub fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); + } + + pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + for ix in 1..clients.len() { + let (left, right) = clients.split_at_mut(ix); + let (client_a, cx_a) = left.last_mut().unwrap(); + for (client_b, cx_b) in right { + client_a + .app_state + .user_store + .update(*cx_a, |store, cx| { + store.request_contact(client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + cx_a.executor().run_until_parked(); + client_b + .app_state + .user_store + .update(*cx_b, |store, cx| { + store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) + }) + .await + .unwrap(); + } + } + } + + pub async fn make_channel( + &self, + channel: &str, + parent: Option, + admin: (&TestClient, &mut TestAppContext), + members: &mut [(&TestClient, &mut TestAppContext)], + ) -> u64 { + let (_, admin_cx) = admin; + let channel_id = admin_cx + .read(ChannelStore::global) + .update(admin_cx, |channel_store, cx| { + channel_store.create_channel(channel, parent, cx) + }) + .await + .unwrap(); + + for (member_client, member_cx) in members { + admin_cx + .read(ChannelStore::global) + .update(admin_cx, |channel_store, cx| { + channel_store.invite_member( + channel_id, + member_client.user_id().unwrap(), + ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + admin_cx.executor().run_until_parked(); + + member_cx + .read(ChannelStore::global) + .update(*member_cx, |channels, cx| { + channels.respond_to_channel_invite(channel_id, true, cx) + }) + .await + .unwrap(); + } + + channel_id + } + + pub async fn make_channel_tree( + &self, + channels: &[(&str, Option<&str>)], + creator: (&TestClient, &mut TestAppContext), + ) -> Vec { + let mut observed_channels = HashMap::default(); + let mut result = Vec::new(); + for (channel, parent) in channels { + let id; + if let Some(parent) = parent { + if let Some(parent_id) = observed_channels.get(parent) { + id = self + .make_channel(channel, Some(*parent_id), (creator.0, creator.1), &mut []) + .await; + } else { + panic!( + "Edge {}->{} referenced before {} was created", + parent, channel, parent + ) + } + } else { + id = self + .make_channel(channel, None, (creator.0, creator.1), &mut []) + .await; + } + + observed_channels.insert(channel, id); + result.push(id); + } + + result + } + + pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + self.make_contacts(clients).await; + + let (left, right) = clients.split_at_mut(1); + let (_client_a, cx_a) = &mut left[0]; + let active_call_a = cx_a.read(ActiveCall::global); + + for (client_b, cx_b) in right { + let user_id_b = client_b.current_user_id(*cx_b).to_proto(); + active_call_a + .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx)) + .await + .unwrap(); + + cx_b.executor().run_until_parked(); + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(*cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + } + } + + pub async fn build_app_state( + test_db: &TestDb, + fake_server: &live_kit_client::TestServer, + ) -> Arc { + Arc::new(AppState { + db: test_db.db().clone(), + live_kit_client: Some(Arc::new(fake_server.create_api_client())), + config: Default::default(), + }) + } +} + +impl Deref for TestServer { + type Target = Server; + + fn deref(&self) -> &Self::Target { + &self.server + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.server.teardown(); + self.test_live_kit_server.teardown().unwrap(); + } +} + +impl Deref for TestClient { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.app_state.client + } +} + +impl TestClient { + pub fn fs(&self) -> &FakeFs { + self.app_state.fs.as_fake() + } + + pub fn channel_store(&self) -> &Model { + &self.channel_store + } + + // todo!(notifications) + // pub fn notification_store(&self) -> &Model { + // &self.notification_store + // } + + pub fn user_store(&self) -> &Model { + &self.app_state.user_store + } + + pub fn language_registry(&self) -> &Arc { + &self.app_state.languages + } + + pub fn client(&self) -> &Arc { + &self.app_state.client + } + + pub fn current_user_id(&self, cx: &TestAppContext) -> UserId { + UserId::from_proto( + self.app_state + .user_store + .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), + ) + } + + pub async fn wait_for_current_user(&self, cx: &TestAppContext) { + let mut authed_user = self + .app_state + .user_store + .read_with(cx, |user_store, _| user_store.watch_current_user()); + while authed_user.next().await.unwrap().is_none() {} + } + + pub async fn clear_contacts(&self, cx: &mut TestAppContext) { + self.app_state + .user_store + .update(cx, |store, _| store.clear_contacts()) + .await; + } + + pub fn local_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.local_projects) + } + + pub fn remote_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.remote_projects) + } + + pub fn local_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects) + } + + pub fn remote_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects) + } + + pub fn buffers_for_project<'a>( + &'a self, + project: &Model, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| { + state.buffers.entry(project.clone()).or_default() + }) + } + + pub fn buffers<'a>( + &'a self, + ) -> impl DerefMut, HashSet>>> + 'a + { + RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) + } + + pub fn channel_buffers<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers) + } + + pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { + self.app_state + .user_store + .read_with(cx, |store, _| ContactsSummary { + current: store + .contacts() + .iter() + .map(|contact| contact.user.github_login.clone()) + .collect(), + outgoing_requests: store + .outgoing_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + incoming_requests: store + .incoming_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + }) + } + + pub async fn build_local_project( + &self, + root_path: impl AsRef, + cx: &mut TestAppContext, + ) -> (Model, WorktreeId) { + let project = self.build_empty_local_project(cx); + let (worktree, _) = project + .update(cx, |p, cx| { + p.find_or_create_local_worktree(root_path, true, cx) + }) + .await + .unwrap(); + worktree + .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + (project, worktree.read_with(cx, |tree, _| tree.id())) + } + + pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> Model { + cx.update(|cx| { + Project::local( + self.client().clone(), + self.app_state.node_runtime.clone(), + self.app_state.user_store.clone(), + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }) + } + + pub async fn build_remote_project( + &self, + host_project_id: u64, + guest_cx: &mut TestAppContext, + ) -> Model { + let active_call = guest_cx.read(ActiveCall::global); + let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone()); + room.update(guest_cx, |room, cx| { + room.join_project( + host_project_id, + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }) + .await + .unwrap() + } + + //todo(workspace) + #[allow(dead_code)] + pub fn build_workspace( + &self, + project: &Model, + cx: &mut TestAppContext, + ) -> WindowHandle { + cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) + } +} + +impl Drop for TestClient { + fn drop(&mut self) { + self.app_state.client.teardown(); + } +} diff --git a/crates/gpui2/src/app.rs b/crates/gpui2/src/app.rs index 79f80f474d..a3ab426321 100644 --- a/crates/gpui2/src/app.rs +++ b/crates/gpui2/src/app.rs @@ -16,8 +16,8 @@ pub use test_context::*; use crate::{ current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle, AppMetadata, AssetSource, BackgroundExecutor, ClipboardItem, Context, DispatchPhase, DisplayId, - Entity, FocusEvent, FocusHandle, FocusId, ForegroundExecutor, KeyBinding, Keymap, LayoutId, - PathPromptOptions, Pixels, Platform, PlatformDisplay, Point, Render, SubscriberSet, + Entity, EventEmitter, FocusEvent, FocusHandle, FocusId, ForegroundExecutor, KeyBinding, Keymap, + LayoutId, PathPromptOptions, Pixels, Platform, PlatformDisplay, Point, Render, SubscriberSet, Subscription, SvgRenderer, Task, TextStyle, TextStyleRefinement, TextSystem, View, Window, WindowContext, WindowHandle, WindowId, }; @@ -48,15 +48,19 @@ pub struct AppCell { impl AppCell { #[track_caller] pub fn borrow(&self) -> AppRef { - // let thread_id = std::thread::current().id(); - // eprintln!("borrowed {thread_id:?}"); + if let Some(_) = option_env!("TRACK_THREAD_BORROWS") { + let thread_id = std::thread::current().id(); + eprintln!("borrowed {thread_id:?}"); + } AppRef(self.app.borrow()) } #[track_caller] pub fn borrow_mut(&self) -> AppRefMut { - // let thread_id = std::thread::current().id(); - // eprintln!("borrowed {thread_id:?}"); + if let Some(_) = option_env!("TRACK_THREAD_BORROWS") { + let thread_id = std::thread::current().id(); + eprintln!("borrowed {thread_id:?}"); + } AppRefMut(self.app.borrow_mut()) } } @@ -64,9 +68,27 @@ impl AppCell { #[derive(Deref, DerefMut)] pub struct AppRef<'a>(Ref<'a, AppContext>); +impl<'a> Drop for AppRef<'a> { + fn drop(&mut self) { + if let Some(_) = option_env!("TRACK_THREAD_BORROWS") { + let thread_id = std::thread::current().id(); + eprintln!("dropped borrow from {thread_id:?}"); + } + } +} + #[derive(Deref, DerefMut)] pub struct AppRefMut<'a>(RefMut<'a, AppContext>); +impl<'a> Drop for AppRefMut<'a> { + fn drop(&mut self) { + if let Some(_) = option_env!("TRACK_THREAD_BORROWS") { + let thread_id = std::thread::current().id(); + eprintln!("dropped {thread_id:?}"); + } + } +} + pub struct App(Rc); /// Represents an application before it is fully launched. Once your app is @@ -291,6 +313,83 @@ impl AppContext { result } + pub fn observe( + &mut self, + entity: &E, + mut on_notify: impl FnMut(E, &mut AppContext) + 'static, + ) -> Subscription + where + W: 'static, + E: Entity, + { + self.observe_internal(entity, move |e, cx| { + on_notify(e, cx); + true + }) + } + + pub fn observe_internal( + &mut self, + entity: &E, + mut on_notify: impl FnMut(E, &mut AppContext) -> bool + 'static, + ) -> Subscription + where + W: 'static, + E: Entity, + { + let entity_id = entity.entity_id(); + let handle = entity.downgrade(); + self.observers.insert( + entity_id, + Box::new(move |cx| { + if let Some(handle) = E::upgrade_from(&handle) { + on_notify(handle, cx) + } else { + false + } + }), + ) + } + + pub fn subscribe( + &mut self, + entity: &E, + mut on_event: impl FnMut(E, &T::Event, &mut AppContext) + 'static, + ) -> Subscription + where + T: 'static + EventEmitter, + E: Entity, + { + self.subscribe_internal(entity, move |entity, event, cx| { + on_event(entity, event, cx); + true + }) + } + + pub(crate) fn subscribe_internal( + &mut self, + entity: &E, + mut on_event: impl FnMut(E, &T::Event, &mut AppContext) -> bool + 'static, + ) -> Subscription + where + T: 'static + EventEmitter, + E: Entity, + { + let entity_id = entity.entity_id(); + let entity = entity.downgrade(); + self.event_listeners.insert( + entity_id, + Box::new(move |event, cx| { + let event: &T::Event = event.downcast_ref().expect("invalid event type"); + if let Some(handle) = E::upgrade_from(&entity) { + on_event(handle, event, cx) + } else { + false + } + }), + ) + } + pub fn windows(&self) -> Vec { self.windows .values() @@ -624,6 +723,7 @@ impl AppContext { } /// Access the global of the given type. Panics if a global for that type has not been assigned. + #[track_caller] pub fn global(&self) -> &G { self.globals_by_type .get(&TypeId::of::()) @@ -640,6 +740,7 @@ impl AppContext { } /// Access the global of the given type mutably. Panics if a global for that type has not been assigned. + #[track_caller] pub fn global_mut(&mut self) -> &mut G { let global_type = TypeId::of::(); self.push_effect(Effect::NotifyGlobalObservers { global_type }); @@ -669,6 +770,24 @@ impl AppContext { self.globals_by_type.insert(global_type, Box::new(global)); } + /// Clear all stored globals. Does not notify global observers. + #[cfg(any(test, feature = "test-support"))] + pub fn clear_globals(&mut self) { + self.globals_by_type.drain(); + } + + /// Remove the global of the given type from the app context. Does not notify global observers. + #[cfg(any(test, feature = "test-support"))] + pub fn remove_global(&mut self) -> G { + let global_type = TypeId::of::(); + *self + .globals_by_type + .remove(&global_type) + .unwrap_or_else(|| panic!("no global added for {}", std::any::type_name::())) + .downcast() + .unwrap() + } + /// Update the global of the given type with a closure. Unlike `global_mut`, this method provides /// your closure with mutable access to the `AppContext` and the global simultaneously. pub fn update_global(&mut self, f: impl FnOnce(&mut G, &mut Self) -> R) -> R { @@ -828,6 +947,18 @@ impl Context for AppContext { Ok(result) }) } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + let entity = self.entities.read(handle); + read(entity, self) + } } /// These effects are processed at the end of each application update cycle. diff --git a/crates/gpui2/src/app/async_context.rs b/crates/gpui2/src/app/async_context.rs index e3ae78d78f..c05182444e 100644 --- a/crates/gpui2/src/app/async_context.rs +++ b/crates/gpui2/src/app/async_context.rs @@ -45,6 +45,19 @@ impl Context for AsyncAppContext { Ok(app.update_model(handle, update)) } + fn read_model( + &self, + handle: &Model, + callback: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + let app = self.app.upgrade().context("app was released")?; + let lock = app.borrow(); + Ok(lock.read_model(handle, callback)) + } + fn update_window(&mut self, window: AnyWindowHandle, f: F) -> Result where F: FnOnce(AnyView, &mut WindowContext<'_>) -> T, @@ -226,6 +239,17 @@ impl Context for AsyncWindowContext { { self.app.update_window(window, update) } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + self.app.read_model(handle, read) + } } impl VisualContext for AsyncWindowContext { diff --git a/crates/gpui2/src/app/entity_map.rs b/crates/gpui2/src/app/entity_map.rs index 588091c7a0..1ae9aec9b5 100644 --- a/crates/gpui2/src/app/entity_map.rs +++ b/crates/gpui2/src/app/entity_map.rs @@ -325,6 +325,14 @@ impl Model { cx.entities.read(self) } + pub fn read_with<'a, R, C: Context>( + &self, + cx: &'a C, + f: impl FnOnce(&T, &AppContext) -> R, + ) -> C::Result { + cx.read_model(self, f) + } + /// Update the entity referenced by this model with the given function. /// /// The update function receives a context appropriate for its environment. diff --git a/crates/gpui2/src/app/model_context.rs b/crates/gpui2/src/app/model_context.rs index cb25adfb63..35d41ab362 100644 --- a/crates/gpui2/src/app/model_context.rs +++ b/crates/gpui2/src/app/model_context.rs @@ -49,19 +49,14 @@ impl<'a, T: 'static> ModelContext<'a, T> { E: Entity, { let this = self.weak_model(); - let entity_id = entity.entity_id(); - let handle = entity.downgrade(); - self.app.observers.insert( - entity_id, - Box::new(move |cx| { - if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&handle)) { - this.update(cx, |this, cx| on_notify(this, handle, cx)); - true - } else { - false - } - }), - ) + self.app.observe_internal(entity, move |e, cx| { + if let Some(this) = this.upgrade() { + this.update(cx, |this, cx| on_notify(this, e, cx)); + true + } else { + false + } + }) } pub fn subscribe( @@ -75,20 +70,14 @@ impl<'a, T: 'static> ModelContext<'a, T> { E: Entity, { let this = self.weak_model(); - let entity_id = entity.entity_id(); - let entity = entity.downgrade(); - self.app.event_listeners.insert( - entity_id, - Box::new(move |event, cx| { - let event: &T2::Event = event.downcast_ref().expect("invalid event type"); - if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&entity)) { - this.update(cx, |this, cx| on_event(this, handle, event, cx)); - true - } else { - false - } - }), - ) + self.app.subscribe_internal(entity, move |e, event, cx| { + if let Some(this) = this.upgrade() { + this.update(cx, |this, cx| on_event(this, e, event, cx)); + true + } else { + false + } + }) } pub fn on_release( @@ -236,6 +225,17 @@ impl<'a, T> Context for ModelContext<'a, T> { { self.app.update_window(window, update) } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&U, &AppContext) -> R, + ) -> Self::Result + where + U: 'static, + { + self.app.read_model(handle, read) + } } impl Borrow for ModelContext<'_, T> { diff --git a/crates/gpui2/src/app/test_context.rs b/crates/gpui2/src/app/test_context.rs index 530138e21e..eb5ce283a5 100644 --- a/crates/gpui2/src/app/test_context.rs +++ b/crates/gpui2/src/app/test_context.rs @@ -1,7 +1,8 @@ use crate::{ AnyView, AnyWindowHandle, AppCell, AppContext, AsyncAppContext, BackgroundExecutor, Context, EventEmitter, ForegroundExecutor, InputEvent, KeyDownEvent, Keystroke, Model, ModelContext, - Result, Task, TestDispatcher, TestPlatform, WindowContext, + Render, Result, Task, TestDispatcher, TestPlatform, ViewContext, VisualContext, WindowContext, + WindowHandle, WindowOptions, }; use anyhow::{anyhow, bail}; use futures::{Stream, StreamExt}; @@ -12,6 +13,7 @@ pub struct TestAppContext { pub app: Rc, pub background_executor: BackgroundExecutor, pub foreground_executor: ForegroundExecutor, + pub dispatcher: TestDispatcher, } impl Context for TestAppContext { @@ -44,13 +46,25 @@ impl Context for TestAppContext { let mut lock = self.app.borrow_mut(); lock.update_window(window, f) } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + let app = self.app.borrow(); + app.read_model(handle, read) + } } impl TestAppContext { pub fn new(dispatcher: TestDispatcher) -> Self { - let dispatcher = Arc::new(dispatcher); - let background_executor = BackgroundExecutor::new(dispatcher.clone()); - let foreground_executor = ForegroundExecutor::new(dispatcher); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); let platform = Rc::new(TestPlatform::new( background_executor.clone(), foreground_executor.clone(), @@ -61,9 +75,14 @@ impl TestAppContext { app: AppContext::new(platform, asset_source, http_client), background_executor, foreground_executor, + dispatcher: dispatcher.clone(), } } + pub fn new_app(&self) -> TestAppContext { + Self::new(self.dispatcher.clone()) + } + pub fn quit(&self) { self.app.borrow_mut().quit(); } @@ -87,6 +106,20 @@ impl TestAppContext { cx.update(f) } + pub fn read(&self, f: impl FnOnce(&AppContext) -> R) -> R { + let cx = self.app.borrow(); + f(&*cx) + } + + pub fn add_window(&mut self, build_window: F) -> WindowHandle + where + F: FnOnce(&mut ViewContext) -> V, + V: Render, + { + let mut cx = self.app.borrow_mut(); + cx.open_window(WindowOptions::default(), |cx| cx.build_view(build_window)) + } + pub fn spawn(&self, f: impl FnOnce(AsyncAppContext) -> Fut) -> Task where Fut: Future + 'static, diff --git a/crates/gpui2/src/executor.rs b/crates/gpui2/src/executor.rs index b7e3610283..bb9b5d0d79 100644 --- a/crates/gpui2/src/executor.rs +++ b/crates/gpui2/src/executor.rs @@ -17,6 +17,9 @@ use std::{ use util::TryFutureExt; use waker_fn::waker_fn; +#[cfg(any(test, feature = "test-support"))] +use rand::rngs::StdRng; + #[derive(Clone)] pub struct BackgroundExecutor { dispatcher: Arc, @@ -95,6 +98,7 @@ impl BackgroundExecutor { } #[cfg(any(test, feature = "test-support"))] + #[track_caller] pub fn block_test(&self, future: impl Future) -> R { self.block_internal(false, future) } @@ -103,6 +107,7 @@ impl BackgroundExecutor { self.block_internal(true, future) } + #[track_caller] pub(crate) fn block_internal( &self, background_only: bool, @@ -226,6 +231,11 @@ impl BackgroundExecutor { self.dispatcher.as_test().unwrap().allow_parking(); } + #[cfg(any(test, feature = "test-support"))] + pub fn rng(&self) -> StdRng { + self.dispatcher.as_test().unwrap().rng() + } + pub fn num_cpus(&self) -> usize { num_cpus::get() } diff --git a/crates/gpui2/src/gpui2.rs b/crates/gpui2/src/gpui2.rs index 8f3dc6c314..e253872ed4 100644 --- a/crates/gpui2/src/gpui2.rs +++ b/crates/gpui2/src/gpui2.rs @@ -93,6 +93,14 @@ pub trait Context { where T: 'static; + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static; + fn update_window(&mut self, window: AnyWindowHandle, f: F) -> Result where F: FnOnce(AnyView, &mut WindowContext<'_>) -> T; diff --git a/crates/gpui2/src/platform/test/dispatcher.rs b/crates/gpui2/src/platform/test/dispatcher.rs index 618d8c7917..258c484063 100644 --- a/crates/gpui2/src/platform/test/dispatcher.rs +++ b/crates/gpui2/src/platform/test/dispatcher.rs @@ -127,6 +127,10 @@ impl TestDispatcher { b }) } + + pub fn rng(&self) -> StdRng { + self.state.lock().random.clone() + } } impl Clone for TestDispatcher { diff --git a/crates/gpui2/src/window.rs b/crates/gpui2/src/window.rs index 1474165742..cf138eb1ef 100644 --- a/crates/gpui2/src/window.rs +++ b/crates/gpui2/src/window.rs @@ -1391,6 +1391,18 @@ impl Context for WindowContext<'_> { window.update(self.app, update) } } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + let entity = self.entities.read(handle); + read(&*entity, &*self.app) + } } impl VisualContext for WindowContext<'_> { @@ -2076,6 +2088,17 @@ impl Context for ViewContext<'_, V> { { self.window_cx.update_window(window, update) } + + fn read_model( + &self, + handle: &Model, + read: impl FnOnce(&T, &AppContext) -> R, + ) -> Self::Result + where + T: 'static, + { + self.window_cx.read_model(handle, read) + } } impl VisualContext for ViewContext<'_, V> { diff --git a/crates/gpui2_macros/src/test.rs b/crates/gpui2_macros/src/test.rs index 05d2c1f63a..70c6da22d5 100644 --- a/crates/gpui2_macros/src/test.rs +++ b/crates/gpui2_macros/src/test.rs @@ -110,6 +110,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { ); )); cx_teardowns.extend(quote!( + dispatcher.run_until_parked(); #cx_varname.quit(); dispatcher.run_until_parked(); )); @@ -174,9 +175,10 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { )); inner_fn_args.extend(quote!(&mut #cx_varname_lock,)); cx_teardowns.extend(quote!( - #cx_varname_lock.quit(); drop(#cx_varname_lock); dispatcher.run_until_parked(); + #cx_varname.update(|cx| { cx.quit() }); + dispatcher.run_until_parked(); )); continue; } @@ -188,6 +190,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { ); )); cx_teardowns.extend(quote!( + dispatcher.run_until_parked(); #cx_varname.quit(); dispatcher.run_until_parked(); )); diff --git a/crates/live_kit_client2/src/test.rs b/crates/live_kit_client2/src/test.rs index 6367e53ba8..1106e66f31 100644 --- a/crates/live_kit_client2/src/test.rs +++ b/crates/live_kit_client2/src/test.rs @@ -16,7 +16,7 @@ pub struct TestServer { pub api_key: String, pub secret_key: String, rooms: Mutex>, - executor: Arc, + executor: BackgroundExecutor, } impl TestServer { @@ -24,7 +24,7 @@ impl TestServer { url: String, api_key: String, secret_key: String, - executor: Arc, + executor: BackgroundExecutor, ) -> Result> { let mut servers = SERVERS.lock(); if servers.contains_key(&url) { diff --git a/crates/rpc2/Cargo.toml b/crates/rpc2/Cargo.toml index 0995029b30..31dbeb55aa 100644 --- a/crates/rpc2/Cargo.toml +++ b/crates/rpc2/Cargo.toml @@ -26,9 +26,11 @@ parking_lot.workspace = true prost.workspace = true rand.workspace = true rsa = "0.4" +serde_json.workspace = true serde.workspace = true serde_derive.workspace = true smol-timeout = "0.6" +strum.workspace = true tracing = { version = "0.1.34", features = ["log"] } zstd = "0.11" diff --git a/crates/rpc2/src/notification.rs b/crates/rpc2/src/notification.rs new file mode 100644 index 0000000000..c5476469be --- /dev/null +++ b/crates/rpc2/src/notification.rs @@ -0,0 +1,105 @@ +use crate::proto; +use serde::{Deserialize, Serialize}; +use serde_json::{map, Value}; +use strum::{EnumVariantNames, VariantNames as _}; + +const KIND: &'static str = "kind"; +const ENTITY_ID: &'static str = "entity_id"; + +/// A notification that can be stored, associated with a given recipient. +/// +/// This struct is stored in the collab database as JSON, so it shouldn't be +/// changed in a backward-incompatible way. For example, when renaming a +/// variant, add a serde alias for the old name. +/// +/// Most notification types have a special field which is aliased to +/// `entity_id`. This field is stored in its own database column, and can +/// be used to query the notification. +#[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)] +#[serde(tag = "kind")] +pub enum Notification { + ContactRequest { + #[serde(rename = "entity_id")] + sender_id: u64, + }, + ContactRequestAccepted { + #[serde(rename = "entity_id")] + responder_id: u64, + }, + ChannelInvitation { + #[serde(rename = "entity_id")] + channel_id: u64, + channel_name: String, + inviter_id: u64, + }, + ChannelMessageMention { + #[serde(rename = "entity_id")] + message_id: u64, + sender_id: u64, + channel_id: u64, + }, +} + +impl Notification { + pub fn to_proto(&self) -> proto::Notification { + let mut value = serde_json::to_value(self).unwrap(); + let mut entity_id = None; + let value = value.as_object_mut().unwrap(); + let Some(Value::String(kind)) = value.remove(KIND) else { + unreachable!("kind is the enum tag") + }; + if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) { + if e.get().is_u64() { + entity_id = e.remove().as_u64(); + } + } + proto::Notification { + kind, + entity_id, + content: serde_json::to_string(&value).unwrap(), + ..Default::default() + } + } + + pub fn from_proto(notification: &proto::Notification) -> Option { + let mut value = serde_json::from_str::(¬ification.content).ok()?; + let object = value.as_object_mut()?; + object.insert(KIND.into(), notification.kind.to_string().into()); + if let Some(entity_id) = notification.entity_id { + object.insert(ENTITY_ID.into(), entity_id.into()); + } + serde_json::from_value(value).ok() + } + + pub fn all_variant_names() -> &'static [&'static str] { + Self::VARIANTS + } +} + +#[test] +fn test_notification() { + // Notifications can be serialized and deserialized. + for notification in [ + Notification::ContactRequest { sender_id: 1 }, + Notification::ContactRequestAccepted { responder_id: 2 }, + Notification::ChannelInvitation { + channel_id: 100, + channel_name: "the-channel".into(), + inviter_id: 50, + }, + Notification::ChannelMessageMention { + sender_id: 200, + channel_id: 30, + message_id: 1, + }, + ] { + let message = notification.to_proto(); + let deserialized = Notification::from_proto(&message).unwrap(); + assert_eq!(deserialized, notification); + } + + // When notifications are serialized, the `kind` and `actor_id` fields are + // stored separately, and do not appear redundantly in the JSON. + let notification = Notification::ContactRequest { sender_id: 1 }; + assert_eq!(notification.to_proto().content, "{}"); +} diff --git a/crates/rpc2/src/rpc.rs b/crates/rpc2/src/rpc.rs index 942672b94b..4bf90669b2 100644 --- a/crates/rpc2/src/rpc.rs +++ b/crates/rpc2/src/rpc.rs @@ -1,8 +1,11 @@ pub mod auth; mod conn; +mod notification; mod peer; pub mod proto; + pub use conn::Connection; +pub use notification::*; pub use peer::*; mod macros; diff --git a/crates/settings2/src/settings_store.rs b/crates/settings2/src/settings_store.rs index 3317a50f52..5bf10a518d 100644 --- a/crates/settings2/src/settings_store.rs +++ b/crates/settings2/src/settings_store.rs @@ -86,6 +86,7 @@ pub trait Settings: 'static + Send + Sync { }); } + #[track_caller] fn get<'a>(path: Option<(usize, &Path)>, cx: &'a AppContext) -> &'a Self where Self: Sized, @@ -93,6 +94,7 @@ pub trait Settings: 'static + Send + Sync { cx.global::().get(path) } + #[track_caller] fn get_global<'a>(cx: &'a AppContext) -> &'a Self where Self: Sized, @@ -100,6 +102,7 @@ pub trait Settings: 'static + Send + Sync { cx.global::().get(None) } + #[track_caller] fn override_global<'a>(settings: Self, cx: &'a mut AppContext) where Self: Sized,