Start work on restoring server-side code for chat messages

This commit is contained in:
Max Brunsfeld 2023-09-07 12:24:25 -07:00
parent 3422eb65e8
commit da5a77badf
16 changed files with 524 additions and 5 deletions

View File

@ -14,4 +14,5 @@ mod channel_store_tests;
pub fn init(client: &Arc<Client>) {
channel_buffer::init(client);
channel_chat::init(client);
}

View File

@ -57,6 +57,10 @@ pub enum ChannelChatEvent {
},
}
pub fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelChat::handle_message_sent);
}
impl Entity for ChannelChat {
type Event = ChannelChatEvent;
@ -70,10 +74,6 @@ impl Entity for ChannelChat {
}
impl ChannelChat {
pub fn init(rpc: &Arc<Client>) {
rpc.add_model_message_handler(Self::handle_message_sent);
}
pub async fn new(
channel: Arc<Channel>,
user_store: ModelHandle<UserStore>,

View File

@ -192,6 +192,26 @@ CREATE TABLE "channels" (
"created_at" TIMESTAMP NOT NULL DEFAULT now
);
CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"user_id" INTEGER NOT NULL REFERENCES users (id),
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"connection_id" INTEGER NOT NULL,
"connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
);
CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");
CREATE TABLE IF NOT EXISTS "channel_messages" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"sender_id" INTEGER NOT NULL REFERENCES users (id),
"body" TEXT NOT NULL,
"sent_at" TIMESTAMP,
"nonce" BLOB NOT NULL
);
CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
CREATE TABLE "channel_paths" (
"id_path" TEXT NOT NULL PRIMARY KEY,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE

View File

@ -0,0 +1,19 @@
CREATE TABLE IF NOT EXISTS "channel_messages" (
"id" SERIAL PRIMARY KEY,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"sender_id" INTEGER NOT NULL REFERENCES users (id),
"body" TEXT NOT NULL,
"sent_at" TIMESTAMP,
"nonce" UUID NOT NULL
);
CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
"id" SERIAL PRIMARY KEY,
"user_id" INTEGER NOT NULL REFERENCES users (id),
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"connection_id" INTEGER NOT NULL,
"connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
);
CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");

View File

@ -112,8 +112,10 @@ fn value_to_integer(v: Value) -> Result<i32, ValueTypeErr> {
id_type!(BufferId);
id_type!(AccessTokenId);
id_type!(ChannelChatParticipantId);
id_type!(ChannelId);
id_type!(ChannelMemberId);
id_type!(MessageId);
id_type!(ContactId);
id_type!(FollowerId);
id_type!(RoomId);

View File

@ -4,6 +4,7 @@ pub mod access_tokens;
pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod messages;
pub mod projects;
pub mod rooms;
pub mod servers;

View File

@ -0,0 +1,152 @@
use super::*;
use time::OffsetDateTime;
impl Database {
pub async fn join_channel_chat(
&self,
channel_id: ChannelId,
connection_id: ConnectionId,
user_id: UserId,
) -> Result<()> {
self.transaction(|tx| async move {
self.check_user_is_channel_member(channel_id, user_id, &*tx)
.await?;
channel_chat_participant::ActiveModel {
id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel_id),
user_id: ActiveValue::Set(user_id),
connection_id: ActiveValue::Set(connection_id.id as i32),
connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
}
.insert(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn leave_channel_chat(
&self,
channel_id: ChannelId,
connection_id: ConnectionId,
_user_id: UserId,
) -> Result<()> {
self.transaction(|tx| async move {
channel_chat_participant::Entity::delete_many()
.filter(
Condition::all()
.add(
channel_chat_participant::Column::ConnectionServerId
.eq(connection_id.owner_id),
)
.add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
.add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn get_channel_messages(
&self,
channel_id: ChannelId,
user_id: UserId,
count: usize,
before_message_id: Option<MessageId>,
) -> Result<Vec<proto::ChannelMessage>> {
self.transaction(|tx| async move {
self.check_user_is_channel_member(channel_id, user_id, &*tx)
.await?;
let mut condition =
Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
if let Some(before_message_id) = before_message_id {
condition = condition.add(channel_message::Column::Id.lt(before_message_id));
}
let mut rows = channel_message::Entity::find()
.filter(condition)
.limit(count as u64)
.stream(&*tx)
.await?;
let mut messages = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let nonce = row.nonce.as_u64_pair();
messages.push(proto::ChannelMessage {
id: row.id.to_proto(),
sender_id: row.sender_id.to_proto(),
body: row.body,
timestamp: row.sent_at.unix_timestamp() as u64,
nonce: Some(proto::Nonce {
upper_half: nonce.0,
lower_half: nonce.1,
}),
});
}
Ok(messages)
})
.await
}
pub async fn create_channel_message(
&self,
channel_id: ChannelId,
user_id: UserId,
body: &str,
timestamp: OffsetDateTime,
nonce: u128,
) -> Result<(MessageId, Vec<ConnectionId>)> {
self.transaction(|tx| async move {
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut participant_connection_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_connection_ids.push(row.connection());
}
drop(rows);
if !is_participant {
Err(anyhow!("not a chat participant"))?;
}
let message = channel_message::Entity::insert(channel_message::ActiveModel {
channel_id: ActiveValue::Set(channel_id),
sender_id: ActiveValue::Set(user_id),
body: ActiveValue::Set(body.to_string()),
sent_at: ActiveValue::Set(timestamp),
nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
id: ActiveValue::NotSet,
})
.on_conflict(
OnConflict::column(channel_message::Column::Nonce)
.update_column(channel_message::Column::Nonce)
.to_owned(),
)
.exec(&*tx)
.await?;
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryConnectionId {
ConnectionId,
}
Ok((message.last_insert_id, participant_connection_ids))
})
.await
}
}

View File

@ -4,7 +4,9 @@ pub mod buffer_operation;
pub mod buffer_snapshot;
pub mod channel;
pub mod channel_buffer_collaborator;
pub mod channel_chat_participant;
pub mod channel_member;
pub mod channel_message;
pub mod channel_path;
pub mod contact;
pub mod feature_flag;

View File

@ -21,6 +21,8 @@ pub enum Relation {
Member,
#[sea_orm(has_many = "super::channel_buffer_collaborator::Entity")]
BufferCollaborators,
#[sea_orm(has_many = "super::channel_chat_participant::Entity")]
ChatParticipants,
}
impl Related<super::channel_member::Entity> for Entity {
@ -46,3 +48,9 @@ impl Related<super::channel_buffer_collaborator::Entity> for Entity {
Relation::BufferCollaborators.def()
}
}
impl Related<super::channel_chat_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::ChatParticipants.def()
}
}

View File

@ -0,0 +1,41 @@
use crate::db::{ChannelChatParticipantId, ChannelId, ServerId, UserId};
use rpc::ConnectionId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "channel_chat_participants")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ChannelChatParticipantId,
pub channel_id: ChannelId,
pub user_id: UserId,
pub connection_id: i32,
pub connection_server_id: ServerId,
}
impl Model {
pub fn connection(&self) -> ConnectionId {
ConnectionId {
owner_id: self.connection_server_id.0 as u32,
id: self.connection_id as u32,
}
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::channel::Entity",
from = "Column::ChannelId",
to = "super::channel::Column::Id"
)]
Channel,
}
impl Related<super::channel::Entity> for Entity {
fn to() -> RelationDef {
Relation::Channel.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,45 @@
use crate::db::{ChannelId, MessageId, UserId};
use sea_orm::entity::prelude::*;
use time::OffsetDateTime;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "channel_messages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: MessageId,
pub channel_id: ChannelId,
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
pub nonce: Uuid,
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::channel::Entity",
from = "Column::ChannelId",
to = "super::channel::Column::Id"
)]
Channel,
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::SenderId",
to = "super::user::Column::Id"
)]
Sender,
}
impl Related<super::channel::Entity> for Entity {
fn to() -> RelationDef {
Relation::Channel.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::Sender.def()
}
}

