Avoid N+1 query for channels with notes changes

Also, start work on new timing for recording observed notes edits.

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
Max Brunsfeld 2023-10-02 15:58:34 -07:00
parent 84c4db13fb
commit d9d997b218
7 changed files with 381 additions and 97 deletions

View File

@ -296,6 +296,7 @@ CREATE TABLE "observed_buffer_edits" (
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
"replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);

View File

@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "observed_buffer_edits" (
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
"replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);

View File

@ -119,7 +119,7 @@ impl Database {
Ok(new_migrations)
}
async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@ -321,7 +321,7 @@ fn is_serialization_error(error: &Error) -> bool {
}
}
struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
impl Deref for TransactionHandle {
type Target = DatabaseTransaction;

View File

@ -79,12 +79,13 @@ impl Database {
self.get_buffer_state(&buffer, &tx).await?;
// Save the last observed operation
if let Some(max_operation) = max_operation {
if let Some(op) = max_operation {
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer.id),
epoch: ActiveValue::Set(max_operation.0),
lamport_timestamp: ActiveValue::Set(max_operation.1),
epoch: ActiveValue::Set(op.epoch),
lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
replica_id: ActiveValue::Set(op.replica_id),
})
.on_conflict(
OnConflict::columns([
@ -99,37 +100,6 @@ impl Database {
)
.exec(&*tx)
.await?;
} else {
let buffer_max = buffer_operation::Entity::find()
.filter(buffer_operation::Column::BufferId.eq(buffer.id))
.filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1)))
.order_by(buffer_operation::Column::Epoch, Desc)
.order_by(buffer_operation::Column::LamportTimestamp, Desc)
.one(&*tx)
.await?
.map(|model| (model.epoch, model.lamport_timestamp));
if let Some(buffer_max) = buffer_max {
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer.id),
epoch: ActiveValue::Set(buffer_max.0),
lamport_timestamp: ActiveValue::Set(buffer_max.1),
})
.on_conflict(
OnConflict::columns([
observed_buffer_edits::Column::UserId,
observed_buffer_edits::Column::BufferId,
])
.update_columns([
observed_buffer_edits::Column::Epoch,
observed_buffer_edits::Column::LamportTimestamp,
])
.to_owned(),
)
.exec(&*tx)
.await?;
}
}
Ok(proto::JoinChannelBufferResponse {
@ -487,13 +457,8 @@ impl Database {
if !operations.is_empty() {
// get current channel participants and save the max operation above
self.save_max_operation_for_collaborators(
operations.as_slice(),
channel_id,
buffer.id,
&*tx,
)
.await?;
self.save_max_operation(user, buffer.id, buffer.epoch, operations.as_slice(), &*tx)
.await?;
channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
let collaborators = self
@ -539,54 +504,55 @@ impl Database {
.await
}
async fn save_max_operation_for_collaborators(
async fn save_max_operation(
&self,
operations: &[buffer_operation::ActiveModel],
channel_id: ChannelId,
user_id: UserId,
buffer_id: BufferId,
epoch: i32,
operations: &[buffer_operation::ActiveModel],
tx: &DatabaseTransaction,
) -> Result<()> {
use observed_buffer_edits::Column;
let max_operation = operations
.iter()
.map(|storage_model| {
(
storage_model.epoch.clone(),
storage_model.lamport_timestamp.clone(),
)
})
.max_by(
|(epoch_a, lamport_timestamp_a), (epoch_b, lamport_timestamp_b)| {
epoch_a.as_ref().cmp(epoch_b.as_ref()).then(
lamport_timestamp_a
.as_ref()
.cmp(lamport_timestamp_b.as_ref()),
)
},
)
.max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
.unwrap();
let users = self
.get_channel_buffer_collaborators_internal(channel_id, tx)
.await?;
observed_buffer_edits::Entity::insert_many(users.iter().map(|id| {
observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(*id),
buffer_id: ActiveValue::Set(buffer_id),
epoch: max_operation.0.clone(),
lamport_timestamp: ActiveValue::Set(*max_operation.1.as_ref()),
}
}))
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer_id),
epoch: ActiveValue::Set(epoch),
replica_id: max_operation.replica_id.clone(),
lamport_timestamp: max_operation.lamport_timestamp.clone(),
})
.on_conflict(
OnConflict::columns([
observed_buffer_edits::Column::UserId,
observed_buffer_edits::Column::BufferId,
])
.update_columns([
observed_buffer_edits::Column::Epoch,
observed_buffer_edits::Column::LamportTimestamp,
])
.to_owned(),
OnConflict::columns([Column::UserId, Column::BufferId])
.update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
.target_cond_where(
Condition::any()
.add(Column::Epoch.lt(*max_operation.epoch.as_ref()))
.add(
Condition::all()
.add(Column::Epoch.eq(*max_operation.epoch.as_ref()))
.add(
Condition::any()
.add(
Column::LamportTimestamp
.lt(*max_operation.lamport_timestamp.as_ref()),
)
.add(
Column::LamportTimestamp
.eq(*max_operation.lamport_timestamp.as_ref())
.and(
Column::ReplicaId
.lt(*max_operation.replica_id.as_ref()),
),
),
),
),
)
.to_owned(),
)
.exec(tx)
.await?;
@ -611,7 +577,7 @@ impl Database {
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
}
async fn get_channel_buffer(
pub async fn get_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
@ -630,7 +596,11 @@ impl Database {
&self,
buffer: &buffer::Model,
tx: &DatabaseTransaction,
) -> Result<(String, Vec<proto::Operation>, Option<(i32, i32)>)> {
) -> Result<(
String,
Vec<proto::Operation>,
Option<buffer_operation::Model>,
)> {
let id = buffer.id;
let (base_text, version) = if buffer.epoch > 0 {
let snapshot = buffer_snapshot::Entity::find()
@ -655,24 +625,28 @@ impl Database {
.eq(id)
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
)
.order_by_asc(buffer_operation::Column::LamportTimestamp)
.order_by_asc(buffer_operation::Column::ReplicaId)
.stream(&*tx)
.await?;
let mut operations = Vec::new();
let mut max_epoch: Option<i32> = None;
let mut max_timestamp: Option<i32> = None;
let mut operations = Vec::new();
let mut last_row = None;
while let Some(row) = rows.next().await {
let row = row?;
max_assign(&mut max_epoch, row.epoch);
max_assign(&mut max_timestamp, row.lamport_timestamp);
last_row = Some(buffer_operation::Model {
buffer_id: row.buffer_id,
epoch: row.epoch,
lamport_timestamp: row.lamport_timestamp,
replica_id: row.lamport_timestamp,
value: Default::default(),
});
operations.push(proto::Operation {
variant: Some(operation_from_storage(row, version)?),
})
});
}
Ok((base_text, operations, max_epoch.zip(max_timestamp)))
Ok((base_text, operations, last_row))
}
async fn snapshot_channel_buffer(
@ -725,6 +699,119 @@ impl Database {
.await
}
pub async fn channels_with_changed_notes(
&self,
user_id: UserId,
channel_ids: impl IntoIterator<Item = ChannelId>,
tx: &DatabaseTransaction,
) -> Result<HashSet<ChannelId>> {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryIds {
ChannelId,
Id,
}
let mut channel_ids_by_buffer_id = HashMap::default();
let mut rows = buffer::Entity::find()
.filter(buffer::Column::ChannelId.is_in(channel_ids))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
channel_ids_by_buffer_id.insert(row.id, row.channel_id);
}
drop(rows);
let mut observed_edits_by_buffer_id = HashMap::default();
let mut rows = observed_buffer_edits::Entity::find()
.filter(observed_buffer_edits::Column::UserId.eq(user_id))
.filter(
observed_buffer_edits::Column::BufferId
.is_in(channel_ids_by_buffer_id.keys().copied()),
)
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
observed_edits_by_buffer_id.insert(row.buffer_id, row);
}
drop(rows);
let last_operations = self
.get_last_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
.await?;
let mut channels_with_new_changes = HashSet::default();
for last_operation in last_operations {
if let Some(observed_edit) = observed_edits_by_buffer_id.get(&last_operation.buffer_id)
{
if observed_edit.epoch == last_operation.epoch
&& observed_edit.lamport_timestamp == last_operation.lamport_timestamp
&& observed_edit.replica_id == last_operation.replica_id
{
continue;
}
}
if let Some(channel_id) = channel_ids_by_buffer_id.get(&last_operation.buffer_id) {
channels_with_new_changes.insert(*channel_id);
}
}
Ok(channels_with_new_changes)
}
pub async fn get_last_operations_for_buffers(
&self,
channel_ids: impl IntoIterator<Item = BufferId>,
tx: &DatabaseTransaction,
) -> Result<Vec<buffer_operation::Model>> {
let mut values = String::new();
for id in channel_ids {
if !values.is_empty() {
values.push_str(", ");
}
write!(&mut values, "({})", id).unwrap();
}
if values.is_empty() {
return Ok(Vec::default());
}
let sql = format!(
r#"
SELECT
*
FROM (
SELECT
buffer_id,
epoch,
lamport_timestamp,
replica_id,
value,
row_number() OVER (
PARTITION BY buffer_id
ORDER BY
epoch DESC,
lamport_timestamp DESC,
replica_id DESC
) as row_number
FROM buffer_operations
WHERE
buffer_id in ({values})
) AS operations
WHERE
row_number = 1
"#,
);
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let operations = buffer_operation::Model::find_by_statement(stmt)
.all(&*tx)
.await?;
Ok(operations)
}
pub async fn has_note_changed(
&self,
user_id: UserId,

View File

@ -463,12 +463,16 @@ impl Database {
}
}
let mut channels_with_changed_notes = HashSet::default();
let channels_with_changed_notes = self
.channels_with_changed_notes(
user_id,
graph.channels.iter().map(|channel| channel.id),
&*tx,
)
.await?;
let mut channels_with_new_messages = HashSet::default();
for channel in graph.channels.iter() {
if self.has_note_changed(user_id, channel.id, tx).await? {
channels_with_changed_notes.insert(channel.id);
}
if self.has_new_message(channel.id, user_id, tx).await? {
channels_with_new_messages.insert(channel.id);
}

View File

@ -9,6 +9,7 @@ pub struct Model {
pub buffer_id: BufferId,
pub epoch: i32,
pub lamport_timestamp: i32,
pub replica_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@ -272,3 +272,193 @@ async fn test_channel_buffers_diffs(db: &Database) {
assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap());
}
test_both_dbs!(
test_channel_buffers_last_operations,
test_channel_buffers_last_operations_postgres,
test_channel_buffers_last_operations_sqlite
);
async fn test_channel_buffers_last_operations(db: &Database) {
let user_id = db
.create_user(
"user_a@example.com",
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 101,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let owner_id = db.create_server("production").await.unwrap().0 as u32;
let connection_id = ConnectionId {
owner_id,
id: user_id.0 as u32,
};
let mut buffers = Vec::new();
let mut text_buffers = Vec::new();
for i in 0..3 {
let channel = db
.create_root_channel(&format!("channel-{i}"), &format!("room-{i}"), user_id)
.await
.unwrap();
db.join_channel_buffer(channel, user_id, connection_id)
.await
.unwrap();
buffers.push(
db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await })
.await
.unwrap(),
);
text_buffers.push(Buffer::new(0, 0, "".to_string()));
}
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
.await
}
})
.await
.unwrap();
assert!(operations.is_empty());
update_buffer(
buffers[0].channel_id,
user_id,
db,
vec![
text_buffers[0].edit([(0..0, "a")]),
text_buffers[0].edit([(0..0, "b")]),
text_buffers[0].edit([(0..0, "c")]),
],
)
.await;
update_buffer(
buffers[1].channel_id,
user_id,
db,
vec![
text_buffers[1].edit([(0..0, "d")]),
text_buffers[1].edit([(1..1, "e")]),
text_buffers[1].edit([(2..2, "f")]),
],
)
.await;
// cause buffer 1's epoch to increment.
db.leave_channel_buffer(buffers[1].channel_id, connection_id)
.await
.unwrap();
db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id)
.await
.unwrap();
text_buffers[1] = Buffer::new(1, 0, "def".to_string());
update_buffer(
buffers[1].channel_id,
user_id,
db,
vec![
text_buffers[1].edit([(0..0, "g")]),
text_buffers[1].edit([(0..0, "h")]),
],
)
.await;
update_buffer(
buffers[2].channel_id,
user_id,
db,
vec![text_buffers[2].edit([(0..0, "i")])],
)
.await;
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
.await
}
})
.await
.unwrap();
assert_operations(
&operations,
&[
(buffers[1].id, 1, &text_buffers[1]),
(buffers[2].id, 0, &text_buffers[2]),
],
);
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
.await
}
})
.await
.unwrap();
assert_operations(
&operations,
&[
(buffers[0].id, 0, &text_buffers[0]),
(buffers[1].id, 1, &text_buffers[1]),
],
);
async fn update_buffer(
channel_id: ChannelId,
user_id: UserId,
db: &Database,
operations: Vec<text::Operation>,
) {
let operations = operations
.into_iter()
.map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
.collect::<Vec<_>>();
db.update_channel_buffer(channel_id, user_id, &operations)
.await
.unwrap();
}
fn assert_operations(
operations: &[buffer_operation::Model],
expected: &[(BufferId, i32, &text::Buffer)],
) {
let actual = operations
.iter()
.map(|op| buffer_operation::Model {
buffer_id: op.buffer_id,
epoch: op.epoch,
lamport_timestamp: op.lamport_timestamp,
replica_id: op.replica_id,
value: vec![],
})
.collect::<Vec<_>>();
let expected = expected
.iter()
.map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
buffer_id: *buffer_id,
epoch: *epoch,
lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
replica_id: buffer.replica_id() as i32,
value: vec![],
})
.collect::<Vec<_>>();
assert_eq!(actual, expected, "unexpected operations")
}
}