diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 3cd868a438..6e58be4f15 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -279,15 +279,21 @@ impl ActiveCall { channel_id: u64, cx: &mut ModelContext, ) -> Task> { + let leave_room; if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { return Task::ready(Ok(())); + } else { + leave_room = room.update(cx, |room, cx| room.leave(cx)); } + } else { + leave_room = Task::ready(Ok(())); } let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); cx.spawn(|this, mut cx| async move { + leave_room.await?; let room = join.await?; this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) .await?; diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 1e86cef4cc..8ef3e32ea8 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -540,6 +540,7 @@ impl Client { } } + #[track_caller] pub fn add_message_handler( self: &Arc, model: ModelHandle, @@ -575,8 +576,11 @@ impl Client { }), ); if prev_handler.is_some() { + let location = std::panic::Location::caller(); panic!( - "registered handler for the same message {} twice", + "{}:{} registered handler for the same message {} twice", + location.file(), + location.line(), std::any::type_name::() ); } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 85f5d5f0b8..36b226b97b 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1342,6 +1342,35 @@ impl Database { .await } + pub async fn is_current_room_different_channel( + &self, + user_id: UserId, + channel_id: ChannelId, + ) -> Result { + self.transaction(|tx| async move { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + ChannelId, + } + + let channel_id_model: Option = room_participant::Entity::find() + .select_only() + .column_as(room::Column::ChannelId, QueryAs::ChannelId) + .inner_join(room::Entity) + .filter(room_participant::Column::UserId.eq(user_id)) + .into_values::<_, QueryAs>() + .one(&*tx) + .await?; + + let result = channel_id_model + .map(|channel_id_model| channel_id_model != channel_id) + .unwrap_or(false); + + Ok(result) + }) + .await + } + pub async fn join_room( &self, room_id: RoomId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 526f12d812..15237049c3 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2276,6 +2276,14 @@ async fn join_channel( let joined_room = { let db = session.db().await; + + if db + .is_current_room_different_channel(session.user_id, channel_id) + .await? + { + leave_room_for_session_with_guard(&session, &db).await?; + } + let room_id = db.room_id_for_channel(channel_id).await?; let joined_room = db @@ -2531,6 +2539,14 @@ fn channel_updated( async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { let db = session.db().await; + update_user_contacts_with_guard(user_id, session, &db).await +} + +async fn update_user_contacts_with_guard( + user_id: UserId, + session: &Session, + db: &DbHandle, +) -> Result<()> { let contacts = db.get_contacts(user_id).await?; let busy = db.is_user_busy(user_id).await?; @@ -2564,6 +2580,11 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> } async fn leave_room_for_session(session: &Session) -> Result<()> { + let db = session.db().await; + leave_room_for_session_with_guard(session, &db).await +} + +async fn leave_room_for_session_with_guard(session: &Session, db: &DbHandle) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_id; @@ -2574,7 +2595,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let channel_members; let channel_id; - if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { + if let Some(mut left_room) = db.leave_room(session.connection_id).await? { contacts_to_update.insert(session.user_id); for project in left_room.left_projects.values() { @@ -2624,7 +2645,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { } for contact_user_id in contacts_to_update { - update_user_contacts(contact_user_id, &session).await?; + update_user_contacts_with_guard(contact_user_id, &session, db).await?; } if let Some(live_kit) = session.live_kit_client.as_ref() { diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index c41ac84d1d..3999740557 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -304,3 +304,50 @@ async fn test_channel_room( } ); } + +#[gpui::test] +async fn test_channel_jumping(deterministic: Arc, cx_a: &mut TestAppContext) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + + let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await; + let rust_id = server + .make_channel("rust", (&client_a, cx_a), &mut []) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + + active_call_a + .update(cx_a, |active_call, cx| active_call.join_channel(zed_id, cx)) + .await + .unwrap(); + + // Give everything a chance to observe user A joining + deterministic.run_until_parked(); + + client_a.channel_store().read_with(cx_a, |channels, _| { + assert_participants_eq( + channels.channel_participants(zed_id), + &[client_a.user_id().unwrap()], + ); + assert_participants_eq(channels.channel_participants(rust_id), &[]); + }); + + active_call_a + .update(cx_a, |active_call, cx| { + active_call.join_channel(rust_id, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_a.channel_store().read_with(cx_a, |channels, _| { + assert_participants_eq(channels.channel_participants(zed_id), &[]); + assert_participants_eq( + channels.channel_participants(rust_id), + &[client_a.user_id().unwrap()], + ); + }); +}