diff --git a/server/src/rpc.rs b/server/src/rpc.rs index c712ced835..696487fdca 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -77,8 +77,14 @@ struct Channel { connection_ids: HashSet, } +#[cfg(debug_assertions)] +const MESSAGE_COUNT_PER_PAGE: usize = 10; + +#[cfg(not(debug_assertions))] const MESSAGE_COUNT_PER_PAGE: usize = 50; +const MAX_MESSAGE_LEN: usize = 1024; + impl Server { pub fn new( app_state: Arc, @@ -661,20 +667,33 @@ impl Server { } } + let receipt = request.receipt(); + let body = request.payload.body.trim().to_string(); + if body.len() > MAX_MESSAGE_LEN { + self.peer + .respond_with_error( + receipt, + proto::Error { + message: "message is too long".to_string(), + }, + ) + .await?; + return Ok(()); + } + let timestamp = OffsetDateTime::now_utc(); let message_id = self .app_state .db - .create_channel_message(channel_id, user_id, &request.payload.body, timestamp) + .create_channel_message(channel_id, user_id, &body, timestamp) .await? .to_proto(); - let receipt = request.receipt(); let message = proto::ChannelMessageSent { channel_id: channel_id.to_proto(), message: Some(proto::ChannelMessage { sender_id: user_id.to_proto(), id: message_id, - body: request.payload.body, + body, timestamp: timestamp.unix_timestamp() as u64, }), }; @@ -1530,18 +1549,25 @@ mod tests { }) .await; - channel_a.update(&mut cx_a, |channel, cx| { - channel.send_message("oh, hi B.".to_string(), cx).unwrap(); - channel.send_message("sup".to_string(), cx).unwrap(); - assert_eq!( + channel_a + .update(&mut cx_a, |channel, cx| { channel - .pending_messages() - .iter() - .map(|m| &m.body) - .collect::>(), - &["oh, hi B.", "sup"] - ) - }); + .send_message("oh, hi B.".to_string(), cx) + .unwrap() + .detach(); + let task = channel.send_message("sup".to_string(), cx).unwrap(); + assert_eq!( + channel + .pending_messages() + .iter() + .map(|m| &m.body) + .collect::>(), + &["oh, hi B.", "sup"] + ); + task + }) + .await + .unwrap(); channel_a .condition(&cx_a, |channel, _| channel.pending_messages().is_empty()) @@ -1582,6 +1608,59 @@ mod tests { } } + #[gpui::test] + async fn test_chat_message_validation(mut cx_a: TestAppContext) { + cx_a.foreground().forbid_parking(); + + let mut server = TestServer::start().await; + let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; + + let db = &server.app_state.db; + let org_id = db.create_org("Test Org", "test-org").await.unwrap(); + let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); + db.add_org_member(org_id, user_id_a, false).await.unwrap(); + db.add_channel_member(channel_id, user_id_a, false) + .await + .unwrap(); + + let user_store_a = Arc::new(UserStore::new(client_a.clone())); + let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); + channels_a + .condition(&mut cx_a, |list, _| list.available_channels().is_some()) + .await; + let channel_a = channels_a.update(&mut cx_a, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + + // Leading and trailing whitespace are trimmed. + channel_a + .update(&mut cx_a, |channel, cx| { + channel + .send_message("\n surrounded by whitespace \n".to_string(), cx) + .unwrap() + }) + .await + .unwrap(); + assert_eq!( + db.get_channel_messages(channel_id, 10, None) + .await + .unwrap() + .iter() + .map(|m| &m.body) + .collect::>(), + &["surrounded by whitespace"] + ); + + // Messages aren't allowed to be too long. + channel_a + .update(&mut cx_a, |channel, cx| { + let long_body = "this is long.\n".repeat(1024); + channel.send_message(long_body, cx).unwrap() + }) + .await + .unwrap_err(); + } + struct TestServer { peer: Arc, app_state: Arc, diff --git a/zed/src/channel.rs b/zed/src/channel.rs index f0f9d4f43c..6bac81fa19 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -224,7 +224,11 @@ impl Channel { &self.details.name } - pub fn send_message(&mut self, body: String, cx: &mut ModelContext) -> Result<()> { + pub fn send_message( + &mut self, + body: String, + cx: &mut ModelContext, + ) -> Result>> { let channel_id = self.details.id; let current_user_id = self.current_user_id()?; let local_id = self.next_local_message_id; @@ -235,41 +239,35 @@ impl Channel { }); let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); - cx.spawn(|this, mut cx| { - async move { - let request = rpc.request(proto::SendChannelMessage { channel_id, body }); - let response = request.await?; - let sender = user_store.get_user(current_user_id).await?; + Ok(cx.spawn(|this, mut cx| async move { + let request = rpc.request(proto::SendChannelMessage { channel_id, body }); + let response = request.await?; + let sender = user_store.get_user(current_user_id).await?; - this.update(&mut cx, |this, cx| { - if let Ok(i) = this - .pending_messages - .binary_search_by_key(&local_id, |msg| msg.local_id) - { - let body = this.pending_messages.remove(i).body; - this.insert_messages( - SumTree::from_item( - ChannelMessage { - id: response.message_id, - timestamp: OffsetDateTime::from_unix_timestamp( - response.timestamp as i64, - )?, - body, - sender, - }, - &(), - ), - cx, - ); - } - Ok(()) - }) - } - .log_err() - }) - .detach(); - cx.notify(); - Ok(()) + this.update(&mut cx, |this, cx| { + if let Ok(i) = this + .pending_messages + .binary_search_by_key(&local_id, |msg| msg.local_id) + { + let body = this.pending_messages.remove(i).body; + this.insert_messages( + SumTree::from_item( + ChannelMessage { + id: response.message_id, + timestamp: OffsetDateTime::from_unix_timestamp( + response.timestamp as i64, + )?, + body, + sender, + }, + &(), + ), + cx, + ); + } + Ok(()) + }) + })) } pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { diff --git a/zed/src/chat_panel.rs b/zed/src/chat_panel.rs index a8b4ed94bf..e94f5712e5 100644 --- a/zed/src/chat_panel.rs +++ b/zed/src/chat_panel.rs @@ -193,9 +193,12 @@ impl ChatPanel { body }); - channel + if let Some(task) = channel .update(cx, |channel, cx| channel.send_message(body, cx)) - .log_err(); + .log_err() + { + task.detach(); + } } } diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index a94c0f6204..123fc0f1da 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -6,34 +6,35 @@ message Envelope { optional uint32 responding_to = 2; optional uint32 original_sender_id = 3; oneof payload { - Ping ping = 4; - Pong pong = 5; - ShareWorktree share_worktree = 6; - ShareWorktreeResponse share_worktree_response = 7; - OpenWorktree open_worktree = 8; - OpenWorktreeResponse open_worktree_response = 9; - UpdateWorktree update_worktree = 10; - CloseWorktree close_worktree = 11; - OpenBuffer open_buffer = 12; - OpenBufferResponse open_buffer_response = 13; - CloseBuffer close_buffer = 14; - UpdateBuffer update_buffer = 15; - SaveBuffer save_buffer = 16; - BufferSaved buffer_saved = 17; - AddPeer add_peer = 18; - RemovePeer remove_peer = 19; - GetChannels get_channels = 20; - GetChannelsResponse get_channels_response = 21; - GetUsers get_users = 22; - GetUsersResponse get_users_response = 23; - JoinChannel join_channel = 24; - JoinChannelResponse join_channel_response = 25; - LeaveChannel leave_channel = 26; - SendChannelMessage send_channel_message = 27; - SendChannelMessageResponse send_channel_message_response = 28; - ChannelMessageSent channel_message_sent = 29; - GetChannelMessages get_channel_messages = 30; - GetChannelMessagesResponse get_channel_messages_response = 31; + Error error = 4; + Ping ping = 5; + Pong pong = 6; + ShareWorktree share_worktree = 7; + ShareWorktreeResponse share_worktree_response = 8; + OpenWorktree open_worktree = 9; + OpenWorktreeResponse open_worktree_response = 10; + UpdateWorktree update_worktree = 11; + CloseWorktree close_worktree = 12; + OpenBuffer open_buffer = 13; + OpenBufferResponse open_buffer_response = 14; + CloseBuffer close_buffer = 15; + UpdateBuffer update_buffer = 16; + SaveBuffer save_buffer = 17; + BufferSaved buffer_saved = 18; + AddPeer add_peer = 19; + RemovePeer remove_peer = 20; + GetChannels get_channels = 21; + GetChannelsResponse get_channels_response = 22; + GetUsers get_users = 23; + GetUsersResponse get_users_response = 24; + JoinChannel join_channel = 25; + JoinChannelResponse join_channel_response = 26; + LeaveChannel leave_channel = 27; + SendChannelMessage send_channel_message = 28; + SendChannelMessageResponse send_channel_message_response = 29; + ChannelMessageSent channel_message_sent = 30; + GetChannelMessages get_channel_messages = 31; + GetChannelMessagesResponse get_channel_messages_response = 32; } } @@ -47,6 +48,10 @@ message Pong { int32 id = 2; } +message Error { + string message = 1; +} + message ShareWorktree { Worktree worktree = 1; } diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index 06d4b01ae0..9a7954341b 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -238,8 +238,12 @@ impl Peer { .recv() .await .ok_or_else(|| anyhow!("connection was closed"))?; - T::Response::from_envelope(response) - .ok_or_else(|| anyhow!("received response of the wrong type")) + if let Some(proto::envelope::Payload::Error(error)) = &response.payload { + Err(anyhow!("request failed").context(error.message.clone())) + } else { + T::Response::from_envelope(response) + .ok_or_else(|| anyhow!("received response of the wrong type")) + } } } @@ -301,6 +305,25 @@ impl Peer { } } + pub fn respond_with_error( + self: &Arc, + receipt: Receipt, + response: proto::Error, + ) -> impl Future> { + let this = self.clone(); + async move { + let mut connection = this.connection(receipt.sender_id).await?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .outgoing_tx + .send(response.into_envelope(message_id, Some(receipt.message_id), None)) + .await?; + Ok(()) + } + } + fn connection( self: &Arc, connection_id: ConnectionId, diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 3743a06a07..002c5bc840 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -125,6 +125,7 @@ messages!( ChannelMessageSent, CloseBuffer, CloseWorktree, + Error, GetChannelMessages, GetChannelMessagesResponse, GetChannels,