mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-08 07:35:01 +03:00
Use synchronous locks for Peer
state
We hold these locks for a short amount of time anyway, and using an async lock could cause parallel sends to happen in an order different than the order in which `send`/`request` was called. Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
310def2923
commit
9e4b118214
@ -661,9 +661,9 @@ impl Client {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
|
pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
|
||||||
let conn_id = self.connection_id()?;
|
let conn_id = self.connection_id()?;
|
||||||
self.peer.disconnect(conn_id).await;
|
self.peer.disconnect(conn_id);
|
||||||
self.set_status(Status::SignedOut, cx);
|
self.set_status(Status::SignedOut, cx);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -764,7 +764,7 @@ mod tests {
|
|||||||
let ping = server.receive::<proto::Ping>().await.unwrap();
|
let ping = server.receive::<proto::Ping>().await.unwrap();
|
||||||
server.respond(ping.receipt(), proto::Ack {}).await;
|
server.respond(ping.receipt(), proto::Ack {}).await;
|
||||||
|
|
||||||
client.disconnect(&cx.to_async()).await.unwrap();
|
client.disconnect(&cx.to_async()).unwrap();
|
||||||
assert!(server.receive::<proto::Ping>().await.is_err());
|
assert!(server.receive::<proto::Ping>().await.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -783,7 +783,7 @@ mod tests {
|
|||||||
assert_eq!(server.auth_count(), 1);
|
assert_eq!(server.auth_count(), 1);
|
||||||
|
|
||||||
server.forbid_connections();
|
server.forbid_connections();
|
||||||
server.disconnect().await;
|
server.disconnect();
|
||||||
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||||
|
|
||||||
server.allow_connections();
|
server.allow_connections();
|
||||||
@ -792,7 +792,7 @@ mod tests {
|
|||||||
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
|
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
|
||||||
|
|
||||||
server.forbid_connections();
|
server.forbid_connections();
|
||||||
server.disconnect().await;
|
server.disconnect();
|
||||||
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||||
|
|
||||||
// Clear cached credentials after authentication fails
|
// Clear cached credentials after authentication fails
|
||||||
|
@ -72,8 +72,8 @@ impl FakeServer {
|
|||||||
server
|
server
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(&self) {
|
pub fn disconnect(&self) {
|
||||||
self.peer.disconnect(self.connection_id()).await;
|
self.peer.disconnect(self.connection_id());
|
||||||
self.connection_id.lock().take();
|
self.connection_id.lock().take();
|
||||||
self.incoming.lock().take();
|
self.incoming.lock().take();
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||||
use super::Connection;
|
use super::Connection;
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_lock::{Mutex, RwLock};
|
|
||||||
use futures::FutureExt as _;
|
use futures::FutureExt as _;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
use postage::{
|
use postage::{
|
||||||
mpsc,
|
mpsc,
|
||||||
prelude::{Sink as _, Stream as _},
|
prelude::{Sink as _, Stream as _},
|
||||||
@ -133,7 +133,7 @@ impl Peer {
|
|||||||
incoming = read_message => match incoming {
|
incoming = read_message => match incoming {
|
||||||
Ok(incoming) => {
|
Ok(incoming) => {
|
||||||
if let Some(responding_to) = incoming.responding_to {
|
if let Some(responding_to) = incoming.responding_to {
|
||||||
let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
|
let channel = response_channels.lock().as_mut().unwrap().remove(&responding_to);
|
||||||
if let Some(mut tx) = channel {
|
if let Some(mut tx) = channel {
|
||||||
tx.send(incoming).await.ok();
|
tx.send(incoming).await.ok();
|
||||||
} else {
|
} else {
|
||||||
@ -169,25 +169,24 @@ impl Peer {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
response_channels.lock().await.take();
|
response_channels.lock().take();
|
||||||
this.connections.write().await.remove(&connection_id);
|
this.connections.write().remove(&connection_id);
|
||||||
result
|
result
|
||||||
};
|
};
|
||||||
|
|
||||||
self.connections
|
self.connections
|
||||||
.write()
|
.write()
|
||||||
.await
|
|
||||||
.insert(connection_id, connection_state);
|
.insert(connection_id, connection_state);
|
||||||
|
|
||||||
(connection_id, handle_io, incoming_rx)
|
(connection_id, handle_io, incoming_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
pub fn disconnect(&self, connection_id: ConnectionId) {
|
||||||
self.connections.write().await.remove(&connection_id);
|
self.connections.write().remove(&connection_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn reset(&self) {
|
pub fn reset(&self) {
|
||||||
self.connections.write().await.clear();
|
self.connections.write().clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn request<T: RequestMessage>(
|
pub fn request<T: RequestMessage>(
|
||||||
@ -216,12 +215,11 @@ impl Peer {
|
|||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let (tx, mut rx) = mpsc::channel(1);
|
let (tx, mut rx) = mpsc::channel(1);
|
||||||
async move {
|
async move {
|
||||||
let mut connection = this.connection_state(receiver_id).await?;
|
let mut connection = this.connection_state(receiver_id)?;
|
||||||
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
|
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
|
||||||
connection
|
connection
|
||||||
.response_channels
|
.response_channels
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.ok_or_else(|| anyhow!("connection was closed"))?
|
.ok_or_else(|| anyhow!("connection was closed"))?
|
||||||
.insert(message_id, tx);
|
.insert(message_id, tx);
|
||||||
@ -250,7 +248,7 @@ impl Peer {
|
|||||||
) -> impl Future<Output = Result<()>> {
|
) -> impl Future<Output = Result<()>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut connection = this.connection_state(receiver_id).await?;
|
let mut connection = this.connection_state(receiver_id)?;
|
||||||
let message_id = connection
|
let message_id = connection
|
||||||
.next_message_id
|
.next_message_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||||
@ -270,7 +268,7 @@ impl Peer {
|
|||||||
) -> impl Future<Output = Result<()>> {
|
) -> impl Future<Output = Result<()>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut connection = this.connection_state(receiver_id).await?;
|
let mut connection = this.connection_state(receiver_id)?;
|
||||||
let message_id = connection
|
let message_id = connection
|
||||||
.next_message_id
|
.next_message_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||||
@ -289,7 +287,7 @@ impl Peer {
|
|||||||
) -> impl Future<Output = Result<()>> {
|
) -> impl Future<Output = Result<()>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
let mut connection = this.connection_state(receipt.sender_id)?;
|
||||||
let message_id = connection
|
let message_id = connection
|
||||||
.next_message_id
|
.next_message_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||||
@ -308,7 +306,7 @@ impl Peer {
|
|||||||
) -> impl Future<Output = Result<()>> {
|
) -> impl Future<Output = Result<()>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
let mut connection = this.connection_state(receipt.sender_id)?;
|
||||||
let message_id = connection
|
let message_id = connection
|
||||||
.next_message_id
|
.next_message_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||||
@ -320,20 +318,14 @@ impl Peer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn connection_state(
|
fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
|
||||||
self: &Arc<Self>,
|
let connections = self.connections.read();
|
||||||
connection_id: ConnectionId,
|
|
||||||
) -> impl Future<Output = Result<ConnectionState>> {
|
|
||||||
let this = self.clone();
|
|
||||||
async move {
|
|
||||||
let connections = this.connections.read().await;
|
|
||||||
let connection = connections
|
let connection = connections
|
||||||
.get(&connection_id)
|
.get(&connection_id)
|
||||||
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
||||||
Ok(connection.clone())
|
Ok(connection.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@ -398,7 +390,7 @@ mod tests {
|
|||||||
proto::OpenBufferResponse {
|
proto::OpenBufferResponse {
|
||||||
buffer: Some(proto::Buffer {
|
buffer: Some(proto::Buffer {
|
||||||
id: 101,
|
id: 101,
|
||||||
content: "path/one content".to_string(),
|
visible_text: "path/one content".to_string(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@ -419,14 +411,14 @@ mod tests {
|
|||||||
proto::OpenBufferResponse {
|
proto::OpenBufferResponse {
|
||||||
buffer: Some(proto::Buffer {
|
buffer: Some(proto::Buffer {
|
||||||
id: 102,
|
id: 102,
|
||||||
content: "path/two content".to_string(),
|
visible_text: "path/two content".to_string(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
client1.disconnect(client1_conn_id).await;
|
client1.disconnect(client1_conn_id);
|
||||||
client2.disconnect(client1_conn_id).await;
|
client2.disconnect(client1_conn_id);
|
||||||
|
|
||||||
async fn handle_messages(
|
async fn handle_messages(
|
||||||
mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
||||||
@ -448,7 +440,7 @@ mod tests {
|
|||||||
proto::OpenBufferResponse {
|
proto::OpenBufferResponse {
|
||||||
buffer: Some(proto::Buffer {
|
buffer: Some(proto::Buffer {
|
||||||
id: 101,
|
id: 101,
|
||||||
content: "path/one content".to_string(),
|
visible_text: "path/one content".to_string(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@ -458,7 +450,7 @@ mod tests {
|
|||||||
proto::OpenBufferResponse {
|
proto::OpenBufferResponse {
|
||||||
buffer: Some(proto::Buffer {
|
buffer: Some(proto::Buffer {
|
||||||
id: 102,
|
id: 102,
|
||||||
content: "path/two content".to_string(),
|
visible_text: "path/two content".to_string(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
@ -502,7 +494,7 @@ mod tests {
|
|||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
client.disconnect(connection_id).await;
|
client.disconnect(connection_id);
|
||||||
|
|
||||||
io_ended_rx.recv().await;
|
io_ended_rx.recv().await;
|
||||||
messages_ended_rx.recv().await;
|
messages_ended_rx.recv().await;
|
||||||
|
@ -174,7 +174,7 @@ impl Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
|
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
|
||||||
self.peer.disconnect(connection_id).await;
|
self.peer.disconnect(connection_id);
|
||||||
let removed_connection = self.state_mut().remove_connection(connection_id)?;
|
let removed_connection = self.state_mut().remove_connection(connection_id)?;
|
||||||
|
|
||||||
for (project_id, project) in removed_connection.hosted_projects {
|
for (project_id, project) in removed_connection.hosted_projects {
|
||||||
@ -1801,7 +1801,7 @@ mod tests {
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Drop client B's connection and ensure client A observes client B leaving the worktree.
|
// Drop client B's connection and ensure client A observes client B leaving the worktree.
|
||||||
client_b.disconnect(&cx_b.to_async()).await.unwrap();
|
client_b.disconnect(&cx_b.to_async()).unwrap();
|
||||||
project_a
|
project_a
|
||||||
.condition(&cx_a, |p, _| p.collaborators().len() == 0)
|
.condition(&cx_a, |p, _| p.collaborators().len() == 0)
|
||||||
.await;
|
.await;
|
||||||
@ -2833,7 +2833,7 @@ mod tests {
|
|||||||
|
|
||||||
impl Drop for TestServer {
|
impl Drop for TestServer {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
task::block_on(self.peer.reset());
|
self.peer.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user