Use synchronous locks for Peer state

We hold these locks for a short amount of time anyway, and using an
async lock could cause parallel sends to happen in an order different
than the order in which `send`/`request` was called.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2022-01-12 18:02:41 +01:00
parent 310def2923
commit 9e4b118214
4 changed files with 36 additions and 44 deletions

View File

@ -661,9 +661,9 @@ impl Client {
}) })
} }
pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> { pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let conn_id = self.connection_id()?; let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await; self.peer.disconnect(conn_id);
self.set_status(Status::SignedOut, cx); self.set_status(Status::SignedOut, cx);
Ok(()) Ok(())
} }
@ -764,7 +764,7 @@ mod tests {
let ping = server.receive::<proto::Ping>().await.unwrap(); let ping = server.receive::<proto::Ping>().await.unwrap();
server.respond(ping.receipt(), proto::Ack {}).await; server.respond(ping.receipt(), proto::Ack {}).await;
client.disconnect(&cx.to_async()).await.unwrap(); client.disconnect(&cx.to_async()).unwrap();
assert!(server.receive::<proto::Ping>().await.is_err()); assert!(server.receive::<proto::Ping>().await.is_err());
} }
@ -783,7 +783,7 @@ mod tests {
assert_eq!(server.auth_count(), 1); assert_eq!(server.auth_count(), 1);
server.forbid_connections(); server.forbid_connections();
server.disconnect().await; server.disconnect();
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
server.allow_connections(); server.allow_connections();
@ -792,7 +792,7 @@ mod tests {
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
server.forbid_connections(); server.forbid_connections();
server.disconnect().await; server.disconnect();
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
// Clear cached credentials after authentication fails // Clear cached credentials after authentication fails

View File

@ -72,8 +72,8 @@ impl FakeServer {
server server
} }
pub async fn disconnect(&self) { pub fn disconnect(&self) {
self.peer.disconnect(self.connection_id()).await; self.peer.disconnect(self.connection_id());
self.connection_id.lock().take(); self.connection_id.lock().take();
self.incoming.lock().take(); self.incoming.lock().take();
} }

View File

@ -1,8 +1,8 @@
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Connection; use super::Connection;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use futures::FutureExt as _; use futures::FutureExt as _;
use parking_lot::{Mutex, RwLock};
use postage::{ use postage::{
mpsc, mpsc,
prelude::{Sink as _, Stream as _}, prelude::{Sink as _, Stream as _},
@ -133,7 +133,7 @@ impl Peer {
incoming = read_message => match incoming { incoming = read_message => match incoming {
Ok(incoming) => { Ok(incoming) => {
if let Some(responding_to) = incoming.responding_to { if let Some(responding_to) = incoming.responding_to {
let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to); let channel = response_channels.lock().as_mut().unwrap().remove(&responding_to);
if let Some(mut tx) = channel { if let Some(mut tx) = channel {
tx.send(incoming).await.ok(); tx.send(incoming).await.ok();
} else { } else {
@ -169,25 +169,24 @@ impl Peer {
} }
}; };
response_channels.lock().await.take(); response_channels.lock().take();
this.connections.write().await.remove(&connection_id); this.connections.write().remove(&connection_id);
result result
}; };
self.connections self.connections
.write() .write()
.await
.insert(connection_id, connection_state); .insert(connection_id, connection_state);
(connection_id, handle_io, incoming_rx) (connection_id, handle_io, incoming_rx)
} }
pub async fn disconnect(&self, connection_id: ConnectionId) { pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id); self.connections.write().remove(&connection_id);
} }
pub async fn reset(&self) { pub fn reset(&self) {
self.connections.write().await.clear(); self.connections.write().clear();
} }
pub fn request<T: RequestMessage>( pub fn request<T: RequestMessage>(
@ -216,12 +215,11 @@ impl Peer {
let this = self.clone(); let this = self.clone();
let (tx, mut rx) = mpsc::channel(1); let (tx, mut rx) = mpsc::channel(1);
async move { async move {
let mut connection = this.connection_state(receiver_id).await?; let mut connection = this.connection_state(receiver_id)?;
let message_id = connection.next_message_id.fetch_add(1, SeqCst); let message_id = connection.next_message_id.fetch_add(1, SeqCst);
connection connection
.response_channels .response_channels
.lock() .lock()
.await
.as_mut() .as_mut()
.ok_or_else(|| anyhow!("connection was closed"))? .ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx); .insert(message_id, tx);
@ -250,7 +248,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection_state(receiver_id).await?; let mut connection = this.connection_state(receiver_id)?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -270,7 +268,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection_state(receiver_id).await?; let mut connection = this.connection_state(receiver_id)?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -289,7 +287,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection_state(receipt.sender_id).await?; let mut connection = this.connection_state(receipt.sender_id)?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -308,7 +306,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection_state(receipt.sender_id).await?; let mut connection = this.connection_state(receipt.sender_id)?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -320,20 +318,14 @@ impl Peer {
} }
} }
fn connection_state( fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
self: &Arc<Self>, let connections = self.connections.read();
connection_id: ConnectionId,
) -> impl Future<Output = Result<ConnectionState>> {
let this = self.clone();
async move {
let connections = this.connections.read().await;
let connection = connections let connection = connections
.get(&connection_id) .get(&connection_id)
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
Ok(connection.clone()) Ok(connection.clone())
} }
} }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -398,7 +390,7 @@ mod tests {
proto::OpenBufferResponse { proto::OpenBufferResponse {
buffer: Some(proto::Buffer { buffer: Some(proto::Buffer {
id: 101, id: 101,
content: "path/one content".to_string(), visible_text: "path/one content".to_string(),
..Default::default() ..Default::default()
}), }),
} }
@ -419,14 +411,14 @@ mod tests {
proto::OpenBufferResponse { proto::OpenBufferResponse {
buffer: Some(proto::Buffer { buffer: Some(proto::Buffer {
id: 102, id: 102,
content: "path/two content".to_string(), visible_text: "path/two content".to_string(),
..Default::default() ..Default::default()
}), }),
} }
); );
client1.disconnect(client1_conn_id).await; client1.disconnect(client1_conn_id);
client2.disconnect(client1_conn_id).await; client2.disconnect(client1_conn_id);
async fn handle_messages( async fn handle_messages(
mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>, mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
@ -448,7 +440,7 @@ mod tests {
proto::OpenBufferResponse { proto::OpenBufferResponse {
buffer: Some(proto::Buffer { buffer: Some(proto::Buffer {
id: 101, id: 101,
content: "path/one content".to_string(), visible_text: "path/one content".to_string(),
..Default::default() ..Default::default()
}), }),
} }
@ -458,7 +450,7 @@ mod tests {
proto::OpenBufferResponse { proto::OpenBufferResponse {
buffer: Some(proto::Buffer { buffer: Some(proto::Buffer {
id: 102, id: 102,
content: "path/two content".to_string(), visible_text: "path/two content".to_string(),
..Default::default() ..Default::default()
}), }),
} }
@ -502,7 +494,7 @@ mod tests {
}) })
.detach(); .detach();
client.disconnect(connection_id).await; client.disconnect(connection_id);
io_ended_rx.recv().await; io_ended_rx.recv().await;
messages_ended_rx.recv().await; messages_ended_rx.recv().await;

View File

@ -174,7 +174,7 @@ impl Server {
} }
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> { async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id).await; self.peer.disconnect(connection_id);
let removed_connection = self.state_mut().remove_connection(connection_id)?; let removed_connection = self.state_mut().remove_connection(connection_id)?;
for (project_id, project) in removed_connection.hosted_projects { for (project_id, project) in removed_connection.hosted_projects {
@ -1801,7 +1801,7 @@ mod tests {
.await; .await;
// Drop client B's connection and ensure client A observes client B leaving the worktree. // Drop client B's connection and ensure client A observes client B leaving the worktree.
client_b.disconnect(&cx_b.to_async()).await.unwrap(); client_b.disconnect(&cx_b.to_async()).unwrap();
project_a project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 0) .condition(&cx_a, |p, _| p.collaborators().len() == 0)
.await; .await;
@ -2833,7 +2833,7 @@ mod tests {
impl Drop for TestServer { impl Drop for TestServer {
fn drop(&mut self) { fn drop(&mut self) {
task::block_on(self.peer.reset()); self.peer.reset();
} }
} }