mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-18 18:08:07 +03:00
Trim whitespace from chat messages and limit their length
Add a way for the server to respond to any request with an error
This commit is contained in:
parent
b3d5f01ba8
commit
a98d293f54
@ -77,8 +77,14 @@ struct Channel {
|
||||
connection_ids: HashSet<ConnectionId>,
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
const MESSAGE_COUNT_PER_PAGE: usize = 10;
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
const MESSAGE_COUNT_PER_PAGE: usize = 50;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 1024;
|
||||
|
||||
impl Server {
|
||||
pub fn new(
|
||||
app_state: Arc<AppState>,
|
||||
@ -661,20 +667,33 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
let receipt = request.receipt();
|
||||
let body = request.payload.body.trim().to_string();
|
||||
if body.len() > MAX_MESSAGE_LEN {
|
||||
self.peer
|
||||
.respond_with_error(
|
||||
receipt,
|
||||
proto::Error {
|
||||
message: "message is too long".to_string(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let timestamp = OffsetDateTime::now_utc();
|
||||
let message_id = self
|
||||
.app_state
|
||||
.db
|
||||
.create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
|
||||
.create_channel_message(channel_id, user_id, &body, timestamp)
|
||||
.await?
|
||||
.to_proto();
|
||||
let receipt = request.receipt();
|
||||
let message = proto::ChannelMessageSent {
|
||||
channel_id: channel_id.to_proto(),
|
||||
message: Some(proto::ChannelMessage {
|
||||
sender_id: user_id.to_proto(),
|
||||
id: message_id,
|
||||
body: request.payload.body,
|
||||
body,
|
||||
timestamp: timestamp.unix_timestamp() as u64,
|
||||
}),
|
||||
};
|
||||
@ -1530,18 +1549,25 @@ mod tests {
|
||||
})
|
||||
.await;
|
||||
|
||||
channel_a.update(&mut cx_a, |channel, cx| {
|
||||
channel.send_message("oh, hi B.".to_string(), cx).unwrap();
|
||||
channel.send_message("sup".to_string(), cx).unwrap();
|
||||
assert_eq!(
|
||||
channel_a
|
||||
.update(&mut cx_a, |channel, cx| {
|
||||
channel
|
||||
.pending_messages()
|
||||
.iter()
|
||||
.map(|m| &m.body)
|
||||
.collect::<Vec<_>>(),
|
||||
&["oh, hi B.", "sup"]
|
||||
)
|
||||
});
|
||||
.send_message("oh, hi B.".to_string(), cx)
|
||||
.unwrap()
|
||||
.detach();
|
||||
let task = channel.send_message("sup".to_string(), cx).unwrap();
|
||||
assert_eq!(
|
||||
channel
|
||||
.pending_messages()
|
||||
.iter()
|
||||
.map(|m| &m.body)
|
||||
.collect::<Vec<_>>(),
|
||||
&["oh, hi B.", "sup"]
|
||||
);
|
||||
task
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
channel_a
|
||||
.condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
|
||||
@ -1582,6 +1608,59 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_chat_message_validation(mut cx_a: TestAppContext) {
|
||||
cx_a.foreground().forbid_parking();
|
||||
|
||||
let mut server = TestServer::start().await;
|
||||
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
|
||||
|
||||
let db = &server.app_state.db;
|
||||
let org_id = db.create_org("Test Org", "test-org").await.unwrap();
|
||||
let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
|
||||
db.add_org_member(org_id, user_id_a, false).await.unwrap();
|
||||
db.add_channel_member(channel_id, user_id_a, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user_store_a = Arc::new(UserStore::new(client_a.clone()));
|
||||
let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
|
||||
channels_a
|
||||
.condition(&mut cx_a, |list, _| list.available_channels().is_some())
|
||||
.await;
|
||||
let channel_a = channels_a.update(&mut cx_a, |this, cx| {
|
||||
this.get_channel(channel_id.to_proto(), cx).unwrap()
|
||||
});
|
||||
|
||||
// Leading and trailing whitespace are trimmed.
|
||||
channel_a
|
||||
.update(&mut cx_a, |channel, cx| {
|
||||
channel
|
||||
.send_message("\n surrounded by whitespace \n".to_string(), cx)
|
||||
.unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
db.get_channel_messages(channel_id, 10, None)
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|m| &m.body)
|
||||
.collect::<Vec<_>>(),
|
||||
&["surrounded by whitespace"]
|
||||
);
|
||||
|
||||
// Messages aren't allowed to be too long.
|
||||
channel_a
|
||||
.update(&mut cx_a, |channel, cx| {
|
||||
let long_body = "this is long.\n".repeat(1024);
|
||||
channel.send_message(long_body, cx).unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
struct TestServer {
|
||||
peer: Arc<Peer>,
|
||||
app_state: Arc<AppState>,
|
||||
|
@ -224,7 +224,11 @@ impl Channel {
|
||||
&self.details.name
|
||||
}
|
||||
|
||||
pub fn send_message(&mut self, body: String, cx: &mut ModelContext<Self>) -> Result<()> {
|
||||
pub fn send_message(
|
||||
&mut self,
|
||||
body: String,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<Task<Result<()>>> {
|
||||
let channel_id = self.details.id;
|
||||
let current_user_id = self.current_user_id()?;
|
||||
let local_id = self.next_local_message_id;
|
||||
@ -235,41 +239,35 @@ impl Channel {
|
||||
});
|
||||
let user_store = self.user_store.clone();
|
||||
let rpc = self.rpc.clone();
|
||||
cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let request = rpc.request(proto::SendChannelMessage { channel_id, body });
|
||||
let response = request.await?;
|
||||
let sender = user_store.get_user(current_user_id).await?;
|
||||
Ok(cx.spawn(|this, mut cx| async move {
|
||||
let request = rpc.request(proto::SendChannelMessage { channel_id, body });
|
||||
let response = request.await?;
|
||||
let sender = user_store.get_user(current_user_id).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Ok(i) = this
|
||||
.pending_messages
|
||||
.binary_search_by_key(&local_id, |msg| msg.local_id)
|
||||
{
|
||||
let body = this.pending_messages.remove(i).body;
|
||||
this.insert_messages(
|
||||
SumTree::from_item(
|
||||
ChannelMessage {
|
||||
id: response.message_id,
|
||||
timestamp: OffsetDateTime::from_unix_timestamp(
|
||||
response.timestamp as i64,
|
||||
)?,
|
||||
body,
|
||||
sender,
|
||||
},
|
||||
&(),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
.detach();
|
||||
cx.notify();
|
||||
Ok(())
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Ok(i) = this
|
||||
.pending_messages
|
||||
.binary_search_by_key(&local_id, |msg| msg.local_id)
|
||||
{
|
||||
let body = this.pending_messages.remove(i).body;
|
||||
this.insert_messages(
|
||||
SumTree::from_item(
|
||||
ChannelMessage {
|
||||
id: response.message_id,
|
||||
timestamp: OffsetDateTime::from_unix_timestamp(
|
||||
response.timestamp as i64,
|
||||
)?,
|
||||
body,
|
||||
sender,
|
||||
},
|
||||
&(),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
|
||||
|
@ -193,9 +193,12 @@ impl ChatPanel {
|
||||
body
|
||||
});
|
||||
|
||||
channel
|
||||
if let Some(task) = channel
|
||||
.update(cx, |channel, cx| channel.send_message(body, cx))
|
||||
.log_err();
|
||||
.log_err()
|
||||
{
|
||||
task.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,34 +6,35 @@ message Envelope {
|
||||
optional uint32 responding_to = 2;
|
||||
optional uint32 original_sender_id = 3;
|
||||
oneof payload {
|
||||
Ping ping = 4;
|
||||
Pong pong = 5;
|
||||
ShareWorktree share_worktree = 6;
|
||||
ShareWorktreeResponse share_worktree_response = 7;
|
||||
OpenWorktree open_worktree = 8;
|
||||
OpenWorktreeResponse open_worktree_response = 9;
|
||||
UpdateWorktree update_worktree = 10;
|
||||
CloseWorktree close_worktree = 11;
|
||||
OpenBuffer open_buffer = 12;
|
||||
OpenBufferResponse open_buffer_response = 13;
|
||||
CloseBuffer close_buffer = 14;
|
||||
UpdateBuffer update_buffer = 15;
|
||||
SaveBuffer save_buffer = 16;
|
||||
BufferSaved buffer_saved = 17;
|
||||
AddPeer add_peer = 18;
|
||||
RemovePeer remove_peer = 19;
|
||||
GetChannels get_channels = 20;
|
||||
GetChannelsResponse get_channels_response = 21;
|
||||
GetUsers get_users = 22;
|
||||
GetUsersResponse get_users_response = 23;
|
||||
JoinChannel join_channel = 24;
|
||||
JoinChannelResponse join_channel_response = 25;
|
||||
LeaveChannel leave_channel = 26;
|
||||
SendChannelMessage send_channel_message = 27;
|
||||
SendChannelMessageResponse send_channel_message_response = 28;
|
||||
ChannelMessageSent channel_message_sent = 29;
|
||||
GetChannelMessages get_channel_messages = 30;
|
||||
GetChannelMessagesResponse get_channel_messages_response = 31;
|
||||
Error error = 4;
|
||||
Ping ping = 5;
|
||||
Pong pong = 6;
|
||||
ShareWorktree share_worktree = 7;
|
||||
ShareWorktreeResponse share_worktree_response = 8;
|
||||
OpenWorktree open_worktree = 9;
|
||||
OpenWorktreeResponse open_worktree_response = 10;
|
||||
UpdateWorktree update_worktree = 11;
|
||||
CloseWorktree close_worktree = 12;
|
||||
OpenBuffer open_buffer = 13;
|
||||
OpenBufferResponse open_buffer_response = 14;
|
||||
CloseBuffer close_buffer = 15;
|
||||
UpdateBuffer update_buffer = 16;
|
||||
SaveBuffer save_buffer = 17;
|
||||
BufferSaved buffer_saved = 18;
|
||||
AddPeer add_peer = 19;
|
||||
RemovePeer remove_peer = 20;
|
||||
GetChannels get_channels = 21;
|
||||
GetChannelsResponse get_channels_response = 22;
|
||||
GetUsers get_users = 23;
|
||||
GetUsersResponse get_users_response = 24;
|
||||
JoinChannel join_channel = 25;
|
||||
JoinChannelResponse join_channel_response = 26;
|
||||
LeaveChannel leave_channel = 27;
|
||||
SendChannelMessage send_channel_message = 28;
|
||||
SendChannelMessageResponse send_channel_message_response = 29;
|
||||
ChannelMessageSent channel_message_sent = 30;
|
||||
GetChannelMessages get_channel_messages = 31;
|
||||
GetChannelMessagesResponse get_channel_messages_response = 32;
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,6 +48,10 @@ message Pong {
|
||||
int32 id = 2;
|
||||
}
|
||||
|
||||
message Error {
|
||||
string message = 1;
|
||||
}
|
||||
|
||||
message ShareWorktree {
|
||||
Worktree worktree = 1;
|
||||
}
|
||||
|
@ -238,8 +238,12 @@ impl Peer {
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?;
|
||||
T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type"))
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
|
||||
Err(anyhow!("request failed").context(error.message.clone()))
|
||||
} else {
|
||||
T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -301,6 +305,25 @@ impl Peer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond_with_error<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
receipt: Receipt<T>,
|
||||
response: proto::Error,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn connection(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
|
@ -125,6 +125,7 @@ messages!(
|
||||
ChannelMessageSent,
|
||||
CloseBuffer,
|
||||
CloseWorktree,
|
||||
Error,
|
||||
GetChannelMessages,
|
||||
GetChannelMessagesResponse,
|
||||
GetChannels,
|
||||
|
Loading…
Reference in New Issue
Block a user