diff --git a/.github/workflows/release_actions.yml b/.github/workflows/release_actions.yml index c1df24a8e5..550eda882b 100644 --- a/.github/workflows/release_actions.yml +++ b/.github/workflows/release_actions.yml @@ -20,9 +20,7 @@ jobs: id: get-content with: stringToTruncate: | - 📣 Zed ${{ github.event.release.tag_name }} was just released! - - Restart your Zed or head to ${{ steps.get-release-url.outputs.URL }} to grab it. + 📣 Zed [${{ github.event.release.tag_name }}](${{ steps.get-release-url.outputs.URL }}) was just released! ${{ github.event.release.body }} maxLength: 2000 diff --git a/Cargo.lock b/Cargo.lock index df9574be93..450d435ac2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,7 @@ dependencies = [ "futures 0.3.28", "gpui", "isahc", + "language", "lazy_static", "log", "matrixmultiply", @@ -103,7 +104,7 @@ dependencies = [ "rusqlite", "serde", "serde_json", - "tiktoken-rs 0.5.4", + "tiktoken-rs", "util", ] @@ -309,6 +310,7 @@ dependencies = [ "language", "log", "menu", + "multi_buffer", "ordered-float 2.10.0", "parking_lot 0.11.2", "project", @@ -316,12 +318,13 @@ dependencies = [ "regex", "schemars", "search", + "semantic_index", "serde", "serde_json", "settings", "smol", "theme", - "tiktoken-rs 0.4.5", + "tiktoken-rs", "util", "uuid 1.4.1", "workspace", @@ -1573,7 +1576,7 @@ dependencies = [ [[package]] name = "collab" -version = "0.24.0" +version = "0.27.0" dependencies = [ "anyhow", "async-trait", @@ -1609,6 +1612,7 @@ dependencies = [ "lsp", "nanoid", "node_runtime", + "notifications", "parking_lot 0.11.2", "pretty_assertions", "project", @@ -1664,20 +1668,26 @@ dependencies = [ "fuzzy", "gpui", "language", + "lazy_static", "log", "menu", + "notifications", "picker", "postage", + "pretty_assertions", "project", "recent_projects", "rich_text", + "rpc", "schemars", "serde", "serde_derive", "settings", + "smallvec", "theme", "theme_selector", "time", + "tree-sitter-markdown", "util", "vcs_menu", "workspace", @@ -1731,6 +1741,7 @@ dependencies = [ "theme", "util", "workspace", + "zed-actions", ] [[package]] @@ -1810,6 +1821,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "parking_lot 0.11.2", "rpc", "serde", "serde_derive", @@ -2556,11 +2568,11 @@ dependencies = [ "lazy_static", "log", "lsp", + "multi_buffer", "ordered-float 2.10.0", "parking_lot 0.11.2", "postage", "project", - "pulldown-cmark", "rand 0.8.5", "rich_text", "rpc", @@ -4244,6 +4256,7 @@ dependencies = [ "lsp", "parking_lot 0.11.2", "postage", + "pulldown-cmark", "rand 0.8.5", "regex", "rpc", @@ -4921,6 +4934,55 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7843ec2de400bcbc6a6328c958dc38e5359da6e93e72e37bc5246bf1ae776389" +[[package]] +name = "multi_buffer" +version = "0.1.0" +dependencies = [ + "aho-corasick", + "anyhow", + "client", + "clock", + "collections", + "context_menu", + "convert_case 0.6.0", + "copilot", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "git", + "gpui", + "indoc", + "itertools 0.10.5", + "language", + "lazy_static", + "log", + "lsp", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "postage", + "project", + "pulldown-cmark", + "rand 0.8.5", + "rich_text", + "schemars", + "serde", + "serde_derive", + "settings", + "smallvec", + "smol", + "snippet", + "sum_tree", + "text", + "theme", + "tree-sitter", + "tree-sitter-html", + "tree-sitter-rust", + "tree-sitter-typescript", + "unindent", + "util", + "workspace", +] + [[package]] name = "multimap" version = "0.8.3" @@ -5070,6 +5132,26 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "notifications" +version = "0.1.0" +dependencies = [ + "anyhow", + "channel", + "client", + "clock", + "collections", + "db", + "feature_flags", + "gpui", + "rpc", + "settings", + "sum_tree", + "text", + "time", + "util", +] + [[package]] name = "ntapi" version = "0.3.7" @@ -5886,6 +5968,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "parking_lot 0.11.2", "serde", "serde_derive", "serde_json", @@ -6831,8 +6914,10 @@ dependencies = [ "rsa 0.4.0", "serde", "serde_derive", + "serde_json", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", @@ -7407,7 +7492,7 @@ dependencies = [ "smol", "tempdir", "theme", - "tiktoken-rs 0.5.4", + "tiktoken-rs", "tree-sitter", "tree-sitter-cpp", "tree-sitter-elixir", @@ -7421,7 +7506,6 @@ dependencies = [ "unindent", "util", "workspace", - "zed", ] [[package]] @@ -8713,21 +8797,6 @@ dependencies = [ "weezl", ] -[[package]] -name = "tiktoken-rs" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614" -dependencies = [ - "anyhow", - "base64 0.21.4", - "bstr", - "fancy-regex", - "lazy_static", - "parking_lot 0.12.1", - "rustc-hash", -] - [[package]] name = "tiktoken-rs" version = "0.5.4" @@ -9148,8 +9217,8 @@ dependencies = [ [[package]] name = "tree-sitter-bash" -version = "0.19.0" -source = "git+https://github.com/tree-sitter/tree-sitter-bash?rev=1b0321ee85701d5036c334a6f04761cdc672e64c#1b0321ee85701d5036c334a6f04761cdc672e64c" +version = "0.20.4" +source = "git+https://github.com/tree-sitter/tree-sitter-bash?rev=7331995b19b8f8aba2d5e26deb51d2195c18bc94#7331995b19b8f8aba2d5e26deb51d2195c18bc94" dependencies = [ "cc", "tree-sitter", @@ -9388,6 +9457,15 @@ dependencies = [ "tree-sitter", ] +[[package]] +name = "tree-sitter-vue" +version = "0.0.1" +source = "git+https://github.com/zed-industries/tree-sitter-vue?rev=95b2890#95b28908d90e928c308866f7631e73ef6e1d4b5f" +dependencies = [ + "cc", + "tree-sitter", +] + [[package]] name = "tree-sitter-yaml" version = "0.0.1" @@ -9712,6 +9790,7 @@ name = "vcs_menu" version = "0.1.0" dependencies = [ "anyhow", + "fs", "fuzzy", "gpui", "picker", @@ -10656,9 +10735,10 @@ dependencies = [ [[package]] name = "zed" -version = "0.109.0" +version = "0.111.0" dependencies = [ "activity_indicator", + "ai", "anyhow", "assistant", "async-compression", @@ -10710,6 +10790,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "notifications", "num_cpus", "outline", "parking_lot 0.11.2", @@ -10771,6 +10852,7 @@ dependencies = [ "tree-sitter-svelte", "tree-sitter-toml", "tree-sitter-typescript", + "tree-sitter-vue", "tree-sitter-yaml", "unindent", "url", @@ -10788,6 +10870,7 @@ name = "zed-actions" version = "0.1.0" dependencies = [ "gpui", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 82af9265dd..7db8e1073d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,9 @@ members = [ "crates/lsp2", "crates/media", "crates/menu", + "crates/multi_buffer", "crates/node_runtime", + "crates/notifications", "crates/outline", "crates/picker", "crates/plugin", @@ -133,6 +135,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } smallvec = { version = "1.6", features = ["union"] } smol = { version = "1.2" } +strum = { version = "0.25.0", features = ["derive"] } sysinfo = "0.29.10" tempdir = { version = "0.3.7" } thiserror = { version = "1.0.29" } @@ -144,7 +147,7 @@ pretty_assertions = "1.3.0" git2 = { version = "0.15", default-features = false} uuid = { version = "1.1.2", features = ["v4"] } -tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "1b0321ee85701d5036c334a6f04761cdc672e64c" } +tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "7331995b19b8f8aba2d5e26deb51d2195c18bc94" } tree-sitter-c = "0.20.1" tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev="f44509141e7e483323d2ec178f2d2e6c0fc041c1" } tree-sitter-css = { git = "https://github.com/tree-sitter/tree-sitter-css", rev = "769203d0f9abe1a9a691ac2b9fe4bb4397a73c51" } @@ -170,7 +173,7 @@ tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", tree-sitter-lua = "0.0.14" tree-sitter-nix = { git = "https://github.com/nix-community/tree-sitter-nix", rev = "66e3e9ce9180ae08fc57372061006ef83f0abde7" } tree-sitter-nu = { git = "https://github.com/nushell/tree-sitter-nu", rev = "786689b0562b9799ce53e824cb45a1a2a04dc673"} - +tree-sitter-vue = {git = "https://github.com/zed-industries/tree-sitter-vue", rev = "95b2890"} [patch.crates-io] tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "35a6052fbcafc5e5fc0f9415b8652be7dcaf7222" } async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" } diff --git a/Procfile b/Procfile index 2eb7de20fb..3f42c3a967 100644 --- a/Procfile +++ b/Procfile @@ -1,4 +1,4 @@ web: cd ../zed.dev && PORT=3000 npm run dev -collab: cd crates/collab && RUST_LOG=${RUST_LOG:-collab=info} cargo run serve +collab: cd crates/collab && RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run serve livekit: livekit-server --dev postgrest: postgrest crates/collab/admin_api.conf diff --git a/assets/icons/bell.svg b/assets/icons/bell.svg new file mode 100644 index 0000000000..ea1c6dd42e --- /dev/null +++ b/assets/icons/bell.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/icons/link.svg b/assets/icons/link.svg new file mode 100644 index 0000000000..4925bd8e00 --- /dev/null +++ b/assets/icons/link.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/public.svg b/assets/icons/public.svg new file mode 100644 index 0000000000..38278cdaba --- /dev/null +++ b/assets/icons/public.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/update.svg b/assets/icons/update.svg new file mode 100644 index 0000000000..b529b2b08b --- /dev/null +++ b/assets/icons/update.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 8422d53abc..ef6a655bdc 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -370,42 +370,15 @@ { "context": "Pane", "bindings": { - "ctrl-1": [ - "pane::ActivateItem", - 0 - ], - "ctrl-2": [ - "pane::ActivateItem", - 1 - ], - "ctrl-3": [ - "pane::ActivateItem", - 2 - ], - "ctrl-4": [ - "pane::ActivateItem", - 3 - ], - "ctrl-5": [ - "pane::ActivateItem", - 4 - ], - "ctrl-6": [ - "pane::ActivateItem", - 5 - ], - "ctrl-7": [ - "pane::ActivateItem", - 6 - ], - "ctrl-8": [ - "pane::ActivateItem", - 7 - ], - "ctrl-9": [ - "pane::ActivateItem", - 8 - ], + "ctrl-1": ["pane::ActivateItem", 0], + "ctrl-2": ["pane::ActivateItem", 1], + "ctrl-3": ["pane::ActivateItem", 2], + "ctrl-4": ["pane::ActivateItem", 3], + "ctrl-5": ["pane::ActivateItem", 4], + "ctrl-6": ["pane::ActivateItem", 5], + "ctrl-7": ["pane::ActivateItem", 6], + "ctrl-8": ["pane::ActivateItem", 7], + "ctrl-9": ["pane::ActivateItem", 8], "ctrl-0": "pane::ActivateLastItem", "ctrl--": "pane::GoBack", "ctrl-_": "pane::GoForward", @@ -416,42 +389,15 @@ { "context": "Workspace", "bindings": { - "cmd-1": [ - "workspace::ActivatePane", - 0 - ], - "cmd-2": [ - "workspace::ActivatePane", - 1 - ], - "cmd-3": [ - "workspace::ActivatePane", - 2 - ], - "cmd-4": [ - "workspace::ActivatePane", - 3 - ], - "cmd-5": [ - "workspace::ActivatePane", - 4 - ], - "cmd-6": [ - "workspace::ActivatePane", - 5 - ], - "cmd-7": [ - "workspace::ActivatePane", - 6 - ], - "cmd-8": [ - "workspace::ActivatePane", - 7 - ], - "cmd-9": [ - "workspace::ActivatePane", - 8 - ], + "cmd-1": ["workspace::ActivatePane", 0], + "cmd-2": ["workspace::ActivatePane", 1], + "cmd-3": ["workspace::ActivatePane", 2], + "cmd-4": ["workspace::ActivatePane", 3], + "cmd-5": ["workspace::ActivatePane", 4], + "cmd-6": ["workspace::ActivatePane", 5], + "cmd-7": ["workspace::ActivatePane", 6], + "cmd-8": ["workspace::ActivatePane", 7], + "cmd-9": ["workspace::ActivatePane", 8], "cmd-b": "workspace::ToggleLeftDock", "cmd-r": "workspace::ToggleRightDock", "cmd-j": "workspace::ToggleBottomDock", @@ -494,38 +440,14 @@ }, { "bindings": { - "cmd-k cmd-left": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "cmd-k cmd-right": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "cmd-k cmd-up": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "cmd-k cmd-down": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "cmd-k shift-left": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "cmd-k shift-right": [ - "workspace::SwapPaneInDirection", - "Right" - ], - "cmd-k shift-up": [ - "workspace::SwapPaneInDirection", - "Up" - ], - "cmd-k shift-down": [ - "workspace::SwapPaneInDirection", - "Down" - ] + "cmd-k cmd-left": ["workspace::ActivatePaneInDirection", "Left"], + "cmd-k cmd-right": ["workspace::ActivatePaneInDirection", "Right"], + "cmd-k cmd-up": ["workspace::ActivatePaneInDirection", "Up"], + "cmd-k cmd-down": ["workspace::ActivatePaneInDirection", "Down"], + "cmd-k shift-left": ["workspace::SwapPaneInDirection", "Left"], + "cmd-k shift-right": ["workspace::SwapPaneInDirection", "Right"], + "cmd-k shift-up": ["workspace::SwapPaneInDirection", "Up"], + "cmd-k shift-down": ["workspace::SwapPaneInDirection", "Down"] } }, // Bindings from Atom @@ -627,14 +549,6 @@ "space": "collab_panel::InsertSpace" } }, - { - "context": "(CollabPanel && not_editing) > Editor", - "bindings": { - "cmd-c": "collab_panel::StartLinkChannel", - "cmd-x": "collab_panel::StartMoveChannel", - "cmd-v": "collab_panel::MoveOrLinkToSelected" - } - }, { "context": "ChannelModal", "bindings": { @@ -655,57 +569,21 @@ "cmd-v": "terminal::Paste", "cmd-k": "terminal::Clear", // Some nice conveniences - "cmd-backspace": [ - "terminal::SendText", - "\u0015" - ], - "cmd-right": [ - "terminal::SendText", - "\u0005" - ], - "cmd-left": [ - "terminal::SendText", - "\u0001" - ], + "cmd-backspace": ["terminal::SendText", "\u0015"], + "cmd-right": ["terminal::SendText", "\u0005"], + "cmd-left": ["terminal::SendText", "\u0001"], // Terminal.app compatibility - "alt-left": [ - "terminal::SendText", - "\u001bb" - ], - "alt-right": [ - "terminal::SendText", - "\u001bf" - ], + "alt-left": ["terminal::SendText", "\u001bb"], + "alt-right": ["terminal::SendText", "\u001bf"], // There are conflicting bindings for these keys in the global context. // these bindings override them, remove at your own risk: - "up": [ - "terminal::SendKeystroke", - "up" - ], - "pageup": [ - "terminal::SendKeystroke", - "pageup" - ], - "down": [ - "terminal::SendKeystroke", - "down" - ], - "pagedown": [ - "terminal::SendKeystroke", - "pagedown" - ], - "escape": [ - "terminal::SendKeystroke", - "escape" - ], - "enter": [ - "terminal::SendKeystroke", - "enter" - ], - "ctrl-c": [ - "terminal::SendKeystroke", - "ctrl-c" - ] + "up": ["terminal::SendKeystroke", "up"], + "pageup": ["terminal::SendKeystroke", "pageup"], + "down": ["terminal::SendKeystroke", "down"], + "pagedown": ["terminal::SendKeystroke", "pagedown"], + "escape": ["terminal::SendKeystroke", "escape"], + "enter": ["terminal::SendKeystroke", "enter"], + "ctrl-c": ["terminal::SendKeystroke", "ctrl-c"] } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index ea025747d8..81235bb72a 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -39,6 +39,7 @@ "w": "vim::NextWordStart", "{": "vim::StartOfParagraph", "}": "vim::EndOfParagraph", + "|": "vim::GoToColumn", "shift-w": [ "vim::NextWordStart", { @@ -97,14 +98,8 @@ "ctrl-o": "pane::GoBack", "ctrl-i": "pane::GoForward", "ctrl-]": "editor::GoToDefinition", - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - "vim::SwitchMode", - "Normal" - ], + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"], "v": "vim::ToggleVisual", "shift-v": "vim::ToggleVisualLine", "ctrl-v": "vim::ToggleVisualBlock", @@ -233,123 +228,36 @@ } ], // Count support - "1": [ - "vim::Number", - 1 - ], - "2": [ - "vim::Number", - 2 - ], - "3": [ - "vim::Number", - 3 - ], - "4": [ - "vim::Number", - 4 - ], - "5": [ - "vim::Number", - 5 - ], - "6": [ - "vim::Number", - 6 - ], - "7": [ - "vim::Number", - 7 - ], - "8": [ - "vim::Number", - 8 - ], - "9": [ - "vim::Number", - 9 - ], + "1": ["vim::Number", 1], + "2": ["vim::Number", 2], + "3": ["vim::Number", 3], + "4": ["vim::Number", 4], + "5": ["vim::Number", 5], + "6": ["vim::Number", 6], + "7": ["vim::Number", 7], + "8": ["vim::Number", 8], + "9": ["vim::Number", 9], // window related commands (ctrl-w X) - "ctrl-w left": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w right": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w up": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w down": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w h": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w l": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w k": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w j": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w ctrl-h": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w ctrl-l": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w ctrl-k": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w ctrl-j": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w shift-left": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "ctrl-w shift-right": [ - "workspace::SwapPaneInDirection", - "Right" - ], - "ctrl-w shift-up": [ - "workspace::SwapPaneInDirection", - "Up" - ], - "ctrl-w shift-down": [ - "workspace::SwapPaneInDirection", - "Down" - ], - "ctrl-w shift-h": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "ctrl-w shift-l": [ - "workspace::SwapPaneInDirection", - "Right" - ], - "ctrl-w shift-k": [ - "workspace::SwapPaneInDirection", - "Up" - ], - "ctrl-w shift-j": [ - "workspace::SwapPaneInDirection", - "Down" - ], + "ctrl-w left": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w right": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w up": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w down": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w h": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w l": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w k": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w j": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w ctrl-h": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w ctrl-l": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w ctrl-k": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w ctrl-j": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w shift-left": ["workspace::SwapPaneInDirection", "Left"], + "ctrl-w shift-right": ["workspace::SwapPaneInDirection", "Right"], + "ctrl-w shift-up": ["workspace::SwapPaneInDirection", "Up"], + "ctrl-w shift-down": ["workspace::SwapPaneInDirection", "Down"], + "ctrl-w shift-h": ["workspace::SwapPaneInDirection", "Left"], + "ctrl-w shift-l": ["workspace::SwapPaneInDirection", "Right"], + "ctrl-w shift-k": ["workspace::SwapPaneInDirection", "Up"], + "ctrl-w shift-j": ["workspace::SwapPaneInDirection", "Down"], "ctrl-w g t": "pane::ActivateNextItem", "ctrl-w ctrl-g t": "pane::ActivateNextItem", "ctrl-w g shift-t": "pane::ActivatePrevItem", @@ -371,14 +279,8 @@ "ctrl-w ctrl-q": "pane::CloseAllItems", "ctrl-w o": "workspace::CloseInactiveTabsAndPanes", "ctrl-w ctrl-o": "workspace::CloseInactiveTabsAndPanes", - "ctrl-w n": [ - "workspace::NewFileInDirection", - "Up" - ], - "ctrl-w ctrl-n": [ - "workspace::NewFileInDirection", - "Up" - ] + "ctrl-w n": ["workspace::NewFileInDirection", "Up"], + "ctrl-w ctrl-n": ["workspace::NewFileInDirection", "Up"] } }, { @@ -393,21 +295,12 @@ "context": "Editor && vim_mode == normal && vim_operator == none && !VimWaiting", "bindings": { ".": "vim::Repeat", - "c": [ - "vim::PushOperator", - "Change" - ], + "c": ["vim::PushOperator", "Change"], "shift-c": "vim::ChangeToEndOfLine", - "d": [ - "vim::PushOperator", - "Delete" - ], + "d": ["vim::PushOperator", "Delete"], "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", - "y": [ - "vim::PushOperator", - "Yank" - ], + "y": ["vim::PushOperator", "Yank"], "shift-y": "vim::YankLine", "i": "vim::InsertBefore", "shift-i": "vim::InsertFirstNonWhitespace", @@ -443,10 +336,7 @@ "backwards": true } ], - "r": [ - "vim::PushOperator", - "Replace" - ], + "r": ["vim::PushOperator", "Replace"], "s": "vim::Substitute", "shift-s": "vim::SubstituteLine", "> >": "editor::Indent", @@ -458,10 +348,7 @@ { "context": "Editor && VimCount", "bindings": { - "0": [ - "vim::Number", - 0 - ] + "0": ["vim::Number", 0] } }, { @@ -497,12 +384,15 @@ "'": "vim::Quotes", "`": "vim::BackQuotes", "\"": "vim::DoubleQuotes", + "|": "vim::VerticalBars", "(": "vim::Parentheses", ")": "vim::Parentheses", + "b": "vim::Parentheses", "[": "vim::SquareBrackets", "]": "vim::SquareBrackets", "{": "vim::CurlyBrackets", "}": "vim::CurlyBrackets", + "shift-b": "vim::CurlyBrackets", "<": "vim::AngleBrackets", ">": "vim::AngleBrackets" } @@ -548,22 +438,10 @@ "shift-i": "vim::InsertBefore", "shift-a": "vim::InsertAfter", "shift-j": "vim::JoinLines", - "r": [ - "vim::PushOperator", - "Replace" - ], - "ctrl-c": [ - "vim::SwitchMode", - "Normal" - ], - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - "vim::SwitchMode", - "Normal" - ], + "r": ["vim::PushOperator", "Replace"], + "ctrl-c": ["vim::SwitchMode", "Normal"], + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"], ">": "editor::Indent", "<": "editor::Outdent", "i": [ @@ -602,14 +480,8 @@ "bindings": { "tab": "vim::Tab", "enter": "vim::Enter", - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - "vim::SwitchMode", - "Normal" - ] + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"] } }, { diff --git a/assets/settings/default.json b/assets/settings/default.json index 1611d80e2f..19c73ca021 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -50,6 +50,9 @@ // Whether to pop the completions menu while typing in an editor without // explicitly requesting it. "show_completions_on_input": true, + // Whether to display inline and alongside documentation for items in the + // completions menu + "show_completion_documentation": true, // Whether to show wrap guides in the editor. Setting this to true will // show a guide at the 'preferred_line_length' value if softwrap is set to // 'preferred_line_length', and will show any additional guides as specified @@ -139,6 +142,14 @@ // Default width of the channels panel. "default_width": 240 }, + "notification_panel": { + // Whether to show the collaboration panel button in the status bar. + "button": true, + // Where to dock channels panel. Can be 'left' or 'right'. + "dock": "right", + // Default width of the channels panel. + "default_width": 380 + }, "assistant": { // Whether to show the assistant panel button in the status bar. "button": true, diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 542d7f422f..b24c4e5ece 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -11,6 +11,7 @@ doctest = false [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } +language = { path = "../language" } async-trait.workspace = true anyhow.workspace = true futures.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 5256a6a643..f168c15793 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,2 +1,4 @@ pub mod completion; pub mod embedding; +pub mod models; +pub mod templates; diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 170b2268f9..de6ce9da71 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -53,6 +53,8 @@ pub struct OpenAIRequest { pub model: String, pub messages: Vec, pub stream: bool, + pub stop: Vec, + pub temperature: f32, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 332470aa54..b791414ba2 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::serde_json; +use gpui::{serde_json, AppContext}; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; @@ -20,9 +20,11 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::completion::OPENAI_API_URL; lazy_static! { - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } @@ -85,25 +87,6 @@ impl Embedding { } } -// impl FromSql for Embedding { -// fn column_result(value: ValueRef) -> FromSqlResult { -// let bytes = value.as_blob()?; -// let embedding: Result, Box> = bincode::deserialize(bytes); -// if embedding.is_err() { -// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); -// } -// Ok(Embedding(embedding.unwrap())) -// } -// } - -// impl ToSql for Embedding { -// fn to_sql(&self) -> rusqlite::Result { -// let bytes = bincode::serialize(&self.0) -// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; -// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) -// } -// } - #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, @@ -139,8 +122,12 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - fn is_authenticated(&self) -> bool; - async fn embed_batch(&self, spans: Vec) -> Result>; + fn retrieve_credentials(&self, cx: &AppContext) -> Option; + async fn embed_batch( + &self, + spans: Vec, + api_key: Option, + ) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; @@ -150,13 +137,17 @@ pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - fn is_authenticated(&self) -> bool { - true + fn retrieve_credentials(&self, _cx: &AppContext) -> Option { + Some("Dummy API KEY".to_string()) } fn rate_limit_expiration(&self) -> Option { None } - async fn embed_batch(&self, spans: Vec) -> Result> { + async fn embed_batch( + &self, + spans: Vec, + _api_key: Option, + ) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); @@ -255,9 +246,21 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { - fn is_authenticated(&self) -> bool { - OPENAI_API_KEY.as_ref().is_some() + fn retrieve_credentials(&self, cx: &AppContext) -> Option { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + } } + fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -280,13 +283,17 @@ impl EmbeddingProvider for OpenAIEmbeddings { (output, tokens.len()) } - async fn embed_batch(&self, spans: Vec) -> Result> { + async fn embed_batch( + &self, + spans: Vec, + api_key: Option, + ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + let Some(api_key) = api_key else { + return Err(anyhow!("no open ai key provided")); + }; let mut request_number = 0; let mut rate_limiting = false; @@ -295,11 +302,12 @@ impl EmbeddingProvider for OpenAIEmbeddings { while request_number < MAX_RETRIES { response = self .send_request( - api_key, + &api_key, spans.iter().map(|x| &**x).collect(), request_timeout, ) .await?; + request_number += 1; match response.status() { diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs new file mode 100644 index 0000000000..d0206cc41c --- /dev/null +++ b/crates/ai/src/models.rs @@ -0,0 +1,66 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate(&self, content: &str, length: usize) -> anyhow::Result; + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} + +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + bpe.decode(tokens[..length].to_vec()) + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + bpe.decode(tokens[length..].to_vec()) + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs new file mode 100644 index 0000000000..bda1d6c30e --- /dev/null +++ b/crates/ai/src/templates/base.rs @@ -0,0 +1,350 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::templates::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? - self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, max_token_length)?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, max_token_length)?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + #[derive(Clone)] + struct DummyLanguageModel { + capacity: usize, + } + + impl LanguageModel for DummyLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[..length] + .into_iter() + .collect::(), + ) + } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[length..] + .into_iter() + .collect::(), + ) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } + } + + let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs new file mode 100644 index 0000000000..1afd61192e --- /dev/null +++ b/crates/ai/src/templates/file_context.rs @@ -0,0 +1,160 @@ +use anyhow::anyhow; +use language::BufferSnapshot; +use language::ToOffset; + +use crate::models::LanguageModel; +use crate::templates::base::PromptArguments; +use crate::templates::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate_start(&start_window, start_goal_tokens)?; + let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? + writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs new file mode 100644 index 0000000000..1eeb197f93 --- /dev/null +++ b/crates/ai/src/templates/generate.rs @@ -0,0 +1,95 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs new file mode 100644 index 0000000000..0025269a44 --- /dev/null +++ b/crates/ai/src/templates/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs new file mode 100644 index 0000000000..9eabaaeb97 --- /dev/null +++ b/crates/ai/src/templates/preamble.rs @@ -0,0 +1,52 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs new file mode 100644 index 0000000000..a8e7f4b5af --- /dev/null +++ b/crates/ai/src/templates/repository_context.rs @@ -0,0 +1,94 @@ +use crate::templates::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index f1daf47bab..256f4d8416 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -17,13 +17,17 @@ fs = { path = "../fs" } gpui = { path = "../gpui" } language = { path = "../language" } menu = { path = "../menu" } +multi_buffer = { path = "../multi_buffer" } search = { path = "../search" } settings = { path = "../settings" } theme = { path = "../theme" } util = { path = "../util" } workspace = { path = "../workspace" } -uuid.workspace = true +semantic_index = { path = "../semantic_index" } +project = { path = "../project" } +uuid.workspace = true +log.workspace = true anyhow.workspace = true chrono = { version = "0.4", features = ["serde"] } futures.workspace = true @@ -36,7 +40,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true smol.workspace = true -tiktoken-rs = "0.4" +tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index b1c6038602..0dee8be510 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -5,8 +5,11 @@ use crate::{ MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; -use ai::completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, +use ai::{ + completion::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + }, + templates::repository_context::PromptCodeSnippet, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -29,13 +32,15 @@ use gpui::{ }, fonts::HighlightStyle, geometry::vector::{vec2f, Vector2F}, - platform::{CursorStyle, MouseButton}, + platform::{CursorStyle, MouseButton, PromptLevel}, Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext, - ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, - WindowContext, + ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, + WeakModelHandle, WeakViewHandle, WindowContext, }; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; +use project::Project; use search::BufferSearchBar; +use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, @@ -46,7 +51,7 @@ use std::{ path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::Duration, + time::{Duration, Instant}, }; use theme::{ components::{action_button::Button, ComponentExt}, @@ -72,6 +77,7 @@ actions!( ResetKey, InlineAssist, ToggleIncludeConversation, + ToggleRetrieveContext, ] ); @@ -108,6 +114,7 @@ pub fn init(cx: &mut AppContext) { cx.add_action(InlineAssistant::confirm); cx.add_action(InlineAssistant::cancel); cx.add_action(InlineAssistant::toggle_include_conversation); + cx.add_action(InlineAssistant::toggle_retrieve_context); cx.add_action(InlineAssistant::move_up); cx.add_action(InlineAssistant::move_down); } @@ -145,6 +152,8 @@ pub struct AssistantPanel { include_conversation_in_next_inline_assist: bool, inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, + semantic_index: Option>, + retrieve_context_in_next_inline_assist: bool, } impl AssistantPanel { @@ -191,6 +200,9 @@ impl AssistantPanel { toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx); toolbar }); + + let semantic_index = SemanticIndex::global(cx); + let mut this = Self { workspace: workspace_handle, active_editor_index: Default::default(), @@ -215,6 +227,8 @@ impl AssistantPanel { include_conversation_in_next_inline_assist: false, inline_prompt_history: Default::default(), _watch_saved_conversations, + semantic_index, + retrieve_context_in_next_inline_assist: false, }; let mut old_dock_position = this.position(cx); @@ -262,12 +276,19 @@ impl AssistantPanel { return; }; + let project = workspace.project(); + this.update(cx, |assistant, cx| { - assistant.new_inline_assist(&active_editor, cx) + assistant.new_inline_assist(&active_editor, cx, project) }); } - fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) { + fn new_inline_assist( + &mut self, + editor: &ViewHandle, + cx: &mut ViewContext, + project: &ModelHandle, + ) { let api_key = if let Some(api_key) = self.api_key.borrow().clone() { api_key } else { @@ -275,7 +296,7 @@ impl AssistantPanel { }; let selection = editor.read(cx).selections.newest_anchor().clone(); - if selection.start.excerpt_id() != selection.end.excerpt_id() { + if selection.start.excerpt_id != selection.end.excerpt_id { return; } let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); @@ -312,6 +333,27 @@ impl AssistantPanel { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); + if let Some(semantic_index) = self.semantic_index.clone() { + let project = project.clone(); + cx.spawn(|_, mut cx| async move { + let previously_indexed = semantic_index + .update(&mut cx, |index, cx| { + index.project_previously_indexed(&project, cx) + }) + .await + .unwrap_or(false); + if previously_indexed { + let _ = semantic_index + .update(&mut cx, |index, cx| { + index.index_project(project.clone(), cx) + }) + .await; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + let measurements = Rc::new(Cell::new(BlockMeasurements::default())); let inline_assistant = cx.add_view(|cx| { let assistant = InlineAssistant::new( @@ -322,6 +364,9 @@ impl AssistantPanel { codegen.clone(), self.workspace.clone(), cx, + self.retrieve_context_in_next_inline_assist, + self.semantic_index.clone(), + project.clone(), ); cx.focus_self(); assistant @@ -362,6 +407,7 @@ impl AssistantPanel { editor: editor.downgrade(), inline_assistant: Some((block_id, inline_assistant.clone())), codegen: codegen.clone(), + project: project.downgrade(), _subscriptions: vec![ cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), cx.subscribe(editor, { @@ -440,8 +486,15 @@ impl AssistantPanel { InlineAssistantEvent::Confirmed { prompt, include_conversation, + retrieve_context, } => { - self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); + self.confirm_inline_assist( + assist_id, + prompt, + *include_conversation, + cx, + *retrieve_context, + ); } InlineAssistantEvent::Canceled => { self.finish_inline_assist(assist_id, true, cx); @@ -454,6 +507,9 @@ impl AssistantPanel { } => { self.include_conversation_in_next_inline_assist = *include_conversation; } + InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => { + self.retrieve_context_in_next_inline_assist = *retrieve_context + } } } @@ -532,6 +588,7 @@ impl AssistantPanel { user_prompt: &str, include_conversation: bool, cx: &mut ViewContext, + retrieve_context: bool, ) { let conversation = if include_conversation { self.active_editor() @@ -553,6 +610,20 @@ impl AssistantPanel { return; }; + let project = pending_assist.project.clone(); + + let project_name = if let Some(project) = project.upgrade(cx) { + Some( + project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + } else { + None + }; + self.inline_prompt_history .retain(|prompt| prompt != user_prompt); self.inline_prompt_history.push_back(user_prompt.into()); @@ -590,13 +661,70 @@ impl AssistantPanel { None }; - let codegen_kind = codegen.read(cx).kind().clone(); + // Higher Temperature increases the randomness of model outputs. + // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.to_string() != "Markdown".to_string() { + 0.5 + } else { + 1.0 + } + } else { + 1.0 + }; + let user_prompt = user_prompt.to_string(); - let mut messages = Vec::new(); + let snippets = if retrieve_context { + let Some(project) = project.upgrade(cx) else { + return; + }; + + let search_results = if let Some(semantic_index) = self.semantic_index.clone() { + let search_results = semantic_index.update(cx, |this, cx| { + this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx) + }); + + cx.background() + .spawn(async move { search_results.await.unwrap_or_default() }) + } else { + Task::ready(Vec::new()) + }; + + let snippets = cx.spawn(|_, cx| async move { + let mut snippets = Vec::new(); + for result in search_results.await { + snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx)); + } + snippets + }); + snippets + } else { + Task::ready(Vec::new()) + }; + let mut model = settings::get::(cx) .default_open_ai_model .clone(); + let model_name = model.full_name(); + + let prompt = cx.background().spawn(async move { + let snippets = snippets.await; + + let language_name = language_name.as_deref(); + generate_content_prompt( + user_prompt, + language_name, + buffer, + range, + snippets, + model_name, + project_name, + ) + }); + + let mut messages = Vec::new(); if let Some(conversation) = conversation { let conversation = conversation.read(cx); let buffer = conversation.buffer.read(cx); @@ -608,24 +736,24 @@ impl AssistantPanel { model = conversation.model.clone(); } - let prompt = cx.background().spawn(async move { - let language_name = language_name.as_deref(); - generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind) - }); - cx.spawn(|_, mut cx| async move { - let prompt = prompt.await; + // I Don't know if we want to return a ? here. + let prompt = prompt.await?; messages.push(RequestMessage { role: Role::User, content: prompt, }); + let request = OpenAIRequest { model: model.full_name().into(), messages, stream: true, + stop: vec!["|END|>".to_string()], + temperature, }; codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); + anyhow::Ok(()) }) .detach(); } @@ -1514,12 +1642,14 @@ impl Conversation { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: self - .buffer - .read(cx) - .text_for_range(message.offset_range) - .collect(), + content: Some( + self.buffer + .read(cx) + .text_for_range(message.offset_range) + .collect(), + ), name: None, + function_call: None, }) }) .collect::>(); @@ -1613,6 +1743,8 @@ impl Conversation { .map(|message| message.to_open_ai_message(self.buffer.read(cx))) .collect(), stream: true, + stop: vec![], + temperature: 1.0, }; let stream = stream_completion(api_key, cx.background().clone(), request); @@ -1897,6 +2029,8 @@ impl Conversation { model: self.model.full_name().to_string(), messages: messages.collect(), stream: true, + stop: vec![], + temperature: 1.0, }; let stream = stream_completion(api_key, cx.background().clone(), request); @@ -2638,12 +2772,16 @@ enum InlineAssistantEvent { Confirmed { prompt: String, include_conversation: bool, + retrieve_context: bool, }, Canceled, Dismissed, IncludeConversationToggled { include_conversation: bool, }, + RetrieveContextToggled { + retrieve_context: bool, + }, } struct InlineAssistant { @@ -2659,6 +2797,11 @@ struct InlineAssistant { pending_prompt: String, codegen: ModelHandle, _subscriptions: Vec, + retrieve_context: bool, + semantic_index: Option>, + semantic_permissioned: Option, + project: WeakModelHandle, + maintain_rate_limit: Option>, } impl Entity for InlineAssistant { @@ -2675,51 +2818,65 @@ impl View for InlineAssistant { let theme = theme::current(cx); Flex::row() - .with_child( - Flex::row() - .with_child( - Button::action(ToggleIncludeConversation) - .with_tooltip("Include Conversation", theme.tooltip.clone()) + .with_children([Flex::row() + .with_child( + Button::action(ToggleIncludeConversation) + .with_tooltip("Include Conversation", theme.tooltip.clone()) + .with_id(self.id) + .with_contents(theme::components::svg::Svg::new("icons/ai.svg")) + .toggleable(self.include_conversation) + .with_style(theme.assistant.inline.include_conversation.clone()) + .element() + .aligned(), + ) + .with_children(if SemanticIndex::enabled(cx) { + Some( + Button::action(ToggleRetrieveContext) + .with_tooltip("Retrieve Context", theme.tooltip.clone()) .with_id(self.id) - .with_contents(theme::components::svg::Svg::new("icons/ai.svg")) - .toggleable(self.include_conversation) - .with_style(theme.assistant.inline.include_conversation.clone()) + .with_contents(theme::components::svg::Svg::new( + "icons/magnifying_glass.svg", + )) + .toggleable(self.retrieve_context) + .with_style(theme.assistant.inline.retrieve_context.clone()) .element() .aligned(), ) - .with_children(if let Some(error) = self.codegen.read(cx).error() { - Some( - Svg::new("icons/error.svg") - .with_color(theme.assistant.error_icon.color) - .constrained() - .with_width(theme.assistant.error_icon.width) - .contained() - .with_style(theme.assistant.error_icon.container) - .with_tooltip::( - self.id, - error.to_string(), - None, - theme.tooltip.clone(), - cx, - ) - .aligned(), - ) - } else { - None - }) - .aligned() - .constrained() - .dynamically({ - let measurements = self.measurements.clone(); - move |constraint, _, _| { - let measurements = measurements.get(); - SizeConstraint { - min: vec2f(measurements.gutter_width, constraint.min.y()), - max: vec2f(measurements.gutter_width, constraint.max.y()), - } + } else { + None + }) + .with_children(if let Some(error) = self.codegen.read(cx).error() { + Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.error_icon.color) + .constrained() + .with_width(theme.assistant.error_icon.width) + .contained() + .with_style(theme.assistant.error_icon.container) + .with_tooltip::( + self.id, + error.to_string(), + None, + theme.tooltip.clone(), + cx, + ) + .aligned(), + ) + } else { + None + }) + .aligned() + .constrained() + .dynamically({ + let measurements = self.measurements.clone(); + move |constraint, _, _| { + let measurements = measurements.get(); + SizeConstraint { + min: vec2f(measurements.gutter_width, constraint.min.y()), + max: vec2f(measurements.gutter_width, constraint.max.y()), } - }), - ) + } + })]) .with_child(Empty::new().constrained().dynamically({ let measurements = self.measurements.clone(); move |constraint, _, _| { @@ -2742,6 +2899,16 @@ impl View for InlineAssistant { .left() .flex(1., true), ) + .with_children(if self.retrieve_context { + Some( + Flex::row() + .with_children(self.retrieve_context_status(cx)) + .flex(1., true) + .aligned(), + ) + } else { + None + }) .contained() .with_style(theme.assistant.inline.container) .into_any() @@ -2767,6 +2934,9 @@ impl InlineAssistant { codegen: ModelHandle, workspace: WeakViewHandle, cx: &mut ViewContext, + retrieve_context: bool, + semantic_index: Option>, + project: ModelHandle, ) -> Self { let prompt_editor = cx.add_view(|cx| { let mut editor = Editor::single_line( @@ -2780,11 +2950,16 @@ impl InlineAssistant { editor.set_placeholder_text(placeholder, cx); editor }); - let subscriptions = vec![ + let mut subscriptions = vec![ cx.observe(&codegen, Self::handle_codegen_changed), cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), ]; - Self { + + if let Some(semantic_index) = semantic_index.clone() { + subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed)); + } + + let assistant = Self { id, prompt_editor, workspace, @@ -2797,7 +2972,33 @@ impl InlineAssistant { pending_prompt: String::new(), codegen, _subscriptions: subscriptions, + retrieve_context, + semantic_permissioned: None, + semantic_index, + project: project.downgrade(), + maintain_rate_limit: None, + }; + + assistant.index_project(cx).log_err(); + + assistant + } + + fn semantic_permissioned(&self, cx: &mut ViewContext) -> Task> { + if let Some(value) = self.semantic_permissioned { + return Task::ready(Ok(value)); } + + let Some(project) = self.project.upgrade(cx) else { + return Task::ready(Err(anyhow!("project was dropped"))); + }; + + self.semantic_index + .as_ref() + .map(|semantic| { + semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx)) + }) + .unwrap_or(Task::ready(Ok(false))) } fn handle_prompt_editor_events( @@ -2812,6 +3013,37 @@ impl InlineAssistant { } } + fn semantic_index_changed( + &mut self, + semantic_index: ModelHandle, + cx: &mut ViewContext, + ) { + let Some(project) = self.project.upgrade(cx) else { + return; + }; + + let status = semantic_index.read(cx).status(&project); + match status { + SemanticIndexStatus::Indexing { + rate_limit_expiry: Some(_), + .. + } => { + if self.maintain_rate_limit.is_none() { + self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move { + loop { + cx.background().timer(Duration::from_secs(1)).await; + this.update(&mut cx, |_, cx| cx.notify()).log_err(); + } + })); + } + return; + } + _ => { + self.maintain_rate_limit = None; + } + } + } + fn handle_codegen_changed(&mut self, _: ModelHandle, cx: &mut ViewContext) { let is_read_only = !self.codegen.read(cx).idle(); self.prompt_editor.update(cx, |editor, cx| { @@ -2861,12 +3093,241 @@ impl InlineAssistant { cx.emit(InlineAssistantEvent::Confirmed { prompt, include_conversation: self.include_conversation, + retrieve_context: self.retrieve_context, }); self.confirmed = true; cx.notify(); } } + fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext) { + let semantic_permissioned = self.semantic_permissioned(cx); + + let Some(project) = self.project.upgrade(cx) else { + return; + }; + + let project_name = project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"); + let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0; + let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name, + if is_plural { + "s" + } else {""}); + + cx.spawn(|this, mut cx| async move { + // If Necessary prompt user + if !semantic_permissioned.await.unwrap_or(false) { + let mut answer = this.update(&mut cx, |_, cx| { + cx.prompt( + PromptLevel::Info, + prompt_text.as_str(), + &["Continue", "Cancel"], + ) + })?; + + if answer.next().await == Some(0) { + this.update(&mut cx, |this, _| { + this.semantic_permissioned = Some(true); + })?; + } else { + return anyhow::Ok(()); + } + } + + // If permissioned, update context appropriately + this.update(&mut cx, |this, cx| { + this.retrieve_context = !this.retrieve_context; + + cx.emit(InlineAssistantEvent::RetrieveContextToggled { + retrieve_context: this.retrieve_context, + }); + + if this.retrieve_context { + this.index_project(cx).log_err(); + } + + cx.notify(); + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + fn index_project(&self, cx: &mut ViewContext) -> anyhow::Result<()> { + let Some(project) = self.project.upgrade(cx) else { + return Err(anyhow!("project was dropped!")); + }; + + let semantic_permissioned = self.semantic_permissioned(cx); + if let Some(semantic_index) = SemanticIndex::global(cx) { + cx.spawn(|_, mut cx| async move { + // This has to be updated to accomodate for semantic_permissions + if semantic_permissioned.await.unwrap_or(false) { + semantic_index + .update(&mut cx, |index, cx| index.index_project(project, cx)) + .await + } else { + Err(anyhow!("project is not permissioned for semantic indexing")) + } + }) + .detach_and_log_err(cx); + } + + anyhow::Ok(()) + } + + fn retrieve_context_status( + &self, + cx: &mut ViewContext, + ) -> Option> { + enum ContextStatusIcon {} + + let Some(project) = self.project.upgrade(cx) else { + return None; + }; + + if let Some(semantic_index) = SemanticIndex::global(cx) { + let status = semantic_index.update(cx, |index, _| index.status(&project)); + let theme = theme::current(cx); + match status { + SemanticIndexStatus::NotAuthenticated {} => Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.error_icon.color) + .constrained() + .with_width(theme.assistant.error_icon.width) + .contained() + .with_style(theme.assistant.error_icon.container) + .with_tooltip::( + self.id, + "Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + SemanticIndexStatus::NotIndexed {} => Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.inline.context_status.error_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.error_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.error_icon.container) + .with_tooltip::( + self.id, + "Not Indexed", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + SemanticIndexStatus::Indexing { + remaining_files, + rate_limit_expiry, + } => { + + let mut status_text = if remaining_files == 0 { + "Indexing...".to_string() + } else { + format!("Remaining files to index: {remaining_files}") + }; + + if let Some(rate_limit_expiry) = rate_limit_expiry { + let remaining_seconds = rate_limit_expiry.duration_since(Instant::now()); + if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 { + write!( + status_text, + " (rate limit expires in {}s)", + remaining_seconds.as_secs() + ) + .unwrap(); + } + } + Some( + Svg::new("icons/update.svg") + .with_color(theme.assistant.inline.context_status.in_progress_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.in_progress_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.in_progress_icon.container) + .with_tooltip::( + self.id, + status_text, + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ) + } + SemanticIndexStatus::Indexed {} => Some( + Svg::new("icons/check.svg") + .with_color(theme.assistant.inline.context_status.complete_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.complete_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.complete_icon.container) + .with_tooltip::( + self.id, + "Index up to date", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + } + } else { + None + } + } + + // fn retrieve_context_status(&self, cx: &mut ViewContext) -> String { + // let project = self.project.clone(); + // if let Some(semantic_index) = self.semantic_index.clone() { + // let status = semantic_index.update(cx, |index, cx| index.status(&project)); + // return match status { + // // This theoretically shouldnt be a valid code path + // // As the inline assistant cant be launched without an API key + // // We keep it here for safety + // semantic_index::SemanticIndexStatus::NotAuthenticated => { + // "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string() + // } + // semantic_index::SemanticIndexStatus::Indexed => { + // "Indexing Complete!".to_string() + // } + // semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => { + + // let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}"); + + // if let Some(rate_limit_expiry) = rate_limit_expiry { + // let remaining_seconds = + // rate_limit_expiry.duration_since(Instant::now()); + // if remaining_seconds > Duration::from_secs(0) { + // write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap(); + // } + // } + // status + // } + // semantic_index::SemanticIndexStatus::NotIndexed => { + // "Not Indexed for Context Retrieval".to_string() + // } + // }; + // } + + // "".to_string() + // } + fn toggle_include_conversation( &mut self, _: &ToggleIncludeConversation, @@ -2929,6 +3390,7 @@ struct PendingInlineAssist { inline_assistant: Option<(BlockId, ViewHandle)>, codegen: ModelHandle, _subscriptions: Vec, + project: WeakModelHandle, } fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index b6ef6b5cfa..6b79daba42 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,10 +1,11 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; use ai::completion::{CompletionProvider, OpenAIRequest}; use anyhow::Result; -use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; +use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; +use multi_buffer; use std::{cmp, future, ops::Range, sync::Arc}; pub enum Event { diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index d326a7f445..dffcbc2923 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,8 +1,13 @@ -use crate::codegen::CodegenKind; +use ai::models::{LanguageModel, OpenAILanguageModel}; +use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::templates::file_context::FileContext; +use ai::templates::generate::GenerateInlineContent; +use ai::templates::preamble::EngineerPreamble; +use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use std::cmp::{self, Reverse}; -use std::fmt::Write; use std::ops::Range; +use std::sync::Arc; #[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { @@ -118,86 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, - buffer: &BufferSnapshot, - range: Range, - kind: CodegenKind, -) -> String { - let range = range.to_offset(buffer); - let mut prompt = String::new(); - - // General Preamble - if let Some(language_name) = language_name { - writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap(); + buffer: BufferSnapshot, + range: Range, + search_results: Vec, + model: &str, + project_name: Option, +) -> anyhow::Result { + // Using new Prompt Templates + let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); + let lang_name = if let Some(language_name) = language_name { + Some(language_name.to_string()) } else { - writeln!(prompt, "You're an expert engineer.\n").unwrap(); - } + None + }; - let mut content = String::new(); - content.extend(buffer.text_for_range(0..range.start)); - if range.start == range.end { - content.push_str("<|START|>"); - } else { - content.push_str("<|START|"); - } - content.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - content.push_str("|END|>"); - } - content.extend(buffer.text_for_range(range.end..buffer.len())); + let args = PromptArguments { + model: openai_model, + language_name: lang_name.clone(), + project_name, + snippets: search_results.clone(), + reserved_tokens: 1000, + buffer: Some(buffer), + selected_range: Some(range), + user_prompt: Some(user_prompt.clone()), + }; - writeln!( - prompt, - "The file you are currently working on has the following content:" - ) - .unwrap(); - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - writeln!(prompt, "```{language_name}\n{content}\n```").unwrap(); - } else { - writeln!(prompt, "```\n{content}\n```").unwrap(); - } + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::Mandatory, Box::new(EngineerPreamble {})), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(RepositoryContext {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(FileContext {}), + ), + ( + PromptPriority::Mandatory, + Box::new(GenerateInlineContent {}), + ), + ]; + let chain = PromptChain::new(args, templates); + let (prompt, _) = chain.generate(true)?; - match kind { - CodegenKind::Generate { position: _ } => { - writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap(); - writeln!( - prompt, - "Assume the cursor is located where the `<|START|` marker is." - ) - .unwrap(); - writeln!( - prompt, - "Text can't be replaced, so assume your answer will be inserted at the cursor." - ) - .unwrap(); - writeln!( - prompt, - "Generate text based on the users prompt: {user_prompt}" - ) - .unwrap(); - } - CodegenKind::Transform { range: _ } => { - writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); - writeln!( - prompt, - "Modify the users code selected text based upon the users prompt: {user_prompt}" - ) - .unwrap(); - writeln!( - prompt, - "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file." - ) - .unwrap(); - } - } - - if let Some(language_name) = language_name { - writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap(); - } - writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap(); - writeln!(prompt, "Never make remarks about the output.").unwrap(); - - prompt + anyhow::Ok(prompt) } #[cfg(test)] diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 0846341325..ca1a60bd63 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -10,7 +10,7 @@ use client::{ ZED_ALWAYS_ACTIVE, }; use collections::HashSet; -use futures::{future::Shared, FutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt}; use gpui::{ AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task, WeakModelHandle, @@ -37,10 +37,42 @@ pub struct IncomingCall { pub initial_project: Option, } +pub struct OneAtATime { + cancel: Option>, +} + +impl OneAtATime { + /// spawn a task in the given context. + /// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None) + /// otherwise you'll see the result of the task. + fn spawn(&mut self, cx: &mut AppContext, f: F) -> Task>> + where + F: 'static + FnOnce(AsyncAppContext) -> Fut, + Fut: Future>, + R: 'static, + { + let (tx, rx) = oneshot::channel(); + self.cancel.replace(tx); + cx.spawn(|cx| async move { + futures::select_biased! { + _ = rx.fuse() => Ok(None), + result = f(cx).fuse() => result.map(Some), + } + }) + } + + fn running(&self) -> bool { + self.cancel + .as_ref() + .is_some_and(|cancel| !cancel.is_canceled()) + } +} + /// Singleton global maintaining the user's participation in a room across workspaces. pub struct ActiveCall { room: Option<(ModelHandle, Vec)>, pending_room_creation: Option, Arc>>>>, + _join_debouncer: OneAtATime, location: Option>, pending_invites: HashSet, incoming_call: ( @@ -69,6 +101,7 @@ impl ActiveCall { pending_invites: Default::default(), incoming_call: watch::channel(), + _join_debouncer: OneAtATime { cancel: None }, _subscriptions: vec![ client.add_request_handler(cx.handle(), Self::handle_incoming_call), client.add_message_handler(cx.handle(), Self::handle_call_canceled), @@ -143,6 +176,10 @@ impl ActiveCall { } cx.notify(); + if self._join_debouncer.running() { + return Task::ready(Ok(())); + } + let room = if let Some(room) = self.room().cloned() { Some(Task::ready(Ok(room)).shared()) } else { @@ -259,11 +296,20 @@ impl ActiveCall { return Task::ready(Err(anyhow!("no incoming call"))); }; - let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(())); + } + + let room_id = call.room_id.clone(); + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self + ._join_debouncer + .spawn(cx, move |cx| Room::join(room_id, client, user_store, cx)); cx.spawn(|this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("accept incoming", cx) @@ -290,20 +336,28 @@ impl ActiveCall { &mut self, channel_id: u64, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>>> { if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { - return Task::ready(Ok(room)); + return Task::ready(Ok(Some(room))); } else { room.update(cx, |room, cx| room.clear_state(cx)); } } - let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(None)); + } - cx.spawn(|this, mut cx| async move { + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self._join_debouncer.spawn(cx, move |cx| async move { + Room::join_channel(channel_id, client, user_store, cx).await + }); + + cx.spawn(move |this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("join channel", cx) @@ -457,3 +511,40 @@ pub fn report_call_event_for_channel( }; telemetry.report_clickhouse_event(event, telemetry_settings); } + +#[cfg(test)] +mod test { + use gpui::TestAppContext; + + use crate::OneAtATime; + + #[gpui::test] + async fn test_one_at_a_time(cx: &mut TestAppContext) { + let mut one_at_a_time = OneAtATime { cancel: None }; + + assert_eq!( + cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) })) + .await + .unwrap(), + Some(1) + ); + + let (a, b) = cx.update(|cx| { + ( + one_at_a_time.spawn(cx, |_| async { + assert!(false); + Ok(2) + }), + one_at_a_time.spawn(cx, |_| async { Ok(3) }), + ) + }); + + assert_eq!(a.await.unwrap(), None); + assert_eq!(b.await.unwrap(), Some(3)); + + let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) })); + drop(one_at_a_time); + + assert_eq!(promise.await.unwrap(), None); + } +} diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index a550624761..8d37194f3a 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -1,7 +1,6 @@ use crate::{ call_settings::CallSettings, participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack}, - IncomingCall, }; use anyhow::{anyhow, Result}; use audio::{Audio, Sound}; @@ -55,7 +54,7 @@ pub enum Event { pub struct Room { id: u64, - channel_id: Option, + pub channel_id: Option, live_kit: Option, status: RoomStatus, shared_projects: HashSet>, @@ -122,6 +121,10 @@ impl Room { } } + pub fn can_publish(&self) -> bool { + self.live_kit.as_ref().is_some_and(|room| room.can_publish) + } + fn new( id: u64, channel_id: Option, @@ -181,20 +184,23 @@ impl Room { }); let connect = room.connect(&connection_info.server_url, &connection_info.token); - cx.spawn(|this, mut cx| async move { - connect.await?; + if connection_info.can_publish { + cx.spawn(|this, mut cx| async move { + connect.await?; - if !cx.read(Self::mute_on_join) { - this.update(&mut cx, |this, cx| this.share_microphone(cx)) - .await?; - } + if !cx.read(Self::mute_on_join) { + this.update(&mut cx, |this, cx| this.share_microphone(cx)) + .await?; + } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } Some(LiveKitRoom { room, + can_publish: connection_info.can_publish, screen_track: LocalTrack::None, microphone_track: LocalTrack::None, next_publish_id: 0, @@ -284,37 +290,32 @@ impl Room { }) } - pub(crate) fn join_channel( + pub(crate) async fn join_channel( channel_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinChannel { channel_id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinChannel { channel_id }).await?, + client, + user_store, + cx, + ) } - pub(crate) fn join( - call: &IncomingCall, + pub(crate) async fn join( + room_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - let id = call.room_id; - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinRoom { id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinRoom { id: room_id }).await?, + client, + user_store, + cx, + ) } pub fn mute_on_join(cx: &AppContext) -> bool { @@ -1498,6 +1499,7 @@ struct LiveKitRoom { deafened: bool, speaking: bool, next_publish_id: usize, + can_publish: bool, _maintain_room: Task<()>, _maintain_tracks: [Task<()>; 2], } diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index d31d4b3c8c..d0a32e16ff 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -7,10 +7,11 @@ use gpui::{AppContext, ModelHandle}; use std::sync::Arc; pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; -pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId}; -pub use channel_store::{ - Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore, +pub use channel_chat::{ + mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, + MessageParams, }; +pub use channel_store::{Channel, ChannelEvent, ChannelId, ChannelMembership, ChannelStore}; #[cfg(test)] mod channel_store_tests; diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index ab7ea78ac1..9089973d32 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -1,4 +1,4 @@ -use crate::Channel; +use crate::{Channel, ChannelId, ChannelStore}; use anyhow::Result; use client::{Client, Collaborator, UserStore}; use collections::HashMap; @@ -19,10 +19,11 @@ pub(crate) fn init(client: &Arc) { } pub struct ChannelBuffer { - pub(crate) channel: Arc, + pub channel_id: ChannelId, connected: bool, collaborators: HashMap, user_store: ModelHandle, + channel_store: ModelHandle, buffer: ModelHandle, buffer_epoch: u64, client: Arc, @@ -34,6 +35,7 @@ pub enum ChannelBufferEvent { CollaboratorsChanged, Disconnected, BufferEdited, + ChannelChanged, } impl Entity for ChannelBuffer { @@ -46,7 +48,7 @@ impl Entity for ChannelBuffer { } self.client .send(proto::LeaveChannelBuffer { - channel_id: self.channel.id, + channel_id: self.channel_id, }) .log_err(); } @@ -58,6 +60,7 @@ impl ChannelBuffer { channel: Arc, client: Arc, user_store: ModelHandle, + channel_store: ModelHandle, mut cx: AsyncAppContext, ) -> Result> { let response = client @@ -90,9 +93,10 @@ impl ChannelBuffer { connected: true, collaborators: Default::default(), acknowledge_task: None, - channel, + channel_id: channel.id, subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())), user_store, + channel_store, }; this.replace_collaborators(response.collaborators, cx); this @@ -179,7 +183,7 @@ impl ChannelBuffer { let operation = language::proto::serialize_operation(operation); self.client .send(proto::UpdateChannelBuffer { - channel_id: self.channel.id, + channel_id: self.channel_id, operations: vec![operation], }) .log_err(); @@ -223,12 +227,15 @@ impl ChannelBuffer { &self.collaborators } - pub fn channel(&self) -> Arc { - self.channel.clone() + pub fn channel(&self, cx: &AppContext) -> Option> { + self.channel_store + .read(cx) + .channel_for_id(self.channel_id) + .cloned() } pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { - log::info!("channel buffer {} disconnected", self.channel.id); + log::info!("channel buffer {} disconnected", self.channel_id); if self.connected { self.connected = false; self.subscription.take(); @@ -237,6 +244,11 @@ impl ChannelBuffer { } } + pub(crate) fn channel_changed(&mut self, cx: &mut ModelContext) { + cx.emit(ChannelBufferEvent::ChannelChanged); + cx.notify() + } + pub fn is_connected(&self) -> bool { self.connected } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 734182886b..ef11d96424 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -3,19 +3,25 @@ use anyhow::{anyhow, Result}; use client::{ proto, user::{User, UserStore}, - Client, Subscription, TypedEnvelope, + Client, Subscription, TypedEnvelope, UserId, }; use futures::lock::Mutex; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use rand::prelude::*; -use std::{collections::HashSet, mem, ops::Range, sync::Arc}; +use std::{ + collections::HashSet, + mem, + ops::{ControlFlow, Range}, + sync::Arc, +}; use sum_tree::{Bias, SumTree}; use time::OffsetDateTime; use util::{post_inc, ResultExt as _, TryFutureExt}; pub struct ChannelChat { - channel: Arc, + pub channel_id: ChannelId, messages: SumTree, + acknowledged_message_ids: HashSet, channel_store: ModelHandle, loaded_all_messages: bool, last_acknowledged_id: Option, @@ -27,6 +33,12 @@ pub struct ChannelChat { _subscription: Subscription, } +#[derive(Debug, PartialEq, Eq)] +pub struct MessageParams { + pub text: String, + pub mentions: Vec<(Range, UserId)>, +} + #[derive(Clone, Debug)] pub struct ChannelMessage { pub id: ChannelMessageId, @@ -34,6 +46,7 @@ pub struct ChannelMessage { pub timestamp: OffsetDateTime, pub sender: Arc, pub nonce: u128, + pub mentions: Vec<(Range, UserId)>, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -74,7 +87,7 @@ impl Entity for ChannelChat { fn release(&mut self, _: &mut AppContext) { self.rpc .send(proto::LeaveChannelChat { - channel_id: self.channel.id, + channel_id: self.channel_id, }) .log_err(); } @@ -99,12 +112,13 @@ impl ChannelChat { Ok(cx.add_model(|cx| { let mut this = Self { - channel, + channel_id: channel.id, user_store, channel_store, rpc: client, outgoing_messages_lock: Default::default(), messages: Default::default(), + acknowledged_message_ids: Default::default(), loaded_all_messages, next_pending_message_id: 0, last_acknowledged_id: None, @@ -116,16 +130,23 @@ impl ChannelChat { })) } - pub fn channel(&self) -> &Arc { - &self.channel + pub fn channel(&self, cx: &AppContext) -> Option> { + self.channel_store + .read(cx) + .channel_for_id(self.channel_id) + .cloned() + } + + pub fn client(&self) -> &Arc { + &self.rpc } pub fn send_message( &mut self, - body: String, + message: MessageParams, cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { + ) -> Result>> { + if message.text.is_empty() { Err(anyhow!("message body can't be empty"))?; } @@ -135,16 +156,17 @@ impl ChannelChat { .current_user() .ok_or_else(|| anyhow!("current_user is not present"))?; - let channel_id = self.channel.id; + let channel_id = self.channel_id; let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); let nonce = self.rng.gen(); self.insert_messages( SumTree::from_item( ChannelMessage { id: pending_id, - body: body.clone(), + body: message.text.clone(), sender: current_user, timestamp: OffsetDateTime::now_utc(), + mentions: message.mentions.clone(), nonce, }, &(), @@ -158,27 +180,25 @@ impl ChannelChat { let outgoing_message_guard = outgoing_messages_lock.lock().await; let request = rpc.request(proto::SendChannelMessage { channel_id, - body, + body: message.text, nonce: Some(nonce.into()), + mentions: mentions_to_proto(&message.mentions), }); let response = request.await?; drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; + let response = response.message.ok_or_else(|| anyhow!("invalid message"))?; + let id = response.id; + let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) + Ok(id) }) })) } pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext) -> Task> { let response = self.rpc.request(proto::RemoveChannelMessage { - channel_id: self.channel.id, + channel_id: self.channel_id, message_id: id, }); cx.spawn(|this, mut cx| async move { @@ -191,41 +211,76 @@ impl ChannelChat { }) } - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.channel.id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - anyhow::Ok(()) + pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> Option>> { + if self.loaded_all_messages { + return None; + } + + let rpc = self.rpc.clone(); + let user_store = self.user_store.clone(); + let channel_id = self.channel_id; + let before_message_id = self.first_loaded_message_id()?; + Some(cx.spawn(|this, mut cx| { + async move { + let response = rpc + .request(proto::GetChannelMessages { + channel_id, + before_message_id, + }) + .await?; + let loaded_all_messages = response.done; + let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; + this.update(&mut cx, |this, cx| { + this.loaded_all_messages = loaded_all_messages; + this.insert_messages(messages, cx); + }); + anyhow::Ok(()) + } + .log_err() + })) + } + + pub fn first_loaded_message_id(&mut self) -> Option { + self.messages.first().and_then(|message| match message.id { + ChannelMessageId::Saved(id) => Some(id), + ChannelMessageId::Pending(_) => None, + }) + } + + /// Load all of the chat messages since a certain message id. + /// + /// For now, we always maintain a suffix of the channel's messages. + pub async fn load_history_since_message( + chat: ModelHandle, + message_id: u64, + mut cx: AsyncAppContext, + ) -> Option { + loop { + let step = chat.update(&mut cx, |chat, cx| { + if let Some(first_id) = chat.first_loaded_message_id() { + if first_id <= message_id { + let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(); + let message_id = ChannelMessageId::Saved(message_id); + cursor.seek(&message_id, Bias::Left, &()); + return ControlFlow::Break( + if cursor + .item() + .map_or(false, |message| message.id == message_id) + { + Some(cursor.start().1 .0) + } else { + None + }, + ); } - .log_err() - }) - .detach(); - return true; + } + ControlFlow::Continue(chat.load_more_messages(cx)) + }); + match step { + ControlFlow::Break(ix) => return ix, + ControlFlow::Continue(task) => task?.await?, } } - false } pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext) { @@ -236,13 +291,13 @@ impl ChannelChat { { self.rpc .send(proto::AckChannelMessage { - channel_id: self.channel.id, + channel_id: self.channel_id, message_id: latest_message_id, }) .ok(); self.last_acknowledged_id = Some(latest_message_id); self.channel_store.update(cx, |store, cx| { - store.acknowledge_message_id(self.channel.id, latest_message_id, cx); + store.acknowledge_message_id(self.channel_id, latest_message_id, cx); }); } } @@ -251,7 +306,7 @@ impl ChannelChat { pub fn rejoin(&mut self, cx: &mut ModelContext) { let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); - let channel_id = self.channel.id; + let channel_id = self.channel_id; cx.spawn(|this, mut cx| { async move { let response = rpc.request(proto::JoinChannelChat { channel_id }).await?; @@ -284,6 +339,7 @@ impl ChannelChat { let request = rpc.request(proto::SendChannelMessage { channel_id, body: pending_message.body, + mentions: mentions_to_proto(&pending_message.mentions), nonce: Some(pending_message.nonce.into()), }); let response = request.await?; @@ -319,6 +375,17 @@ impl ChannelChat { cursor.item().unwrap() } + pub fn acknowledge_message(&mut self, id: u64) { + if self.acknowledged_message_ids.insert(id) { + self.rpc + .send(proto::AckChannelMessage { + channel_id: self.channel_id, + message_id: id, + }) + .ok(); + } + } + pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(); cursor.seek(&Count(range.start), Bias::Right, &()); @@ -348,7 +415,7 @@ impl ChannelChat { this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); cx.emit(ChannelChatEvent::NewMessage { - channel_id: this.channel.id, + channel_id: this.channel_id, message_id, }) }); @@ -451,22 +518,7 @@ async fn messages_from_proto( user_store: &ModelHandle, cx: &mut AsyncAppContext, ) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } + let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?; let mut result = SumTree::new(); result.extend(messages, &()); Ok(result) @@ -486,6 +538,14 @@ impl ChannelMessage { Ok(ChannelMessage { id: ChannelMessageId::Saved(message.id), body: message.body, + mentions: message + .mentions + .into_iter() + .filter_map(|mention| { + let range = mention.range?; + Some((range.start as usize..range.end as usize, mention.user_id)) + }) + .collect(), timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, nonce: message @@ -498,6 +558,43 @@ impl ChannelMessage { pub fn is_pending(&self) -> bool { matches!(self.id, ChannelMessageId::Pending(_)) } + + pub async fn from_proto_vec( + proto_messages: Vec, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, + ) -> Result> { + let unique_user_ids = proto_messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store + .update(cx, |user_store, cx| { + user_store.get_users(unique_user_ids, cx) + }) + .await?; + + let mut messages = Vec::with_capacity(proto_messages.len()); + for message in proto_messages { + messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); + } + Ok(messages) + } +} + +pub fn mentions_to_proto(mentions: &[(Range, UserId)]) -> Vec { + mentions + .iter() + .map(|(range, user_id)| proto::ChatMention { + range: Some(proto::Range { + start: range.start as u64, + end: range.end as u64, + }), + user_id: *user_id as u64, + }) + .collect() } impl sum_tree::Item for ChannelMessage { @@ -538,3 +635,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { self.0 += summary.count; } } + +impl<'a> From<&'a str> for MessageParams { + fn from(value: &'a str) -> Self { + Self { + text: value.into(), + mentions: Vec::new(), + } + } +} diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index bceb2c094d..efa05d51a9 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,6 +1,6 @@ mod channel_index; -use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; +use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage}; use anyhow::{anyhow, Result}; use channel_index::ChannelIndex; use client::{Client, Subscription, User, UserId, UserStore}; @@ -9,11 +9,10 @@ use db::RELEASE_CHANNEL; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{ - proto::{self, ChannelEdge, ChannelPermission}, + proto::{self, ChannelVisibility}, TypedEnvelope, }; -use serde_derive::{Deserialize, Serialize}; -use std::{borrow::Cow, hash::Hash, mem, ops::Deref, sync::Arc, time::Duration}; +use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; pub fn init(client: &Arc, user_store: ModelHandle, cx: &mut AppContext) { @@ -27,10 +26,9 @@ pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub type ChannelId = u64; pub struct ChannelStore { - channel_index: ChannelIndex, + pub channel_index: ChannelIndex, channel_invitations: Vec>, channel_participants: HashMap>>, - channels_with_admin_privileges: HashSet, outgoing_invites: HashSet<(ChannelId, UserId)>, update_channels_tx: mpsc::UnboundedSender, opened_buffers: HashMap>, @@ -43,14 +41,15 @@ pub struct ChannelStore { _update_channels: Task<()>, } -pub type ChannelData = (Channel, ChannelPath); - #[derive(Clone, Debug, PartialEq)] pub struct Channel { pub id: ChannelId, pub name: String, + pub visibility: proto::ChannelVisibility, + pub role: proto::ChannelRole, pub unseen_note_version: Option<(u64, clock::Global)>, pub unseen_message_id: Option, + pub parent_path: Vec, } impl Channel { @@ -71,15 +70,41 @@ impl Channel { slug.trim_matches(|c| c == '-').to_string() } -} -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct ChannelPath(Arc<[ChannelId]>); + pub fn can_edit_notes(&self) -> bool { + self.role == proto::ChannelRole::Member || self.role == proto::ChannelRole::Admin + } +} pub struct ChannelMembership { pub user: Arc, pub kind: proto::channel_member::Kind, - pub admin: bool, + pub role: proto::ChannelRole, +} +impl ChannelMembership { + pub fn sort_key(&self) -> MembershipSortKey { + MembershipSortKey { + role_order: match self.role { + proto::ChannelRole::Admin => 0, + proto::ChannelRole::Member => 1, + proto::ChannelRole::Banned => 2, + proto::ChannelRole::Guest => 3, + }, + kind_order: match self.kind { + proto::channel_member::Kind::Member => 0, + proto::channel_member::Kind::AncestorMember => 1, + proto::channel_member::Kind::Invitee => 2, + }, + username_order: self.user.github_login.as_str(), + } + } +} + +#[derive(PartialOrd, Ord, PartialEq, Eq)] +pub struct MembershipSortKey<'a> { + role_order: u8, + kind_order: u8, + username_order: &'a str, } pub enum ChannelEvent { @@ -127,9 +152,6 @@ impl ChannelStore { this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx)); } } - if status.is_connected() { - } else { - } } Some(()) }); @@ -138,7 +160,6 @@ impl ChannelStore { channel_invitations: Vec::default(), channel_index: ChannelIndex::default(), channel_participants: Default::default(), - channels_with_admin_privileges: Default::default(), outgoing_invites: Default::default(), opened_buffers: Default::default(), opened_chats: Default::default(), @@ -167,16 +188,6 @@ impl ChannelStore { self.client.clone() } - pub fn has_children(&self, channel_id: ChannelId) -> bool { - self.channel_index.iter().any(|path| { - if let Some(ix) = path.iter().position(|id| *id == channel_id) { - path.len() > ix + 1 - } else { - false - } - }) - } - /// Returns the number of unique channels in the store pub fn channel_count(&self) -> usize { self.channel_index.by_id().len() @@ -196,26 +207,31 @@ impl ChannelStore { } /// Iterate over all entries in the channel DAG - pub fn channel_dag_entries(&self) -> impl '_ + Iterator)> { - self.channel_index.iter().map(move |path| { - let id = path.last().unwrap(); - let channel = self.channel_for_id(*id).unwrap(); - (path.len() - 1, channel) - }) + pub fn ordered_channels(&self) -> impl '_ + Iterator)> { + self.channel_index + .ordered_channels() + .iter() + .filter_map(move |id| { + let channel = self.channel_index.by_id().get(id)?; + Some((channel.parent_path.len(), channel)) + }) } - pub fn channel_dag_entry_at(&self, ix: usize) -> Option<(&Arc, &ChannelPath)> { - let path = self.channel_index.get(ix)?; - let id = path.last().unwrap(); - let channel = self.channel_for_id(*id).unwrap(); - - Some((channel, path)) + pub fn channel_at_index(&self, ix: usize) -> Option<&Arc> { + let channel_id = self.channel_index.ordered_channels().get(ix)?; + self.channel_index.by_id().get(channel_id) } pub fn channel_at(&self, ix: usize) -> Option<&Arc> { self.channel_index.by_id().values().nth(ix) } + pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool { + self.channel_invitations + .iter() + .any(|channel| channel.id == channel_id) + } + pub fn channel_invitations(&self) -> &[Arc] { &self.channel_invitations } @@ -240,14 +256,42 @@ impl ChannelStore { ) -> Task>> { let client = self.client.clone(); let user_store = self.user_store.clone(); + let channel_store = cx.handle(); self.open_channel_resource( channel_id, |this| &mut this.opened_buffers, - |channel, cx| ChannelBuffer::new(channel, client, user_store, cx), + |channel, cx| ChannelBuffer::new(channel, client, user_store, channel_store, cx), cx, ) } + pub fn fetch_channel_messages( + &self, + message_ids: Vec, + cx: &mut ModelContext, + ) -> Task>> { + let request = if message_ids.is_empty() { + None + } else { + Some( + self.client + .request(proto::GetChannelMessagesById { message_ids }), + ) + }; + cx.spawn_weak(|this, mut cx| async move { + if let Some(request) = request { + let response = request.await?; + let this = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("channel store dropped"))?; + let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); + ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await + } else { + Ok(Vec::new()) + } + }) + } + pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option { self.channel_index .by_id() @@ -393,16 +437,11 @@ impl ChannelStore { .spawn(async move { task.await.map_err(|error| anyhow!("{}", error)) }) } - pub fn is_user_admin(&self, channel_id: ChannelId) -> bool { - self.channel_index.iter().any(|path| { - if let Some(ix) = path.iter().position(|id| *id == channel_id) { - path[..=ix] - .iter() - .any(|id| self.channels_with_admin_privileges.contains(id)) - } else { - false - } - }) + pub fn is_channel_admin(&self, channel_id: ChannelId) -> bool { + let Some(channel) = self.channel_for_id(channel_id) else { + return false; + }; + channel.role == proto::ChannelRole::Admin } pub fn channel_participants(&self, channel_id: ChannelId) -> &[Arc] { @@ -429,24 +468,19 @@ impl ChannelStore { .ok_or_else(|| anyhow!("missing channel in response"))?; let channel_id = channel.id; - let parent_edge = if let Some(parent_id) = parent_id { - vec![ChannelEdge { - channel_id: channel.id, - parent_id, - }] - } else { - vec![] - }; + // let parent_edge = if let Some(parent_id) = parent_id { + // vec![ChannelEdge { + // channel_id: channel.id, + // parent_id, + // }] + // } else { + // vec![] + // }; this.update(&mut cx, |this, cx| { let task = this.update_channels( proto::UpdateChannels { channels: vec![channel], - insert_edge: parent_edge, - channel_permissions: vec![ChannelPermission { - channel_id, - is_admin: true, - }], ..Default::default() }, cx, @@ -464,52 +498,34 @@ impl ChannelStore { }) } - pub fn link_channel( - &mut self, - channel_id: ChannelId, - to: ChannelId, - cx: &mut ModelContext, - ) -> Task> { - let client = self.client.clone(); - cx.spawn(|_, _| async move { - let _ = client - .request(proto::LinkChannel { channel_id, to }) - .await?; - - Ok(()) - }) - } - - pub fn unlink_channel( - &mut self, - channel_id: ChannelId, - from: ChannelId, - cx: &mut ModelContext, - ) -> Task> { - let client = self.client.clone(); - cx.spawn(|_, _| async move { - let _ = client - .request(proto::UnlinkChannel { channel_id, from }) - .await?; - - Ok(()) - }) - } - pub fn move_channel( &mut self, channel_id: ChannelId, - from: ChannelId, - to: ChannelId, + to: Option, cx: &mut ModelContext, ) -> Task> { let client = self.client.clone(); cx.spawn(|_, _| async move { let _ = client - .request(proto::MoveChannel { + .request(proto::MoveChannel { channel_id, to }) + .await?; + + Ok(()) + }) + } + + pub fn set_channel_visibility( + &mut self, + channel_id: ChannelId, + visibility: ChannelVisibility, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.clone(); + cx.spawn(|_, _| async move { + let _ = client + .request(proto::SetChannelVisibility { channel_id, - from, - to, + visibility: visibility.into(), }) .await?; @@ -521,7 +537,7 @@ impl ChannelStore { &mut self, channel_id: ChannelId, user_id: UserId, - admin: bool, + role: proto::ChannelRole, cx: &mut ModelContext, ) -> Task> { if !self.outgoing_invites.insert((channel_id, user_id)) { @@ -535,7 +551,7 @@ impl ChannelStore { .request(proto::InviteChannelMember { channel_id, user_id, - admin, + role: role.into(), }) .await; @@ -579,11 +595,11 @@ impl ChannelStore { }) } - pub fn set_member_admin( + pub fn set_member_role( &mut self, channel_id: ChannelId, user_id: UserId, - admin: bool, + role: proto::ChannelRole, cx: &mut ModelContext, ) -> Task> { if !self.outgoing_invites.insert((channel_id, user_id)) { @@ -594,10 +610,10 @@ impl ChannelStore { let client = self.client.clone(); cx.spawn(|this, mut cx| async move { let result = client - .request(proto::SetChannelMemberAdmin { + .request(proto::SetChannelMemberRole { channel_id, user_id, - admin, + role: role.into(), }) .await; @@ -649,14 +665,15 @@ impl ChannelStore { &mut self, channel_id: ChannelId, accept: bool, - ) -> impl Future> { + cx: &mut ModelContext, + ) -> Task> { let client = self.client.clone(); - async move { + cx.background().spawn(async move { client .request(proto::RespondToChannelInvite { channel_id, accept }) .await?; Ok(()) - } + }) } pub fn get_channel_member_details( @@ -685,8 +702,8 @@ impl ChannelStore { .filter_map(|(user, member)| { Some(ChannelMembership { user, - admin: member.admin, - kind: proto::channel_member::Kind::from_i32(member.kind)?, + role: member.role(), + kind: member.kind(), }) }) .collect()) @@ -724,6 +741,11 @@ impl ChannelStore { } fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.channel_index.clear(); + self.channel_invitations.clear(); + self.channel_participants.clear(); + self.channel_index.clear(); + self.outgoing_invites.clear(); self.disconnect_channel_buffers_task.take(); for chat in self.opened_chats.values() { @@ -743,7 +765,7 @@ impl ChannelStore { let channel_buffer = buffer.read(cx); let buffer = channel_buffer.buffer().read(cx); buffer_versions.push(proto::ChannelBufferVersion { - channel_id: channel_buffer.channel().id, + channel_id: channel_buffer.channel_id, epoch: channel_buffer.epoch(), version: language::proto::serialize_version(&buffer.version()), }); @@ -770,13 +792,13 @@ impl ChannelStore { }; channel_buffer.update(cx, |channel_buffer, cx| { - let channel_id = channel_buffer.channel().id; + let channel_id = channel_buffer.channel_id; if let Some(remote_buffer) = response .buffers .iter_mut() .find(|buffer| buffer.channel_id == channel_id) { - let channel_id = channel_buffer.channel().id; + let channel_id = channel_buffer.channel_id; let remote_version = language::proto::deserialize_version(&remote_buffer.version); @@ -833,12 +855,6 @@ impl ChannelStore { } fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext) { - self.channel_index.clear(); - self.channel_invitations.clear(); - self.channel_participants.clear(); - self.channels_with_admin_privileges.clear(); - self.channel_index.clear(); - self.outgoing_invites.clear(); cx.notify(); self.disconnect_channel_buffers_task.get_or_insert_with(|| { @@ -881,9 +897,12 @@ impl ChannelStore { ix, Arc::new(Channel { id: channel.id, + visibility: channel.visibility(), + role: channel.role(), name: channel.name, unseen_note_version: None, unseen_message_id: None, + parent_path: channel.parent_path, }), ), } @@ -891,8 +910,6 @@ impl ChannelStore { let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty() - || !payload.insert_edge.is_empty() - || !payload.delete_edge.is_empty() || !payload.unseen_channel_messages.is_empty() || !payload.unseen_channel_buffer_changes.is_empty(); @@ -900,12 +917,17 @@ impl ChannelStore { if !payload.delete_channels.is_empty() { self.channel_index.delete_channels(&payload.delete_channels); self.channel_participants - .retain(|channel_id, _| !payload.delete_channels.contains(channel_id)); - self.channels_with_admin_privileges - .retain(|channel_id| !payload.delete_channels.contains(channel_id)); + .retain(|channel_id, _| !&payload.delete_channels.contains(channel_id)); for channel_id in &payload.delete_channels { let channel_id = *channel_id; + if payload + .channels + .iter() + .any(|channel| channel.id == channel_id) + { + continue; + } if let Some(OpenedModelHandle::Open(buffer)) = self.opened_buffers.remove(&channel_id) { @@ -918,7 +940,16 @@ impl ChannelStore { let mut index = self.channel_index.bulk_insert(); for channel in payload.channels { - index.insert(channel) + let id = channel.id; + let channel_changed = index.insert(channel); + + if channel_changed { + if let Some(OpenedModelHandle::Open(buffer)) = self.opened_buffers.get(&id) { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, ChannelBuffer::channel_changed); + } + } + } } for unseen_buffer_change in payload.unseen_channel_buffer_changes { @@ -936,24 +967,6 @@ impl ChannelStore { unseen_channel_message.message_id, ); } - - for edge in payload.insert_edge { - index.insert_edge(edge.channel_id, edge.parent_id); - } - - for edge in payload.delete_edge { - index.delete_edge(edge.parent_id, edge.channel_id); - } - } - - for permission in payload.channel_permissions { - if permission.is_admin { - self.channels_with_admin_privileges - .insert(permission.channel_id); - } else { - self.channels_with_admin_privileges - .remove(&permission.channel_id); - } } cx.notify(); @@ -1002,44 +1015,3 @@ impl ChannelStore { })) } } - -impl Deref for ChannelPath { - type Target = [ChannelId]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl ChannelPath { - pub fn new(path: Arc<[ChannelId]>) -> Self { - debug_assert!(path.len() >= 1); - Self(path) - } - - pub fn parent_id(&self) -> Option { - self.0.len().checked_sub(2).map(|i| self.0[i]) - } - - pub fn channel_id(&self) -> ChannelId { - self.0[self.0.len() - 1] - } -} - -impl From for Cow<'static, ChannelPath> { - fn from(value: ChannelPath) -> Self { - Cow::Owned(value) - } -} - -impl<'a> From<&'a ChannelPath> for Cow<'a, ChannelPath> { - fn from(value: &'a ChannelPath) -> Self { - Cow::Borrowed(value) - } -} - -impl Default for ChannelPath { - fn default() -> Self { - ChannelPath(Arc::from([])) - } -} diff --git a/crates/channel/src/channel_store/channel_index.rs b/crates/channel/src/channel_store/channel_index.rs index bf0de1b644..97b2ab6318 100644 --- a/crates/channel/src/channel_store/channel_index.rs +++ b/crates/channel/src/channel_store/channel_index.rs @@ -1,14 +1,11 @@ -use std::{ops::Deref, sync::Arc}; - use crate::{Channel, ChannelId}; use collections::BTreeMap; use rpc::proto; - -use super::ChannelPath; +use std::sync::Arc; #[derive(Default, Debug)] pub struct ChannelIndex { - paths: Vec, + channels_ordered: Vec, channels_by_id: BTreeMap>, } @@ -17,8 +14,12 @@ impl ChannelIndex { &self.channels_by_id } + pub fn ordered_channels(&self) -> &[ChannelId] { + &self.channels_ordered + } + pub fn clear(&mut self) { - self.paths.clear(); + self.channels_ordered.clear(); self.channels_by_id.clear(); } @@ -26,15 +27,13 @@ impl ChannelIndex { pub fn delete_channels(&mut self, channels: &[ChannelId]) { self.channels_by_id .retain(|channel_id, _| !channels.contains(channel_id)); - self.paths.retain(|path| { - path.iter() - .all(|channel_id| self.channels_by_id.contains_key(channel_id)) - }); + self.channels_ordered + .retain(|channel_id| !channels.contains(channel_id)); } pub fn bulk_insert(&mut self) -> ChannelPathsInsertGuard { ChannelPathsInsertGuard { - paths: &mut self.paths, + channels_ordered: &mut self.channels_ordered, channels_by_id: &mut self.channels_by_id, } } @@ -77,42 +76,15 @@ impl ChannelIndex { } } -impl Deref for ChannelIndex { - type Target = [ChannelPath]; - - fn deref(&self) -> &Self::Target { - &self.paths - } -} - /// A guard for ensuring that the paths index maintains its sort and uniqueness /// invariants after a series of insertions #[derive(Debug)] pub struct ChannelPathsInsertGuard<'a> { - paths: &'a mut Vec, + channels_ordered: &'a mut Vec, channels_by_id: &'a mut BTreeMap>, } impl<'a> ChannelPathsInsertGuard<'a> { - /// Remove the given edge from this index. This will not remove the channel. - /// If this operation would result in a dangling edge, re-insert it. - pub fn delete_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) { - self.paths.retain(|path| { - !path - .windows(2) - .any(|window| window == [parent_id, channel_id]) - }); - - // Ensure that there is at least one channel path in the index - if !self - .paths - .iter() - .any(|path| path.iter().any(|id| id == &channel_id)) - { - self.insert_root(channel_id); - } - } - pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) { insert_note_changed(&mut self.channels_by_id, channel_id, epoch, &version); } @@ -121,91 +93,65 @@ impl<'a> ChannelPathsInsertGuard<'a> { insert_new_message(&mut self.channels_by_id, channel_id, message_id) } - pub fn insert(&mut self, channel_proto: proto::Channel) { + pub fn insert(&mut self, channel_proto: proto::Channel) -> bool { + let mut ret = false; if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) { - Arc::make_mut(existing_channel).name = channel_proto.name; + let existing_channel = Arc::make_mut(existing_channel); + + ret = existing_channel.visibility != channel_proto.visibility() + || existing_channel.role != channel_proto.role() + || existing_channel.name != channel_proto.name; + + existing_channel.visibility = channel_proto.visibility(); + existing_channel.role = channel_proto.role(); + existing_channel.name = channel_proto.name; } else { self.channels_by_id.insert( channel_proto.id, Arc::new(Channel { id: channel_proto.id, + visibility: channel_proto.visibility(), + role: channel_proto.role(), name: channel_proto.name, unseen_note_version: None, unseen_message_id: None, + parent_path: channel_proto.parent_path, }), ); self.insert_root(channel_proto.id); } - } - - pub fn insert_edge(&mut self, channel_id: ChannelId, parent_id: ChannelId) { - let mut parents = Vec::new(); - let mut descendants = Vec::new(); - let mut ixs_to_remove = Vec::new(); - - for (ix, path) in self.paths.iter().enumerate() { - if path - .windows(2) - .any(|window| window[0] == parent_id && window[1] == channel_id) - { - // We already have this edge in the index - return; - } - if path.ends_with(&[parent_id]) { - parents.push(path); - } else if let Some(position) = path.iter().position(|id| id == &channel_id) { - if position == 0 { - ixs_to_remove.push(ix); - } - descendants.push(path.split_at(position).1); - } - } - - let mut new_paths = Vec::new(); - for parent in parents.iter() { - if descendants.is_empty() { - let mut new_path = Vec::with_capacity(parent.len() + 1); - new_path.extend_from_slice(parent); - new_path.push(channel_id); - new_paths.push(ChannelPath::new(new_path.into())); - } else { - for descendant in descendants.iter() { - let mut new_path = Vec::with_capacity(parent.len() + descendant.len()); - new_path.extend_from_slice(parent); - new_path.extend_from_slice(descendant); - new_paths.push(ChannelPath::new(new_path.into())); - } - } - } - - for ix in ixs_to_remove.into_iter().rev() { - self.paths.swap_remove(ix); - } - self.paths.extend(new_paths) + ret } fn insert_root(&mut self, channel_id: ChannelId) { - self.paths.push(ChannelPath::new(Arc::from([channel_id]))); + self.channels_ordered.push(channel_id); } } impl<'a> Drop for ChannelPathsInsertGuard<'a> { fn drop(&mut self) { - self.paths.sort_by(|a, b| { - let a = channel_path_sorting_key(a, &self.channels_by_id); - let b = channel_path_sorting_key(b, &self.channels_by_id); + self.channels_ordered.sort_by(|a, b| { + let a = channel_path_sorting_key(*a, &self.channels_by_id); + let b = channel_path_sorting_key(*b, &self.channels_by_id); a.cmp(b) }); - self.paths.dedup(); + self.channels_ordered.dedup(); } } fn channel_path_sorting_key<'a>( - path: &'a [ChannelId], + id: ChannelId, channels_by_id: &'a BTreeMap>, -) -> impl 'a + Iterator> { - path.iter() - .map(|id| Some(channels_by_id.get(id)?.name.as_str())) +) -> impl Iterator { + let (parent_path, name) = channels_by_id + .get(&id) + .map_or((&[] as &[_], None), |channel| { + (channel.parent_path.as_slice(), Some(channel.name.as_str())) + }); + parent_path + .iter() + .filter_map(|id| Some(channels_by_id.get(id)?.name.as_str())) + .chain(name) } fn insert_note_changed( diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index 9303a52092..ff8761ee91 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -3,7 +3,7 @@ use crate::channel_chat::ChannelChatEvent; use super::*; use client::{test::FakeServer, Client, UserStore}; use gpui::{AppContext, ModelHandle, TestAppContext}; -use rpc::proto; +use rpc::proto::{self}; use settings::SettingsStore; use util::http::FakeHttpClient; @@ -18,16 +18,18 @@ fn test_update_channels(cx: &mut AppContext) { proto::Channel { id: 1, name: "b".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: Vec::new(), }, proto::Channel { id: 2, name: "a".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: Vec::new(), }, ], - channel_permissions: vec![proto::ChannelPermission { - channel_id: 1, - is_admin: true, - }], ..Default::default() }, cx, @@ -36,8 +38,8 @@ fn test_update_channels(cx: &mut AppContext) { &channel_store, &[ // - (0, "a".to_string(), false), - (0, "b".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Member), + (0, "b".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -49,20 +51,16 @@ fn test_update_channels(cx: &mut AppContext) { proto::Channel { id: 3, name: "x".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![1], }, proto::Channel { id: 4, name: "y".to_string(), - }, - ], - insert_edge: vec![ - proto::ChannelEdge { - parent_id: 1, - channel_id: 3, - }, - proto::ChannelEdge { - parent_id: 2, - channel_id: 4, + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: vec![2], }, ], ..Default::default() @@ -72,10 +70,10 @@ fn test_update_channels(cx: &mut AppContext) { assert_channels( &channel_store, &[ - (0, "a".to_string(), false), - (1, "y".to_string(), false), - (0, "b".to_string(), true), - (1, "x".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Member), + (1, "y".to_string(), proto::ChannelRole::Member), + (0, "b".to_string(), proto::ChannelRole::Admin), + (1, "x".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -92,30 +90,25 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { proto::Channel { id: 0, name: "a".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![], }, proto::Channel { id: 1, name: "b".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![0], }, proto::Channel { id: 2, name: "c".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![0, 1], }, ], - insert_edge: vec![ - proto::ChannelEdge { - parent_id: 0, - channel_id: 1, - }, - proto::ChannelEdge { - parent_id: 1, - channel_id: 2, - }, - ], - channel_permissions: vec![proto::ChannelPermission { - channel_id: 0, - is_admin: true, - }], ..Default::default() }, cx, @@ -125,9 +118,9 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { &channel_store, &[ // - (0, "a".to_string(), true), - (1, "b".to_string(), true), - (2, "c".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Admin), + (1, "b".to_string(), proto::ChannelRole::Admin), + (2, "c".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -142,7 +135,11 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { ); // Make sure that the 1/2/3 path is gone - assert_channels(&channel_store, &[(0, "a".to_string(), true)], cx); + assert_channels( + &channel_store, + &[(0, "a".to_string(), proto::ChannelRole::Admin)], + cx, + ); } #[gpui::test] @@ -158,12 +155,19 @@ async fn test_channel_messages(cx: &mut TestAppContext) { channels: vec![proto::Channel { id: channel_id, name: "the-channel".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: vec![], }], ..Default::default() }); cx.foreground().run_until_parked(); cx.read(|cx| { - assert_channels(&channel_store, &[(0, "the-channel".to_string(), false)], cx); + assert_channels( + &channel_store, + &[(0, "the-channel".to_string(), proto::ChannelRole::Member)], + cx, + ); }); let get_users = server.receive::().await.unwrap(); @@ -181,7 +185,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Join a channel and populate its existing messages. let channel = channel_store.update(cx, |store, cx| { - let channel_id = store.channel_dag_entries().next().unwrap().1.id; + let channel_id = store.ordered_channels().next().unwrap().1.id; store.open_channel_chat(channel_id, cx) }); let join_channel = server.receive::().await.unwrap(); @@ -194,6 +198,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "a".into(), timestamp: 1000, sender_id: 5, + mentions: vec![], nonce: Some(1.into()), }, proto::ChannelMessage { @@ -201,6 +206,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "b".into(), timestamp: 1001, sender_id: 6, + mentions: vec![], nonce: Some(2.into()), }, ], @@ -247,6 +253,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "c".into(), timestamp: 1002, sender_id: 7, + mentions: vec![], nonce: Some(3.into()), }), }); @@ -284,7 +291,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Scroll up to view older messages. channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); + channel.load_more_messages(cx).unwrap().detach(); }); let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); @@ -300,6 +307,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 998, sender_id: 5, nonce: Some(4.into()), + mentions: vec![], }, proto::ChannelMessage { id: 9, @@ -307,6 +315,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 999, sender_id: 6, nonce: Some(5.into()), + mentions: vec![], }, ], }, @@ -358,19 +367,13 @@ fn update_channels( #[track_caller] fn assert_channels( channel_store: &ModelHandle, - expected_channels: &[(usize, String, bool)], + expected_channels: &[(usize, String, proto::ChannelRole)], cx: &AppContext, ) { let actual = channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() - .map(|(depth, channel)| { - ( - depth, - channel.name.to_string(), - store.is_user_admin(channel.id), - ) - }) + .ordered_channels() + .map(|(depth, channel)| (depth, channel.name.to_string(), channel.role)) .collect::>() }); assert_eq!(actual, expected_channels); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 70878bf2e4..fd93aaeec8 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -4,7 +4,9 @@ use lazy_static::lazy_static; use parking_lot::Mutex; use serde::Serialize; use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration}; -use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; +use sysinfo::{ + CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt, +}; use tempfile::NamedTempFile; use util::http::HttpClient; use util::{channel::ReleaseChannel, TryFutureExt}; @@ -166,8 +168,16 @@ impl Telemetry { let this = self.clone(); cx.spawn(|mut cx| async move { - let mut system = System::new_all(); - system.refresh_all(); + // Avoiding calling `System::new_all()`, as there have been crashes related to it + let refresh_kind = RefreshKind::new() + .with_memory() // For memory usage + .with_processes(ProcessRefreshKind::everything()) // For process usage + .with_cpu(CpuRefreshKind::everything()); // For core count + + let mut system = System::new_with_specifics(refresh_kind); + + // Avoiding calling `refresh_all()`, just update what we need + system.refresh_specifics(refresh_kind); loop { // Waiting some amount of time before the first query is important to get a reasonable value @@ -175,8 +185,7 @@ impl Telemetry { const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60); smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await; - system.refresh_memory(); - system.refresh_processes(); + system.refresh_specifics(refresh_kind); let current_process = Pid::from_u32(std::process::id()); let Some(process) = system.processes().get(¤t_process) else { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 6aa41708e3..8299b7c6e4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -293,21 +293,19 @@ impl UserStore { // No need to paralellize here let mut updated_contacts = Vec::new(); for contact in message.contacts { - let should_notify = contact.should_notify; - updated_contacts.push(( - Arc::new(Contact::from_proto(contact, &this, &mut cx).await?), - should_notify, + updated_contacts.push(Arc::new( + Contact::from_proto(contact, &this, &mut cx).await?, )); } let mut incoming_requests = Vec::new(); for request in message.incoming_requests { - incoming_requests.push({ - let user = this - .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx)) - .await?; - (user, request.should_notify) - }); + incoming_requests.push( + this.update(&mut cx, |this, cx| { + this.get_user(request.requester_id, cx) + }) + .await?, + ); } let mut outgoing_requests = Vec::new(); @@ -330,13 +328,7 @@ impl UserStore { this.contacts .retain(|contact| !removed_contacts.contains(&contact.user.id)); // Update existing contacts and insert new ones - for (updated_contact, should_notify) in updated_contacts { - if should_notify { - cx.emit(Event::Contact { - user: updated_contact.user.clone(), - kind: ContactEventKind::Accepted, - }); - } + for updated_contact in updated_contacts { match this.contacts.binary_search_by_key( &&updated_contact.user.github_login, |contact| &contact.user.github_login, @@ -359,14 +351,7 @@ impl UserStore { } }); // Update existing incoming requests and insert new ones - for (user, should_notify) in incoming_requests { - if should_notify { - cx.emit(Event::Contact { - user: user.clone(), - kind: ContactEventKind::Requested, - }); - } - + for user in incoming_requests { match this .incoming_contact_requests .binary_search_by_key(&&user.github_login, |contact| { @@ -415,6 +400,12 @@ impl UserStore { &self.incoming_contact_requests } + pub fn has_incoming_contact_request(&self, user_id: u64) -> bool { + self.incoming_contact_requests + .iter() + .any(|user| user.id == user_id) + } + pub fn outgoing_contact_requests(&self) -> &[Arc] { &self.outgoing_contact_requests } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b91f0e1a5f..987c295407 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] default-run = "collab" edition = "2021" name = "collab" -version = "0.24.0" +version = "0.27.0" publish = false [[bin]] @@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } node_runtime = { path = "../node_runtime" } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5a84bfd796..775a4c1bbe 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -44,7 +44,7 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "room_id" INTEGER REFERENCES rooms (id) NOT NULL, + "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE NOT NULL, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, "host_connection_id" INTEGER, "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, @@ -192,9 +192,13 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" VARCHAR NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT now + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "visibility" VARCHAR NOT NULL, + "parent_path" TEXT ); +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); + CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "user_id" INTEGER NOT NULL REFERENCES users (id), @@ -213,19 +217,22 @@ CREATE TABLE IF NOT EXISTS "channel_messages" ( "nonce" BLOB NOT NULL ); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); -CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); -CREATE TABLE "channel_paths" ( - "id_path" TEXT NOT NULL PRIMARY KEY, - "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) ); -CREATE INDEX "index_channel_paths_on_channel_id" ON "channel_paths" ("channel_id"); CREATE TABLE "channel_members" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, "admin" BOOLEAN NOT NULL DEFAULT false, + "role" VARCHAR, "accepted" BOOLEAN NOT NULL DEFAULT false, "updated_at" TIMESTAMP NOT NULL DEFAULT now ); @@ -312,3 +319,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( ); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231004130100_create_notifications.sql b/crates/collab/migrations/20231004130100_create_notifications.sql new file mode 100644 index 0000000000..93c282c631 --- /dev/null +++ b/crates/collab/migrations/20231004130100_create_notifications.sql @@ -0,0 +1,22 @@ +CREATE TABLE "notification_kinds" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE notifications ( + "id" SERIAL PRIMARY KEY, + "created_at" TIMESTAMP NOT NULL DEFAULT now(), + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231011214412_add_guest_role.sql b/crates/collab/migrations/20231011214412_add_guest_role.sql new file mode 100644 index 0000000000..1713547158 --- /dev/null +++ b/crates/collab/migrations/20231011214412_add_guest_role.sql @@ -0,0 +1,4 @@ +ALTER TABLE channel_members ADD COLUMN role TEXT; +UPDATE channel_members SET role = CASE WHEN admin THEN 'admin' ELSE 'member' END; + +ALTER TABLE channels ADD COLUMN visibility TEXT NOT NULL DEFAULT 'members'; diff --git a/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql new file mode 100644 index 0000000000..be535ff7fa --- /dev/null +++ b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql @@ -0,0 +1,8 @@ +-- Add migration script here + +ALTER TABLE projects + DROP CONSTRAINT projects_room_id_fkey, + ADD CONSTRAINT projects_room_id_fkey + FOREIGN KEY (room_id) + REFERENCES rooms (id) + ON DELETE CASCADE; diff --git a/crates/collab/migrations/20231018102700_create_mentions.sql b/crates/collab/migrations/20231018102700_create_mentions.sql new file mode 100644 index 0000000000..221a1748cf --- /dev/null +++ b/crates/collab/migrations/20231018102700_create_mentions.sql @@ -0,0 +1,11 @@ +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); + +-- We use 'on conflict update' with this index, so it should be per-user. +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); +DROP INDEX "index_channel_messages_on_nonce"; diff --git a/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql b/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql new file mode 100644 index 0000000000..d9fc6c8722 --- /dev/null +++ b/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql @@ -0,0 +1,12 @@ +ALTER TABLE channels ADD COLUMN parent_path TEXT; + +UPDATE channels +SET parent_path = substr( + channel_paths.id_path, + 2, + length(channel_paths.id_path) - length('/' || channel_paths.channel_id::text || '/') +) +FROM channel_paths +WHERE channel_paths.channel_id = channels.id; + +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index cb1594e941..88fe0a647b 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -71,7 +71,6 @@ async fn main() { db::NewUserParams { github_login: github_user.login, github_user_id: github_user.id, - invite_count: 5, }, ) .await diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e60b7cc33d..df33416a46 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -20,7 +20,7 @@ use rpc::{ }; use sea_orm::{ entity::prelude::*, - sea_query::{Alias, Expr, OnConflict, Query}, + sea_query::{Alias, Expr, OnConflict}, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, @@ -47,14 +47,14 @@ pub use ids::*; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; -use self::queries::channels::ChannelGraph; - pub struct Database { options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, rng: Mutex, executor: Executor, + notification_kinds_by_id: HashMap, + notification_kinds_by_name: HashMap, #[cfg(test)] runtime: Option, } @@ -69,6 +69,8 @@ impl Database { pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), rng: Mutex::new(StdRng::seed_from_u64(0)), + notification_kinds_by_id: HashMap::default(), + notification_kinds_by_name: HashMap::default(), executor, #[cfg(test)] runtime: None, @@ -121,6 +123,11 @@ impl Database { Ok(new_migrations) } + pub async fn initialize_static_data(&mut self) -> Result<()> { + self.initialize_notification_kinds().await?; + Ok(()) + } + pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, @@ -361,18 +368,9 @@ impl RoomGuard { #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, + Accepted { user_id: UserId, busy: bool }, + Outgoing { user_id: UserId }, + Incoming { user_id: UserId }, } impl Contact { @@ -385,6 +383,15 @@ impl Contact { } } +pub type NotificationBatch = Vec<(UserId, proto::Notification)>; + +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, @@ -417,7 +424,6 @@ pub struct WaitlistSummary { pub struct NewUserParams { pub github_login: String, pub github_user_id: i32, - pub invite_count: i32, } #[derive(Debug)] @@ -428,17 +434,115 @@ pub struct NewUserResult { pub signup_device_id: Option, } -#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)] +#[derive(Debug)] +pub struct MoveChannelResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub moved_channels: HashSet, +} + +#[derive(Debug)] +pub struct RenameChannelResult { + pub channel: Channel, + pub participants_to_update: HashMap, +} + +#[derive(Debug)] +pub struct CreateChannelResult { + pub channel: Channel, + pub participants_to_update: Vec<(UserId, ChannelsForUser)>, +} + +#[derive(Debug)] +pub struct SetChannelVisibilityResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub channels_to_remove: Vec, +} + +#[derive(Debug)] +pub struct MembershipUpdated { + pub channel_id: ChannelId, + pub new_channels: ChannelsForUser, + pub removed_channels: Vec, +} + +#[derive(Debug)] +pub enum SetMemberRoleResult { + InviteUpdated(Channel), + MembershipUpdated(MembershipUpdated), +} + +#[derive(Debug)] +pub struct InviteMemberResult { + pub channel: Channel, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RespondToChannelInvite { + pub membership_update: Option, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RemoveChannelMemberResult { + pub membership_update: MembershipUpdated, + pub notification_id: Option, +} + +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Channel { pub id: ChannelId, pub name: String, + pub visibility: ChannelVisibility, + pub role: ChannelRole, + pub parent_path: Vec, +} + +impl Channel { + fn from_model(value: channel::Model, role: ChannelRole) -> Self { + Channel { + id: value.id, + visibility: value.visibility, + name: value.clone().name, + role, + parent_path: value.ancestors().collect(), + } + } + + pub fn to_proto(&self) -> proto::Channel { + proto::Channel { + id: self.id.to_proto(), + name: self.name.clone(), + visibility: self.visibility.into(), + role: self.role.into(), + parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ChannelMember { + pub role: ChannelRole, + pub user_id: UserId, + pub kind: proto::channel_member::Kind, +} + +impl ChannelMember { + pub fn to_proto(&self) -> proto::ChannelMember { + proto::ChannelMember { + role: self.role.into(), + user_id: self.user_id.to_proto(), + kind: self.kind.into(), + } + } } #[derive(Debug, PartialEq)] pub struct ChannelsForUser { - pub channels: ChannelGraph, + pub channels: Vec, pub channel_participants: HashMap>, - pub channels_with_admin_privileges: HashSet, pub unseen_buffer_changes: Vec, pub channel_messages: Vec, } diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 23bb9e53bf..5f0df90811 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -1,4 +1,5 @@ use crate::Result; +use rpc::proto; use sea_orm::{entity::prelude::*, DbErr}; use serde::{Deserialize, Serialize}; @@ -80,3 +81,119 @@ id_type!(SignupId); id_type!(UserId); id_type!(ChannelBufferCollaboratorId); id_type!(FlagId); +id_type!(NotificationId); +id_type!(NotificationKindId); + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelRole { + #[sea_orm(string_value = "admin")] + Admin, + #[sea_orm(string_value = "member")] + #[default] + Member, + #[sea_orm(string_value = "guest")] + Guest, + #[sea_orm(string_value = "banned")] + Banned, +} + +impl ChannelRole { + pub fn should_override(&self, other: Self) -> bool { + use ChannelRole::*; + match self { + Admin => matches!(other, Member | Banned | Guest), + Member => matches!(other, Banned | Guest), + Banned => matches!(other, Guest), + Guest => false, + } + } + + pub fn max(&self, other: Self) -> Self { + if self.should_override(other) { + *self + } else { + other + } + } + + pub fn can_see_all_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Admin | Member => true, + Guest | Banned => false, + } + } + + pub fn can_only_see_public_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Guest => true, + Admin | Member | Banned => false, + } + } +} + +impl From for ChannelRole { + fn from(value: proto::ChannelRole) -> Self { + match value { + proto::ChannelRole::Admin => ChannelRole::Admin, + proto::ChannelRole::Member => ChannelRole::Member, + proto::ChannelRole::Guest => ChannelRole::Guest, + proto::ChannelRole::Banned => ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> proto::ChannelRole { + match self { + ChannelRole::Admin => proto::ChannelRole::Admin, + ChannelRole::Member => proto::ChannelRole::Member, + ChannelRole::Guest => proto::ChannelRole::Guest, + ChannelRole::Banned => proto::ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> i32 { + let proto: proto::ChannelRole = self.into(); + proto.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelVisibility { + #[sea_orm(string_value = "public")] + Public, + #[sea_orm(string_value = "members")] + #[default] + Members, +} + +impl From for ChannelVisibility { + fn from(value: proto::ChannelVisibility) -> Self { + match value { + proto::ChannelVisibility::Public => ChannelVisibility::Public, + proto::ChannelVisibility::Members => ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> proto::ChannelVisibility { + match self { + ChannelVisibility::Public => proto::ChannelVisibility::Public, + ChannelVisibility::Members => proto::ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> i32 { + let proto: proto::ChannelVisibility = self.into(); + proto.into() + } +} diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 80bd8704b2..629e26f1a9 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod messages; +pub mod notifications; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/access_tokens.rs b/crates/collab/src/db/queries/access_tokens.rs index def9428a2b..589b6483df 100644 --- a/crates/collab/src/db/queries/access_tokens.rs +++ b/crates/collab/src/db/queries/access_tokens.rs @@ -1,4 +1,5 @@ use super::*; +use sea_orm::sea_query::Query; impl Database { pub async fn create_access_token( diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index c85432f2bb..9eddb1f618 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -16,7 +16,8 @@ impl Database { connection: ConnectionId, ) -> Result { self.transaction(|tx| async move { - self.check_user_is_channel_member(channel_id, user_id, &tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &tx) .await?; let buffer = channel::Model { @@ -129,9 +130,11 @@ impl Database { self.transaction(|tx| async move { let mut results = Vec::new(); for client_buffer in buffers { - let channel_id = ChannelId::from_proto(client_buffer.channel_id); + let channel = self + .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx) + .await?; if self - .check_user_is_channel_member(channel_id, user_id, &*tx) + .check_user_is_channel_participant(&channel, user_id, &*tx) .await .is_err() { @@ -139,9 +142,9 @@ impl Database { continue; } - let buffer = self.get_channel_buffer(channel_id, &*tx).await?; + let buffer = self.get_channel_buffer(channel.id, &*tx).await?; let mut collaborators = channel_buffer_collaborator::Entity::find() - .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id)) .all(&*tx) .await?; @@ -439,7 +442,8 @@ impl Database { Vec, )> { self.transaction(move |tx| async move { - self.check_user_is_channel_member(channel_id, user, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_member(&channel, user, &*tx) .await?; let buffer = buffer::Entity::find() @@ -482,7 +486,7 @@ impl Database { ) .await?; - channel_members = self.get_channel_members_internal(channel_id, &*tx).await?; + channel_members = self.get_channel_participants(&channel, &*tx).await?; let collaborators = self .get_channel_buffer_collaborators_internal(channel_id, &*tx) .await?; diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index c576d2406b..68b06e435d 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -1,8 +1,6 @@ use super::*; -use rpc::proto::ChannelEdge; -use smallvec::SmallVec; - -type ChannelDescendants = HashMap>; +use rpc::proto::channel_member::Kind; +use sea_orm::TryGetableMany; impl Database { #[cfg(test)] @@ -19,71 +17,242 @@ impl Database { .await } + #[cfg(test)] pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result { - self.create_channel(name, None, creator_id).await + Ok(self + .create_channel(name, None, creator_id) + .await? + .channel + .id) + } + + #[cfg(test)] + pub async fn create_sub_channel( + &self, + name: &str, + parent: ChannelId, + creator_id: UserId, + ) -> Result { + Ok(self + .create_channel(name, Some(parent), creator_id) + .await? + .channel + .id) } pub async fn create_channel( &self, name: &str, - parent: Option, - creator_id: UserId, - ) -> Result { + parent_channel_id: Option, + admin_id: UserId, + ) -> Result { let name = Self::sanitize_channel_name(name)?; self.transaction(move |tx| async move { - if let Some(parent) = parent { - self.check_user_is_channel_admin(parent, creator_id, &*tx) + let mut parent = None; + + if let Some(parent_channel_id) = parent_channel_id { + let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?; + self.check_user_is_channel_admin(&parent_channel, admin_id, &*tx) .await?; + parent = Some(parent_channel); } let channel = channel::ActiveModel { + id: ActiveValue::NotSet, name: ActiveValue::Set(name.to_string()), - ..Default::default() + visibility: ActiveValue::Set(ChannelVisibility::Members), + parent_path: ActiveValue::Set( + parent + .as_ref() + .map_or(String::new(), |parent| parent.path()), + ), } .insert(&*tx) .await?; - if let Some(parent) = parent { - let sql = r#" - INSERT INTO channel_paths - (id_path, channel_id) - SELECT - id_path || $1 || '/', $2 - FROM - channel_paths - WHERE - channel_id = $3 - "#; - let channel_paths_stmt = Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [ - channel.id.to_proto().into(), - channel.id.to_proto().into(), - parent.to_proto().into(), - ], - ); - tx.execute(channel_paths_stmt).await?; + let participants_to_update; + if let Some(parent) = &parent { + participants_to_update = self + .participants_to_notify_for_channel_change(parent, &*tx) + .await?; } else { - channel_path::Entity::insert(channel_path::ActiveModel { + participants_to_update = vec![]; + + channel_member::ActiveModel { + id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel.id), - id_path: ActiveValue::Set(format!("/{}/", channel.id)), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + }; + + Ok(CreateChannelResult { + channel: Channel::from_model(channel, ChannelRole::Admin), + participants_to_update, + }) + }) + .await + } + + pub async fn join_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + environment: &str, + ) -> Result<(JoinRoom, Option, ChannelRole)> { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let mut role = self.channel_role_for_user(&channel, user_id, &*tx).await?; + + let mut accept_invite_result = None; + + if role.is_none() { + if let Some(invitation) = self + .pending_invite_for_channel(&channel, user_id, &*tx) + .await? + { + // note, this may be a parent channel + role = Some(invitation.role); + channel_member::Entity::update(channel_member::ActiveModel { + accepted: ActiveValue::Set(true), + ..invitation.into_active_model() + }) + .exec(&*tx) + .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ); + + debug_assert!( + self.channel_role_for_user(&channel, user_id, &*tx).await? == role + ); + } + } + + if channel.visibility == ChannelVisibility::Public { + role = Some(ChannelRole::Guest); + let channel_to_join = self + .public_ancestors_including_self(&channel, &*tx) + .await? + .first() + .cloned() + .unwrap_or(channel.clone()); + + channel_member::Entity::insert(channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_to_join.id), + user_id: ActiveValue::Set(user_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Guest), }) .exec(&*tx) .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel_to_join, user_id, &*tx) + .await?, + ); + + debug_assert!(self.channel_role_for_user(&channel, user_id, &*tx).await? == role); } - channel_member::ActiveModel { - channel_id: ActiveValue::Set(channel.id), - user_id: ActiveValue::Set(creator_id), - accepted: ActiveValue::Set(true), - admin: ActiveValue::Set(true), - ..Default::default() + if role.is_none() || role == Some(ChannelRole::Banned) { + Err(anyhow!("not allowed"))? } - .insert(&*tx) - .await?; - Ok(channel.id) + let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + let room_id = self + .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) + .await?; + + self.join_channel_room_internal(room_id, user_id, connection, &*tx) + .await + .map(|jr| (jr, accept_invite_result, role.unwrap())) + }) + .await + } + + pub async fn set_channel_visibility( + &self, + channel_id: ChannelId, + visibility: ChannelVisibility, + admin_id: UserId, + ) -> Result { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let previous_members = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + let mut model = channel.into_active_model(); + model.visibility = ActiveValue::Set(visibility); + let channel = model.update(&*tx).await?; + + let mut participants_to_update: HashMap = self + .participants_to_notify_for_channel_change(&channel, &*tx) + .await? + .into_iter() + .collect(); + + let mut channels_to_remove: Vec = vec![]; + let mut participants_to_remove: HashSet = HashSet::default(); + match visibility { + ChannelVisibility::Members => { + let all_descendents: Vec = self + .get_channel_descendants_including_self(vec![channel_id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove = channel::Entity::find() + .filter( + channel::Column::Id + .is_in(all_descendents) + .and(channel::Column::Visibility.eq(ChannelVisibility::Public)), + ) + .all(&*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove.push(channel_id); + + for member in previous_members { + if member.role.can_only_see_public_descendants() { + participants_to_remove.insert(member.user_id); + } + } + } + ChannelVisibility::Public => { + if let Some(public_parent) = self.public_parent_channel(&channel, &*tx).await? { + let parent_updates = self + .participants_to_notify_for_channel_change(&public_parent, &*tx) + .await?; + + for (user_id, channels) in parent_updates { + participants_to_update.insert(user_id, channels); + } + } + } + } + + Ok(SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + }) }) .await } @@ -94,37 +263,12 @@ impl Database { user_id: UserId, ) -> Result<(Vec, Vec)> { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, user_id, &*tx) .await?; - // Don't remove descendant channels that have additional parents. - let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?; - { - let mut channels_to_keep = channel_path::Entity::find() - .filter( - channel_path::Column::ChannelId - .is_in( - channels_to_remove - .keys() - .copied() - .filter(|&id| id != channel_id), - ) - .and( - channel_path::Column::IdPath - .not_like(&format!("%/{}/%", channel_id)), - ), - ) - .stream(&*tx) - .await?; - while let Some(row) = channels_to_keep.next().await { - let row = row?; - channels_to_remove.remove(&row.channel_id); - } - } - - let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?; let members_to_notify: Vec = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.is_in(channel_ancestors)) + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) .select_only() .column(channel_member::Column::UserId) .distinct() @@ -132,25 +276,19 @@ impl Database { .all(&*tx) .await?; + let channels_to_remove = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect::>(); + channel::Entity::delete_many() - .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied())) + .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied())) .exec(&*tx) .await?; - // Delete any other paths that include this channel - let sql = r#" - DELETE FROM channel_paths - WHERE - id_path LIKE '%' || $1 || '%' - "#; - let channel_paths_stmt = Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [channel_id.to_proto().into()], - ); - tx.execute(channel_paths_stmt).await?; - - Ok((channels_to_remove.into_keys().collect(), members_to_notify)) + Ok((channels_to_remove, members_to_notify)) }) .await } @@ -160,23 +298,44 @@ impl Database { channel_id: ChannelId, invitee_id: UserId, inviter_id: UserId, - is_admin: bool, - ) -> Result<()> { + role: ChannelRole, + ) -> Result { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, inviter_id, &*tx) .await?; channel_member::ActiveModel { + id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel_id), user_id: ActiveValue::Set(invitee_id), accepted: ActiveValue::Set(false), - admin: ActiveValue::Set(is_admin), - ..Default::default() + role: ActiveValue::Set(role), } .insert(&*tx) .await?; - Ok(()) + let channel = Channel::from_model(channel, role); + + let notifications = self + .create_notification( + invitee_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: channel.name.clone(), + inviter_id: inviter_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect(); + + Ok(InviteMemberResult { + channel, + notifications, + }) }) .await } @@ -192,24 +351,37 @@ impl Database { pub async fn rename_channel( &self, channel_id: ChannelId, - user_id: UserId, + admin_id: UserId, new_name: &str, - ) -> Result { + ) -> Result { self.transaction(move |tx| async move { let new_name = Self::sanitize_channel_name(new_name)?.to_string(); - self.check_user_is_channel_admin(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - channel::ActiveModel { - id: ActiveValue::Unchanged(channel_id), - name: ActiveValue::Set(new_name.clone()), - ..Default::default() - } - .update(&*tx) - .await?; + let mut model = channel.into_active_model(); + model.name = ActiveValue::Set(new_name.clone()); + let channel = model.update(&*tx).await?; - Ok(new_name) + let participants = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + Ok(RenameChannelResult { + channel: Channel::from_model(channel.clone(), role), + participants_to_update: participants + .iter() + .map(|participant| { + ( + participant.user_id, + Channel::from_model(channel.clone(), participant.role), + ) + }) + .collect(), + }) }) .await } @@ -219,10 +391,12 @@ impl Database { channel_id: ChannelId, user_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(move |tx| async move { - let rows_affected = if accept { - channel_member::Entity::update_many() + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + let membership_update = if accept { + let rows_affected = channel_member::Entity::update_many() .set(channel_member::ActiveModel { accepted: ActiveValue::Set(accept), ..Default::default() @@ -235,35 +409,91 @@ impl Database { ) .exec(&*tx) .await? - .rows_affected - } else { - channel_member::ActiveModel { - channel_id: ActiveValue::Unchanged(channel_id), - user_id: ActiveValue::Unchanged(user_id), - ..Default::default() + .rows_affected; + + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; } - .delete(&*tx) - .await? - .rows_affected + + Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ) + } else { + let rows_affected = channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected; + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; + } + + None }; - if rows_affected == 0 { - Err(anyhow!("no such invitation"))?; - } - - Ok(()) + Ok(RespondToChannelInvite { + membership_update, + notifications: self + .mark_notification_as_read_with_response( + user_id, + &rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + accept, + &*tx, + ) + .await? + .into_iter() + .collect(), + }) }) .await } + async fn calculate_membership_updated( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let new_channels = self.get_user_channels(user_id, Some(channel), &*tx).await?; + let removed_channels = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? + .into_iter() + .filter_map(|channel| { + if !new_channels.channels.iter().any(|c| c.id == channel.id) { + Some(channel.id) + } else { + None + } + }) + .collect::>(); + + Ok(MembershipUpdated { + channel_id: channel.id, + new_channels, + removed_channels, + }) + } + pub async fn remove_channel_member( &self, channel_id: ChannelId, member_id: UserId, - remover_id: UserId, - ) -> Result<()> { + admin_id: UserId, + ) -> Result { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, remover_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; let result = channel_member::Entity::delete_many() @@ -279,13 +509,30 @@ impl Database { Err(anyhow!("no such member"))?; } - Ok(()) + Ok(RemoveChannelMemberResult { + membership_update: self + .calculate_membership_updated(&channel, member_id, &*tx) + .await?, + notification_id: self + .remove_notification( + member_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + &*tx, + ) + .await?, + }) }) .await } pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result> { self.transaction(|tx| async move { + let mut role_for_channel: HashMap = HashMap::default(); + let channel_invites = channel_member::Entity::find() .filter( channel_member::Column::UserId @@ -295,22 +542,20 @@ impl Database { .all(&*tx) .await?; + for invite in channel_invites { + role_for_channel.insert(invite.channel_id, invite.role); + } + let channels = channel::Entity::find() - .filter( - channel::Column::Id.is_in( - channel_invites - .into_iter() - .map(|channel_member| channel_member.channel_id), - ), - ) + .filter(channel::Column::Id.is_in(role_for_channel.keys().copied())) .all(&*tx) .await?; let channels = channels .into_iter() - .map(|channel| Channel { - id: channel.id, - name: channel.name, + .filter_map(|channel| { + let role = *role_for_channel.get(&channel.id)?; + Some(Channel::from_model(channel, role)) }) .collect(); @@ -319,88 +564,11 @@ impl Database { .await } - async fn get_channel_graph( - &self, - parents_by_child_id: ChannelDescendants, - trim_dangling_parents: bool, - tx: &DatabaseTransaction, - ) -> Result { - let mut channels = Vec::with_capacity(parents_by_child_id.len()); - { - let mut rows = channel::Entity::find() - .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied())) - .stream(&*tx) - .await?; - while let Some(row) = rows.next().await { - let row = row?; - channels.push(Channel { - id: row.id, - name: row.name, - }) - } - } - - let mut edges = Vec::with_capacity(parents_by_child_id.len()); - for (channel, parents) in parents_by_child_id.iter() { - for parent in parents.into_iter() { - if trim_dangling_parents { - if parents_by_child_id.contains_key(parent) { - edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }); - } - } else { - edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }); - } - } - } - - Ok(ChannelGraph { channels, edges }) - } - pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { self.transaction(|tx| async move { let tx = tx; - let channel_memberships = channel_member::Entity::find() - .filter( - channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::Accepted.eq(true)), - ) - .all(&*tx) - .await?; - - self.get_user_channels(user_id, channel_memberships, &tx) - .await - }) - .await - } - - pub async fn get_channel_for_user( - &self, - channel_id: ChannelId, - user_id: UserId, - ) -> Result { - self.transaction(|tx| async move { - let tx = tx; - - let channel_membership = channel_member::Entity::find() - .filter( - channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::ChannelId.eq(channel_id)) - .and(channel_member::Column::Accepted.eq(true)), - ) - .all(&*tx) - .await?; - - self.get_user_channels(user_id, channel_membership, &tx) - .await + self.get_user_channels(user_id, None, &tx).await }) .await } @@ -408,22 +576,78 @@ impl Database { pub async fn get_user_channels( &self, user_id: UserId, - channel_memberships: Vec, + ancestor_channel: Option<&channel::Model>, tx: &DatabaseTransaction, ) -> Result { - let parents_by_child_id = self - .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx) + let channel_memberships = channel_member::Entity::find() + .filter( + channel_member::Column::UserId + .eq(user_id) + .and(channel_member::Column::Accepted.eq(true)), + ) + .all(&*tx) .await?; - let channels_with_admin_privileges = channel_memberships - .iter() - .filter_map(|membership| membership.admin.then_some(membership.channel_id)) + let descendants = self + .get_channel_descendants_including_self( + channel_memberships.iter().map(|m| m.channel_id), + &*tx, + ) + .await?; + + let mut roles_by_channel_id: HashMap = HashMap::default(); + for membership in channel_memberships.iter() { + roles_by_channel_id.insert(membership.channel_id, membership.role); + } + + let mut visible_channel_ids: HashSet = HashSet::default(); + + let channels: Vec = descendants + .into_iter() + .filter_map(|channel| { + let parent_role = channel + .parent_id() + .and_then(|parent_id| roles_by_channel_id.get(&parent_id)); + + let role = if let Some(parent_role) = parent_role { + let role = if let Some(existing_role) = roles_by_channel_id.get(&channel.id) { + existing_role.max(*parent_role) + } else { + *parent_role + }; + roles_by_channel_id.insert(channel.id, role); + role + } else { + *roles_by_channel_id.get(&channel.id)? + }; + + let can_see_parent_paths = role.can_see_all_descendants() + || role.can_only_see_public_descendants() + && channel.visibility == ChannelVisibility::Public; + if !can_see_parent_paths { + return None; + } + + visible_channel_ids.insert(channel.id); + + if let Some(ancestor) = ancestor_channel { + if !channel + .ancestors_including_self() + .any(|id| id == ancestor.id) + { + return None; + } + } + + let mut channel = Channel::from_model(channel, role); + channel + .parent_path + .retain(|id| visible_channel_ids.contains(&id)); + + Some(channel) + }) .collect(); - let graph = self - .get_channel_graph(parents_by_child_id, true, &tx) - .await?; - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIdsAndChannelIds { ChannelId, @@ -434,7 +658,7 @@ impl Database { { let mut rows = room_participant::Entity::find() .inner_join(room::Entity) - .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id))) + .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id))) .select_only() .column(room::Column::ChannelId) .column(room_participant::Column::UserId) @@ -447,7 +671,7 @@ impl Database { } } - let channel_ids = graph.channels.iter().map(|c| c.id).collect::>(); + let channel_ids = channels.iter().map(|c| c.id).collect::>(); let channel_buffer_changes = self .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx) .await?; @@ -457,228 +681,428 @@ impl Database { .await?; Ok(ChannelsForUser { - channels: graph, + channels, channel_participants, - channels_with_admin_privileges, unseen_buffer_changes: channel_buffer_changes, channel_messages: unseen_messages, }) } - pub async fn get_channel_members(&self, id: ChannelId) -> Result> { - self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await }) - .await + async fn participants_to_notify_for_channel_change( + &self, + new_parent: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut results: Vec<(UserId, ChannelsForUser)> = Vec::new(); + + let members = self + .get_channel_participant_details_internal(new_parent, &*tx) + .await?; + + for member in members.iter() { + if !member.role.can_see_all_descendants() { + continue; + } + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(new_parent), &*tx) + .await?, + )) + } + + let public_parents = self + .public_ancestors_including_self(new_parent, &*tx) + .await?; + let public_parent = public_parents.last(); + + let Some(public_parent) = public_parent else { + return Ok(results); + }; + + // could save some time in the common case by skipping this if the + // new channel is not public and has no public descendants. + let public_members = if public_parent == new_parent { + members + } else { + self.get_channel_participant_details_internal(public_parent, &*tx) + .await? + }; + + for member in public_members { + if !member.role.can_only_see_public_descendants() { + continue; + }; + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(public_parent), &*tx) + .await?, + )) + } + + Ok(results) } - pub async fn set_channel_member_admin( + pub async fn set_channel_member_role( &self, channel_id: ChannelId, - from: UserId, + admin_id: UserId, for_user: UserId, - admin: bool, - ) -> Result<()> { + role: ChannelRole, + ) -> Result { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, from, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - let result = channel_member::Entity::update_many() + let membership = channel_member::Entity::find() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(for_user)), ) - .set(channel_member::ActiveModel { - admin: ActiveValue::set(admin), - ..Default::default() - }) - .exec(&*tx) + .one(&*tx) .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such member"))?; - } + let Some(membership) = membership else { + Err(anyhow!("no such member"))? + }; - Ok(()) + let mut update = membership.into_active_model(); + update.role = ActiveValue::Set(role); + let updated = channel_member::Entity::update(update).exec(&*tx).await?; + + if updated.accepted { + Ok(SetMemberRoleResult::MembershipUpdated( + self.calculate_membership_updated(&channel, for_user, &*tx) + .await?, + )) + } else { + Ok(SetMemberRoleResult::InviteUpdated(Channel::from_model( + channel, role, + ))) + } }) .await } - pub async fn get_channel_member_details( + pub async fn get_channel_participant_details( &self, channel_id: ChannelId, user_id: UserId, ) -> Result> { - self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, user_id, &*tx) - .await?; + let (role, members) = self + .transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + Ok(( + role, + self.get_channel_participant_details_internal(&channel, &*tx) + .await?, + )) + }) + .await?; - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryMemberDetails { - UserId, - Admin, - IsDirectMember, - Accepted, - } - - let tx = tx; - let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?; - let mut stream = channel_member::Entity::find() - .distinct() - .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied())) - .select_only() - .column(channel_member::Column::UserId) - .column(channel_member::Column::Admin) - .column_as( - channel_member::Column::ChannelId.eq(channel_id), - QueryMemberDetails::IsDirectMember, - ) - .column(channel_member::Column::Accepted) - .order_by_asc(channel_member::Column::UserId) - .into_values::<_, QueryMemberDetails>() - .stream(&*tx) - .await?; - - let mut rows = Vec::::new(); - while let Some(row) = stream.next().await { - let (user_id, is_admin, is_direct_member, is_invite_accepted): ( - UserId, - bool, - bool, - bool, - ) = row?; - let kind = match (is_direct_member, is_invite_accepted) { - (true, true) => proto::channel_member::Kind::Member, - (true, false) => proto::channel_member::Kind::Invitee, - (false, true) => proto::channel_member::Kind::AncestorMember, - (false, false) => continue, - }; - let user_id = user_id.to_proto(); - let kind = kind.into(); - if let Some(last_row) = rows.last_mut() { - if last_row.user_id == user_id { - if is_direct_member { - last_row.kind = kind; - last_row.admin = is_admin; - } - continue; + if role == ChannelRole::Admin { + Ok(members + .into_iter() + .map(|channel_member| channel_member.to_proto()) + .collect()) + } else { + return Ok(members + .into_iter() + .filter_map(|member| { + if member.kind == proto::channel_member::Kind::Invitee { + return None; } - } - rows.push(proto::ChannelMember { - user_id, - kind, - admin: is_admin, - }); - } - - Ok(rows) - }) - .await + Some(ChannelMember { + role: member.role, + user_id: member.user_id, + kind: proto::channel_member::Kind::Member, + }) + }) + .map(|channel_member| channel_member.to_proto()) + .collect()); + } } - pub async fn get_channel_members_internal( + async fn get_channel_participant_details_internal( &self, - id: ChannelId, + channel: &channel::Model, tx: &DatabaseTransaction, - ) -> Result> { - let ancestor_ids = self.get_channel_ancestors(id, tx).await?; - let user_ids = channel_member::Entity::find() - .distinct() - .filter( - channel_member::Column::ChannelId - .is_in(ancestor_ids.iter().copied()) - .and(channel_member::Column::Accepted.eq(true)), - ) + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryMemberDetails { + UserId, + Role, + IsDirectMember, + Accepted, + Visibility, + } + + let mut stream = channel_member::Entity::find() + .left_join(channel::Entity) + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) .select_only() .column(channel_member::Column::UserId) - .into_values::<_, QueryUserIds>() - .all(&*tx) + .column(channel_member::Column::Role) + .column_as( + channel_member::Column::ChannelId.eq(channel.id), + QueryMemberDetails::IsDirectMember, + ) + .column(channel_member::Column::Accepted) + .column(channel::Column::Visibility) + .into_values::<_, QueryMemberDetails>() + .stream(&*tx) .await?; - Ok(user_ids) + + let mut user_details: HashMap = HashMap::default(); + + while let Some(user_membership) = stream.next().await { + let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): ( + UserId, + ChannelRole, + bool, + bool, + ChannelVisibility, + ) = user_membership?; + let kind = match (is_direct_member, is_invite_accepted) { + (true, true) => proto::channel_member::Kind::Member, + (true, false) => proto::channel_member::Kind::Invitee, + (false, true) => proto::channel_member::Kind::AncestorMember, + (false, false) => continue, + }; + + if channel_role == ChannelRole::Guest + && visibility != ChannelVisibility::Public + && channel.visibility != ChannelVisibility::Public + { + continue; + } + + if let Some(details_mut) = user_details.get_mut(&user_id) { + if channel_role.should_override(details_mut.role) { + details_mut.role = channel_role; + } + if kind == Kind::Member { + details_mut.kind = kind; + // the UI is going to be a bit confusing if you already have permissions + // that are greater than or equal to the ones you're being invited to. + } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember { + details_mut.kind = kind; + } + } else { + user_details.insert( + user_id, + ChannelMember { + user_id, + kind, + role: channel_role, + }, + ); + } + } + + Ok(user_details + .into_iter() + .map(|(_, details)| details) + .collect()) } - pub async fn check_user_is_channel_member( + pub async fn get_channel_participants( &self, - channel_id: ChannelId, - user_id: UserId, + channel: &channel::Model, tx: &DatabaseTransaction, - ) -> Result<()> { - let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .is_in(channel_ids) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?; - Ok(()) + ) -> Result> { + let participants = self + .get_channel_participant_details_internal(channel, &*tx) + .await?; + Ok(participants + .into_iter() + .map(|member| member.user_id) + .collect()) } pub async fn check_user_is_channel_admin( &self, - channel_id: ChannelId, + channel: &channel::Model, user_id: UserId, tx: &DatabaseTransaction, - ) -> Result<()> { - let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) => Ok(role.unwrap()), + Some(ChannelRole::Member) + | Some(ChannelRole::Banned) + | Some(ChannelRole::Guest) + | None => Err(anyhow!( + "user is not a channel admin or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_member( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let channel_role = self.channel_role_for_user(channel, user_id, tx).await?; + match channel_role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()), + Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( + "user is not a channel member or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_participant( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => { + Ok(role.unwrap()) + } + Some(ChannelRole::Banned) | None => Err(anyhow!( + "user is not a channel participant or channel does not exist" + ))?, + } + } + + pub async fn pending_invite_for_channel( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + let row = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) + .filter(channel_member::Column::UserId.eq(user_id)) + .filter(channel_member::Column::Accepted.eq(false)) + .one(&*tx) + .await?; + + Ok(row) + } + + pub async fn public_parent_channel( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut path = self.public_ancestors_including_self(channel, &*tx).await?; + if path.last().unwrap().id == channel.id { + path.pop(); + } + Ok(path.pop()) + } + + pub async fn public_ancestors_including_self( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let visible_channels = channel::Entity::find() + .filter(channel::Column::Id.is_in(channel.ancestors_including_self())) + .filter(channel::Column::Visibility.eq(ChannelVisibility::Public)) + .order_by_asc(channel::Column::ParentPath) + .all(&*tx) + .await?; + + Ok(visible_channels) + } + + pub async fn channel_role_for_user( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelMembership { + ChannelId, + Role, + Visibility, + } + + let mut rows = channel_member::Entity::find() + .left_join(channel::Entity) .filter( channel_member::Column::ChannelId - .is_in(channel_ids) + .is_in(channel.ancestors_including_self()) .and(channel_member::Column::UserId.eq(user_id)) - .and(channel_member::Column::Admin.eq(true)), + .and(channel_member::Column::Accepted.eq(true)), ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?; - Ok(()) - } - - /// Returns the channel ancestors, deepest first - pub async fn get_channel_ancestors( - &self, - channel_id: ChannelId, - tx: &DatabaseTransaction, - ) -> Result> { - let paths = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(channel_id)) - .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc) - .all(tx) + .select_only() + .column(channel_member::Column::ChannelId) + .column(channel_member::Column::Role) + .column(channel::Column::Visibility) + .into_values::<_, QueryChannelMembership>() + .stream(&*tx) .await?; - let mut channel_ids = Vec::new(); - for path in paths { - for id in path.id_path.trim_matches('/').split('/') { - if let Ok(id) = id.parse() { - let id = ChannelId::from_proto(id); - if let Err(ix) = channel_ids.binary_search(&id) { - channel_ids.insert(ix, id); + + let mut user_role: Option = None; + + let mut is_participant = false; + let mut current_channel_visibility = None; + + // note these channels are not iterated in any particular order, + // our current logic takes the highest permission available. + while let Some(row) = rows.next().await { + let (membership_channel, role, visibility): ( + ChannelId, + ChannelRole, + ChannelVisibility, + ) = row?; + + match role { + ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => { + if let Some(users_role) = user_role { + user_role = Some(users_role.max(role)); + } else { + user_role = Some(role) } } + ChannelRole::Guest if visibility == ChannelVisibility::Public => { + is_participant = true + } + ChannelRole::Guest => {} + } + if channel.id == membership_channel { + current_channel_visibility = Some(visibility); } } - Ok(channel_ids) + // free up database connection + drop(rows); + + if is_participant && user_role.is_none() { + if current_channel_visibility.is_none() { + current_channel_visibility = channel::Entity::find() + .filter(channel::Column::Id.eq(channel.id)) + .one(&*tx) + .await? + .map(|channel| channel.visibility); + } + if current_channel_visibility == Some(ChannelVisibility::Public) { + user_role = Some(ChannelRole::Guest); + } + } + + Ok(user_role) } - /// Returns the channel descendants, - /// Structured as a map from child ids to their parent ids - /// For example, the descendants of 'a' in this DAG: - /// - /// /- b -\ - /// a -- c -- d - /// - /// would be: - /// { - /// a: [], - /// b: [a], - /// c: [a], - /// d: [a, c], - /// } - async fn get_channel_descendants( + // Get the descendants of the given set if channels, ordered by their + // path. + async fn get_channel_descendants_including_self( &self, channel_ids: impl IntoIterator, tx: &DatabaseTransaction, - ) -> Result { + ) -> Result> { let mut values = String::new(); for id in channel_ids { if !values.is_empty() { @@ -688,403 +1112,201 @@ impl Database { } if values.is_empty() { - return Ok(HashMap::default()); + return Ok(vec![]); } let sql = format!( r#" - SELECT - descendant_paths.* + SELECT DISTINCT + descendant_channels.*, + descendant_channels.parent_path || descendant_channels.id as full_path FROM - channel_paths parent_paths, channel_paths descendant_paths + channels parent_channels, channels descendant_channels WHERE - parent_paths.channel_id IN ({values}) AND - descendant_paths.id_path LIKE (parent_paths.id_path || '%') - "# + descendant_channels.id IN ({values}) OR + ( + parent_channels.id IN ({values}) AND + descendant_channels.parent_path LIKE (parent_channels.parent_path || parent_channels.id || '/%') + ) + ORDER BY + full_path ASC + "# ); - let stmt = Statement::from_string(self.pool.get_database_backend(), sql); - - let mut parents_by_child_id: ChannelDescendants = HashMap::default(); - let mut paths = channel_path::Entity::find() - .from_raw_sql(stmt) - .stream(tx) - .await?; - - while let Some(path) = paths.next().await { - let path = path?; - let ids = path.id_path.trim_matches('/').split('/'); - let mut parent_id = None; - for id in ids { - if let Ok(id) = id.parse() { - let id = ChannelId::from_proto(id); - if id == path.channel_id { - break; - } - parent_id = Some(id); - } - } - let entry = parents_by_child_id.entry(path.channel_id).or_default(); - if let Some(parent_id) = parent_id { - entry.insert(parent_id); - } - } - - Ok(parents_by_child_id) + Ok(channel::Entity::find() + .from_raw_sql(Statement::from_string( + self.pool.get_database_backend(), + sql, + )) + .all(tx) + .await?) } - /// Returns the channel with the given ID and: - /// - true if the user is a member - /// - false if the user hasn't accepted the invitation yet - pub async fn get_channel( - &self, - channel_id: ChannelId, - user_id: UserId, - ) -> Result> { + /// Returns the channel with the given ID + pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result { self.transaction(|tx| async move { - let tx = tx; + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; - let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?; - - if let Some(channel) = channel { - if self - .check_user_is_channel_member(channel_id, user_id, &*tx) - .await - .is_err() - { - return Ok(None); - } - - let channel_membership = channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .eq(channel_id) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await?; - - let is_accepted = channel_membership - .map(|membership| membership.accepted) - .unwrap_or(false); - - Ok(Some(( - Channel { - id: channel.id, - name: channel.name, - }, - is_accepted, - ))) - } else { - Ok(None) - } + Ok(Channel::from_model(channel, role)) }) .await } - pub async fn get_or_create_channel_room( + pub async fn get_channel_internal( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Entity::find_by_id(channel_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such channel"))?) + } + + pub(crate) async fn get_or_create_channel_room( &self, channel_id: ChannelId, live_kit_room: &str, - enviroment: &str, - ) -> Result { - self.transaction(|tx| async move { - let tx = tx; - - let room = room::Entity::find() - .filter(room::Column::ChannelId.eq(channel_id)) - .one(&*tx) - .await?; - - let room_id = if let Some(room) = room { - room.id - } else { - let result = room::Entity::insert(room::ActiveModel { - channel_id: ActiveValue::Set(Some(channel_id)), - live_kit_room: ActiveValue::Set(live_kit_room.to_string()), - enviroment: ActiveValue::Set(Some(enviroment.to_string())), - ..Default::default() - }) - .exec(&*tx) - .await?; - - result.last_insert_id - }; - - Ok(room_id) - }) - .await - } - - // Insert an edge from the given channel to the given other channel. - pub async fn link_channel( - &self, - user: UserId, - channel: ChannelId, - to: ChannelId, - ) -> Result { - self.transaction(|tx| async move { - // Note that even with these maxed permissions, this linking operation - // is still insecure because you can't remove someone's permissions to a - // channel if they've linked the channel to one where they're an admin. - self.check_user_is_channel_admin(channel, user, &*tx) - .await?; - - self.link_channel_internal(user, channel, to, &*tx).await - }) - .await - } - - pub async fn link_channel_internal( - &self, - user: UserId, - channel: ChannelId, - to: ChannelId, + environment: &str, tx: &DatabaseTransaction, - ) -> Result { - self.check_user_is_channel_admin(to, user, &*tx).await?; - - let paths = channel_path::Entity::find() - .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel))) - .all(tx) + ) -> Result { + let room = room::Entity::find() + .filter(room::Column::ChannelId.eq(channel_id)) + .one(&*tx) .await?; - let mut new_path_suffixes = HashSet::default(); - for path in paths { - if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) { - new_path_suffixes.insert(( - path.channel_id, - path.id_path[(start_offset + 1)..].to_string(), - )); - } - } - - let paths_to_new_parent = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(to)) - .all(tx) - .await?; - - let mut new_paths = Vec::new(); - for path in paths_to_new_parent { - if path.id_path.contains(&format!("/{}/", channel)) { - Err(anyhow!("cycle"))?; - } - - new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| { - channel_path::ActiveModel { - channel_id: ActiveValue::Set(*channel_id), - id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)), + let room_id = if let Some(room) = room { + if let Some(env) = room.enviroment { + if &env != environment { + Err(anyhow!("must join using the {} release", env))?; } - })); - } - - channel_path::Entity::insert_many(new_paths) + } + room.id + } else { + let result = room::Entity::insert(room::ActiveModel { + channel_id: ActiveValue::Set(Some(channel_id)), + live_kit_room: ActiveValue::Set(live_kit_room.to_string()), + enviroment: ActiveValue::Set(Some(environment.to_string())), + ..Default::default() + }) .exec(&*tx) .await?; - // remove any root edges for the channel we just linked - { - channel_path::Entity::delete_many() - .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel))) - .exec(&*tx) - .await?; - } + result.last_insert_id + }; - let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?; - if let Some(channel) = channel_descendants.get_mut(&channel) { - // Remove the other parents - channel.clear(); - channel.insert(to); - } - - let channels = self - .get_channel_graph(channel_descendants, false, &*tx) - .await?; - - Ok(channels) + Ok(room_id) } - /// Unlink a channel from a given parent. This will add in a root edge if - /// the channel has no other parents after this operation. - pub async fn unlink_channel( - &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - ) -> Result<()> { - self.transaction(|tx| async move { - // Note that even with these maxed permissions, this linking operation - // is still insecure because you can't remove someone's permissions to a - // channel if they've linked the channel to one where they're an admin. - self.check_user_is_channel_admin(channel, user, &*tx) - .await?; - - self.unlink_channel_internal(user, channel, from, &*tx) - .await?; - - Ok(()) - }) - .await - } - - pub async fn unlink_channel_internal( - &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - tx: &DatabaseTransaction, - ) -> Result<()> { - self.check_user_is_channel_admin(from, user, &*tx).await?; - - let sql = r#" - DELETE FROM channel_paths - WHERE - id_path LIKE '%/' || $1 || '/' || $2 || '/%' - RETURNING id_path, channel_id - "#; - - let paths = channel_path::Entity::find() - .from_raw_sql(Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [from.to_proto().into(), channel.to_proto().into()], - )) - .all(&*tx) - .await?; - - let is_stranded = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(channel)) - .count(&*tx) - .await? - == 0; - - // Make sure that there is always at least one path to the channel - if is_stranded { - let root_paths: Vec<_> = paths - .iter() - .map(|path| { - let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap(); - channel_path::ActiveModel { - channel_id: ActiveValue::Set(path.channel_id), - id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()), - } - }) - .collect(); - channel_path::Entity::insert_many(root_paths) - .exec(&*tx) - .await?; - } - - Ok(()) - } - - /// Move a channel from one parent to another, returns the - /// Channels that were moved for notifying clients + /// Move a channel from one parent to another pub async fn move_channel( &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - to: ChannelId, - ) -> Result { - if from == to { - return Ok(ChannelGraph { - channels: vec![], - edges: vec![], - }); - } - + channel_id: ChannelId, + new_parent_id: Option, + admin_id: UserId, + ) -> Result> { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel, user, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?; + let new_parent_path; + let new_parent_channel; + if let Some(new_parent_id) = new_parent_id { + let new_parent = self.get_channel_internal(new_parent_id, &*tx).await?; + self.check_user_is_channel_admin(&new_parent, admin_id, &*tx) + .await?; - self.unlink_channel_internal(user, channel, from, &*tx) + new_parent_path = new_parent.path(); + new_parent_channel = Some(new_parent); + } else { + new_parent_path = String::new(); + new_parent_channel = None; + }; + + let previous_participants = self + .get_channel_participant_details_internal(&channel, &*tx) .await?; - Ok(moved_channels) + let old_path = format!("{}{}/", channel.parent_path, channel.id); + let new_path = format!("{}{}/", new_parent_path, channel.id); + + if old_path == new_path { + return Ok(None); + } + + let mut model = channel.into_active_model(); + model.parent_path = ActiveValue::Set(new_parent_path); + let channel = model.update(&*tx).await?; + + if new_parent_channel.is_none() { + channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + } + + let descendent_ids = + ChannelId::find_by_statement::(Statement::from_sql_and_values( + self.pool.get_database_backend(), + " + UPDATE channels SET parent_path = REPLACE(parent_path, $1, $2) + WHERE parent_path LIKE $3 || '%' + RETURNING id + ", + [old_path.clone().into(), new_path.into(), old_path.into()], + )) + .all(&*tx) + .await?; + + let participants_to_update: HashMap<_, _> = self + .participants_to_notify_for_channel_change( + new_parent_channel.as_ref().unwrap_or(&channel), + &*tx, + ) + .await? + .into_iter() + .collect(); + + let mut moved_channels: HashSet = HashSet::default(); + for id in descendent_ids { + moved_channels.insert(id); + } + moved_channels.insert(channel_id); + + let mut participants_to_remove: HashSet = HashSet::default(); + for participant in previous_participants { + if participant.kind == proto::channel_member::Kind::AncestorMember { + if !participants_to_update.contains_key(&participant.user_id) { + participants_to_remove.insert(participant.user_id); + } + } + } + + Ok(Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + })) }) .await } } +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryIds { + Id, +} + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIds { UserId, } - -#[derive(Debug)] -pub struct ChannelGraph { - pub channels: Vec, - pub edges: Vec, -} - -impl ChannelGraph { - pub fn is_empty(&self) -> bool { - self.channels.is_empty() && self.edges.is_empty() - } -} - -#[cfg(test)] -impl PartialEq for ChannelGraph { - fn eq(&self, other: &Self) -> bool { - // Order independent comparison for tests - let channels_set = self.channels.iter().collect::>(); - let other_channels_set = other.channels.iter().collect::>(); - let edges_set = self - .edges - .iter() - .map(|edge| (edge.channel_id, edge.parent_id)) - .collect::>(); - let other_edges_set = other - .edges - .iter() - .map(|edge| (edge.channel_id, edge.parent_id)) - .collect::>(); - - channels_set == other_channels_set && edges_set == other_edges_set - } -} - -#[cfg(not(test))] -impl PartialEq for ChannelGraph { - fn eq(&self, other: &Self) -> bool { - self.channels == other.channels && self.edges == other.edges - } -} - -struct SmallSet(SmallVec<[T; 1]>); - -impl Deref for SmallSet { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -impl Default for SmallSet { - fn default() -> Self { - Self(SmallVec::new()) - } -} - -impl SmallSet { - fn insert(&mut self, value: T) -> bool - where - T: Ord, - { - match self.binary_search(&value) { - Ok(_) => false, - Err(ix) => { - self.0.insert(ix, value); - true - } - } - } - - fn clear(&mut self) { - self.0.clear(); - } -} diff --git a/crates/collab/src/db/queries/contacts.rs b/crates/collab/src/db/queries/contacts.rs index 2171f1a6bf..f31f1addbd 100644 --- a/crates/collab/src/db/queries/contacts.rs +++ b/crates/collab/src/db/queries/contacts.rs @@ -8,7 +8,6 @@ impl Database { user_id_b: UserId, a_to_b: bool, accepted: bool, - should_notify: bool, user_a_busy: bool, user_b_busy: bool, } @@ -53,7 +52,6 @@ impl Database { if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, busy: db_contact.user_b_busy, }); } else if db_contact.a_to_b { @@ -63,19 +61,16 @@ impl Database { } else { contacts.push(Contact::Incoming { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, }); } } else if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, busy: db_contact.user_a_busy, }); } else if db_contact.a_to_b { contacts.push(Contact::Incoming { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, }); } else { contacts.push(Contact::Outgoing { @@ -124,7 +119,11 @@ impl Database { .await } - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + pub async fn send_contact_request( + &self, + sender_id: UserId, + receiver_id: UserId, + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) @@ -161,11 +160,22 @@ impl Database { .exec_without_returning(&*tx) .await?; - if rows_affected == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? + if rows_affected == 0 { + Err(anyhow!("contact already requested"))?; } + + Ok(self + .create_notification( + receiver_id, + rpc::Notification::ContactRequest { + sender_id: sender_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -179,7 +189,11 @@ impl Database { /// /// * `requester_id` - The user that initiates this request /// * `responder_id` - The user that will be removed - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result { + pub async fn remove_contact( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<(bool, Option)> { self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) @@ -198,7 +212,21 @@ impl Database { .ok_or_else(|| anyhow!("no such contact"))?; contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; - Ok(contact.accepted) + + let mut deleted_notification_id = None; + if !contact.accepted { + deleted_notification_id = self + .remove_notification( + responder_id, + rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + &*tx, + ) + .await?; + } + + Ok((contact.accepted, deleted_notification_id)) }) .await } @@ -249,7 +277,7 @@ impl Database { responder_id: UserId, requester_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) @@ -287,11 +315,38 @@ impl Database { result.rows_affected }; - if rows_affected == 1 { - Ok(()) - } else { + if rows_affected == 0 { Err(anyhow!("no such contact request"))? } + + let mut notifications = Vec::new(); + notifications.extend( + self.mark_notification_as_read_with_response( + responder_id, + &rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + accept, + &*tx, + ) + .await?, + ); + + if accept { + notifications.extend( + self.create_notification( + requester_id, + rpc::Notification::ContactRequestAccepted { + responder_id: responder_id.to_proto(), + }, + true, + &*tx, + ) + .await?, + ); + } + + Ok(notifications) }) .await } diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index a48d425d90..47bb27df39 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,4 +1,6 @@ use super::*; +use rpc::Notification; +use sea_orm::TryInsertResult; use time::OffsetDateTime; impl Database { @@ -9,7 +11,8 @@ impl Database { user_id: UserId, ) -> Result<()> { self.transaction(|tx| async move { - self.check_user_is_channel_member(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) .await?; channel_chat_participant::ActiveModel { id: ActiveValue::NotSet, @@ -77,7 +80,8 @@ impl Database { before_message_id: Option, ) -> Result> { self.transaction(|tx| async move { - self.check_user_is_channel_member(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) .await?; let mut condition = @@ -87,33 +91,103 @@ impl Database { condition = condition.add(channel_message::Column::Id.lt(before_message_id)); } - let mut rows = channel_message::Entity::find() + let rows = channel_message::Entity::find() .filter(condition) .order_by_desc(channel_message::Column::Id) .limit(count as u64) - .stream(&*tx) + .all(&*tx) .await?; - let mut messages = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; + self.load_channel_messages(rows, &*tx).await + }) + .await + } + + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) + .all(&*tx) + .await?; + + let mut channels = HashMap::::default(); + for row in &rows { + channels.insert( + row.channel_id, + self.get_channel_internal(row.channel_id, &*tx).await?, + ); + } + + for (_, channel) in channels { + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + } + + let messages = self.load_channel_messages(rows, &*tx).await?; + Ok(messages) + }) + .await + } + + async fn load_channel_messages( + &self, + rows: Vec, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = rows + .into_iter() + .map(|row| { let nonce = row.nonce.as_u64_pair(); - messages.push(proto::ChannelMessage { + proto::ChannelMessage { id: row.id.to_proto(), sender_id: row.sender_id.to_proto(), body: row.body, timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], nonce: Some(proto::Nonce { upper_half: nonce.0, lower_half: nonce.1, }), - }); + } + }) + .collect::>(); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + .order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } } - drop(rows); - messages.reverse(); - Ok(messages) - }) - .await + } + + Ok(messages) } pub async fn create_channel_message( @@ -121,10 +195,15 @@ impl Database { channel_id: ChannelId, user_id: UserId, body: &str, + mentions: &[proto::ChatMention], timestamp: OffsetDateTime, nonce: u128, - ) -> Result<(MessageId, Vec, Vec)> { + ) -> Result { self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + let mut rows = channel_chat_participant::Entity::find() .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) .stream(&*tx) @@ -150,7 +229,7 @@ impl Database { let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); - let message = channel_message::Entity::insert(channel_message::ActiveModel { + let result = channel_message::Entity::insert(channel_message::ActiveModel { channel_id: ActiveValue::Set(channel_id), sender_id: ActiveValue::Set(user_id), body: ActiveValue::Set(body.to_string()), @@ -159,35 +238,85 @@ impl Database { id: ActiveValue::NotSet, }) .on_conflict( - OnConflict::column(channel_message::Column::Nonce) - .update_column(channel_message::Column::Nonce) - .to_owned(), + OnConflict::columns([ + channel_message::Column::SenderId, + channel_message::Column::Nonce, + ]) + .do_nothing() + .to_owned(), ) + .do_nothing() .exec(&*tx) .await?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryConnectionId { - ConnectionId, + let message_id; + let mut notifications = Vec::new(); + match result { + TryInsertResult::Inserted(result) => { + message_id = result.last_insert_id; + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions + .iter() + .filter_map(|mention| { + let range = mention.range.as_ref()?; + if !body.is_char_boundary(range.start as usize) + || !body.is_char_boundary(range.end as usize) + { + return None; + } + Some(channel_message_mention::ActiveModel { + message_id: ActiveValue::Set(message_id), + start_offset: ActiveValue::Set(range.start as i32), + end_offset: ActiveValue::Set(range.end as i32), + user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), + }) + }) + .collect::>(); + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) + .exec(&*tx) + .await?; + } + + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + } + _ => { + message_id = channel_message::Entity::find() + .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert message"))? + .id; + } } - // Observe this message for the sender - self.observe_channel_message_internal( - channel_id, - user_id, - message.last_insert_id, - &*tx, - ) - .await?; - - let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?; + let mut channel_members = self.get_channel_participants(&channel, &*tx).await?; channel_members.retain(|member| !participant_user_ids.contains(member)); - Ok(( - message.last_insert_id, + Ok(CreatedChannelMessage { + message_id, participant_connection_ids, channel_members, - )) + notifications, + }) }) .await } @@ -197,11 +326,24 @@ impl Database { channel_id: ChannelId, user_id: UserId, message_id: MessageId, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) .await?; - Ok(()) + let mut batch = NotificationBatch::default(); + batch.extend( + self.mark_notification_as_read( + user_id, + &Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: Default::default(), + channel_id: Default::default(), + }, + &*tx, + ) + .await?, + ); + Ok(batch) }) .await } @@ -337,8 +479,23 @@ impl Database { .filter(channel_message::Column::SenderId.eq(user_id)) .exec(&*tx) .await?; + if result.rows_affected == 0 { - Err(anyhow!("no such message"))?; + let channel = self.get_channel_internal(channel_id, &*tx).await?; + if self + .check_user_is_channel_admin(&channel, user_id, &*tx) + .await + .is_ok() + { + let result = channel_message::Entity::delete_by_id(message_id) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such message"))?; + } + } else { + Err(anyhow!("operation could not be completed"))?; + } } Ok(participant_connection_ids) diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs new file mode 100644 index 0000000000..6f2511c23e --- /dev/null +++ b/crates/collab/src/db/queries/notifications.rs @@ -0,0 +1,262 @@ +use super::*; +use rpc::Notification; + +impl Database { + pub async fn initialize_notification_kinds(&mut self) -> Result<()> { + notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map( + |kind| notification_kind::ActiveModel { + name: ActiveValue::Set(kind.to_string()), + ..Default::default() + }, + )) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&self.pool) + .await?; + + let mut rows = notification_kind::Entity::find().stream(&self.pool).await?; + while let Some(row) = rows.next().await { + let row = row?; + self.notification_kinds_by_name.insert(row.name, row.id); + } + + for name in Notification::all_variant_names() { + if let Some(id) = self.notification_kinds_by_name.get(*name).copied() { + self.notification_kinds_by_id.insert(id, name); + } + } + + Ok(()) + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + before_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut condition = + Condition::all().add(notification::Column::RecipientId.eq(recipient_id)); + + if let Some(before_id) = before_id { + condition = condition.add(notification::Column::Id.lt(before_id)); + } + + let mut rows = notification::Entity::find() + .filter(condition) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + let kind = row.kind; + if let Some(proto) = model_to_proto(self, row) { + result.push(proto); + } else { + log::warn!("unknown notification kind {:?}", kind); + } + } + result.reverse(); + Ok(result) + }) + .await + } + + /// Create a notification. If `avoid_duplicates` is set to true, then avoid + /// creating a new notification if the given recipient already has an + /// unread notification with the given kind and entity id. + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + avoid_duplicates: bool, + tx: &DatabaseTransaction, + ) -> Result> { + if avoid_duplicates { + if self + .find_notification(recipient_id, ¬ification, tx) + .await? + .is_some() + { + return Ok(None); + } + } + + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + let model = notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind), + entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)), + content: ActiveValue::Set(proto.content.clone()), + ..Default::default() + } + .save(&*tx) + .await?; + + Ok(Some(( + recipient_id, + proto::Notification { + id: model.id.as_ref().to_proto(), + kind: proto.kind, + timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, + is_read: false, + response: None, + content: proto.content, + entity_id: proto.entity_id, + }, + ))) + } + + /// Remove an unread notification with the given recipient, kind and + /// entity id. + pub async fn remove_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let id = self + .find_notification(recipient_id, ¬ification, tx) + .await?; + if let Some(id) = id { + notification::Entity::delete_by_id(id).exec(tx).await?; + } + Ok(id) + } + + /// Populate the response for the notification with the given kind and + /// entity id. + pub async fn mark_notification_as_read_with_response( + &self, + recipient_id: UserId, + notification: &Notification, + response: bool, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx) + .await + } + + pub async fn mark_notification_as_read( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, None, tx) + .await + } + + pub async fn mark_notification_as_read_by_id( + &self, + recipient_id: UserId, + notification_id: NotificationId, + ) -> Result { + self.transaction(|tx| async move { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(notification_id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .into_iter() + .collect()) + }) + .await + } + + async fn mark_notification_as_read_internal( + &self, + recipient_id: UserId, + notification: &Notification, + response: Option, + tx: &DatabaseTransaction, + ) -> Result> { + if let Some(id) = self + .find_notification(recipient_id, notification, &*tx) + .await? + { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + response: if let Some(response) = response { + ActiveValue::Set(Some(response)) + } else { + ActiveValue::NotSet + }, + ..Default::default() + }) + .exec(tx) + .await?; + Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + } else { + Ok(None) + } + } + + /// Find an unread notification by its recipient, kind and entity id. + async fn find_notification( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryIds { + Id, + } + + Ok(notification::Entity::find() + .select_only() + .column(notification::Column::Id) + .filter( + Condition::all() + .add(notification::Column::RecipientId.eq(recipient_id)) + .add(notification::Column::IsRead.eq(false)) + .add(notification::Column::Kind.eq(kind)) + .add(if proto.entity_id.is_some() { + notification::Column::EntityId.eq(proto.entity_id) + } else { + notification::Column::EntityId.is_null() + }), + ) + .into_values::<_, QueryIds>() + .one(&*tx) + .await?) + } +} + +fn model_to_proto(this: &Database, row: notification::Model) -> Option { + let kind = this.notification_kinds_by_id.get(&row.kind)?; + Some(proto::Notification { + id: row.id.to_proto(), + kind: kind.to_string(), + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + response: row.response, + content: row.content, + entity_id: row.entity_id.map(|id| id as u64), + }) +} + +fn notification_kind_from_proto( + this: &Database, + proto: &proto::Notification, +) -> Result { + Ok(this + .notification_kinds_by_name + .get(&proto.kind) + .copied() + .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?) +} diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index a38c77dc0f..40fdf5d58f 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -50,10 +50,10 @@ impl Database { .map(|participant| participant.user_id), ); - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; + let (channel, room) = self.get_channel_room(room_id, &tx).await?; let channel_members; - if let Some(channel_id) = channel_id { - channel_members = self.get_channel_members_internal(channel_id, &tx).await?; + if let Some(channel) = &channel { + channel_members = self.get_channel_participants(channel, &tx).await?; } else { channel_members = Vec::new(); @@ -69,7 +69,7 @@ impl Database { Ok(RefreshedRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, stale_participant_user_ids, canceled_calls_to_user_ids, @@ -298,98 +298,137 @@ impl Database { } } - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryParticipantIndices { - ParticipantIndex, + if channel_id.is_some() { + Err(anyhow!("tried to join channel call directly"))? } - let existing_participant_indices: Vec = room_participant::Entity::find() - .filter( - room_participant::Column::RoomId - .eq(room_id) - .and(room_participant::Column::ParticipantIndex.is_not_null()), - ) - .select_only() - .column(room_participant::Column::ParticipantIndex) - .into_values::<_, QueryParticipantIndices>() - .all(&*tx) + + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) .await?; - let mut participant_index = 0; - while existing_participant_indices.contains(&participant_index) { - participant_index += 1; - } - - if let Some(channel_id) = channel_id { - self.check_user_is_channel_member(channel_id, user_id, &*tx) - .await?; - - room_participant::Entity::insert_many([room_participant::ActiveModel { - room_id: ActiveValue::set(room_id), - user_id: ActiveValue::set(user_id), + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + participant_index: ActiveValue::Set(Some(participant_index)), answering_connection_id: ActiveValue::set(Some(connection.id as i32)), answering_connection_server_id: ActiveValue::set(Some(ServerId( connection.owner_id as i32, ))), answering_connection_lost: ActiveValue::set(false), - calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection.id as i32), - calling_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - participant_index: ActiveValue::Set(Some(participant_index)), ..Default::default() - }]) - .on_conflict( - OnConflict::columns([room_participant::Column::UserId]) - .update_columns([ - room_participant::Column::AnsweringConnectionId, - room_participant::Column::AnsweringConnectionServerId, - room_participant::Column::AnsweringConnectionLost, - room_participant::Column::ParticipantIndex, - ]) - .to_owned(), - ) + }) .exec(&*tx) .await?; - } else { - let result = room_participant::Entity::update_many() - .filter( - Condition::all() - .add(room_participant::Column::RoomId.eq(room_id)) - .add(room_participant::Column::UserId.eq(user_id)) - .add(room_participant::Column::AnsweringConnectionId.is_null()), - ) - .set(room_participant::ActiveModel { - participant_index: ActiveValue::Set(Some(participant_index)), - answering_connection_id: ActiveValue::set(Some(connection.id as i32)), - answering_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - answering_connection_lost: ActiveValue::set(false), - ..Default::default() - }) - .exec(&*tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("room does not exist or was already joined"))?; - } + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))?; } let room = self.get_room(room_id, &tx).await?; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? - } else { - Vec::new() - }; Ok(JoinRoom { room, - channel_id, - channel_members, + channel_id: None, + channel_members: vec![], }) }) .await } + async fn get_next_participant_index_internal( + &self, + room_id: RoomId, + tx: &DatabaseTransaction, + ) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryParticipantIndices { + ParticipantIndex, + } + let existing_participant_indices: Vec = room_participant::Entity::find() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::ParticipantIndex.is_not_null()), + ) + .select_only() + .column(room_participant::Column::ParticipantIndex) + .into_values::<_, QueryParticipantIndices>() + .all(&*tx) + .await?; + + let mut participant_index = 0; + while existing_participant_indices.contains(&participant_index) { + participant_index += 1; + } + + Ok(participant_index) + } + + pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result> { + self.transaction(|tx| async move { + let room: Option = room::Entity::find() + .filter(room::Column::Id.eq(room_id)) + .one(&*tx) + .await?; + + Ok(room.and_then(|room| room.channel_id)) + }) + .await + } + + pub(crate) async fn join_channel_room_internal( + &self, + room_id: RoomId, + user_id: UserId, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result { + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) + .await?; + + room_participant::Entity::insert_many([room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::Set(Some(participant_index)), + ..Default::default() + }]) + .on_conflict( + OnConflict::columns([room_participant::Column::UserId]) + .update_columns([ + room_participant::Column::AnsweringConnectionId, + room_participant::Column::AnsweringConnectionServerId, + room_participant::Column::AnsweringConnectionLost, + room_participant::Column::ParticipantIndex, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; + let channel_members = self.get_channel_participants(&channel, &*tx).await?; + Ok(JoinRoom { + room, + channel_id: Some(channel.id), + channel_members, + }) + } + pub async fn rejoin_room( &self, rejoin_room: proto::RejoinRoom, @@ -679,16 +718,16 @@ impl Database { }); } - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(&channel, &tx).await? } else { Vec::new() }; Ok(RejoinedRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, rejoined_projects, reshared_projects, @@ -830,7 +869,7 @@ impl Database { .exec(&*tx) .await?; - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; + let (channel, room) = self.get_channel_room(room_id, &tx).await?; let deleted = if room.participants.is_empty() { let result = room::Entity::delete_by_id(room_id).exec(&*tx).await?; result.rows_affected > 0 @@ -838,14 +877,14 @@ impl Database { false }; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(channel, &tx).await? } else { Vec::new() }; let left_room = LeftRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, left_projects, canceled_calls_to_user_ids, @@ -1033,7 +1072,7 @@ impl Database { &self, room_id: RoomId, tx: &DatabaseTransaction, - ) -> Result<(Option, proto::Room)> { + ) -> Result<(Option, proto::Room)> { let db_room = room::Entity::find_by_id(room_id) .one(tx) .await? @@ -1142,9 +1181,16 @@ impl Database { project_id: db_follower.project_id.to_proto(), }); } + drop(db_followers); + + let channel = if let Some(channel_id) = db_room.channel_id { + Some(self.get_channel_internal(channel_id, &*tx).await?) + } else { + None + }; Ok(( - db_room.channel_id, + channel, proto::Room { id: db_room.id.to_proto(), live_kit_room: db_room.live_kit_room, diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index e19391da7d..4f28ce4fbd 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -7,11 +7,13 @@ pub mod channel_buffer_collaborator; pub mod channel_chat_participant; pub mod channel_member; pub mod channel_message; -pub mod channel_path; +pub mod channel_message_mention; pub mod contact; pub mod feature_flag; pub mod follower; pub mod language_server; +pub mod notification; +pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; diff --git a/crates/collab/src/db/tables/channel.rs b/crates/collab/src/db/tables/channel.rs index 54f12defc1..e30ec9af61 100644 --- a/crates/collab/src/db/tables/channel.rs +++ b/crates/collab/src/db/tables/channel.rs @@ -1,4 +1,4 @@ -use crate::db::ChannelId; +use crate::db::{ChannelId, ChannelVisibility}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] @@ -7,6 +7,29 @@ pub struct Model { #[sea_orm(primary_key)] pub id: ChannelId, pub name: String, + pub visibility: ChannelVisibility, + pub parent_path: String, +} + +impl Model { + pub fn parent_id(&self) -> Option { + self.ancestors().last() + } + + pub fn ancestors(&self) -> impl Iterator + '_ { + self.parent_path + .trim_end_matches('/') + .split('/') + .filter_map(|id| Some(ChannelId::from_proto(id.parse().ok()?))) + } + + pub fn ancestors_including_self(&self) -> impl Iterator + '_ { + self.ancestors().chain(Some(self.id)) + } + + pub fn path(&self) -> String { + format!("{}{}/", self.parent_path, self.id) + } } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_member.rs b/crates/collab/src/db/tables/channel_member.rs index ba3db5a155..5498a00856 100644 --- a/crates/collab/src/db/tables/channel_member.rs +++ b/crates/collab/src/db/tables/channel_member.rs @@ -1,7 +1,7 @@ -use crate::db::{channel_member, ChannelId, ChannelMemberId, UserId}; +use crate::db::{channel_member, ChannelId, ChannelMemberId, ChannelRole, UserId}; use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "channel_members")] pub struct Model { #[sea_orm(primary_key)] @@ -9,7 +9,7 @@ pub struct Model { pub channel_id: ChannelId, pub user_id: UserId, pub accepted: bool, - pub admin: bool, + pub role: ChannelRole, } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_message_mention.rs b/crates/collab/src/db/tables/channel_message_mention.rs new file mode 100644 index 0000000000..6155b057f0 --- /dev/null +++ b/crates/collab/src/db/tables/channel_message_mention.rs @@ -0,0 +1,43 @@ +use crate::db::{MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_message_mentions")] +pub struct Model { + #[sea_orm(primary_key)] + pub message_id: MessageId, + #[sea_orm(primary_key)] + pub start_offset: i32, + pub end_offset: i32, + pub user_id: UserId, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel_message::Entity", + from = "Column::MessageId", + to = "super::channel_message::Column::Id" + )] + Message, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + MentionedUser, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Message.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::MentionedUser.def() + } +} diff --git a/crates/collab/src/db/tables/notification.rs b/crates/collab/src/db/tables/notification.rs new file mode 100644 index 0000000000..3105198fa2 --- /dev/null +++ b/crates/collab/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, NotificationKindId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub created_at: PrimitiveDateTime, + pub recipient_id: UserId, + pub kind: NotificationKindId, + pub entity_id: Option, + pub content: String, + pub is_read: bool, + pub response: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_path.rs b/crates/collab/src/db/tables/notification_kind.rs similarity index 51% rename from crates/collab/src/db/tables/channel_path.rs rename to crates/collab/src/db/tables/notification_kind.rs index 323f116dae..865b5da04b 100644 --- a/crates/collab/src/db/tables/channel_path.rs +++ b/crates/collab/src/db/tables/notification_kind.rs @@ -1,15 +1,15 @@ -use crate::db::ChannelId; +use crate::db::NotificationKindId; use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "channel_paths")] +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] pub struct Model { #[sea_orm(primary_key)] - pub id_path: String, - pub channel_id: ChannelId, + pub id: NotificationKindId, + pub name: String, } -impl ActiveModelBehavior for ActiveModel {} - #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 6a91fd6ffe..b6a89ff6f8 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -7,10 +7,12 @@ mod message_tests; use super::*; use gpui::executor::Background; use parking_lot::Mutex; -use rpc::proto::ChannelEdge; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + Arc, +}; const TEST_RELEASE_CHANNEL: &'static str = "test"; @@ -31,7 +33,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let sql = include_str!(concat!( @@ -45,6 +47,7 @@ impl TestDb { )) .await .unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -79,11 +82,12 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -148,26 +152,39 @@ impl Drop for TestDb { } } -/// The second tuples are (channel_id, parent) -fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph { - let mut graph = ChannelGraph { - channels: vec![], - edges: vec![], - }; - - for (id, name) in channels { - graph.channels.push(Channel { +fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str, ChannelRole)]) -> Vec { + channels + .iter() + .map(|(id, parent_path, name, role)| Channel { id: *id, name: name.to_string(), + visibility: ChannelVisibility::Members, + role: *role, + parent_path: parent_path.to_vec(), }) - } - - for (channel, parent) in edges { - graph.edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }) - } - - graph + .collect() +} + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst), + }, + ) + .await + .unwrap() + .user_id +} + +static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); +fn new_test_connection(server: ServerId) -> ConnectionId { + ConnectionId { + id: TEST_CONNECTION_ID.fetch_add(1, SeqCst), + owner_id: server.0 as u32, + } } diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 0ac41a8b0b..222514da0b 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_c".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -56,7 +53,7 @@ async fn test_channel_buffers(db: &Arc) { let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); - db.invite_channel_member(zed_id, b_id, a_id, false) + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) .await .unwrap(); @@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -211,7 +206,7 @@ async fn test_channel_buffers_last_operations(db: &Database) { .await .unwrap(); - db.invite_channel_member(channel, observer_id, user_id, false) + db.invite_channel_member(channel, observer_id, user_id, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(channel, observer_id, true) diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 7d2bc04a35..43526c7f24 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -1,56 +1,28 @@ -use collections::{HashMap, HashSet}; +use crate::{ + db::{ + tests::{channel_tree, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL}, + Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, + }, + test_both_dbs, +}; use rpc::{ proto::{self}, ConnectionId, }; - -use crate::{ - db::{ - queries::channels::ChannelGraph, - tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, Database, NewUserParams, - }, - test_both_dbs, -}; use std::sync::Arc; test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); async fn test_channels(db: &Arc) { - let a_id = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let b_id = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let a_id = new_test_user(db, "user1@example.com").await; + let b_id = new_test_user(db, "user2@example.com").await; let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); // Make sure that people cannot read channels they haven't been invited to - assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none()); + assert!(db.get_channel(zed_id, b_id).await.is_err()); - db.invite_channel_member(zed_id, b_id, a_id, false) + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) .await .unwrap(); @@ -58,99 +30,103 @@ async fn test_channels(db: &Arc) { .await .unwrap(); - let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap(); + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(zed_id), a_id) + .create_sub_channel("livestreaming", zed_id, a_id) .await .unwrap(); let replace_id = db - .create_channel("replace", Some(zed_id), a_id) + .create_sub_channel("replace", zed_id, a_id) .await .unwrap(); - let mut members = db.get_channel_members(replace_id).await.unwrap(); + let mut members = db + .transaction(|tx| async move { + let channel = db.get_channel_internal(replace_id, &*tx).await?; + Ok(db.get_channel_participants(&channel, &*tx).await?) + }) + .await + .unwrap(); members.sort(); assert_eq!(members, &[a_id, b_id]); let rust_id = db.create_root_channel("rust", a_id).await.unwrap(); - let cargo_id = db - .create_channel("cargo", Some(rust_id), a_id) - .await - .unwrap(); + let cargo_id = db.create_sub_channel("cargo", rust_id, a_id).await.unwrap(); let cargo_ra_id = db - .create_channel("cargo-ra", Some(cargo_id), a_id) + .create_sub_channel("cargo-ra", cargo_id, a_id) .await .unwrap(); let result = db.get_channels_for_user(a_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace"), - (rust_id, "rust"), - (cargo_id, "cargo"), - (cargo_ra_id, "cargo-ra") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id), - (cargo_id, rust_id), - (cargo_ra_id, cargo_id), - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin), + (rust_id, &[], "rust", ChannelRole::Admin), + (cargo_id, &[rust_id], "cargo", ChannelRole::Admin), + ( + cargo_ra_id, + &[rust_id, cargo_id], + "cargo-ra", + ChannelRole::Admin + ) + ],) ); let result = db.get_channels_for_user(b_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id) - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Member), + (crdb_id, &[zed_id], "crdb", ChannelRole::Member), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Member + ), + (replace_id, &[zed_id], "replace", ChannelRole::Member) + ],) ); // Update member permissions - let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await; + let set_subchannel_admin = db + .set_channel_member_role(crdb_id, a_id, b_id, ChannelRole::Admin) + .await; assert!(set_subchannel_admin.is_err()); - let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await; + let set_channel_admin = db + .set_channel_member_role(zed_id, a_id, b_id, ChannelRole::Admin) + .await; assert!(set_channel_admin.is_ok()); let result = db.get_channels_for_user(b_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id) - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin) + ],) ); // Remove a single channel db.delete_channel(crdb_id, a_id).await.unwrap(); - assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(crdb_id, a_id).await.is_err()); // Remove a channel tree let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); @@ -158,9 +134,9 @@ async fn test_channels(db: &Arc) { assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]); assert_eq!(user_ids, &[a_id]); - assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(rust_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_ra_id, a_id).await.is_err()); } test_both_dbs!( @@ -172,43 +148,15 @@ test_both_dbs!( async fn test_joining_channels(db: &Arc) { let owner_id = db.create_server("test").await.unwrap().0 as u32; - let user_1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); - let room_1 = db - .get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL) - .await - .unwrap(); // can join a room with membership to its channel - let joined_room = db - .join_room( - room_1, + let (joined_room, _, _) = db + .join_channel( + channel_1, user_1, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL, @@ -217,11 +165,12 @@ async fn test_joining_channels(db: &Arc) { .unwrap(); assert_eq!(joined_room.room.participants.len(), 1); + let room_id = RoomId::from_proto(joined_room.room.id); drop(joined_room); // cannot join a room without membership to its channel assert!(db .join_room( - room_1, + room_id, user_2, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL @@ -239,58 +188,21 @@ test_both_dbs!( async fn test_channel_invites(db: &Arc) { db.create_server("test").await.unwrap(); - let user_1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user_3 = db - .create_user( - "user3@example.com", - false, - NewUserParams { - github_login: "user3".into(), - github_user_id: 7, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; + let user_3 = new_test_user(db, "user3@example.com").await; let channel_1_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); let channel_1_2 = db.create_root_channel("channel_2", user_1).await.unwrap(); - db.invite_channel_member(channel_1_1, user_2, user_1, false) + db.invite_channel_member(channel_1_1, user_2, user_1, ChannelRole::Member) .await .unwrap(); - db.invite_channel_member(channel_1_2, user_2, user_1, false) + db.invite_channel_member(channel_1_2, user_2, user_1, ChannelRole::Member) .await .unwrap(); - db.invite_channel_member(channel_1_1, user_3, user_1, true) + db.invite_channel_member(channel_1_1, user_3, user_1, ChannelRole::Admin) .await .unwrap(); @@ -314,27 +226,29 @@ async fn test_channel_invites(db: &Arc) { assert_eq!(user_3_invites, &[channel_1_1]); - let members = db - .get_channel_member_details(channel_1_1, user_1) + let mut members = db + .get_channel_participant_details(channel_1_1, user_1) .await .unwrap(); + + members.sort_by_key(|member| member.user_id); assert_eq!( members, &[ proto::ChannelMember { user_id: user_1.to_proto(), kind: proto::channel_member::Kind::Member.into(), - admin: true, + role: proto::ChannelRole::Admin.into(), }, proto::ChannelMember { user_id: user_2.to_proto(), kind: proto::channel_member::Kind::Invitee.into(), - admin: false, + role: proto::ChannelRole::Member.into(), }, proto::ChannelMember { user_id: user_3.to_proto(), kind: proto::channel_member::Kind::Invitee.into(), - admin: true, + role: proto::ChannelRole::Admin.into(), }, ] ); @@ -344,12 +258,12 @@ async fn test_channel_invites(db: &Arc) { .unwrap(); let channel_1_3 = db - .create_channel("channel_3", Some(channel_1_1), user_1) + .create_sub_channel("channel_3", channel_1_1, user_1) .await .unwrap(); let members = db - .get_channel_member_details(channel_1_3, user_1) + .get_channel_participant_details(channel_1_3, user_1) .await .unwrap(); assert_eq!( @@ -357,13 +271,13 @@ async fn test_channel_invites(db: &Arc) { &[ proto::ChannelMember { user_id: user_1.to_proto(), - kind: proto::channel_member::Kind::Member.into(), - admin: true, + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), }, proto::ChannelMember { user_id: user_2.to_proto(), kind: proto::channel_member::Kind::AncestorMember.into(), - admin: false, + role: proto::ChannelRole::Member.into(), }, ] ); @@ -385,7 +299,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -399,7 +312,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -412,18 +324,10 @@ async fn test_channel_renames(db: &Arc) { .await .unwrap(); - let zed_archive_id = zed_id; - - let (channel, _) = db - .get_channel(zed_archive_id, user_1) - .await - .unwrap() - .unwrap(); + let channel = db.get_channel(zed_id, user_1).await.unwrap(); assert_eq!(channel.name, "zed-archive"); - let non_permissioned_rename = db - .rename_channel(zed_archive_id, user_2, "hacked-lol") - .await; + let non_permissioned_rename = db.rename_channel(zed_id, user_2, "hacked-lol").await; assert!(non_permissioned_rename.is_err()); let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await; @@ -444,7 +348,6 @@ async fn test_db_channel_moving(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -453,20 +356,17 @@ async fn test_db_channel_moving(db: &Arc) { let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); - let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap(); + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); - let gpui2_id = db - .create_channel("gpui2", Some(zed_id), a_id) - .await - .unwrap(); + let gpui2_id = db.create_sub_channel("gpui2", zed_id, a_id).await.unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(crdb_id), a_id) + .create_sub_channel("livestreaming", crdb_id, a_id) .await .unwrap(); let livestreaming_dag_id = db - .create_channel("livestreaming_dag", Some(livestreaming_id), a_id) + .create_sub_channel("livestreaming_dag", livestreaming_id, a_id) .await .unwrap(); @@ -476,316 +376,16 @@ async fn test_db_channel_moving(db: &Arc) { // /- gpui2 // zed -- crdb - livestreaming - livestreaming_dag let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( + assert_channel_tree( result.channels, &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), + (zed_id, &[]), + (crdb_id, &[zed_id]), + (livestreaming_id, &[zed_id, crdb_id]), + (livestreaming_dag_id, &[zed_id, crdb_id, livestreaming_id]), + (gpui2_id, &[zed_id]), ], ); - - // Attempt to make a cycle - assert!(db - .link_channel(a_id, zed_id, livestreaming_id) - .await - .is_err()); - - // ======================================================================== - // Make a link - db.link_channel(a_id, livestreaming_id, zed_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed -- crdb - livestreaming - livestreaming_dag - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - ], - ); - - // ======================================================================== - // Create a new channel below a channel with multiple parents - let livestreaming_dag_sub_id = db - .create_channel("livestreaming_dag_sub", Some(livestreaming_dag_id), a_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test a complex DAG by making another link - let returned_channels = db - .link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 /---------------------\ - // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \--------/ - - // make sure we're getting just the new link - // Not using the assert_dag helper because we want to make sure we're returning the full data - pretty_assertions::assert_eq!( - returned_channels, - graph( - &[(livestreaming_dag_sub_id, "livestreaming_dag_sub")], - &[(livestreaming_dag_sub_id, livestreaming_id)] - ) - ); - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test a complex DAG by making another link - let returned_channels = db - .link_channel(a_id, livestreaming_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -\ /---------------------\ - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \---------/ - - // Make sure that we're correctly getting the full sub-dag - pretty_assertions::assert_eq!( - returned_channels, - graph( - &[ - (livestreaming_id, "livestreaming"), - (livestreaming_dag_id, "livestreaming_dag"), - (livestreaming_dag_sub_id, "livestreaming_dag_sub"), - ], - &[ - (livestreaming_id, gpui2_id), - (livestreaming_dag_id, livestreaming_id), - (livestreaming_dag_sub_id, livestreaming_id), - (livestreaming_dag_sub_id, livestreaming_dag_id), - ] - ) - ); - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test unlinking in a complex DAG by removing the inner link - db.unlink_channel(a_id, livestreaming_dag_sub_id, livestreaming_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -\ - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // \---------/ - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test unlinking in a complex DAG by removing the inner link - db.unlink_channel(a_id, livestreaming_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test moving DAG nodes by moving livestreaming to be below gpui2 - db.move_channel(a_id, livestreaming_id, crdb_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // zed - crdb / - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Deleting a channel should not delete children that still have other parents - db.delete_channel(gpui2_id, a_id).await.unwrap(); - - // DAG is now: - // zed - crdb - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Unlinking a channel from it's parent should automatically promote it to a root channel - db.unlink_channel(a_id, crdb_id, zed_id).await.unwrap(); - - // DAG is now: - // crdb - // zed - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, None), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // You should be able to move a root channel into a non-root channel - db.link_channel(a_id, crdb_id, zed_id).await.unwrap(); - - // DAG is now: - // zed - crdb - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Prep for DAG deletion test - db.link_channel(a_id, livestreaming_id, crdb_id) - .await - .unwrap(); - - // DAG is now: - // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub - // \--------/ - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // Deleting the parent of a DAG should delete the whole DAG: - db.delete_channel(zed_id, a_id).await.unwrap(); - let result = db.get_channels_for_user(a_id).await.unwrap(); - - assert!(result.channels.is_empty()) } test_both_dbs!( @@ -802,7 +402,6 @@ async fn test_db_channel_moving_bugs(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -812,12 +411,12 @@ async fn test_db_channel_moving_bugs(db: &Arc) { let zed_id = db.create_root_channel("zed", user_id).await.unwrap(); let projects_id = db - .create_channel("projects", Some(zed_id), user_id) + .create_sub_channel("projects", zed_id, user_id) .await .unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(projects_id), user_id) + .create_sub_channel("livestreaming", projects_id, user_id) .await .unwrap(); @@ -825,48 +424,396 @@ async fn test_db_channel_moving_bugs(db: &Arc) { // Move to same parent should be a no-op assert!(db - .move_channel(user_id, projects_id, zed_id, zed_id) + .move_channel(projects_id, Some(zed_id), user_id) .await .unwrap() - .is_empty()); - - // Stranding a channel should retain it's sub channels - db.unlink_channel(user_id, projects_id, zed_id) - .await - .unwrap(); + .is_none()); let result = db.get_channels_for_user(user_id).await.unwrap(); - assert_dag( + assert_channel_tree( result.channels, &[ - (zed_id, None), - (projects_id, None), - (livestreaming_id, Some(projects_id)), + (zed_id, &[]), + (projects_id, &[zed_id]), + (livestreaming_id, &[zed_id, projects_id]), + ], + ); + + // Move the project channel to the root + db.move_channel(projects_id, None, user_id).await.unwrap(); + let result = db.get_channels_for_user(user_id).await.unwrap(); + assert_channel_tree( + result.channels, + &[ + (zed_id, &[]), + (projects_id, &[]), + (livestreaming_id, &[projects_id]), ], ); } -#[track_caller] -fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) { - let mut actual_map: HashMap> = HashMap::default(); - for channel in actual.channels { - actual_map.insert(channel.id, HashSet::default()); - } - for edge in actual.edges { - actual_map - .get_mut(&ChannelId::from_proto(edge.channel_id)) - .unwrap() - .insert(ChannelId::from_proto(edge.parent_id)); - } +test_both_dbs!( + test_user_is_channel_participant, + test_user_is_channel_participant_postgres, + test_user_is_channel_participant_sqlite +); - let mut expected_map: HashMap> = HashMap::default(); +async fn test_user_is_channel_participant(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + let member = new_test_user(db, "member@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; - for (child, parent) in expected { - let entry = expected_map.entry(*child).or_default(); - if let Some(parent) = parent { - entry.insert(*parent); - } - } + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + let active_channel_id = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + let vim_channel_id = db + .create_sub_channel("vim", active_channel_id, admin) + .await + .unwrap(); - pretty_assertions::assert_eq!(actual_map, expected_map) + db.set_channel_visibility(vim_channel_id, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + db.invite_channel_member(active_channel_id, member, admin, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(vim_channel_id, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + db.respond_to_channel_invite(active_channel_id, member, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + admin, + &*tx, + ) + .await + }) + .await + .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + member, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Invitee.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + db.respond_to_channel_invite(vim_channel_id, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree(channels, &[(vim_channel_id, &[])]); + let channels = db.get_channels_for_user(member).await.unwrap().channels; + assert_channel_tree( + channels, + &[ + (active_channel_id, &[]), + (vim_channel_id, &[active_channel_id]), + ], + ); + + db.set_channel_member_role(vim_channel_id, admin, guest, ChannelRole::Banned) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err()); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Member.into(), + role: proto::ChannelRole::Banned.into(), + }, + ] + ); + + db.remove_channel_member(vim_channel_id, guest, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.invite_channel_member(zed_channel, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + // currently people invited to parent channels are not shown here + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + ] + ); + + db.respond_to_channel_invite(zed_channel, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(zed_channel, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(active_channel_id, &*tx) + .await + .unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err(),); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree( + channels, + &[(zed_channel, &[]), (vim_channel_id, &[zed_channel])], + ) +} + +test_both_dbs!( + test_user_joins_correct_channel, + test_user_joins_correct_channel_postgres, + test_user_joins_correct_channel_sqlite +); + +async fn test_user_joins_correct_channel(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + + let active_channel = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + + let vim_channel = db + .create_sub_channel("vim", active_channel, admin) + .await + .unwrap(); + + let vim2_channel = db + .create_sub_channel("vim2", vim_channel, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim2_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + let most_public = db + .transaction(|tx| async move { + Ok(db + .public_ancestors_including_self( + &db.get_channel_internal(vim_channel, &*tx).await.unwrap(), + &tx, + ) + .await? + .first() + .cloned()) + }) + .await + .unwrap() + .unwrap() + .id; + + assert_eq!(most_public, zed_channel) +} + +test_both_dbs!( + test_guest_access, + test_guest_access_postgres, + test_guest_access_sqlite +); + +async fn test_guest_access(db: &Arc) { + let server = db.create_server("test").await.unwrap(); + + let admin = new_test_user(db, "admin@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; + let guest_connection = new_test_connection(server); + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_err()); + + db.join_channel(zed_channel, guest, guest_connection, TEST_RELEASE_CHANNEL) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_ok()) +} + +#[track_caller] +fn assert_channel_tree(actual: Vec, expected: &[(ChannelId, &[ChannelId])]) { + let actual = actual + .iter() + .map(|channel| (channel.id, channel.parent_path.as_slice())) + .collect::>(); + pretty_assertions::assert_eq!( + actual, + expected.to_vec(), + "wrong channel ids and parent paths" + ); } diff --git a/crates/collab/src/db/tests/db_tests.rs b/crates/collab/src/db/tests/db_tests.rs index 1520e081c0..c4b82f8cec 100644 --- a/crates/collab/src/db/tests/db_tests.rs +++ b/crates/collab/src/db/tests/db_tests.rs @@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login1".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login2".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc) { NewUserParams { github_login: "u1".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc) { ); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] + &[Contact::Incoming { user_id: user_1 }] ); // User 2 dismisses the contact request notification without accepting or rejecting. @@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc) { .unwrap(); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] + &[Contact::Incoming { user_id: user_1 }] ); // User can't accept their own contact request @@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }], ); @@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }] ); @@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }] ); @@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc) { &[ Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false, busy: false, } ] @@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person1".into(), github_user_id: 101, - invite_count: 5, }, ) .await @@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person2".into(), github_user_id: 102, - invite_count: 5, }, ) .await @@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() { NewUserParams { github_login: github_login.into(), github_user_id: i as i32, - invite_count: 0, }, ) .await @@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs index 9d5f039747..0286a6308e 100644 --- a/crates/collab/src/db/tests/feature_flag_tests.rs +++ b/crates/collab/src/db/tests/feature_flag_tests.rs @@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user1"), github_user_id: 1, - invite_count: 0, }, ) .await @@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user2"), github_user_id: 2, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index e758fcfb5d..10d9778612 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -1,7 +1,9 @@ +use super::new_test_user; use crate::{ - db::{Database, MessageId, NewUserParams}, + db::{ChannelRole, Database, MessageId}, test_both_dbs, }; +use channel::mentions_to_proto; use std::sync::Arc; use time::OffsetDateTime; @@ -12,39 +14,38 @@ test_both_dbs!( ); async fn test_channel_message_retrieval(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let channel = db.create_channel("channel", None, user).await.unwrap(); + let user = new_test_user(db, "user@example.com").await; + let result = db.create_channel("channel", None, user).await.unwrap(); let owner_id = db.create_server("test").await.unwrap().0 as u32; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + db.join_channel_chat( + result.channel.id, + rpc::ConnectionId { owner_id, id: 0 }, + user, + ) + .await + .unwrap(); let mut all_messages = Vec::new(); for i in 0..10 { all_messages.push( - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap() - .0 - .to_proto(), + db.create_channel_message( + result.channel.id, + user, + &i.to_string(), + &[], + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap() + .message_id + .to_proto(), ); } let messages = db - .get_channel_messages(channel, user, 3, None) + .get_channel_messages(result.channel.id, user, 3, None) .await .unwrap() .into_iter() @@ -54,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc) { let messages = db .get_channel_messages( - channel, + result.channel.id, user, 4, Some(MessageId::from_proto(all_messages[6])), @@ -74,99 +75,154 @@ test_both_dbs!( ); async fn test_channel_message_nonces(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + let channel = db.create_root_channel("channel", user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_c, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) + .await + .unwrap(); + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) + .await + .unwrap(); + + // As user A, create messages that re-use the same nonces. The requests + // succeed, but return the same ids. + let id1 = db + .create_channel_message( + channel, + user_a, + "hi @user_b", + &mentions_to_proto(&[(3..10, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 100, ) .await .unwrap() - .user_id; - let channel = db.create_channel("channel", None, user).await.unwrap(); + .message_id; + let id2 = db + .create_channel_message( + channel, + user_a, + "hello, fellow users", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + let id3 = db + .create_channel_message( + channel, + user_a, + "bye @user_c (same nonce as first message)", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id4 = db + .create_channel_message( + channel, + user_a, + "omg (same nonce as second message)", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; - let owner_id = db.create_server("test").await.unwrap().0 as u32; + // As a different user, reuse one of the same nonces. This request succeeds + // and returns a different id. + let id5 = db + .create_channel_message( + channel, + user_b, + "omg @user_a (same nonce as user_a's first message)", + &mentions_to_proto(&[(4..11, user_a.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + assert_ne!(id1, id2); + assert_eq!(id1, id3); + assert_eq!(id2, id4); + assert_ne!(id5, id1); - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + let messages = db + .get_channel_messages(channel, user_a, 5, None) .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); + .unwrap() + .into_iter() + .map(|m| (m.id, m.body, m.mentions)) + .collect::>(); + assert_eq!( + messages, + &[ + ( + id1.to_proto(), + "hi @user_b".into(), + mentions_to_proto(&[(3..10, user_b.to_proto())]), + ), + ( + id2.to_proto(), + "hello, fellow users".into(), + mentions_to_proto(&[]) + ), + ( + id5.to_proto(), + "omg @user_a (same nonce as user_a's first message)".into(), + mentions_to_proto(&[(4..11, user_a.to_proto())]), + ), + ] + ); } test_both_dbs!( - test_channel_message_new_notification, - test_channel_message_new_notification_postgres, - test_channel_message_new_notification_sqlite + test_unseen_channel_messages, + test_unseen_channel_messages_postgres, + test_unseen_channel_messages_sqlite ); -async fn test_channel_message_new_notification(db: &Arc) { - let user = db - .create_user( - "user_a@example.com", - false, - NewUserParams { - github_login: "user_a".into(), - github_user_id: 1, - invite_count: 0, - }, - ) +async fn test_unseen_channel_messages(db: &Arc) { + let user = new_test_user(db, "user_a@example.com").await; + let observer = new_test_user(db, "user_b@example.com").await; + + let channel_1 = db.create_root_channel("channel", user).await.unwrap(); + let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); + + db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) .await - .unwrap() - .user_id; - let observer = db - .create_user( - "user_b@example.com", - false, - NewUserParams { - github_login: "user_b".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let channel_1 = db.create_channel("channel", None, user).await.unwrap(); - - let channel_2 = db.create_channel("channel-2", None, user).await.unwrap(); - - db.invite_channel_member(channel_1, observer, user, false) + .unwrap(); + db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(channel_1, observer, true) .await .unwrap(); - - db.invite_channel_member(channel_2, observer, user, false) - .await - .unwrap(); - db.respond_to_channel_invite(channel_2, observer, true) .await .unwrap(); @@ -179,28 +235,31 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); let _ = db - .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1) + .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1) .await .unwrap(); - let (second_message, _, _) = db - .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2) + let second_message = db + .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) .await - .unwrap(); + .unwrap() + .message_id; - let (third_message, _, _) = db - .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3) + let third_message = db + .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) .await - .unwrap(); + .unwrap() + .message_id; db.join_channel_chat(channel_2, user_connection_id, user) .await .unwrap(); - let (fourth_message, _, _) = db - .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4) + let fourth_message = db + .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) .await - .unwrap(); + .unwrap() + .message_id; // Check that observer has new messages let unseen_messages = db @@ -295,3 +354,101 @@ async fn test_channel_message_new_notification(db: &Arc) { }] ); } + +test_both_dbs!( + test_channel_message_mentions, + test_channel_message_mentions_postgres, + test_channel_message_mentions_sqlite +); + +async fn test_channel_message_mentions(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + + let channel = db + .create_channel("channel", None, user_a) + .await + .unwrap() + .channel + .id; + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let connection_id = rpc::ConnectionId { owner_id, id: 0 }; + db.join_channel_chat(channel, connection_id, user_a) + .await + .unwrap(); + + db.create_channel_message( + channel, + user_a, + "hi @user_b and @user_c", + &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 1, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "bye @user_c", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 2, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "umm", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 3, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "@user_b, stop.", + &mentions_to_proto(&[(0..7, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 4, + ) + .await + .unwrap(); + + let messages = db + .get_channel_messages(channel, user_b, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.body, m.mentions)) + .collect::>(); + assert_eq!( + &messages, + &[ + ( + "hi @user_b and @user_c".into(), + mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + ), + ( + "bye @user_c".into(), + mentions_to_proto(&[(4..11, user_c.to_proto())]), + ), + ("umm".into(), mentions_to_proto(&[]),), + ( + "@user_b, stop.".into(), + mentions_to_proto(&[(0..7, user_b.to_proto())]), + ), + ] + ); +} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 13fb8ed0eb..85216525b0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -119,7 +119,9 @@ impl AppState { pub async fn new(config: Config) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); - let db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options, Executor::Production).await?; + db.initialize_notification_kinds().await?; + let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e5c6d94ce0..7e847e8bff 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,8 +3,11 @@ mod connection_pool; use crate::{ auth, db::{ - self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, - ServerId, User, UserId, + self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, + RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, + User, UserId, }, executor::Executor, AppState, Result, @@ -38,8 +41,8 @@ use lazy_static::lazy_static; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, - LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, Peer, Receipt, TypedEnvelope, }; @@ -70,6 +73,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; +const NOTIFICATION_COUNT_PER_PAGE: usize = 50; lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = @@ -225,6 +229,7 @@ impl Server { .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) @@ -254,7 +259,8 @@ impl Server { .add_request_handler(delete_channel) .add_request_handler(invite_channel_member) .add_request_handler(remove_channel_member) - .add_request_handler(set_channel_member_admin) + .add_request_handler(set_channel_member_role) + .add_request_handler(set_channel_visibility) .add_request_handler(rename_channel) .add_request_handler(join_channel_buffer) .add_request_handler(leave_channel_buffer) @@ -268,8 +274,9 @@ impl Server { .add_request_handler(send_channel_message) .add_request_handler(remove_channel_message) .add_request_handler(get_channel_messages) - .add_request_handler(link_channel) - .add_request_handler(unlink_channel) + .add_request_handler(get_channel_messages_by_id) + .add_request_handler(get_notifications) + .add_request_handler(mark_notification_as_read) .add_request_handler(move_channel) .add_request_handler(follow) .add_message_handler(unfollow) @@ -387,7 +394,7 @@ impl Server { let contacts = app_state.db.get_contacts(user_id).await.trace_err(); if let Some((busy, contacts)) = busy.zip(contacts) { let pool = pool.lock(); - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, @@ -581,14 +588,14 @@ impl Server { let (contacts, channels_for_user, channel_invites) = future::try_join3( this.app_state.db.get_contacts(user_id), this.app_state.db.get_channels_for_user(user_id), - this.app_state.db.get_channel_invites_for_user(user_id) + this.app_state.db.get_channel_invites_for_user(user_id), ).await?; { let mut pool = this.connection_pool.lock(); pool.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; - this.peer.send(connection_id, build_initial_channels_update( + this.peer.send(connection_id, build_channels_update( channels_for_user, channel_invites ))?; @@ -687,7 +694,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { let pool = self.connection_pool.lock(); - let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + let invitee_contact = contact_for_user(invitee_id, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, @@ -935,7 +942,7 @@ async fn create_room( let live_kit_room = live_kit_room.clone(); let live_kit = session.live_kit_client.as_ref(); - util::async_iife!({ + util::async_maybe!({ let live_kit = live_kit?; let token = live_kit @@ -945,6 +952,7 @@ async fn create_room( Some(proto::LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish: true, }) }) } @@ -976,6 +984,13 @@ async fn join_room( session: Session, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); + + let channel_id = session.db().await.channel_id_for_room(room_id).await?; + + if let Some(channel_id) = channel_id { + return join_channel_internal(channel_id, Box::new(response), session).await; + } + let joined_room = { let room = session .db() @@ -991,16 +1006,6 @@ async fn join_room( room.into_inner() }; - if let Some(channel_id) = joined_room.channel_id { - channel_updated( - channel_id, - &joined_room.room, - &joined_room.channel_members, - &session.peer, - &*session.connection_pool().await, - ) - } - for connection_id in session .connection_pool() .await @@ -1028,6 +1033,7 @@ async fn join_room( Some(proto::LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish: true, }) } else { None @@ -1038,7 +1044,7 @@ async fn join_room( response.send(proto::JoinRoomResponse { room: Some(joined_room.room), - channel_id: joined_room.channel_id.map(|id| id.to_proto()), + channel_id: None, live_kit_connection_info, })?; @@ -2064,7 +2070,7 @@ async fn request_contact( return Err(anyhow!("cannot add yourself as a contact"))?; } - session + let notifications = session .db() .await .send_contact_request(requester_id, responder_id) @@ -2087,16 +2093,14 @@ async fn request_contact( .incoming_requests .push(proto::IncomingContactRequest { requester_id: requester_id.to_proto(), - should_notify: true, }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(responder_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2115,7 +2119,8 @@ async fn respond_to_contact_request( } else { let accept = request.response == proto::ContactRequestResponse::Accept as i32; - db.respond_to_contact_request(responder_id, requester_id, accept) + let notifications = db + .respond_to_contact_request(responder_id, requester_id, accept) .await?; let requester_busy = db.is_user_busy(requester_id).await?; let responder_busy = db.is_user_busy(responder_id).await?; @@ -2126,7 +2131,7 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(requester_id, false, requester_busy, &pool)); + .push(contact_for_user(requester_id, requester_busy, &pool)); } update .remove_incoming_requests @@ -2140,14 +2145,17 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(responder_id, true, responder_busy, &pool)); + .push(contact_for_user(responder_id, responder_busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { session.peer.send(connection_id, update.clone())?; } + + send_notifications(&*pool, &session.peer, notifications); } response.send(proto::Ack {})?; @@ -2162,7 +2170,8 @@ async fn remove_contact( let requester_id = session.user_id; let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; - let contact_accepted = db.remove_contact(requester_id, responder_id).await?; + let (contact_accepted, deleted_notification_id) = + db.remove_contact(requester_id, responder_id).await?; let pool = session.connection_pool().await; // Update outgoing contact requests of requester @@ -2189,6 +2198,14 @@ async fn remove_contact( } for connection_id in pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; + if let Some(notification_id) = deleted_notification_id { + session.peer.send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + )?; + } } response.send(proto::Ack {})?; @@ -2203,37 +2220,21 @@ async fn create_channel( let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); - let id = db + let CreateChannelResult { + channel, + participants_to_update, + } = db .create_channel(&request.name, parent_id, session.user_id) .await?; - let channel = proto::Channel { - id: id.to_proto(), - name: request.name, - }; - response.send(proto::CreateChannelResponse { - channel: Some(channel.clone()), + channel: Some(channel.to_proto()), parent_id: request.parent_id, })?; - let Some(parent_id) = parent_id else { - return Ok(()); - }; - - let update = proto::UpdateChannels { - channels: vec![channel], - insert_edge: vec![ChannelEdge { - parent_id: parent_id.to_proto(), - channel_id: id.to_proto(), - }], - ..Default::default() - }; - - let user_ids_to_notify = db.get_channel_members(parent_id).await?; - let connection_pool = session.connection_pool().await; - for user_id in user_ids_to_notify { + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); for connection_id in connection_pool.user_connection_ids(user_id) { if user_id == session.user_id { continue; @@ -2282,27 +2283,30 @@ async fn invite_channel_member( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let invitee_id = UserId::from_proto(request.user_id); - db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin) + let InviteMemberResult { + channel, + notifications, + } = db + .invite_channel_member( + channel_id, + invitee_id, + session.user_id, + request.role().into(), + ) .await?; - let (channel, _) = db - .get_channel(channel_id, session.user_id) - .await? - .ok_or_else(|| anyhow!("channel not found"))?; + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; - let mut update = proto::UpdateChannels::default(); - update.channel_invitations.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(invitee_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(invitee_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2316,54 +2320,117 @@ async fn remove_channel_member( let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.remove_channel_member(channel_id, member_id, session.user_id) + let RemoveChannelMemberResult { + membership_update, + notification_id, + } = db + .remove_channel_member(channel_id, member_id, session.user_id) .await?; - let mut update = proto::UpdateChannels::default(); - update.delete_channels.push(channel_id.to_proto()); - - for connection_id in session - .connection_pool() - .await - .user_connection_ids(member_id) - { - session.peer.send(connection_id, update.clone())?; + let connection_pool = &session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ); + for connection_id in connection_pool.user_connection_ids(member_id) { + if let Some(notification_id) = notification_id { + session + .peer + .send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + ) + .trace_err(); + } } response.send(proto::Ack {})?; Ok(()) } -async fn set_channel_member_admin( - request: proto::SetChannelMemberAdmin, - response: Response, +async fn set_channel_visibility( + request: proto::SetChannelVisibility, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let visibility = request.visibility().into(); + + let SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + } = db + .set_channel_visibility(channel_id, visibility, session.user_id) + .await?; + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn set_channel_member_role( + request: proto::SetChannelMemberRole, + response: Response, session: Session, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin) + let result = db + .set_channel_member_role( + channel_id, + session.user_id, + member_id, + request.role().into(), + ) .await?; - let (channel, has_accepted) = db - .get_channel(channel_id, member_id) - .await? - .ok_or_else(|| anyhow!("channel not found"))?; + match result { + db::SetMemberRoleResult::MembershipUpdated(membership_update) => { + let connection_pool = session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ) + } + db::SetMemberRoleResult::InviteUpdated(channel) => { + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; - let mut update = proto::UpdateChannels::default(); - if has_accepted { - update.channel_permissions.push(proto::ChannelPermission { - channel_id: channel.id.to_proto(), - is_admin: request.admin, - }); - } - - for connection_id in session - .connection_pool() - .await - .user_connection_ids(member_id) - { - session.peer.send(connection_id, update.clone())?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(member_id) + { + session.peer.send(connection_id, update.clone())?; + } + } } response.send(proto::Ack {})?; @@ -2377,25 +2444,25 @@ async fn rename_channel( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let new_name = db + let RenameChannelResult { + channel, + participants_to_update, + } = db .rename_channel(channel_id, session.user_id, &request.name) .await?; - let channel = proto::Channel { - id: request.channel_id, - name: new_name, - }; response.send(proto::RenameChannelResponse { - channel: Some(channel.clone()), + channel: Some(channel.to_proto()), })?; - let mut update = proto::UpdateChannels::default(); - update.channels.push(channel); - - let member_ids = db.get_channel_members(channel_id).await?; let connection_pool = session.connection_pool().await; - for member_id in member_ids { - for connection_id in connection_pool.user_connection_ids(member_id) { + for (user_id, channel) in participants_to_update { + for connection_id in connection_pool.user_connection_ids(user_id) { + let update = proto::UpdateChannels { + channels: vec![channel.to_proto()], + ..Default::default() + }; + session.peer.send(connection_id, update.clone())?; } } @@ -2403,129 +2470,55 @@ async fn rename_channel( Ok(()) } -async fn link_channel( - request: proto::LinkChannel, - response: Response, - session: Session, -) -> Result<()> { - let db = session.db().await; - let channel_id = ChannelId::from_proto(request.channel_id); - let to = ChannelId::from_proto(request.to); - let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?; - - let members = db.get_channel_members(to).await?; - let connection_pool = session.connection_pool().await; - let update = proto::UpdateChannels { - channels: channels_to_send - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }) - .collect(), - insert_edge: channels_to_send.edges, - ..Default::default() - }; - for member_id in members { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - response.send(Ack {})?; - - Ok(()) -} - -async fn unlink_channel( - request: proto::UnlinkChannel, - response: Response, - session: Session, -) -> Result<()> { - let db = session.db().await; - let channel_id = ChannelId::from_proto(request.channel_id); - let from = ChannelId::from_proto(request.from); - - db.unlink_channel(session.user_id, channel_id, from).await?; - - let members = db.get_channel_members(from).await?; - - let update = proto::UpdateChannels { - delete_edge: vec![proto::ChannelEdge { - channel_id: channel_id.to_proto(), - parent_id: from.to_proto(), - }], - ..Default::default() - }; - let connection_pool = session.connection_pool().await; - for member_id in members { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - response.send(Ack {})?; - - Ok(()) -} - async fn move_channel( request: proto::MoveChannel, response: Response, session: Session, ) -> Result<()> { - let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let from_parent = ChannelId::from_proto(request.from); - let to = ChannelId::from_proto(request.to); + let to = request.to.map(ChannelId::from_proto); - let channels_to_send = db - .move_channel(session.user_id, channel_id, from_parent, to) + let result = session + .db() + .await + .move_channel(channel_id, to, session.user_id) .await?; - if channels_to_send.is_empty() { - response.send(Ack {})?; - return Ok(()); - } - - let members_from = db.get_channel_members(from_parent).await?; - let members_to = db.get_channel_members(to).await?; - - let update = proto::UpdateChannels { - delete_edge: vec![proto::ChannelEdge { - channel_id: channel_id.to_proto(), - parent_id: from_parent.to_proto(), - }], - ..Default::default() - }; - let connection_pool = session.connection_pool().await; - for member_id in members_from { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - let update = proto::UpdateChannels { - channels: channels_to_send - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }) - .collect(), - insert_edge: channels_to_send.edges, - ..Default::default() - }; - for member_id in members_to { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } + notify_channel_moved(result, session).await?; response.send(Ack {})?; + Ok(()) +} +async fn notify_channel_moved(result: Option, session: Session) -> Result<()> { + let Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + }) = result + else { + return Ok(()); + }; + let moved_channels: Vec = moved_channels.iter().map(|id| id.to_proto()).collect(); + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let mut update = build_channels_update(channels, vec![]); + update.delete_channels = moved_channels.clone(); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: moved_channels.clone(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } Ok(()) } @@ -2537,7 +2530,7 @@ async fn get_channel_members( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let members = db - .get_channel_member_details(channel_id, session.user_id) + .get_channel_participant_details(channel_id, session.user_id) .await?; response.send(proto::GetChannelMembersResponse { members })?; Ok(()) @@ -2550,54 +2543,34 @@ async fn respond_to_channel_invite( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - db.respond_to_channel_invite(channel_id, session.user_id, request.accept) + let RespondToChannelInvite { + membership_update, + notifications, + } = db + .respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; - let mut update = proto::UpdateChannels::default(); - update - .remove_channel_invitations - .push(channel_id.to_proto()); - if request.accept { - let result = db.get_channel_for_user(channel_id, session.user_id).await?; - update - .channels - .extend( - result - .channels - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }), - ); - update.unseen_channel_messages = result.channel_messages; - update.unseen_channel_buffer_changes = result.unseen_buffer_changes; - update.insert_edge = result.channels.edges; - update - .channel_participants - .extend( - result - .channel_participants - .into_iter() - .map(|(channel_id, user_ids)| proto::ChannelParticipants { - channel_id: channel_id.to_proto(), - participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(), - }), - ); - update - .channel_permissions - .extend( - result - .channels_with_admin_privileges - .into_iter() - .map(|channel_id| proto::ChannelPermission { - channel_id: channel_id.to_proto(), - is_admin: true, - }), - ); - } - session.peer.send(session.connection_id, update)?; + let connection_pool = session.connection_pool().await; + if let Some(membership_update) = membership_update { + notify_membership_updated( + &connection_pool, + membership_update, + session.user_id, + &session.peer, + ); + } else { + let update = proto::UpdateChannels { + remove_channel_invitations: vec![channel_id.to_proto()], + ..Default::default() + }; + + for connection_id in connection_pool.user_connection_ids(session.user_id) { + session.peer.send(connection_id, update.clone())?; + } + }; + + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) @@ -2609,19 +2582,35 @@ async fn join_channel( session: Session, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); - let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + join_channel_internal(channel_id, Box::new(response), session).await +} +trait JoinChannelInternalResponse { + fn send(self, result: proto::JoinRoomResponse) -> Result<()>; +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} + +async fn join_channel_internal( + channel_id: ChannelId, + response: Box, + session: Session, +) -> Result<()> { let joined_room = { leave_room_for_session(&session).await?; let db = session.db().await; - let room_id = db - .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME) - .await?; - - let joined_room = db - .join_room( - room_id, + let (joined_room, membership_updated, role) = db + .join_channel( + channel_id, session.user_id, session.connection_id, RELEASE_CHANNEL_NAME.as_str(), @@ -2629,16 +2618,32 @@ async fn join_channel( .await?; let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { - let token = live_kit - .room_token( - &joined_room.room.live_kit_room, - &session.user_id.to_string(), + let (can_publish, token) = if role == ChannelRole::Guest { + ( + false, + live_kit + .guest_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, ) - .trace_err()?; + } else { + ( + true, + live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, + ) + }; Some(LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish, }) }); @@ -2648,9 +2653,19 @@ async fn join_channel( live_kit_connection_info, })?; + let connection_pool = session.connection_pool().await; + if let Some(membership_updated) = membership_updated { + notify_membership_updated( + &connection_pool, + membership_updated, + session.user_id, + &session.peer, + ); + } + room_updated(&joined_room.room, &session.peer); - joined_room.into_inner() + joined_room }; channel_updated( @@ -2662,7 +2677,6 @@ async fn join_channel( ); update_user_contacts(session.user_id, &session).await?; - Ok(()) } @@ -2815,6 +2829,29 @@ fn channel_buffer_updated( }); } +fn send_notifications( + connection_pool: &ConnectionPool, + peer: &Peer, + notifications: db::NotificationBatch, +) { + for (user_id, notification) in notifications { + for connection_id in connection_pool.user_connection_ids(user_id) { + if let Err(error) = peer.send( + connection_id, + proto::AddNotification { + notification: Some(notification.clone()), + }, + ) { + tracing::error!( + "failed to send notification to {:?} {}", + connection_id, + error + ); + } + } + } +} + async fn send_channel_message( request: proto::SendChannelMessage, response: Response, @@ -2829,19 +2866,27 @@ async fn send_channel_message( return Err(anyhow!("message can't be blank"))?; } + // TODO: adjust mentions if body is trimmed + let timestamp = OffsetDateTime::now_utc(); let nonce = request .nonce .ok_or_else(|| anyhow!("nonce can't be blank"))?; let channel_id = ChannelId::from_proto(request.channel_id); - let (message_id, connection_ids, non_participants) = session + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session .db() .await .create_channel_message( channel_id, session.user_id, &body, + &request.mentions, timestamp, nonce.clone().into(), ) @@ -2850,18 +2895,23 @@ async fn send_channel_message( sender_id: session.user_id.to_proto(), id: message_id.to_proto(), body, + mentions: request.mentions, timestamp: timestamp.unix_timestamp() as u64, nonce: Some(nonce), }; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); response.send(proto::SendChannelMessageResponse { message: Some(message), })?; @@ -2869,7 +2919,7 @@ async fn send_channel_message( let pool = &*session.connection_pool().await; broadcast( None, - non_participants + channel_members .iter() .flat_map(|user_id| pool.user_connection_ids(*user_id)), |peer_id| { @@ -2885,6 +2935,7 @@ async fn send_channel_message( ) }, ); + send_notifications(pool, &session.peer, notifications); Ok(()) } @@ -2914,11 +2965,16 @@ async fn acknowledge_channel_message( ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); - session + let notifications = session .db() .await .observe_channel_message(channel_id, session.user_id, message_id) .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); Ok(()) } @@ -2993,6 +3049,72 @@ async fn get_channel_messages( Ok(()) } +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_notifications( + request: proto::GetNotifications, + response: Response, + session: Session, +) -> Result<()> { + let notifications = session + .db() + .await + .get_notifications( + session.user_id, + NOTIFICATION_COUNT_PER_PAGE, + request + .before_id + .map(|id| db::NotificationId::from_proto(id)), + ) + .await?; + response.send(proto::GetNotificationsResponse { + done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE, + notifications, + })?; + Ok(()) +} + +async fn mark_notification_as_read( + request: proto::MarkNotificationRead, + response: Response, + session: Session, +) -> Result<()> { + let database = &session.db().await; + let notifications = database + .mark_notification_as_read_by_id( + session.user_id, + NotificationId::from_proto(request.notification_id), + ) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + response.send(proto::Ack {})?; + Ok(()) +} + async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = session @@ -3062,22 +3184,37 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } -fn build_initial_channels_update( +fn notify_membership_updated( + connection_pool: &ConnectionPool, + result: MembershipUpdated, + user_id: UserId, + peer: &Peer, +) { + let mut update = build_channels_update(result.new_channels, vec![]); + update.delete_channels = result + .removed_channels + .into_iter() + .map(|id| id.to_proto()) + .collect(); + update.remove_channel_invitations = vec![result.channel_id.to_proto()]; + + for connection_id in connection_pool.user_connection_ids(user_id) { + peer.send(connection_id, update.clone()).trace_err(); + } +} + +fn build_channels_update( channels: ChannelsForUser, channel_invites: Vec, ) -> proto::UpdateChannels { let mut update = proto::UpdateChannels::default(); - for channel in channels.channels.channels { - update.channels.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); + for channel in channels.channels { + update.channels.push(channel.to_proto()); } update.unseen_channel_buffer_changes = channels.unseen_buffer_changes; update.unseen_channel_messages = channels.channel_messages; - update.insert_edge = channels.channels.edges; for (channel_id, participants) in channels.channel_participants { update @@ -3088,23 +3225,8 @@ fn build_initial_channels_update( }); } - update - .channel_permissions - .extend( - channels - .channels_with_admin_privileges - .into_iter() - .map(|id| proto::ChannelPermission { - channel_id: id.to_proto(), - is_admin: true, - }), - ); - for channel in channel_invites { - update.channel_invitations.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); + update.channel_invitations.push(channel.to_proto()); } update @@ -3118,42 +3240,28 @@ fn build_initial_contacts_update( for contact in contacts { match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(contact_for_user(user_id, should_notify, busy, &pool)); + db::Contact::Accepted { user_id, busy } => { + update.contacts.push(contact_for_user(user_id, busy, &pool)); } db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), + db::Contact::Incoming { user_id } => { + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + }) + } } } update } -fn contact_for_user( - user_id: UserId, - should_notify: bool, - busy: bool, - pool: &ConnectionPool, -) -> proto::Contact { +fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact { proto::Contact { user_id: user_id.to_proto(), online: pool.is_user_online(user_id), busy, - should_notify, } } @@ -3214,7 +3322,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> let busy = db.is_user_busy(user_id).await?; let pool = session.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index e78bbe3466..e8da66a75a 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -6,6 +6,7 @@ mod channel_message_tests; mod channel_tests; mod following_tests; mod integration_tests; +mod notification_tests; mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; @@ -39,3 +40,7 @@ fn room_participants(room: &ModelHandle, cx: &mut TestAppContext) -> RoomP RoomParticipants { remote, pending } }) } + +fn channel_id(room: &ModelHandle, cx: &mut TestAppContext) -> Option { + cx.read(|cx| room.read(cx).channel_id()) +} diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index a0b9b52484..5ca40a3c2d 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -3,7 +3,7 @@ use crate::{ tests::TestServer, }; use call::ActiveCall; -use channel::{Channel, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; +use channel::ACKNOWLEDGE_DEBOUNCE_INTERVAL; use client::ParticipantIndex; use client::{Collaborator, UserId}; use collab_ui::channel_view::ChannelView; @@ -407,11 +407,8 @@ async fn test_channel_buffer_disconnect( server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - channel_buffer_a.update(cx_a, |buffer, _| { - assert_eq!( - buffer.channel().as_ref(), - &channel(channel_id, "the-channel") - ); + channel_buffer_a.update(cx_a, |buffer, cx| { + assert_eq!(buffer.channel(cx).unwrap().name, "the-channel"); assert!(!buffer.is_connected()); }); @@ -432,24 +429,12 @@ async fn test_channel_buffer_disconnect( deterministic.run_until_parked(); // Channel buffer observed the deletion - channel_buffer_b.update(cx_b, |buffer, _| { - assert_eq!( - buffer.channel().as_ref(), - &channel(channel_id, "the-channel") - ); + channel_buffer_b.update(cx_b, |buffer, cx| { + assert!(buffer.channel(cx).is_none()); assert!(!buffer.is_connected()); }); } -fn channel(id: u64, name: &'static str) -> Channel { - Channel { - id, - name: name.to_string(), - unseen_note_version: None, - unseen_message_id: None, - } -} - #[gpui::test] async fn test_rejoin_channel_buffer( deterministic: Arc, @@ -694,7 +679,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .await .unwrap(); channel_view_1_a.update(cx_a, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-1"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); notes.editor.update(cx, |editor, cx| { editor.insert("Hello from A.", cx); editor.change_selections(None, cx, |selections| { @@ -726,7 +711,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .expect("active item is not a channel view") }); channel_view_1_b.read_with(cx_b, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-1"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); let editor = notes.editor.read(cx); assert_eq!(editor.text(cx), "Hello from A."); assert_eq!(editor.selections.ranges::(cx), &[3..4]); @@ -738,7 +723,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .await .unwrap(); channel_view_2_a.read_with(cx_a, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-2"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); }); // Client B is taken to the notes for channel 2. @@ -755,7 +740,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .expect("active item is not a channel view") }); channel_view_2_b.read_with(cx_b, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-2"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); }); } diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index 0fc3b085ed..918eb053d3 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -1,27 +1,30 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; -use channel::{ChannelChat, ChannelMessageId}; +use channel::{ChannelChat, ChannelMessageId, MessageParams}; use collab_ui::chat_panel::ChatPanel; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; +use rpc::Notification; use std::sync::Arc; use workspace::dock::Panel; #[gpui::test] async fn test_basic_channel_messages( deterministic: Arc, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, + mut cx_a: &mut TestAppContext, + mut cx_b: &mut TestAppContext, + mut cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; let channel_id = server .make_channel( "the-channel", None, (&client_a, cx_a), - &mut [(&client_b, cx_b)], + &mut [(&client_b, cx_b), (&client_c, cx_c)], ) .await; @@ -36,8 +39,17 @@ async fn test_basic_channel_messages( .await .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + let message_id = channel_chat_a + .update(cx_a, |c, cx| { + c.send_message( + MessageParams { + text: "hi @user_c!".into(), + mentions: vec![(3..10, client_c.id())], + }, + cx, + ) + .unwrap() + }) .await .unwrap(); channel_chat_a @@ -52,15 +64,55 @@ async fn test_basic_channel_messages( .unwrap(); deterministic.run_until_parked(); - channel_chat_a.update(cx_a, |c, _| { + + let channel_chat_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + for (chat, cx) in [ + (&channel_chat_a, &mut cx_a), + (&channel_chat_b, &mut cx_b), + (&channel_chat_c, &mut cx_c), + ] { + chat.update(*cx, |c, _| { + assert_eq!( + c.messages() + .iter() + .map(|m| (m.body.as_str(), m.mentions.as_slice())) + .collect::>(), + vec![ + ("hi @user_c!", [(3..10, client_c.id())].as_slice()), + ("two", &[]), + ("three", &[]) + ], + "results for user {}", + c.client().id(), + ); + }); + } + + client_c.notification_store().update(cx_c, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); assert_eq!( - c.messages() - .iter() - .map(|m| m.body.as_str()) - .collect::>(), - vec!["one", "two", "three"] + store.notification_at(0).unwrap().notification, + Notification::ChannelMessageMention { + message_id, + sender_id: client_a.id(), + channel_id, + } ); - }) + assert_eq!( + store.notification_at(1).unwrap().notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + }); } #[gpui::test] @@ -280,7 +332,7 @@ async fn test_channel_message_changes( chat_panel_b .update(cx_b, |chat_panel, cx| { chat_panel.set_active(true, cx); - chat_panel.select_channel(channel_id, cx) + chat_panel.select_channel(channel_id, None, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 7cfcce832b..a33ded6492 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -1,12 +1,17 @@ use crate::{ + db::{self, UserId}, rpc::RECONNECT_TIMEOUT, tests::{room_participants, RoomParticipants, TestServer}, }; use call::ActiveCall; use channel::{ChannelId, ChannelMembership, ChannelStore}; use client::User; +use futures::future::try_join_all; use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; -use rpc::{proto, RECEIVE_TIMEOUT}; +use rpc::{ + proto::{self, ChannelRole}, + RECEIVE_TIMEOUT, +}; use std::sync::Arc; #[gpui::test] @@ -44,22 +49,19 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), depth: 1, - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); client_b.channel_store().read_with(cx_b, |channels, _| { - assert!(channels - .channel_dag_entries() - .collect::>() - .is_empty()) + assert!(channels.ordered_channels().collect::>().is_empty()) }); // Invite client B to channel A as client A. @@ -68,7 +70,12 @@ async fn test_core_channels( .update(cx_a, |store, cx| { assert!(!store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); - let invite = store.invite_member(channel_a_id, client_b.user_id().unwrap(), false, cx); + let invite = store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ); // Make sure we're synchronously storing the pending invite assert!(store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); @@ -86,7 +93,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: false, + role: ChannelRole::Member, }], ); @@ -103,12 +110,12 @@ async fn test_core_channels( &[ ( client_a.user_id().unwrap(), - true, + proto::ChannelRole::Admin, proto::channel_member::Kind::Member, ), ( client_b.user_id().unwrap(), - false, + proto::ChannelRole::Member, proto::channel_member::Kind::Invitee, ), ], @@ -117,8 +124,8 @@ async fn test_core_channels( // Client B accepts the invitation. client_b .channel_store() - .update(cx_b, |channels, _| { - channels.respond_to_channel_invite(channel_a_id, true) + .update(cx_b, |channels, cx| { + channels.respond_to_channel_invite(channel_a_id, true, cx) }) .await .unwrap(); @@ -133,13 +140,13 @@ async fn test_core_channels( ExpectedChannel { id: channel_a_id, name: "channel-a".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 0, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 1, }, ], @@ -161,19 +168,19 @@ async fn test_core_channels( ExpectedChannel { id: channel_a_id, name: "channel-a".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 0, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 1, }, ExpectedChannel { id: channel_c_id, name: "channel-c".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 2, }, ], @@ -183,7 +190,12 @@ async fn test_core_channels( client_a .channel_store() .update(cx_a, |store, cx| { - store.set_member_admin(channel_a_id, client_b.user_id().unwrap(), true, cx) + store.set_member_role( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) }) .await .unwrap(); @@ -200,19 +212,19 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), depth: 1, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_c_id, name: "channel-c".to_string(), depth: 2, - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); @@ -234,7 +246,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); assert_channels( @@ -244,7 +256,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); @@ -267,18 +279,27 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); // Client B no longer has access to the channel assert_channels(client_b.channel_store(), cx_b, &[]); - // When disconnected, client A sees no channels. server.forbid_connections(); server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - assert_channels(client_a.channel_store(), cx_a, &[]); + + server + .app_state + .db + .rename_channel( + db::ChannelId::from_proto(channel_a_id), + UserId::from_proto(client_a.id()), + "channel-a-renamed", + ) + .await + .unwrap(); server.allow_connections(); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); @@ -287,9 +308,9 @@ async fn test_core_channels( cx_a, &[ExpectedChannel { id: channel_a_id, - name: "channel-a".to_string(), + name: "channel-a-renamed".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); } @@ -305,12 +326,12 @@ fn assert_participants_eq(participants: &[Arc], expected_partitipants: &[u #[track_caller] fn assert_members_eq( members: &[ChannelMembership], - expected_members: &[(u64, bool, proto::channel_member::Kind)], + expected_members: &[(u64, proto::ChannelRole, proto::channel_member::Kind)], ) { assert_eq!( members .iter() - .map(|member| (member.user.id, member.admin, member.kind)) + .map(|member| (member.user.id, member.role, member.kind)) .collect::>(), expected_members ); @@ -397,7 +418,7 @@ async fn test_channel_room( id: zed_id, name: "zed".to_string(), depth: 0, - user_is_admin: false, + role: ChannelRole::Member, }], ); client_b.channel_store().read_with(cx_b, |channels, _| { @@ -611,7 +632,12 @@ async fn test_permissions_update_while_invited( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.invite_member(rust_id, client_b.user_id().unwrap(), false, cx) + channel_store.invite_member( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) }) .await .unwrap(); @@ -625,7 +651,7 @@ async fn test_permissions_update_while_invited( depth: 0, id: rust_id, name: "rust".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); assert_channels(client_b.channel_store(), cx_b, &[]); @@ -634,7 +660,12 @@ async fn test_permissions_update_while_invited( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.set_member_admin(rust_id, client_b.user_id().unwrap(), true, cx) + channel_store.set_member_role( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) }) .await .unwrap(); @@ -648,7 +679,7 @@ async fn test_permissions_update_while_invited( depth: 0, id: rust_id, name: "rust".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); assert_channels(client_b.channel_store(), cx_b, &[]); @@ -688,7 +719,7 @@ async fn test_channel_rename( depth: 0, id: rust_id, name: "rust-archive".to_string(), - user_is_admin: true, + role: ChannelRole::Admin, }], ); @@ -700,7 +731,7 @@ async fn test_channel_rename( depth: 0, id: rust_id, name: "rust-archive".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); } @@ -803,7 +834,12 @@ async fn test_lost_channel_creation( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.invite_member(channel_id, client_b.user_id().unwrap(), false, cx) + channel_store.invite_member( + channel_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) }) .await .unwrap(); @@ -818,7 +854,7 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); @@ -842,13 +878,13 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { depth: 1, id: subchannel_id, name: "subchannel".to_string(), - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); @@ -856,8 +892,8 @@ async fn test_lost_channel_creation( // Client B accepts the invite client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(channel_id, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -873,31 +909,489 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }, ExpectedChannel { depth: 1, id: subchannel_id, name: "subchannel".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }, ], ); } #[gpui::test] -async fn test_channel_moving( +async fn test_channel_link_notifications( deterministic: Arc, cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; + let user_b = client_b.user_id().unwrap(); + let user_c = client_c.user_id().unwrap(); + + let channels = server + .make_channel_tree(&[("zed", None)], (&client_a, cx_a)) + .await; + let zed_channel = channels[0]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_c, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_c + .channel_store() + .update(cx_c, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // we have an admin (a), member (b) and guest (c) all part of the zed channel. + + // create a new private channel, make it public, and move it under the previous one, and verify it shows for b and not c + let active_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("active", Some(zed_channel), cx) + }) + .await + .unwrap(); + + // the new channel shows for b and not c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); + + let vim_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("vim", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(vim_channel, Some(active_channel), cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1)], + ); + + let helix_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("helix", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(helix_channel, Some(vim_channel), cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility( + helix_channel, + proto::ChannelVisibility::Public, + cx, + ) + }) + .await + .unwrap(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1), (helix_channel, 2)], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + // the members-only channel is still shown for c, but hidden for b + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + client_b + .channel_store() + .read_with(cx_b, |channel_store, _| { + assert_eq!( + channel_store + .channel_for_id(vim_channel) + .unwrap() + .visibility, + proto::ChannelVisibility::Members + ) + }); + + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); +} + +#[gpui::test] +async fn test_channel_membership_notifications( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_c").await; + + let user_b = client_b.user_id().unwrap(); + + let channels = server + .make_channel_tree( + &[ + ("zed", None), + ("active", Some("zed")), + ("vim", Some("active")), + ], + (&client_a, cx_a), + ) + .await; + let zed_channel = channels[0]; + let _active_channel = channels[1]; + let vim_channel = channels[2]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(vim_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(vim_channel, true, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // we have an admin (a), and a guest (b) with access to all of zed, and membership in vim. + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Member, + }, + ], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.remove_member(vim_channel, user_b, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Guest, + }, + ], + ) +} + +#[gpui::test] +async fn test_guest_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a = channels[0]; + let channel_b = channels[1]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // Non-members should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .is_err()); + + // Make channels A and B public + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_b, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + // Client B joins channel A as a guest + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(channel_a, 0), (channel_b, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(channel_a, 0), (channel_b, 1)], + ); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_a); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + assert_channels_list_shape(client_b.channel_store(), cx_b, &[]); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + assert_channels_list_shape(client_b.channel_store(), cx_b, &[(channel_b, 0)]); +} + +#[gpui::test] +async fn test_invite_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a_id = channels[0]; + let channel_b_id = channels[0]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .is_err()); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b.channel_store().update(cx_b, |channel_store, _| { + assert!(channel_store.channel_for_id(channel_b_id).is_some()); + assert!(channel_store.channel_for_id(channel_a_id).is_some()); + }); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_b_id); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }) +} + +#[gpui::test] +async fn test_channel_moving( + deterministic: Arc, + cx_a: &mut TestAppContext, + _cx_b: &mut TestAppContext, + _cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + // let client_b = server.create_client(cx_b, "user_b").await; + // let client_c = server.create_client(cx_c, "user_c").await; + let channels = server .make_channel_tree( &[ @@ -930,7 +1424,7 @@ async fn test_channel_moving( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.move_channel(channel_d_id, channel_c_id, channel_b_id, cx) + channel_store.move_channel(channel_d_id, Some(channel_b_id), cx) }) .await .unwrap(); @@ -948,188 +1442,6 @@ async fn test_channel_moving( (channel_d_id, 2), ], ); - - client_a - .channel_store() - .update(cx_a, |channel_store, cx| { - channel_store.link_channel(channel_d_id, channel_c_id, cx) - }) - .await - .unwrap(); - - // Current shape for A: - // /------\ - // a - b -- c -- d - assert_channels_list_shape( - client_a.channel_store(), - cx_a, - &[ - (channel_a_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - ], - ); - - let b_channels = server - .make_channel_tree( - &[ - ("channel-mu", None), - ("channel-gamma", Some("channel-mu")), - ("channel-epsilon", Some("channel-mu")), - ], - (&client_b, cx_b), - ) - .await; - let channel_mu_id = b_channels[0]; - let channel_ga_id = b_channels[1]; - let channel_ep_id = b_channels[2]; - - // Current shape for B: - // /- ep - // mu -- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[(channel_mu_id, 0), (channel_ep_id, 1), (channel_ga_id, 1)], - ); - - client_a - .add_admin_to_channel((&client_b, cx_b), channel_b_id, cx_a) - .await; - - // Current shape for B: - // /- ep - // mu -- ga - // /---------\ - // b -- c -- d - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - // New channels from a - (channel_b_id, 0), - (channel_c_id, 1), - (channel_d_id, 2), - (channel_d_id, 1), - // B's old channels - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_ga_id, 1), - ], - ); - - client_b - .add_admin_to_channel((&client_c, cx_c), channel_ep_id, cx_b) - .await; - - // Current shape for C: - // - ep - assert_channels_list_shape(client_c.channel_store(), cx_c, &[(channel_ep_id, 0)]); - - client_b - .channel_store() - .update(cx_b, |channel_store, cx| { - channel_store.link_channel(channel_b_id, channel_ep_id, cx) - }) - .await - .unwrap(); - - // Current shape for B: - // /---------\ - // /- ep -- b -- c -- d - // mu -- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_b_id, 2), - (channel_c_id, 3), - (channel_d_id, 4), - (channel_d_id, 3), - (channel_ga_id, 1), - ], - ); - - // Current shape for C: - // /---------\ - // ep -- b -- c -- d - assert_channels_list_shape( - client_c.channel_store(), - cx_c, - &[ - (channel_ep_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - ], - ); - - client_b - .channel_store() - .update(cx_b, |channel_store, cx| { - channel_store.link_channel(channel_ga_id, channel_b_id, cx) - }) - .await - .unwrap(); - - // Current shape for B: - // /---------\ - // /- ep -- b -- c -- d - // / \ - // mu ---------- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_b_id, 2), - (channel_c_id, 3), - (channel_d_id, 4), - (channel_d_id, 3), - (channel_ga_id, 3), - (channel_ga_id, 1), - ], - ); - - // Current shape for A: - // /------\ - // a - b -- c -- d - // \-- ga - assert_channels_list_shape( - client_a.channel_store(), - cx_a, - &[ - (channel_a_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - (channel_ga_id, 2), - ], - ); - - // Current shape for C: - // /-------\ - // ep -- b -- c -- d - // \-- ga - assert_channels_list_shape( - client_c.channel_store(), - cx_c, - &[ - (channel_ep_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - (channel_ga_id, 2), - ], - ); } #[derive(Debug, PartialEq)] @@ -1137,7 +1449,7 @@ struct ExpectedChannel { depth: usize, id: ChannelId, name: String, - user_is_admin: bool, + role: ChannelRole, } #[track_caller] @@ -1154,7 +1466,7 @@ fn assert_channel_invitations( depth: 0, name: channel.name.clone(), id: channel.id, - user_is_admin: store.is_user_admin(channel.id), + role: channel.role, }) .collect::>() }); @@ -1169,12 +1481,12 @@ fn assert_channels( ) { let actual = channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() + .ordered_channels() .map(|(depth, channel)| ExpectedChannel { depth, name: channel.name.clone(), id: channel.id, - user_is_admin: store.is_user_admin(channel.id), + role: channel.role, }) .collect::>() }); @@ -1191,7 +1503,7 @@ fn assert_channels_list_shape( let actual = channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() + .ordered_channels() .map(|(depth, channel)| (channel.id, depth)) .collect::>() }); diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index f3857e3db3..a28f2ae87f 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -1,6 +1,6 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; -use collab_ui::project_shared_notification::ProjectSharedNotification; +use collab_ui::notifications::project_shared_notification::ProjectSharedNotification; use editor::{Editor, ExcerptRange, MultiBuffer}; use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle}; use live_kit_client::MacOSDisplay; diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index d6d449fd47..550c3a2bd8 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -1,6 +1,6 @@ use crate::{ rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - tests::{room_participants, RoomParticipants, TestClient, TestServer}, + tests::{channel_id, room_participants, RoomParticipants, TestClient, TestServer}, }; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; @@ -15,8 +15,8 @@ use gpui::{executor::Deterministic, test::EmptyView, AppContext, ModelHandle, Te use indoc::indoc; use language::{ language_settings::{AllLanguageSettings, Formatter, InlayHintSettings}, - tree_sitter_rust, Anchor, BundledFormatter, Diagnostic, DiagnosticEntry, FakeLspAdapter, - Language, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, + tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, + LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, }; use live_kit_client::MacOSDisplay; use lsp::LanguageServerId; @@ -469,6 +469,119 @@ async fn test_calling_multiple_users_simultaneously( ); } +#[gpui::test(iterations = 10)] +async fn test_joining_channels_and_calling_multiple_users_simultaneously( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let channel_1 = server + .make_channel( + "channel1", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_2 = server + .make_channel( + "channel2", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + + // Simultaneously join channel 1 and then channel 2 + active_call_a + .update(cx_a, |call, cx| call.join_channel(channel_1, cx)) + .detach(); + let join_channel_2 = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_2, cx)); + + join_channel_2.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); + + assert_eq!(channel_id(&room_a, cx_a), Some(channel_2)); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Initiating invites and then joining a channel should fail gracefully + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + b_invite.await.unwrap(); + c_invite.await.unwrap(); + join_channel.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); + + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string(), "user_c".to_string()] + } + ); + + assert_eq!(channel_id(&room_a, cx_a), None); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Simultaneously join channel 1 and call user B and user C from client A. + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + join_channel.await.unwrap(); + b_invite.await.unwrap(); + c_invite.await.unwrap(); + + active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); +} + #[gpui::test(iterations = 10)] async fn test_room_uniqueness( deterministic: Arc, @@ -4530,6 +4643,7 @@ async fn test_prettier_formatting_buffer( LanguageConfig { name: "Rust".into(), path_suffixes: vec!["rs".to_string()], + prettier_parser_name: Some("test_parser".to_string()), ..Default::default() }, Some(tree_sitter_rust::language()), @@ -4537,10 +4651,7 @@ async fn test_prettier_formatting_buffer( let test_plugin = "test_plugin"; let mut fake_language_servers = language .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { - enabled_formatters: vec![BundledFormatter::Prettier { - parser_name: Some("test_parser"), - plugin_names: vec![test_plugin], - }], + prettier_plugins: vec![test_plugin], ..Default::default() })) .await; @@ -4557,11 +4668,7 @@ async fn test_prettier_formatting_buffer( .insert_tree(&directory, json!({ "a.rs": buffer_text })) .await; let (project_a, worktree_id) = client_a.build_local_project(&directory, cx_a).await; - let prettier_format_suffix = project_a.update(cx_a, |project, _| { - let suffix = project.enable_test_prettier(&[test_plugin]); - project.languages().add(language); - suffix - }); + let prettier_format_suffix = project::TEST_PRETTIER_FORMAT_SUFFIX; let buffer_a = cx_a .background() .spawn(project_a.update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx))) diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs new file mode 100644 index 0000000000..1114470449 --- /dev/null +++ b/crates/collab/src/tests/notification_tests.rs @@ -0,0 +1,159 @@ +use crate::tests::TestServer; +use gpui::{executor::Deterministic, TestAppContext}; +use notifications::NotificationEvent; +use parking_lot::Mutex; +use rpc::{proto, Notification}; +use std::sync::Arc; + +#[gpui::test] +async fn test_notifications( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let notification_events_a = Arc::new(Mutex::new(Vec::new())); + let notification_events_b = Arc::new(Mutex::new(Vec::new())); + client_a.notification_store().update(cx_a, |_, cx| { + let events = notification_events_a.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + client_b.notification_store().update(cx_b, |_, cx| { + let events = notification_events_b.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + + // Client A sends a contact request to client B. + client_a + .user_store() + .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx)) + .await + .unwrap(); + + // Client B receives a contact request notification and responds to the + // request, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequest { + sender_id: client_a.id() + } + ); + assert!(!entry.is_read); + assert_eq!( + ¬ification_events_b.lock()[0..], + &[ + NotificationEvent::NewNotification { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..0, + new_count: 1 + } + ] + ); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + assert_eq!( + ¬ification_events_b.lock()[2..], + &[ + NotificationEvent::NotificationRead { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..1, + new_count: 1 + } + ] + ); + }); + + // Client A receives a notification that client B accepted their request. + client_a.notification_store().read_with(cx_a, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequestAccepted { + responder_id: client_b.id() + } + ); + assert!(!entry.is_read); + }); + + // Client A creates a channel and invites client B to be a member. + let channel_id = client_a + .channel_store() + .update(cx_a, |store, cx| { + store.create_channel("the-channel", None, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |store, cx| { + store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx) + }) + .await + .unwrap(); + + // Client B receives a channel invitation notification and responds to the + // invitation, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + assert!(!entry.is_read); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + }); +} diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs index 6e0bef225c..38bc3f7c12 100644 --- a/crates/collab/src/tests/random_channel_buffer_tests.rs +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -1,3 +1,5 @@ +use crate::db::ChannelRole; + use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; use anyhow::Result; use async_trait::async_trait; @@ -46,11 +48,11 @@ impl RandomizedTest for RandomChannelBufferTest { let db = &server.app_state.db; for ix in 0..CHANNEL_COUNT { let id = db - .create_channel(&format!("channel-{ix}"), None, users[0].user_id) + .create_root_channel(&format!("channel-{ix}"), users[0].user_id) .await .unwrap(); for user in &users[1..] { - db.invite_channel_member(id, user.user_id, users[0].user_id, false) + db.invite_channel_member(id, user.user_id, users[0].user_id, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(id, user.user_id, true) @@ -81,7 +83,7 @@ impl RandomizedTest for RandomChannelBufferTest { match rng.gen_range(0..100_u32) { 0..=29 => { let channel_name = client.channel_store().read_with(cx, |store, cx| { - store.channel_dag_entries().find_map(|(_, channel)| { + store.ordered_channels().find_map(|(_, channel)| { if store.has_open_channel_buffer(channel.id, cx) { None } else { @@ -96,15 +98,16 @@ impl RandomizedTest for RandomChannelBufferTest { 30..=40 => { if let Some(buffer) = channel_buffers.iter().choose(rng) { - let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone()); + let channel_name = + buffer.read_with(cx, |b, cx| b.channel(cx).unwrap().name.clone()); break ChannelBufferOperation::LeaveChannelNotes { channel_name }; } } _ => { if let Some(buffer) = channel_buffers.iter().choose(rng) { - break buffer.read_with(cx, |b, _| { - let channel_name = b.channel().name.clone(); + break buffer.read_with(cx, |b, cx| { + let channel_name = b.channel(cx).unwrap().name.clone(); let edits = b .buffer() .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); @@ -128,7 +131,7 @@ impl RandomizedTest for RandomChannelBufferTest { ChannelBufferOperation::JoinChannelNotes { channel_name } => { let buffer = client.channel_store().update(cx, |store, cx| { let channel_id = store - .channel_dag_entries() + .ordered_channels() .find(|(_, c)| c.name == channel_name) .unwrap() .1 @@ -151,7 +154,7 @@ impl RandomizedTest for RandomChannelBufferTest { let buffer = cx.update(|cx| { let mut left_buffer = Err(TestError::Inapplicable); client.channel_buffers().retain(|buffer| { - if buffer.read(cx).channel().name == channel_name { + if buffer.read(cx).channel(cx).unwrap().name == channel_name { left_buffer = Ok(buffer.clone()); false } else { @@ -177,7 +180,9 @@ impl RandomizedTest for RandomChannelBufferTest { client .channel_buffers() .iter() - .find(|buffer| buffer.read(cx).channel().name == channel_name) + .find(|buffer| { + buffer.read(cx).channel(cx).unwrap().name == channel_name + }) .cloned() }) .ok_or_else(|| TestError::Inapplicable)?; @@ -248,7 +253,7 @@ impl RandomizedTest for RandomChannelBufferTest { if let Some(channel_buffer) = client .channel_buffers() .iter() - .find(|b| b.read(cx).channel().id == channel_id.to_proto()) + .find(|b| b.read(cx).channel_id == channel_id.to_proto()) { let channel_buffer = channel_buffer.read(cx); diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index 39598bdaf9..1cec945282 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -208,8 +208,7 @@ impl TestPlan { false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, + github_user_id: ix as i32, }, ) .await diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 7397489b34..d6ebe1e84e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -16,9 +16,10 @@ use futures::{channel::oneshot, StreamExt as _}; use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; use language::LanguageRegistry; use node_runtime::FakeNodeRuntime; +use notifications::NotificationStore; use parking_lot::Mutex; use project::{Project, WorktreeId}; -use rpc::RECEIVE_TIMEOUT; +use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT}; use settings::SettingsStore; use std::{ cell::{Ref, RefCell, RefMut}, @@ -46,6 +47,7 @@ pub struct TestClient { pub username: String, pub app_state: Arc, channel_store: ModelHandle, + notification_store: ModelHandle, state: RefCell, } @@ -138,7 +140,6 @@ impl TestServer { NewUserParams { github_login: name.into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -231,7 +232,8 @@ impl TestServer { workspace::init(app_state.clone(), cx); audio::init((), cx); call::init(client.clone(), user_store.clone(), cx); - channel::init(&client, user_store, cx); + channel::init(&client, user_store.clone(), cx); + notifications::init(client.clone(), user_store, cx); }); client @@ -243,6 +245,7 @@ impl TestServer { app_state, username: name.to_string(), channel_store: cx.read(ChannelStore::global).clone(), + notification_store: cx.read(NotificationStore::global).clone(), state: Default::default(), }; client.wait_for_current_user(cx).await; @@ -327,7 +330,7 @@ impl TestServer { channel_store.invite_member( channel_id, member_client.user_id().unwrap(), - false, + ChannelRole::Member, cx, ) }) @@ -338,8 +341,8 @@ impl TestServer { member_cx .read(ChannelStore::global) - .update(*member_cx, |channels, _| { - channels.respond_to_channel_invite(channel_id, true) + .update(*member_cx, |channels, cx| { + channels.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -448,6 +451,10 @@ impl TestClient { &self.channel_store } + pub fn notification_store(&self) -> &ModelHandle { + &self.notification_store + } + pub fn user_store(&self) -> &ModelHandle { &self.app_state.user_store } @@ -604,33 +611,6 @@ impl TestClient { ) -> WindowHandle { cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) } - - pub async fn add_admin_to_channel( - &self, - user: (&TestClient, &mut TestAppContext), - channel: u64, - cx_self: &mut TestAppContext, - ) { - let (other_client, other_cx) = user; - - cx_self - .read(ChannelStore::global) - .update(cx_self, |channel_store, cx| { - channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx) - }) - .await - .unwrap(); - - cx_self.foreground().run_until_parked(); - - other_cx - .read(ChannelStore::global) - .update(other_cx, |channel_store, _| { - channel_store.respond_to_channel_invite(channel, true) - }) - .await - .unwrap(); - } } impl Drop for TestClient { diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index 98790778c9..791c6b2fa7 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" } gpui = { path = "../gpui" } language = { path = "../language" } menu = { path = "../menu" } +notifications = { path = "../notifications" } rich_text = { path = "../rich_text" } picker = { path = "../picker" } project = { path = "../project" } -recent_projects = {path = "../recent_projects"} +recent_projects = { path = "../recent_projects" } +rpc = { path = "../rpc" } settings = { path = "../settings" } feature_flags = {path = "../feature_flags"} theme = { path = "../theme" } @@ -52,12 +54,14 @@ zed-actions = {path = "../zed-actions"} anyhow.workspace = true futures.workspace = true +lazy_static.workspace = true log.workspace = true schemars.workspace = true postage.workspace = true serde.workspace = true serde_derive.workspace = true time.workspace = true +smallvec.workspace = true [dev-dependencies] call = { path = "../call", features = ["test-support"] } @@ -65,7 +69,12 @@ client = { path = "../client", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } editor = { path = "../editor", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } + +pretty_assertions.workspace = true +tree-sitter-markdown.workspace = true diff --git a/crates/collab_ui/src/channel_view.rs b/crates/collab_ui/src/channel_view.rs index e62ee8ef4b..1bdcebd018 100644 --- a/crates/collab_ui/src/channel_view.rs +++ b/crates/collab_ui/src/channel_view.rs @@ -15,13 +15,14 @@ use gpui::{ ViewContext, ViewHandle, }; use project::Project; +use smallvec::SmallVec; use std::{ any::{Any, TypeId}, sync::Arc, }; use util::ResultExt; use workspace::{ - item::{FollowableItem, Item, ItemHandle}, + item::{FollowableItem, Item, ItemEvent, ItemHandle}, register_followable_item, searchable::SearchableItemHandle, ItemNavHistory, Pane, SaveIntent, ViewId, Workspace, WorkspaceId, @@ -140,6 +141,12 @@ impl ChannelView { editor.set_collaboration_hub(Box::new(ChannelBufferCollaborationHub( channel_buffer.clone(), ))); + editor.set_read_only( + !channel_buffer + .read(cx) + .channel(cx) + .is_some_and(|c| c.can_edit_notes()), + ); editor }); let _editor_event_subscription = cx.subscribe(&editor, |_, _, e, cx| cx.emit(e.clone())); @@ -157,8 +164,8 @@ impl ChannelView { } } - pub fn channel(&self, cx: &AppContext) -> Arc { - self.channel_buffer.read(cx).channel() + pub fn channel(&self, cx: &AppContext) -> Option> { + self.channel_buffer.read(cx).channel(cx) } fn handle_channel_buffer_event( @@ -172,6 +179,13 @@ impl ChannelView { editor.set_read_only(true); cx.notify(); }), + ChannelBufferEvent::ChannelChanged => { + self.editor.update(cx, |editor, cx| { + editor.set_read_only(!self.channel(cx).is_some_and(|c| c.can_edit_notes())); + cx.emit(editor::Event::TitleChanged); + cx.notify() + }); + } ChannelBufferEvent::BufferEdited => { if cx.is_self_focused() || self.editor.is_focused(cx) { self.acknowledge_buffer_version(cx); @@ -179,7 +193,7 @@ impl ChannelView { self.channel_store.update(cx, |store, cx| { let channel_buffer = self.channel_buffer.read(cx); store.notes_changed( - channel_buffer.channel().id, + channel_buffer.channel_id, channel_buffer.epoch(), &channel_buffer.buffer().read(cx).version(), cx, @@ -187,7 +201,7 @@ impl ChannelView { }); } } - _ => {} + ChannelBufferEvent::CollaboratorsChanged => {} } } @@ -195,7 +209,7 @@ impl ChannelView { self.channel_store.update(cx, |store, cx| { let channel_buffer = self.channel_buffer.read(cx); store.acknowledge_notes_version( - channel_buffer.channel().id, + channel_buffer.channel_id, channel_buffer.epoch(), &channel_buffer.buffer().read(cx).version(), cx, @@ -250,11 +264,17 @@ impl Item for ChannelView { style: &theme::Tab, cx: &gpui::AppContext, ) -> AnyElement { - let channel_name = &self.channel_buffer.read(cx).channel().name; - let label = if self.channel_buffer.read(cx).is_connected() { - format!("#{}", channel_name) + let label = if let Some(channel) = self.channel(cx) { + match ( + channel.can_edit_notes(), + self.channel_buffer.read(cx).is_connected(), + ) { + (true, true) => format!("#{}", channel.name), + (false, true) => format!("#{} (read-only)", channel.name), + (_, false) => format!("#{} (disconnected)", channel.name), + } } else { - format!("#{} (disconnected)", channel_name) + format!("channel notes (disconnected)") }; Label::new(label, style.label.to_owned()).into_any() } @@ -298,6 +318,10 @@ impl Item for ChannelView { fn pixel_position_of_cursor(&self, cx: &AppContext) -> Option { self.editor.read(cx).pixel_position_of_cursor(cx) } + + fn to_item_events(event: &Self::Event) -> SmallVec<[ItemEvent; 2]> { + editor::Editor::to_item_events(event) + } } impl FollowableItem for ChannelView { @@ -313,7 +337,7 @@ impl FollowableItem for ChannelView { Some(proto::view::Variant::ChannelView( proto::view::ChannelView { - channel_id: channel_buffer.channel().id, + channel_id: channel_buffer.channel_id, editor: if let Some(proto::view::Variant::Editor(proto)) = self.editor.read(cx).to_state_proto(cx) { diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 1a17b48f19..5a4dafb6d4 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1,4 +1,6 @@ -use crate::{channel_view::ChannelView, ChatPanelSettings}; +use crate::{ + channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings, +}; use anyhow::Result; use call::ActiveCall; use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore}; @@ -6,18 +8,18 @@ use client::Client; use collections::HashMap; use db::kvp::KEY_VALUE_STORE; use editor::Editor; -use feature_flags::{ChannelsAlpha, FeatureFlagAppExt}; use gpui::{ actions, elements::*, platform::{CursorStyle, MouseButton}, serde_json, views::{ItemType, Select, SelectStyle}, - AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task, - View, ViewContext, ViewHandle, WeakViewHandle, + AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View, + ViewContext, ViewHandle, WeakViewHandle, }; -use language::{language_settings::SoftWrap, LanguageRegistry}; +use language::LanguageRegistry; use menu::Confirm; +use message_editor::MessageEditor; use project::Fs; use rich_text::RichText; use serde::{Deserialize, Serialize}; @@ -31,6 +33,8 @@ use workspace::{ Workspace, }; +mod message_editor; + const MESSAGE_LOADING_THRESHOLD: usize = 50; const CHAT_PANEL_KEY: &'static str = "ChatPanel"; @@ -40,7 +44,7 @@ pub struct ChatPanel { languages: Arc, active_chat: Option<(ModelHandle, Subscription)>, message_list: ListState, - input_editor: ViewHandle, + input_editor: ViewHandle, channel_select: ViewHandle