Prevent the same user from being called more than once

This commit is contained in:
Antonio Scandurra 2022-09-26 11:13:34 +02:00
parent 55b095cbd3
commit f4697ff4d1

View File

@ -13,7 +13,7 @@ pub type RoomId = u64;
#[derive(Default, Serialize)] #[derive(Default, Serialize)]
pub struct Store { pub struct Store {
connections: BTreeMap<ConnectionId, ConnectionState>, connections: BTreeMap<ConnectionId, ConnectionState>,
connections_by_user_id: BTreeMap<UserId, HashSet<ConnectionId>>, connections_by_user_id: BTreeMap<UserId, UserConnectionState>,
next_room_id: RoomId, next_room_id: RoomId,
rooms: BTreeMap<RoomId, proto::Room>, rooms: BTreeMap<RoomId, proto::Room>,
projects: BTreeMap<ProjectId, Project>, projects: BTreeMap<ProjectId, Project>,
@ -21,16 +21,27 @@ pub struct Store {
channels: BTreeMap<ChannelId, Channel>, channels: BTreeMap<ChannelId, Channel>,
} }
#[derive(Default, Serialize)]
struct UserConnectionState {
connection_ids: HashSet<ConnectionId>,
room: Option<RoomState>,
}
#[derive(Serialize)] #[derive(Serialize)]
struct ConnectionState { struct ConnectionState {
user_id: UserId, user_id: UserId,
admin: bool, admin: bool,
room: Option<RoomId>,
projects: BTreeSet<ProjectId>, projects: BTreeSet<ProjectId>,
requested_projects: HashSet<ProjectId>, requested_projects: HashSet<ProjectId>,
channels: HashSet<ChannelId>, channels: HashSet<ChannelId>,
} }
#[derive(Copy, Clone, Eq, PartialEq, Serialize)]
enum RoomState {
Joined,
Calling { room_id: RoomId },
}
#[derive(Serialize)] #[derive(Serialize)]
pub struct Project { pub struct Project {
pub online: bool, pub online: bool,
@ -140,7 +151,6 @@ impl Store {
ConnectionState { ConnectionState {
user_id, user_id,
admin, admin,
room: Default::default(),
projects: Default::default(), projects: Default::default(),
requested_projects: Default::default(), requested_projects: Default::default(),
channels: Default::default(), channels: Default::default(),
@ -149,6 +159,7 @@ impl Store {
self.connections_by_user_id self.connections_by_user_id
.entry(user_id) .entry(user_id)
.or_default() .or_default()
.connection_ids
.insert(connection_id); .insert(connection_id);
} }
@ -185,9 +196,9 @@ impl Store {
} }
} }
let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap(); let user_connection_state = self.connections_by_user_id.get_mut(&user_id).unwrap();
user_connections.remove(&connection_id); user_connection_state.connection_ids.remove(&connection_id);
if user_connections.is_empty() { if user_connection_state.connection_ids.is_empty() {
self.connections_by_user_id.remove(&user_id); self.connections_by_user_id.remove(&user_id);
} }
@ -239,6 +250,7 @@ impl Store {
self.connections_by_user_id self.connections_by_user_id
.get(&user_id) .get(&user_id)
.into_iter() .into_iter()
.map(|state| &state.connection_ids)
.flatten() .flatten()
.copied() .copied()
} }
@ -248,6 +260,7 @@ impl Store {
.connections_by_user_id .connections_by_user_id
.get(&user_id) .get(&user_id)
.unwrap_or(&Default::default()) .unwrap_or(&Default::default())
.connection_ids
.is_empty() .is_empty()
} }
@ -295,9 +308,10 @@ impl Store {
} }
pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> { pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
let connection_ids = self.connections_by_user_id.get(&user_id); let user_connection_state = self.connections_by_user_id.get(&user_id);
let project_ids = connection_ids.iter().flat_map(|connection_ids| { let project_ids = user_connection_state.iter().flat_map(|state| {
connection_ids state
.connection_ids
.iter() .iter()
.filter_map(|connection_id| self.connections.get(connection_id)) .filter_map(|connection_id| self.connections.get(connection_id))
.flat_map(|connection| connection.projects.iter().copied()) .flat_map(|connection| connection.projects.iter().copied())
@ -333,8 +347,12 @@ impl Store {
.connections .connections
.get_mut(&creator_connection_id) .get_mut(&creator_connection_id)
.ok_or_else(|| anyhow!("no such connection"))?; .ok_or_else(|| anyhow!("no such connection"))?;
let user_connection_state = self
.connections_by_user_id
.get_mut(&connection.user_id)
.ok_or_else(|| anyhow!("no such connection"))?;
anyhow::ensure!( anyhow::ensure!(
connection.room.is_none(), user_connection_state.room.is_none(),
"cannot participate in more than one room at once" "cannot participate in more than one room at once"
); );
@ -352,7 +370,7 @@ impl Store {
let room_id = post_inc(&mut self.next_room_id); let room_id = post_inc(&mut self.next_room_id);
self.rooms.insert(room_id, room); self.rooms.insert(room_id, room);
connection.room = Some(room_id); user_connection_state.room = Some(RoomState::Joined);
Ok(room_id) Ok(room_id)
} }
@ -365,14 +383,20 @@ impl Store {
.connections .connections
.get_mut(&connection_id) .get_mut(&connection_id)
.ok_or_else(|| anyhow!("no such connection"))?; .ok_or_else(|| anyhow!("no such connection"))?;
anyhow::ensure!(
connection.room.is_none(),
"cannot participate in more than one room at once"
);
let user_id = connection.user_id; let user_id = connection.user_id;
let recipient_ids = self.connection_ids_for_user(user_id).collect::<Vec<_>>(); let recipient_ids = self.connection_ids_for_user(user_id).collect::<Vec<_>>();
let mut user_connection_state = self
.connections_by_user_id
.get_mut(&user_id)
.ok_or_else(|| anyhow!("no such connection"))?;
anyhow::ensure!(
user_connection_state
.room
.map_or(true, |room| room == RoomState::Calling { room_id }),
"cannot participate in more than one room at once"
);
let room = self let room = self
.rooms .rooms
.get_mut(&room_id) .get_mut(&room_id)
@ -393,6 +417,7 @@ impl Store {
)), )),
}), }),
}); });
user_connection_state.room = Some(RoomState::Joined);
Ok((room, recipient_ids)) Ok((room, recipient_ids))
} }
@ -404,7 +429,17 @@ impl Store {
to_user_id: UserId, to_user_id: UserId,
) -> Result<(UserId, Vec<ConnectionId>, &proto::Room)> { ) -> Result<(UserId, Vec<ConnectionId>, &proto::Room)> {
let from_user_id = self.user_id_for_connection(from_connection_id)?; let from_user_id = self.user_id_for_connection(from_connection_id)?;
let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::<Vec<_>>(); let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::<Vec<_>>();
let mut to_user_connection_state = self
.connections_by_user_id
.get_mut(&to_user_id)
.ok_or_else(|| anyhow!("no such connection"))?;
anyhow::ensure!(
to_user_connection_state.room.is_none(),
"recipient is already on another call"
);
let room = self let room = self
.rooms .rooms
.get_mut(&room_id) .get_mut(&room_id)
@ -422,11 +457,18 @@ impl Store {
"cannot call the same user more than once" "cannot call the same user more than once"
); );
room.pending_calls_to_user_ids.push(to_user_id.to_proto()); room.pending_calls_to_user_ids.push(to_user_id.to_proto());
to_user_connection_state.room = Some(RoomState::Calling { room_id });
Ok((from_user_id, to_connection_ids, room)) Ok((from_user_id, to_connection_ids, room))
} }
pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> { pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> {
let mut to_user_connection_state = self
.connections_by_user_id
.get_mut(&to_user_id)
.ok_or_else(|| anyhow!("no such connection"))?;
anyhow::ensure!(to_user_connection_state.room == Some(RoomState::Calling { room_id }));
to_user_connection_state.room = None;
let room = self let room = self
.rooms .rooms
.get_mut(&room_id) .get_mut(&room_id)
@ -548,10 +590,12 @@ impl Store {
} }
for requester_user_id in project.join_requests.keys() { for requester_user_id in project.join_requests.keys() {
if let Some(requester_connection_ids) = if let Some(requester_user_connection_state) =
self.connections_by_user_id.get_mut(requester_user_id) self.connections_by_user_id.get_mut(requester_user_id)
{ {
for requester_connection_id in requester_connection_ids.iter() { for requester_connection_id in
&requester_user_connection_state.connection_ids
{
if let Some(requester_connection) = if let Some(requester_connection) =
self.connections.get_mut(requester_connection_id) self.connections.get_mut(requester_connection_id)
{ {
@ -907,11 +951,12 @@ impl Store {
.connections_by_user_id .connections_by_user_id
.get(&connection.user_id) .get(&connection.user_id)
.unwrap() .unwrap()
.connection_ids
.contains(connection_id)); .contains(connection_id));
} }
for (user_id, connection_ids) in &self.connections_by_user_id { for (user_id, state) in &self.connections_by_user_id {
for connection_id in connection_ids { for connection_id in &state.connection_ids {
assert_eq!( assert_eq!(
self.connections.get(connection_id).unwrap().user_id, self.connections.get(connection_id).unwrap().user_id,
*user_id *user_id