diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 0527e070ea..e9249edcb1 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -460,6 +460,8 @@ pub struct UpdatedChannelMessage { pub notifications: NotificationBatch, pub reply_to_message_id: Option, pub timestamp: PrimitiveDateTime, + pub deleted_mention_notification_ids: Vec, + pub updated_mention_notifications: Vec, } #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index f0db33b2da..2fa6edeb0a 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,7 +1,8 @@ use super::*; use rpc::Notification; -use sea_orm::TryInsertResult; +use sea_orm::{SelectColumns, TryInsertResult}; use time::OffsetDateTime; +use util::ResultExt; impl Database { /// Inserts a record representing a user joining the chat for a given channel. @@ -480,13 +481,20 @@ impl Database { Ok(results) } + fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option { + self.notification_kinds_by_id + .iter() + .find(|(_, kind)| **kind == notification_kind) + .map(|kind| kind.0 .0) + } + /// Removes the channel message with the given ID. pub async fn remove_channel_message( &self, channel_id: ChannelId, message_id: MessageId, user_id: UserId, - ) -> Result> { + ) -> Result<(Vec, Vec)> { self.transaction(|tx| async move { let mut rows = channel_chat_participant::Entity::find() .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) @@ -531,7 +539,29 @@ impl Database { } } - Ok(participant_connection_ids) + let notification_kind_id = + self.get_notification_kind_id_by_name("ChannelMessageMention"); + + let existing_notifications = notification::Entity::find() + .filter(notification::Column::EntityId.eq(message_id)) + .filter(notification::Column::Kind.eq(notification_kind_id)) + .select_column(notification::Column::Id) + .all(&*tx) + .await?; + + let existing_notification_ids = existing_notifications + .into_iter() + .map(|notification| notification.id) + .collect(); + + // remove all the mention notifications for this message + notification::Entity::delete_many() + .filter(notification::Column::EntityId.eq(message_id)) + .filter(notification::Column::Kind.eq(notification_kind_id)) + .exec(&*tx) + .await?; + + Ok((participant_connection_ids, existing_notification_ids)) }) .await } @@ -629,14 +659,44 @@ impl Database { .await?; } - let mut mentioned_user_ids = mentions.iter().map(|m| m.user_id).collect::>(); + let mut update_mention_user_ids = HashSet::default(); + let mut new_mention_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); // Filter out users that were mentioned before - for mention in old_mentions { - mentioned_user_ids.remove(&mention.user_id.to_proto()); + for mention in &old_mentions { + if new_mention_user_ids.contains(&mention.user_id.to_proto()) { + update_mention_user_ids.insert(mention.user_id.to_proto()); + } + + new_mention_user_ids.remove(&mention.user_id.to_proto()); + } + + let notification_kind_id = + self.get_notification_kind_id_by_name("ChannelMessageMention"); + + let existing_notifications = notification::Entity::find() + .filter(notification::Column::EntityId.eq(message_id)) + .filter(notification::Column::Kind.eq(notification_kind_id)) + .all(&*tx) + .await?; + + // determine which notifications should be updated or deleted + let mut deleted_notification_ids = HashSet::default(); + let mut updated_mention_notifications = Vec::new(); + for notification in existing_notifications { + if update_mention_user_ids.contains(¬ification.recipient_id.to_proto()) { + if let Some(notification) = + self::notifications::model_to_proto(self, notification).log_err() + { + updated_mention_notifications.push(notification); + } + } else { + deleted_notification_ids.insert(notification.id); + } } let mut notifications = Vec::new(); - for mentioned_user in mentioned_user_ids { + for mentioned_user in new_mention_user_ids { notifications.extend( self.create_notification( UserId::from_proto(mentioned_user), @@ -658,6 +718,10 @@ impl Database { notifications, reply_to_message_id: channel_message.reply_to_message_id, timestamp: channel_message.sent_at, + deleted_mention_notification_ids: deleted_notification_ids + .into_iter() + .collect::>(), + updated_mention_notifications, }) }) .await diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs index 5a44f62a53..e0993f0d56 100644 --- a/crates/collab/src/db/queries/notifications.rs +++ b/crates/collab/src/db/queries/notifications.rs @@ -1,5 +1,6 @@ use super::*; use rpc::Notification; +use util::ResultExt; impl Database { /// Initializes the different kinds of notifications by upserting records for them. @@ -53,11 +54,8 @@ impl Database { .await?; while let Some(row) = rows.next().await { let row = row?; - let kind = row.kind; - if let Some(proto) = model_to_proto(self, row) { + if let Some(proto) = model_to_proto(self, row).log_err() { result.push(proto); - } else { - log::warn!("unknown notification kind {:?}", kind); } } result.reverse(); @@ -200,7 +198,9 @@ impl Database { }) .exec(tx) .await?; - Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .ok()) } else { Ok(None) } @@ -241,9 +241,12 @@ impl Database { } } -fn model_to_proto(this: &Database, row: notification::Model) -> Option { - let kind = this.notification_kinds_by_id.get(&row.kind)?; - Some(proto::Notification { +pub fn model_to_proto(this: &Database, row: notification::Model) -> Result { + let kind = this + .notification_kinds_by_id + .get(&row.kind) + .ok_or_else(|| anyhow!("Unknown notification kind"))?; + Ok(proto::Notification { id: row.id.to_proto(), kind: kind.to_string(), timestamp: row.created_at.assume_utc().unix_timestamp() as u64, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 4055b7ca8d..16f131af51 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3388,14 +3388,30 @@ async fn remove_channel_message( ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); - let connection_ids = session + let (connection_ids, existing_notification_ids) = session .db() .await .remove_channel_message(channel_id, message_id, session.user_id()) .await?; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send(connection, request.clone()) - }); + + broadcast( + Some(session.connection_id), + connection_ids, + move |connection| { + session.peer.send(connection, request.clone())?; + + for notification_id in &existing_notification_ids { + session.peer.send( + connection, + proto::DeleteNotification { + notification_id: (*notification_id).to_proto(), + }, + )?; + } + + Ok(()) + }, + ); response.send(proto::Ack {})?; Ok(()) } @@ -3414,6 +3430,8 @@ async fn update_channel_message( notifications, reply_to_message_id, timestamp, + deleted_mention_notification_ids, + updated_mention_notifications, } = session .db() .await @@ -3456,7 +3474,27 @@ async fn update_channel_message( channel_id: channel_id.to_proto(), message: Some(message.clone()), }, - ) + )?; + + for notification_id in &deleted_mention_notification_ids { + session.peer.send( + connection, + proto::DeleteNotification { + notification_id: (*notification_id).to_proto(), + }, + )?; + } + + for notification in &updated_mention_notifications { + session.peer.send( + connection, + proto::UpdateNotification { + notification: Some(notification.clone()), + }, + )?; + } + + Ok(()) }, ); diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index b3242485bd..459604468c 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -222,8 +222,18 @@ async fn test_remove_channel_message( .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) .await .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) + let msg_id_2 = channel_chat_a + .update(cx_a, |c, cx| { + c.send_message( + MessageParams { + text: "two @user_b".to_string(), + mentions: vec![(4..12, client_b.id())], + reply_to_message_id: None, + }, + cx, + ) + .unwrap() + }) .await .unwrap(); channel_chat_a @@ -233,10 +243,24 @@ async fn test_remove_channel_message( // Clients A and B see all of the messages. executor.run_until_parked(); - let expected_messages = &["one", "two", "three"]; + let expected_messages = &["one", "two @user_b", "three"]; assert_messages(&channel_chat_a, expected_messages, cx_a); assert_messages(&channel_chat_b, expected_messages, cx_b); + // Ensure that client B received a notification for the mention. + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 2); + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ChannelMessageMention { + message_id: msg_id_2, + sender_id: client_a.id(), + channel_id: channel_id.0, + } + ); + }); + // Client A deletes one of their messages. channel_chat_a .update(cx_a, |c, cx| { @@ -261,6 +285,13 @@ async fn test_remove_channel_message( .await .unwrap(); assert_messages(&channel_chat_c, expected_messages, cx_c); + + // Ensure we remove the notifications when the message is removed + client_b.notification_store().read_with(cx_b, |store, _| { + // First notification is the channel invitation, second would be the mention + // notification, which should now be removed. + assert_eq!(store.notification_count(), 1); + }); } #[track_caller] @@ -598,4 +629,97 @@ async fn test_chat_editing(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) } ); }); + + // Test update message and keep the mention and check that the body is updated correctly + + channel_chat_a + .update(cx_a, |c, cx| { + c.update_message( + msg_id, + MessageParams { + text: "Updated body v2 including a mention for @user_b".into(), + reply_to_message_id: None, + mentions: vec![(37..45, client_b.id())], + }, + cx, + ) + .unwrap() + }) + .await + .unwrap(); + + cx_a.run_until_parked(); + cx_b.run_until_parked(); + + channel_chat_a.update(cx_a, |channel_chat, _| { + assert_eq!( + channel_chat.find_loaded_message(msg_id).unwrap().body, + "Updated body v2 including a mention for @user_b", + ) + }); + channel_chat_b.update(cx_b, |channel_chat, _| { + assert_eq!( + channel_chat.find_loaded_message(msg_id).unwrap().body, + "Updated body v2 including a mention for @user_b", + ) + }); + + client_b.notification_store().read_with(cx_b, |store, _| { + let message = store.channel_message_for_id(msg_id); + assert!(message.is_some()); + assert_eq!( + message.unwrap().body, + "Updated body v2 including a mention for @user_b" + ); + assert_eq!(store.notification_count(), 2); + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ChannelMessageMention { + message_id: msg_id, + sender_id: client_a.id(), + channel_id: channel_id.0, + } + ); + }); + + // If we remove a mention from a message the corresponding mention notification + // should also be removed. + + channel_chat_a + .update(cx_a, |c, cx| { + c.update_message( + msg_id, + MessageParams { + text: "Updated body without a mention".into(), + reply_to_message_id: None, + mentions: vec![], + }, + cx, + ) + .unwrap() + }) + .await + .unwrap(); + + cx_a.run_until_parked(); + cx_b.run_until_parked(); + + channel_chat_a.update(cx_a, |channel_chat, _| { + assert_eq!( + channel_chat.find_loaded_message(msg_id).unwrap().body, + "Updated body without a mention", + ) + }); + channel_chat_b.update(cx_b, |channel_chat, _| { + assert_eq!( + channel_chat.find_loaded_message(msg_id).unwrap().body, + "Updated body without a mention", + ) + }); + client_b.notification_store().read_with(cx_b, |store, _| { + // First notification is the channel invitation, second would be the mention + // notification, which should now be removed. + assert_eq!(store.notification_count(), 1); + }); } diff --git a/crates/notifications/src/notification_store.rs b/crates/notifications/src/notification_store.rs index 67a1ec487a..100398f3f4 100644 --- a/crates/notifications/src/notification_store.rs +++ b/crates/notifications/src/notification_store.rs @@ -114,6 +114,7 @@ impl NotificationStore { _subscriptions: vec![ client.add_message_handler(cx.weak_model(), Self::handle_new_notification), client.add_message_handler(cx.weak_model(), Self::handle_delete_notification), + client.add_message_handler(cx.weak_model(), Self::handle_update_notification), ], user_store, client, @@ -236,6 +237,40 @@ impl NotificationStore { })? } + async fn handle_update_notification( + this: Model, + envelope: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + if let Some(notification) = envelope.payload.notification { + if let Some(rpc::Notification::ChannelMessageMention { + message_id, + sender_id: _, + channel_id: _, + }) = Notification::from_proto(¬ification) + { + let fetch_message_task = this.channel_store.update(cx, |this, cx| { + this.fetch_channel_messages(vec![message_id], cx) + }); + + cx.spawn(|this, mut cx| async move { + let messages = fetch_message_task.await?; + this.update(&mut cx, move |this, cx| { + for message in messages { + this.channel_messages.insert(message_id, message); + } + cx.notify(); + }) + }) + .detach_and_log_err(cx) + } + } + Ok(()) + })? + } + async fn add_notifications( this: Model, notifications: Vec, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index a610e2b850..630308b459 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -208,7 +208,9 @@ message Envelope { ChannelMessageUpdate channel_message_update = 171; BlameBuffer blame_buffer = 172; - BlameBufferResponse blame_buffer_response = 173; // Current max + BlameBufferResponse blame_buffer_response = 173; + + UpdateNotification update_notification = 174; // current max } reserved 158 to 161; @@ -1715,6 +1717,10 @@ message DeleteNotification { uint64 notification_id = 1; } +message UpdateNotification { + Notification notification = 1; +} + message MarkNotificationRead { uint64 notification_id = 1; } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index bc2b44046f..89f44faab8 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -163,6 +163,7 @@ messages!( (DeclineCall, Foreground), (DeleteChannel, Foreground), (DeleteNotification, Foreground), + (UpdateNotification, Foreground), (DeleteProjectEntry, Foreground), (EndStream, Foreground), (Error, Foreground),