Fetch older messages when scrolling up in the chat message list

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-08-27 14:58:28 -07:00
parent 5262dcd3cb
commit bc63fca8d7
8 changed files with 281 additions and 68 deletions

View File

@ -29,6 +29,7 @@ struct StateInner {
heights: SumTree<ElementHeight>,
scroll_position: f32,
orientation: Orientation,
scroll_handler: Option<Box<dyn FnMut(Range<usize>, &mut EventContext)>>,
}
#[derive(Clone, Debug)]
@ -272,6 +273,7 @@ impl ListState {
heights,
scroll_position: 0.,
orientation,
scroll_handler: None,
})))
}
@ -290,6 +292,13 @@ impl ListState {
drop(old_heights);
state.heights = new_heights;
}
pub fn set_scroll_handler(
&mut self,
handler: impl FnMut(Range<usize>, &mut EventContext) + 'static,
) {
self.0.borrow_mut().scroll_handler = Some(Box::new(handler))
}
}
impl StateInner {
@ -320,6 +329,11 @@ impl StateInner {
Orientation::Bottom => delta.y(),
};
self.scroll_position = (self.scroll_position + delta_y).max(0.).min(scroll_max);
if self.scroll_handler.is_some() {
let range = self.visible_range(height);
self.scroll_handler.as_mut().unwrap()(range, cx);
}
cx.notify();
true

View File

@ -85,7 +85,7 @@ async fn post_user(mut request: Request) -> tide::Result {
async fn put_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = request.param("id")?.parse::<i32>()?;
let user_id = request.param("id")?.parse()?;
#[derive(Deserialize)]
struct Body {
@ -104,14 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
let user_id = db::UserId(request.param("id")?.parse()?);
request.db().delete_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
let signup_id = db::SignupId(request.param("id")?.parse()?);
request.db().delete_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}

View File

@ -380,6 +380,7 @@ impl Db {
&self,
channel_id: ChannelId,
count: usize,
before_id: Option<MessageId>,
) -> Result<Vec<ChannelMessage>> {
test_support!(self, {
let query = r#"
@ -389,14 +390,16 @@ impl Db {
FROM
channel_messages
WHERE
channel_id = $1
channel_id = $1 AND
id < $2
ORDER BY id DESC
LIMIT $2
LIMIT $3
) as recent_messages
ORDER BY id ASC
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(before_id.unwrap_or(MessageId::MAX))
.bind(count as i64)
.fetch_all(&self.pool)
.await
@ -412,6 +415,9 @@ macro_rules! id_type {
pub struct $name(pub i32);
impl $name {
#[allow(unused)]
pub const MAX: Self = Self(i32::MAX);
#[allow(unused)]
pub fn from_proto(value: u64) -> Self {
Self(value as i32)
@ -512,10 +518,22 @@ pub mod tests {
.unwrap();
}
let messages = db.get_recent_channel_messages(channel, 5).await.unwrap();
let messages = db
.get_recent_channel_messages(channel, 5, None)
.await
.unwrap();
assert_eq!(
messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
["5", "6", "7", "8", "9"]
);
let prev_messages = db
.get_recent_channel_messages(channel, 4, Some(messages[0].id))
.await
.unwrap();
assert_eq!(
prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
["1", "2", "3", "4"]
);
}
}

View File

@ -1,6 +1,6 @@
use super::{
auth,
db::{ChannelId, UserId},
db::{ChannelId, MessageId, UserId},
AppState,
};
use anyhow::anyhow;
@ -77,6 +77,8 @@ struct Channel {
connection_ids: HashSet<ConnectionId>,
}
const MESSAGE_COUNT_PER_PAGE: usize = 50;
impl Server {
pub fn new(
app_state: Arc<AppState>,
@ -105,7 +107,8 @@ impl Server {
.add_handler(Server::get_users)
.add_handler(Server::join_channel)
.add_handler(Server::leave_channel)
.add_handler(Server::send_channel_message);
.add_handler(Server::send_channel_message)
.add_handler(Server::get_channel_messages);
Arc::new(server)
}
@ -592,7 +595,7 @@ impl Server {
let messages = self
.app_state
.db
.get_recent_channel_messages(channel_id, 50)
.get_recent_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
@ -601,9 +604,15 @@ impl Server {
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
})
.collect();
.collect::<Vec<_>>();
self.peer
.respond(request.receipt(), proto::JoinChannelResponse { messages })
.respond(
request.receipt(),
proto::JoinChannelResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
},
)
.await?;
Ok(())
}
@ -685,6 +694,54 @@ impl Server {
Ok(())
}
async fn get_channel_messages(
self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<()> {
let user_id = self
.state
.read()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
.db
.can_user_access_channel(user_id, channel_id)
.await?
{
Err(anyhow!("access denied"))?;
}
let messages = self
.app_state
.db
.get_recent_channel_messages(
channel_id,
MESSAGE_COUNT_PER_PAGE,
Some(MessageId::from_proto(request.payload.before_message_id)),
)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
id: msg.id.to_proto(),
body: msg.body,
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
})
.collect::<Vec<_>>();
self.peer
.respond(
request.receipt(),
proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
},
)
.await?;
Ok(())
}
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
&self,
worktree_id: u64,

View File

@ -37,6 +37,7 @@ pub struct ChannelDetails {
pub struct Channel {
details: ChannelDetails,
messages: SumTree<ChannelMessage>,
loaded_all_messages: bool,
pending_messages: Vec<PendingChannelMessage>,
next_local_message_id: u64,
user_store: Arc<UserStore>,
@ -70,7 +71,7 @@ pub enum ChannelListEvent {}
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
Message {
MessagesAdded {
old_range: Range<usize>,
new_count: usize,
},
@ -192,31 +193,12 @@ impl Channel {
cx.spawn(|channel, mut cx| {
async move {
let response = rpc.request(proto::JoinChannel { channel_id }).await?;
let unique_user_ids = response
.messages
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store.load_users(unique_user_ids).await?;
let mut messages = Vec::with_capacity(response.messages.len());
for message in response.messages {
messages.push(ChannelMessage::from_proto(message, &user_store).await?);
}
let messages = messages_from_proto(response.messages, &user_store).await?;
let loaded_all_messages = response.done;
channel.update(&mut cx, |channel, cx| {
let old_count = channel.messages.summary().count;
let new_count = messages.len();
channel.messages = SumTree::new();
channel.messages.extend(messages, &());
cx.emit(ChannelEvent::Message {
old_range: 0..old_count,
new_count,
});
channel.insert_messages(messages, cx);
channel.loaded_all_messages = loaded_all_messages;
});
Ok(())
@ -232,6 +214,7 @@ impl Channel {
rpc,
messages: Default::default(),
pending_messages: Default::default(),
loaded_all_messages: false,
next_local_message_id: 0,
_subscription,
}
@ -264,15 +247,18 @@ impl Channel {
.binary_search_by_key(&local_id, |msg| msg.local_id)
{
let body = this.pending_messages.remove(i).body;
this.insert_message(
ChannelMessage {
id: response.message_id,
timestamp: OffsetDateTime::from_unix_timestamp(
response.timestamp as i64,
)?,
body,
sender,
},
this.insert_messages(
SumTree::from_item(
ChannelMessage {
id: response.message_id,
timestamp: OffsetDateTime::from_unix_timestamp(
response.timestamp as i64,
)?,
body,
sender,
},
&(),
),
cx,
);
}
@ -286,6 +272,37 @@ impl Channel {
Ok(())
}
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
if !self.loaded_all_messages {
let rpc = self.rpc.clone();
let user_store = self.user_store.clone();
let channel_id = self.details.id;
if let Some(before_message_id) = self.messages.first().map(|message| message.id) {
cx.spawn(|this, mut cx| {
async move {
let response = rpc
.request(proto::GetChannelMessages {
channel_id,
before_message_id,
})
.await?;
let loaded_all_messages = response.done;
let messages = messages_from_proto(response.messages, &user_store).await?;
this.update(&mut cx, |this, cx| {
this.loaded_all_messages = loaded_all_messages;
this.insert_messages(messages, cx);
});
Ok(())
}
.log_err()
})
.detach();
return true;
}
}
false
}
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
@ -326,7 +343,9 @@ impl Channel {
cx.spawn(|this, mut cx| {
async move {
let message = ChannelMessage::from_proto(message, &user_store).await?;
this.update(&mut cx, |this, cx| this.insert_message(message, cx));
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx)
});
Ok(())
}
.log_err()
@ -335,29 +354,51 @@ impl Channel {
Ok(())
}
fn insert_message(&mut self, message: ChannelMessage, cx: &mut ModelContext<Self>) {
let mut old_cursor = self.messages.cursor::<u64, Count>();
let mut new_messages = old_cursor.slice(&message.id, Bias::Left, &());
let start_ix = old_cursor.sum_start().0;
let mut end_ix = start_ix;
if old_cursor.item().map_or(false, |m| m.id == message.id) {
old_cursor.next(&());
end_ix += 1;
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
let mut old_cursor = self.messages.cursor::<u64, Count>();
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
let start_ix = old_cursor.sum_start().0;
let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
let removed_count = removed_messages.summary().count;
let new_count = messages.summary().count;
let end_ix = start_ix + removed_count;
new_messages.push_tree(messages, &());
new_messages.push_tree(old_cursor.suffix(&()), &());
drop(old_cursor);
self.messages = new_messages;
cx.emit(ChannelEvent::MessagesAdded {
old_range: start_ix..end_ix,
new_count,
});
cx.notify();
}
new_messages.push(message.clone(), &());
new_messages.push_tree(old_cursor.suffix(&()), &());
drop(old_cursor);
self.messages = new_messages;
cx.emit(ChannelEvent::Message {
old_range: start_ix..end_ix,
new_count: 1,
});
cx.notify();
}
}
async fn messages_from_proto(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &UserStore,
) -> Result<SumTree<ChannelMessage>> {
let unique_user_ids = proto_messages
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store.load_users(unique_user_ids).await?;
let mut messages = Vec::with_capacity(proto_messages.len());
for message in proto_messages {
messages.push(ChannelMessage::from_proto(message, &user_store).await?);
}
let mut result = SumTree::new();
result.extend(messages, &());
Ok(result)
}
impl From<proto::Channel> for ChannelDetails {
fn from(message: proto::Channel) -> Self {
Self {
@ -489,9 +530,11 @@ mod tests {
sender_id: 6,
},
],
done: false,
},
)
.await;
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await;
get_users.payload.user_ids.sort();
@ -518,7 +561,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::Message {
ChannelEvent::MessagesAdded {
old_range: 0..0,
new_count: 2,
}
@ -567,7 +610,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::Message {
ChannelEvent::MessagesAdded {
old_range: 2..2,
new_count: 1,
}
@ -580,7 +623,57 @@ mod tests {
.collect::<Vec<_>>(),
&[("as-cii".into(), "c".into())]
)
})
});
// Scroll up to view older messages.
channel.update(&mut cx, |channel, cx| {
assert!(channel.load_more_messages(cx));
});
let get_messages = server.receive::<proto::GetChannelMessages>().await;
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server
.respond(
get_messages.receipt(),
proto::GetChannelMessagesResponse {
done: true,
messages: vec![
proto::ChannelMessage {
id: 8,
body: "y".into(),
timestamp: 998,
sender_id: 5,
},
proto::ChannelMessage {
id: 9,
body: "z".into(),
timestamp: 999,
sender_id: 6,
},
],
},
)
.await;
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::MessagesAdded {
old_range: 0..0,
new_count: 2,
}
);
channel.read_with(&cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("nathansobo".into(), "y".into()),
("maxbrunsfeld".into(), "z".into())
]
);
});
}
struct FakeServer {

View File

@ -22,9 +22,11 @@ pub struct ChatPanel {
pub enum Event {}
action!(Send);
action!(LoadMoreMessages);
pub fn init(cx: &mut MutableAppContext) {
cx.add_action(ChatPanel::send);
cx.add_action(ChatPanel::load_more_messages);
cx.add_bindings(vec![Binding::new("enter", Send, Some("ChatPanel"))]);
}
@ -78,6 +80,11 @@ impl ChatPanel {
let subscription = cx.subscribe(&channel, Self::channel_did_change);
self.message_list =
ListState::new(channel.read(cx).message_count(), Orientation::Bottom);
self.message_list.set_scroll_handler(|visible_range, cx| {
if visible_range.start < 5 {
cx.dispatch_action(LoadMoreMessages);
}
});
self.active_channel = Some((channel, subscription));
}
}
@ -89,7 +96,7 @@ impl ChatPanel {
cx: &mut ViewContext<Self>,
) {
match event {
ChannelEvent::Message {
ChannelEvent::MessagesAdded {
old_range,
new_count,
} => {
@ -191,6 +198,14 @@ impl ChatPanel {
.log_err();
}
}
fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
if let Some((channel, _)) = self.active_channel.as_ref() {
channel.update(cx, |channel, cx| {
channel.load_more_messages(cx);
})
}
}
}
impl Entity for ChatPanel {

View File

@ -32,6 +32,8 @@ message Envelope {
SendChannelMessage send_channel_message = 27;
SendChannelMessageResponse send_channel_message_response = 28;
ChannelMessageSent channel_message_sent = 29;
GetChannelMessages get_channel_messages = 30;
GetChannelMessagesResponse get_channel_messages_response = 31;
}
}
@ -130,6 +132,7 @@ message JoinChannel {
message JoinChannelResponse {
repeated ChannelMessage messages = 1;
bool done = 2;
}
message LeaveChannel {
@ -159,6 +162,16 @@ message ChannelMessageSent {
ChannelMessage message = 2;
}
message GetChannelMessages {
uint64 channel_id = 1;
uint64 before_message_id = 2;
}
message GetChannelMessagesResponse {
repeated ChannelMessage messages = 1;
bool done = 2;
}
// Entities
message Peer {

View File

@ -125,6 +125,8 @@ messages!(
ChannelMessageSent,
CloseBuffer,
CloseWorktree,
GetChannelMessages,
GetChannelMessagesResponse,
GetChannels,
GetChannelsResponse,
GetUsers,
@ -158,6 +160,7 @@ request_messages!(
(SaveBuffer, BufferSaved),
(ShareWorktree, ShareWorktreeResponse),
(SendChannelMessage, SendChannelMessageResponse),
(GetChannelMessages, GetChannelMessagesResponse),
);
entity_messages!(