diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 74ff7c731e..cc5009ead8 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -62,6 +62,7 @@ pub struct ChannelStore { opened_buffers: HashMap>, opened_chats: HashMap>, client: Arc, + did_subscribe: bool, user_store: Model, _rpc_subscriptions: [Subscription; 2], _watch_connection_status: Task>, @@ -243,6 +244,20 @@ impl ChannelStore { .log_err(); }), channel_states: Default::default(), + did_subscribe: false, + } + } + + pub fn initialize(&mut self) { + if !self.did_subscribe { + if self + .client + .send(proto::SubscribeToChannels {}) + .log_err() + .is_some() + { + self.did_subscribe = true; + } } } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 02b182aca7..d4ed9ea5e7 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -654,6 +654,7 @@ pub struct ChannelsForUser { pub channel_memberships: Vec, pub channel_participants: HashMap>, pub hosted_projects: Vec, + pub invited_channels: Vec, pub observed_buffer_versions: Vec, pub observed_channel_messages: Vec, diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 966e52811a..ff7a95cf76 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -416,7 +416,9 @@ impl Database { user_id: UserId, tx: &DatabaseTransaction, ) -> Result { - let new_channels = self.get_user_channels(user_id, Some(channel), tx).await?; + let new_channels = self + .get_user_channels(user_id, Some(channel), false, tx) + .await?; let removed_channels = self .get_channel_descendants_excluding_self([channel], tx) .await? @@ -481,44 +483,10 @@ impl Database { .await } - /// Returns all channel invites for the user with the given ID. - pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result> { - self.transaction(|tx| async move { - let mut role_for_channel: HashMap = HashMap::default(); - - let channel_invites = channel_member::Entity::find() - .filter( - channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::Accepted.eq(false)), - ) - .all(&*tx) - .await?; - - for invite in channel_invites { - role_for_channel.insert(invite.channel_id, invite.role); - } - - let channels = channel::Entity::find() - .filter(channel::Column::Id.is_in(role_for_channel.keys().copied())) - .all(&*tx) - .await?; - - let channels = channels.into_iter().map(Channel::from_model).collect(); - - Ok(channels) - }) - .await - } - /// Returns all channels for the user with the given ID. pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { - self.transaction(|tx| async move { - let tx = tx; - - self.get_user_channels(user_id, None, &tx).await - }) - .await + self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await }) + .await } /// Returns all channels for the user with the given ID that are descendants @@ -527,25 +495,37 @@ impl Database { &self, user_id: UserId, ancestor_channel: Option<&channel::Model>, + include_invites: bool, tx: &DatabaseTransaction, ) -> Result { - let mut filter = channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::Accepted.eq(true)); - + let mut filter = channel_member::Column::UserId.eq(user_id); + if !include_invites { + filter = filter.and(channel_member::Column::Accepted.eq(true)) + } if let Some(ancestor) = ancestor_channel { filter = filter.and(channel_member::Column::ChannelId.eq(ancestor.root_id())); } - let channel_memberships = channel_member::Entity::find() + let mut channels = Vec::::new(); + let mut invited_channels = Vec::::new(); + let mut channel_memberships = Vec::::new(); + let mut rows = channel_member::Entity::find() .filter(filter) - .all(tx) - .await?; - - let channels = channel::Entity::find() - .filter(channel::Column::Id.is_in(channel_memberships.iter().map(|m| m.channel_id))) - .all(tx) + .inner_join(channel::Entity) + .select_also(channel::Entity) + .stream(tx) .await?; + while let Some(row) = rows.next().await { + if let (membership, Some(channel)) = row? { + if membership.accepted { + channel_memberships.push(membership); + channels.push(channel); + } else { + invited_channels.push(Channel::from_model(channel)); + } + } + } + drop(rows); let mut descendants = self .get_channel_descendants_excluding_self(channels.iter(), tx) @@ -643,6 +623,7 @@ impl Database { Ok(ChannelsForUser { channel_memberships, channels, + invited_channels, hosted_projects, channel_participants, latest_buffer_versions, diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 4482549e91..d409867447 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -176,23 +176,23 @@ async fn test_channel_invites(db: &Arc) { .unwrap(); let user_2_invites = db - .get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2] + .get_channels_for_user(user_2) .await .unwrap() + .invited_channels .into_iter() .map(|channel| channel.id) .collect::>(); - assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]); let user_3_invites = db - .get_channel_invites_for_user(user_3) // -> [channel_1_1] + .get_channels_for_user(user_3) .await .unwrap() + .invited_channels .into_iter() .map(|channel| channel.id) .collect::>(); - assert_eq!(user_3_invites, &[channel_1_1]); let (mut members, _) = db diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7798a1492a..75626646d9 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -557,6 +557,7 @@ impl Server { .add_request_handler(user_handler(request_contact)) .add_request_handler(user_handler(remove_contact)) .add_request_handler(user_handler(respond_to_contact_request)) + .add_message_handler(subscribe_to_channels) .add_request_handler(user_handler(create_channel)) .add_request_handler(user_handler(delete_channel)) .add_request_handler(user_handler(invite_channel_member)) @@ -1105,34 +1106,25 @@ impl Server { .await?; } - let (contacts, channels_for_user, channel_invites, dev_server_projects) = - future::try_join4( - self.app_state.db.get_contacts(user.id), - self.app_state.db.get_channels_for_user(user.id), - self.app_state.db.get_channel_invites_for_user(user.id), - self.app_state.db.dev_server_projects_update(user.id), - ) - .await?; + let (contacts, dev_server_projects) = future::try_join( + self.app_state.db.get_contacts(user.id), + self.app_state.db.dev_server_projects_update(user.id), + ) + .await?; { let mut pool = self.connection_pool.lock(); pool.add_connection(connection_id, user.id, user.admin, zed_version); - for membership in &channels_for_user.channel_memberships { - pool.subscribe_to_channel(user.id, membership.channel_id, membership.role) - } self.peer.send( connection_id, build_initial_contacts_update(contacts, &pool), )?; - self.peer.send( - connection_id, - build_update_user_channels(&channels_for_user), - )?; - self.peer.send( - connection_id, - build_channels_update(channels_for_user, channel_invites), - )?; } + + if should_auto_subscribe_to_channels(zed_version) { + subscribe_user_to_channels(user.id, session).await?; + } + send_dev_server_projects_update(user.id, dev_server_projects, session).await; if let Some(incoming_call) = @@ -3399,6 +3391,36 @@ async fn remove_contact( Ok(()) } +fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { + version.0.minor() < 139 +} + +async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> { + subscribe_user_to_channels( + session.user_id().ok_or_else(|| anyhow!("must be a user"))?, + &session, + ) + .await?; + Ok(()) +} + +async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> { + let channels_for_user = session.db().await.get_channels_for_user(user_id).await?; + let mut pool = session.connection_pool().await; + for membership in &channels_for_user.channel_memberships { + pool.subscribe_to_channel(user_id, membership.channel_id, membership.role) + } + session.peer.send( + session.connection_id, + build_update_user_channels(&channels_for_user), + )?; + session.peer.send( + session.connection_id, + build_channels_update(channels_for_user), + )?; + Ok(()) +} + /// Creates a new channel. async fn create_channel( request: proto::CreateChannel, @@ -5034,7 +5056,7 @@ fn notify_membership_updated( ..Default::default() }; - let mut update = build_channels_update(result.new_channels, vec![]); + let mut update = build_channels_update(result.new_channels); update.delete_channels = result .removed_channels .into_iter() @@ -5064,10 +5086,7 @@ fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserCh } } -fn build_channels_update( - channels: ChannelsForUser, - channel_invites: Vec, -) -> proto::UpdateChannels { +fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels { let mut update = proto::UpdateChannels::default(); for channel in channels.channels { @@ -5086,7 +5105,7 @@ fn build_channels_update( }); } - for channel in channel_invites { + for channel in channels.invited_channels { update.channel_invitations.push(channel.to_proto()); } diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index de77af6dd6..6cdc4d484d 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2161,6 +2161,9 @@ impl CollabPanel { } fn render_signed_in(&mut self, cx: &mut ViewContext) -> Div { + self.channel_store.update(cx, |channel_store, _| { + channel_store.initialize(); + }); v_flex() .size_full() .child(list(self.list_state.clone()).size_full()) diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 2965c92815..67c2a8045f 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -159,6 +159,7 @@ message Envelope { SetChannelMemberRole set_channel_member_role = 123; RenameChannel rename_channel = 124; RenameChannelResponse rename_channel_response = 125; + SubscribeToChannels subscribe_to_channels = 207; // current max JoinChannelBuffer join_channel_buffer = 126; JoinChannelBufferResponse join_channel_buffer_response = 127; @@ -250,7 +251,7 @@ message Envelope { TaskContextForLocation task_context_for_location = 203; TaskContext task_context = 204; TaskTemplatesResponse task_templates_response = 205; - TaskTemplates task_templates = 206; // Current max + TaskTemplates task_templates = 206; } reserved 158 to 161; @@ -1297,6 +1298,8 @@ message ChannelMember { } } +message SubscribeToChannels {} + message CreateChannel { string name = 1; optional uint64 parent_id = 2; diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index e41685af7a..1cb17b104f 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -277,6 +277,7 @@ messages!( (ShareProjectResponse, Foreground), (ShowContacts, Foreground), (StartLanguageServer, Foreground), + (SubscribeToChannels, Foreground), (SynchronizeBuffers, Foreground), (SynchronizeBuffersResponse, Foreground), (TaskContextForLocation, Background),