diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 4bb61c3404..056f94ecfe 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -17,10 +17,11 @@ pub trait Db: Send + Sync { async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; async fn destroy_user(&self, id: UserId) -> Result<()>; - async fn get_contacts(&self, id: UserId) -> Result; + async fn get_contacts(&self, id: UserId) -> Result>; + async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result; async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; - async fn dismiss_contact_request( + async fn dismiss_contact_notification( &self, responder_id: UserId, requester_id: UserId, @@ -190,7 +191,7 @@ impl Db for PostgresDb { // contacts - async fn get_contacts(&self, user_id: UserId) -> Result { + async fn get_contacts(&self, user_id: UserId) -> Result> { let query = " SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify FROM contacts @@ -201,46 +202,67 @@ impl Db for PostgresDb { .bind(user_id) .fetch(&self.pool); - let mut current = vec![user_id]; - let mut outgoing_requests = Vec::new(); - let mut incoming_requests = Vec::new(); + let mut contacts = vec![Contact::Accepted { + user_id, + should_notify: false, + }]; while let Some(row) = rows.next().await { let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; if user_id_a == user_id { if accepted { - current.push(user_id_b); + contacts.push(Contact::Accepted { + user_id: user_id_b, + should_notify: should_notify && a_to_b, + }); } else if a_to_b { - outgoing_requests.push(user_id_b); + contacts.push(Contact::Outgoing { user_id: user_id_b }) } else { - incoming_requests.push(IncomingContactRequest { - requester_id: user_id_b, + contacts.push(Contact::Incoming { + user_id: user_id_b, should_notify, }); } } else { if accepted { - current.push(user_id_a); + contacts.push(Contact::Accepted { + user_id: user_id_a, + should_notify: should_notify && !a_to_b, + }); } else if a_to_b { - incoming_requests.push(IncomingContactRequest { - requester_id: user_id_a, + contacts.push(Contact::Incoming { + user_id: user_id_a, should_notify, }); } else { - outgoing_requests.push(user_id_a); + contacts.push(Contact::Outgoing { user_id: user_id_a }); } } } - current.sort_unstable(); - outgoing_requests.sort_unstable(); - incoming_requests.sort_unstable(); + contacts.sort_unstable_by_key(|contact| contact.user_id()); - Ok(Contacts { - current, - outgoing_requests, - incoming_requests, - }) + Ok(contacts) + } + + async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + + let query = " + SELECT 1 FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't' + LIMIT 1 + "; + Ok(sqlx::query_scalar::<_, i32>(query) + .bind(id_a.0) + .bind(id_b.0) + .fetch_optional(&self.pool) + .await? + .is_some()) } async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { @@ -254,7 +276,8 @@ impl Db for PostgresDb { VALUES ($1, $2, $3, 'f', 't') ON CONFLICT (user_id_a, user_id_b) DO UPDATE SET - accepted = 't' + accepted = 't', + should_notify = 'f' WHERE NOT contacts.accepted AND ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR @@ -297,21 +320,26 @@ impl Db for PostgresDb { } } - async fn dismiss_contact_request( + async fn dismiss_contact_notification( &self, - responder_id: UserId, - requester_id: UserId, + user_id: UserId, + contact_user_id: UserId, ) -> Result<()> { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) } else { - (requester_id, responder_id, true) + (contact_user_id, user_id, false) }; let query = " UPDATE contacts SET should_notify = 'f' - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + WHERE + user_id_a = $1 AND user_id_b = $2 AND + ( + (a_to_b = $3 AND accepted) OR + (a_to_b != $3 AND NOT accepted) + ); "; let result = sqlx::query(query) @@ -342,7 +370,7 @@ impl Db for PostgresDb { let result = if accept { let query = " UPDATE contacts - SET accepted = 't', should_notify = 'f' + SET accepted = 't', should_notify = 't' WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; "; sqlx::query(query) @@ -702,10 +730,28 @@ pub struct ChannelMessage { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Contacts { - pub current: Vec, - pub incoming_requests: Vec, - pub outgoing_requests: Vec, +pub enum Contact { + Accepted { + user_id: UserId, + should_notify: bool, + }, + Outgoing { + user_id: UserId, + }, + Incoming { + user_id: UserId, + should_notify: bool, + }, +} + +impl Contact { + pub fn user_id(&self) -> UserId { + match self { + Contact::Accepted { user_id, .. } => *user_id, + Contact::Outgoing { user_id } => *user_id, + Contact::Incoming { user_id, .. } => *user_id, + } + } } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -947,51 +993,60 @@ pub mod tests { // User starts with no contacts assert_eq!( db.get_contacts(user_1).await.unwrap(), - Contacts { - current: vec![user_1], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + vec![Contact::Accepted { + user_id: user_1, + should_notify: false + }], ); // User requests a contact. Both users see the pending request. db.send_contact_request(user_1, user_2).await.unwrap(); + assert!(!db.has_contact(user_1, user_2).await.unwrap()); + assert!(!db.has_contact(user_2, user_1).await.unwrap()); assert_eq!( db.get_contacts(user_1).await.unwrap(), - Contacts { - current: vec![user_1], - outgoing_requests: vec![user_2], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Outgoing { user_id: user_2 } + ], ); assert_eq!( db.get_contacts(user_2).await.unwrap(), - Contacts { - current: vec![user_2], - outgoing_requests: vec![], - incoming_requests: vec![IncomingContactRequest { - requester_id: user_1, + &[ + Contact::Incoming { + user_id: user_1, should_notify: true - }], - }, + }, + Contact::Accepted { + user_id: user_2, + should_notify: false + }, + ] ); // User 2 dismisses the contact request notification without accepting or rejecting. // We shouldn't notify them again. - db.dismiss_contact_request(user_1, user_2) + db.dismiss_contact_notification(user_1, user_2) .await .unwrap_err(); - db.dismiss_contact_request(user_2, user_1).await.unwrap(); + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap(); assert_eq!( db.get_contacts(user_2).await.unwrap(), - Contacts { - current: vec![user_2], - outgoing_requests: vec![], - incoming_requests: vec![IncomingContactRequest { - requester_id: user_1, + &[ + Contact::Incoming { + user_id: user_1, should_notify: false - }], - }, + }, + Contact::Accepted { + user_id: user_2, + should_notify: false + }, + ] ); // User can't accept their own contact request @@ -1005,44 +1060,106 @@ pub mod tests { .unwrap(); assert_eq!( db.get_contacts(user_1).await.unwrap(), - Contacts { - current: vec![user_1, user_2], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_2, + should_notify: true + } + ], ); + assert!(db.has_contact(user_1, user_2).await.unwrap()); + assert!(db.has_contact(user_2, user_1).await.unwrap()); assert_eq!( db.get_contacts(user_2).await.unwrap(), - Contacts { - current: vec![user_1, user_2], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false, + }, + Contact::Accepted { + user_id: user_2, + should_notify: false, + }, + ] ); // Users cannot re-request existing contacts. db.send_contact_request(user_1, user_2).await.unwrap_err(); db.send_contact_request(user_2, user_1).await.unwrap_err(); + // Users can't dismiss notifications of them accepting other users' requests. + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap_err(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_2, + should_notify: true, + }, + ] + ); + + // Users can dismiss notifications of other users accepting their requests. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_2, + should_notify: false, + }, + ] + ); + // Users send each other concurrent contact requests and // see that they are immediately accepted. db.send_contact_request(user_1, user_3).await.unwrap(); db.send_contact_request(user_3, user_1).await.unwrap(); assert_eq!( db.get_contacts(user_1).await.unwrap(), - Contacts { - current: vec![user_1, user_2, user_3], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_2, + should_notify: false, + }, + Contact::Accepted { + user_id: user_3, + should_notify: false + }, + ] ); assert_eq!( db.get_contacts(user_3).await.unwrap(), - Contacts { - current: vec![user_1, user_3], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_3, + should_notify: false + } + ], ); // User declines a contact request. Both users see that it is gone. @@ -1050,21 +1167,33 @@ pub mod tests { db.respond_to_contact_request(user_3, user_2, false) .await .unwrap(); + assert!(!db.has_contact(user_2, user_3).await.unwrap()); + assert!(!db.has_contact(user_3, user_2).await.unwrap()); assert_eq!( db.get_contacts(user_2).await.unwrap(), - Contacts { - current: vec![user_1, user_2], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_2, + should_notify: false + } + ] ); assert_eq!( db.get_contacts(user_3).await.unwrap(), - Contacts { - current: vec![user_1, user_3], - outgoing_requests: vec![], - incoming_requests: vec![], - }, + &[ + Contact::Accepted { + user_id: user_1, + should_notify: false + }, + Contact::Accepted { + user_id: user_3, + should_notify: false + } + ], ); } } @@ -1219,40 +1348,51 @@ pub mod tests { unimplemented!() } - async fn get_contacts(&self, id: UserId) -> Result { + async fn get_contacts(&self, id: UserId) -> Result> { self.background.simulate_random_delay().await; - let mut current = vec![id]; - let mut outgoing_requests = Vec::new(); - let mut incoming_requests = Vec::new(); + let mut contacts = vec![Contact::Accepted { + user_id: id, + should_notify: false, + }]; for contact in self.contacts.lock().iter() { if contact.requester_id == id { if contact.accepted { - current.push(contact.responder_id); + contacts.push(Contact::Accepted { + user_id: contact.responder_id, + should_notify: contact.should_notify, + }); } else { - outgoing_requests.push(contact.responder_id); + contacts.push(Contact::Outgoing { + user_id: contact.responder_id, + }); } } else if contact.responder_id == id { if contact.accepted { - current.push(contact.requester_id); + contacts.push(Contact::Accepted { + user_id: contact.requester_id, + should_notify: false, + }); } else { - incoming_requests.push(IncomingContactRequest { - requester_id: contact.requester_id, + contacts.push(Contact::Incoming { + user_id: contact.requester_id, should_notify: contact.should_notify, }); } } } - current.sort_unstable(); - outgoing_requests.sort_unstable(); - incoming_requests.sort_unstable(); + contacts.sort_unstable_by_key(|contact| contact.user_id()); + Ok(contacts) + } - Ok(Contacts { - current, - outgoing_requests, - incoming_requests, - }) + async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result { + self.background.simulate_random_delay().await; + Ok(self.contacts.lock().iter().any(|contact| { + contact.accepted + && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b) + || (contact.requester_id == user_id_b && contact.responder_id == user_id_a)) + })) } async fn send_contact_request( @@ -1274,6 +1414,7 @@ pub mod tests { Err(anyhow!("contact already exists"))?; } else { contact.accepted = true; + contact.should_notify = false; return Ok(()); } } @@ -1294,22 +1435,29 @@ pub mod tests { Ok(()) } - async fn dismiss_contact_request( + async fn dismiss_contact_notification( &self, - responder_id: UserId, - requester_id: UserId, + user_id: UserId, + contact_user_id: UserId, ) -> Result<()> { let mut contacts = self.contacts.lock(); for contact in contacts.iter_mut() { - if contact.requester_id == requester_id && contact.responder_id == responder_id { - if contact.accepted { - return Err(anyhow!("contact already confirmed")); - } + if contact.requester_id == contact_user_id + && contact.responder_id == user_id + && !contact.accepted + { + contact.should_notify = false; + return Ok(()); + } + if contact.requester_id == user_id + && contact.responder_id == contact_user_id + && contact.accepted + { contact.should_notify = false; return Ok(()); } } - Err(anyhow!("no such contact request")) + Err(anyhow!("no such notification")) } async fn respond_to_contact_request( @@ -1326,6 +1474,7 @@ pub mod tests { } if accept { contact.accepted = true; + contact.should_notify = true; } else { contacts.remove(ix); } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 8cd4b6387c..4bf06fe7a3 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod store; use crate::{ auth, - db::{ChannelId, MessageId, UserId}, + db::{self, ChannelId, MessageId, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -421,21 +421,27 @@ impl Server { let contacts = self.app_state.db.get_contacts(user_id).await?; let store = self.store().await; let updated_contact = store.contact_for_user(user_id); - for contact_user_id in contacts.current { - for contact_conn_id in store.connection_ids_for_user(contact_user_id) { - self.peer - .send( - contact_conn_id, - proto::UpdateContacts { - contacts: vec![updated_contact.clone()], - remove_contacts: Default::default(), - incoming_requests: Default::default(), - remove_incoming_requests: Default::default(), - outgoing_requests: Default::default(), - remove_outgoing_requests: Default::default(), - }, - ) - .trace_err(); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. + } = contact + { + for contact_conn_id in store.connection_ids_for_user(contact_user_id) { + self.peer + .send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } } } Ok(()) @@ -473,8 +479,12 @@ impl Server { guest_user_id = state.user_id_for_connection(request.sender_id)?; }; - let guest_contacts = self.app_state.db.get_contacts(guest_user_id).await?; - if !guest_contacts.current.contains(&host_user_id) { + let has_contact = self + .app_state + .db + .has_contact(guest_user_id, host_user_id) + .await?; + if !has_contact { return Err(anyhow!("no such project"))?; } @@ -1026,7 +1036,7 @@ impl Server { if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { self.app_state .db - .dismiss_contact_request(responder_id, requester_id) + .dismiss_contact_notification(responder_id, requester_id) .await?; } else { let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32; diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 8ca2706228..9f56c95a47 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -217,23 +217,30 @@ impl Store { .is_empty() } - pub fn build_initial_contacts_update(&self, contacts: db::Contacts) -> proto::UpdateContacts { + pub fn build_initial_contacts_update( + &self, + contacts: Vec, + ) -> proto::UpdateContacts { let mut update = proto::UpdateContacts::default(); - for user_id in contacts.current { - update.contacts.push(self.contact_for_user(user_id)); - } - for request in contacts.incoming_requests { - update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: request.requester_id.to_proto(), - should_notify: request.should_notify, - }) - } - - for requested_user_id in contacts.outgoing_requests { - update.outgoing_requests.push(requested_user_id.to_proto()) + for contact in contacts { + match contact { + db::Contact::Accepted { user_id, .. } => { + update.contacts.push(self.contact_for_user(user_id)); + } + db::Contact::Outgoing { user_id } => { + update.outgoing_requests.push(user_id.to_proto()) + } + db::Contact::Incoming { + user_id, + should_notify, + } => update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + should_notify, + }), + } } update