View File

@ -1,6 +1,7 @@
mod buffer_tests;
mod db_tests;
mod feature_flag_tests;
mod message_tests;
use super::*;
use gpui::executor::Background;

View File

@ -0,0 +1,53 @@
use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
use time::OffsetDateTime;
test_both_dbs!(
test_channel_message_nonces,
test_channel_message_nonces_postgres,
test_channel_message_nonces_sqlite
);
async fn test_channel_message_nonces(db: &Arc<Database>) {
let user = db
.create_user(
"user@example.com",
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel = db
.create_channel("channel", None, "room", user)
.await
.unwrap();
let msg1_id = db
.create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg2_id = db
.create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
let msg3_id = db
.create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg4_id = db
.create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
assert_ne!(msg1_id, msg2_id);
assert_eq!(msg1_id, msg3_id);
assert_eq!(msg2_id, msg4_id);
}

View File

@ -2,7 +2,10 @@ mod connection_pool;
use crate::{
auth,
db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
db::{
self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
UserId,
},
executor::Executor,
AppState, Result,
};
@ -56,6 +59,7 @@ use std::{
},
time::{Duration, Instant},
};
use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument};
@ -63,6 +67,9 @@ use tracing::{info_span, instrument, Instrument};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge =
register_int_gauge!("connections", "number of connections").unwrap();
@ -255,6 +262,10 @@ impl Server {
.add_request_handler(get_channel_members)
.add_request_handler(respond_to_channel_invite)
.add_request_handler(join_channel)
.add_request_handler(join_channel_chat)
.add_message_handler(leave_channel_chat)
.add_request_handler(send_channel_message)
.add_request_handler(get_channel_messages)
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
@ -2641,6 +2652,112 @@ fn channel_buffer_updated<T: EnvelopedMessage>(
});
}
async fn send_channel_message(
request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>,
session: Session,
) -> Result<()> {
// Validate the message body.
let body = request.body.trim().to_string();
if body.len() > MAX_MESSAGE_LEN {
return Err(anyhow!("message is too long"))?;
}
if body.is_empty() {
return Err(anyhow!("message can't be blank"))?;
}
let timestamp = OffsetDateTime::now_utc();
let nonce = request
.nonce
.ok_or_else(|| anyhow!("nonce can't be blank"))?;
let channel_id = ChannelId::from_proto(request.channel_id);
let (message_id, connection_ids) = session
.db()
.await
.create_channel_message(
channel_id,
session.user_id,
&body,
timestamp,
nonce.clone().into(),
)
.await?;
let message = proto::ChannelMessage {
sender_id: session.user_id.to_proto(),
id: message_id.to_proto(),
body,
timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce),
};
broadcast(Some(session.connection_id), connection_ids, |connection| {
session.peer.send(
connection,
proto::ChannelMessageSent {
channel_id: channel_id.to_proto(),
message: Some(message.clone()),
},
)
});
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})?;
Ok(())
}
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let db = session.db().await;
db.join_channel_chat(channel_id, session.connection_id, session.user_id)
.await?;
let messages = db
.get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
.await?;
response.send(proto::JoinChannelChatResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
}
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
session
.db()
.await
.leave_channel_chat(channel_id, session.connection_id, session.user_id)
.await?;
Ok(())
}
async fn get_channel_messages(
request: proto::GetChannelMessages,
response: Response<proto::GetChannelMessages>,
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let messages = session
.db()
.await
.get_channel_messages(
channel_id,
session.user_id,
MESSAGE_COUNT_PER_PAGE,
Some(MessageId::from_proto(request.before_message_id)),
)
.await?;
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
}
async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let project_connection_ids = session

View File

@ -2,6 +2,7 @@ use call::Room;
use gpui::{ModelHandle, TestAppContext};
mod channel_buffer_tests;
mod channel_message_tests;
mod channel_tests;
mod integration_tests;
mod random_channel_buffer_tests;

View File

@ -0,0 +1,56 @@
use crate::tests::TestServer;
use gpui::{executor::Deterministic, TestAppContext};
use std::sync::Arc;
#[gpui::test]
async fn test_basic_channel_messages(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
channel_chat_b
.update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
channel_chat_a.update(cx_a, |c, _| {
assert_eq!(
c.messages()
.iter()
.map(|m| m.body.as_str())
.collect::<Vec<_>>(),
vec!["one", "two", "three"]
);
})
}