mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-28 19:54:10 +03:00
WIP: Clear cached credentials if authentication fails
Still need to actually handle an HTTP response from the server indicating there was an invalid token. Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
77a4a36eb3
commit
4a9918979e
9
Cargo.lock
generated
9
Cargo.lock
generated
@ -5108,18 +5108,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.24"
|
||||
version = "1.0.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0f4a65597094d4483ddaed134f409b2cb7c1beccf25201a9f73c719254fa98e"
|
||||
checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.24"
|
||||
version = "1.0.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7765189610d8241a44529806d6fd1f2e0a08734313a35d5b3a556f92b381f3c0"
|
||||
checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -5914,6 +5914,7 @@ dependencies = [
|
||||
"smol",
|
||||
"surf",
|
||||
"tempdir",
|
||||
"thiserror",
|
||||
"time 0.3.2",
|
||||
"tiny_http",
|
||||
"toml 0.5.8",
|
||||
|
@ -27,7 +27,7 @@ use time::OffsetDateTime;
|
||||
use zrpc::{
|
||||
auth::random_token,
|
||||
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
|
||||
Conn, ConnectionId, Peer, TypedEnvelope,
|
||||
Connection, ConnectionId, Peer, TypedEnvelope,
|
||||
};
|
||||
|
||||
type ReplicaId = u16;
|
||||
@ -48,13 +48,13 @@ pub struct Server {
|
||||
|
||||
#[derive(Default)]
|
||||
struct ServerState {
|
||||
connections: HashMap<ConnectionId, Connection>,
|
||||
connections: HashMap<ConnectionId, ConnectionState>,
|
||||
pub worktrees: HashMap<u64, Worktree>,
|
||||
channels: HashMap<ChannelId, Channel>,
|
||||
next_worktree_id: u64,
|
||||
}
|
||||
|
||||
struct Connection {
|
||||
struct ConnectionState {
|
||||
user_id: UserId,
|
||||
worktrees: HashSet<u64>,
|
||||
channels: HashSet<ChannelId>,
|
||||
@ -133,7 +133,7 @@ impl Server {
|
||||
|
||||
pub fn handle_connection(
|
||||
self: &Arc<Self>,
|
||||
connection: Conn,
|
||||
connection: Connection,
|
||||
addr: String,
|
||||
user_id: UserId,
|
||||
) -> impl Future<Output = ()> {
|
||||
@ -211,7 +211,7 @@ impl Server {
|
||||
async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
|
||||
self.state.write().await.connections.insert(
|
||||
connection_id,
|
||||
Connection {
|
||||
ConnectionState {
|
||||
user_id,
|
||||
worktrees: Default::default(),
|
||||
channels: Default::default(),
|
||||
@ -972,7 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
||||
task::spawn(async move {
|
||||
if let Some(stream) = upgrade_receiver.await {
|
||||
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
|
||||
server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
|
||||
}
|
||||
});
|
||||
|
||||
@ -1023,7 +1023,7 @@ mod tests {
|
||||
editor::{Editor, Insert},
|
||||
fs::{FakeFs, Fs as _},
|
||||
language::LanguageRegistry,
|
||||
rpc::{self, Client, Credentials},
|
||||
rpc::{self, Client, Credentials, EstablishConnectionError},
|
||||
settings,
|
||||
test::FakeHttpClient,
|
||||
user::UserStore,
|
||||
@ -1941,9 +1941,11 @@ mod tests {
|
||||
let client_name = client_name.clone();
|
||||
cx.spawn(move |cx| async move {
|
||||
if forbid_connections.load(SeqCst) {
|
||||
Err(anyhow!("server is forbidding connections"))
|
||||
Err(EstablishConnectionError::other(anyhow!(
|
||||
"server is forbidding connections"
|
||||
)))
|
||||
} else {
|
||||
let (client_conn, server_conn, kill_conn) = Conn::in_memory();
|
||||
let (client_conn, server_conn, kill_conn) = Connection::in_memory();
|
||||
connection_killers.lock().insert(client_user_id, kill_conn);
|
||||
cx.background()
|
||||
.spawn(server.handle_connection(
|
||||
|
@ -50,6 +50,7 @@ smallvec = { version = "1.6", features = ["union"] }
|
||||
smol = "1.2.5"
|
||||
surf = "2.2"
|
||||
tempdir = { version = "0.3.7", optional = true }
|
||||
thiserror = "1.0.29"
|
||||
time = { version = "0.3" }
|
||||
tiny_http = "0.8"
|
||||
toml = "0.5"
|
||||
|
@ -15,10 +15,11 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use surf::Url;
|
||||
use thiserror::Error;
|
||||
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
||||
use zrpc::{
|
||||
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
|
||||
Conn, Peer, Receipt,
|
||||
Connection, Peer, Receipt,
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
@ -32,10 +33,32 @@ pub struct Client {
|
||||
authenticate:
|
||||
Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
|
||||
establish_connection: Option<
|
||||
Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>,
|
||||
Box<
|
||||
dyn 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Fn(
|
||||
&Credentials,
|
||||
&AsyncAppContext,
|
||||
) -> Task<Result<Connection, EstablishConnectionError>>,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum EstablishConnectionError {
|
||||
#[error("invalid access token")]
|
||||
InvalidAccessToken,
|
||||
#[error("{0}")]
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
|
||||
impl EstablishConnectionError {
|
||||
pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
|
||||
Self::Other(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum Status {
|
||||
SignedOut,
|
||||
@ -122,7 +145,10 @@ impl Client {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>,
|
||||
F: 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
|
||||
{
|
||||
self.establish_connection = Some(Box::new(connect));
|
||||
self
|
||||
@ -288,13 +314,18 @@ impl Client {
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("error in authenticate and connect {}", err);
|
||||
if matches!(err, EstablishConnectionError::InvalidAccessToken) {
|
||||
eprintln!("nuking credentials");
|
||||
self.state.write().credentials.take();
|
||||
}
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(err)
|
||||
Err(err)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) {
|
||||
async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
|
||||
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
||||
cx.foreground()
|
||||
.spawn({
|
||||
@ -359,7 +390,7 @@ impl Client {
|
||||
self: &Arc<Self>,
|
||||
credentials: &Credentials,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Task<Result<Conn>> {
|
||||
) -> Task<Result<Connection, EstablishConnectionError>> {
|
||||
if let Some(callback) = self.establish_connection.as_ref() {
|
||||
callback(credentials, cx)
|
||||
} else {
|
||||
@ -371,28 +402,43 @@ impl Client {
|
||||
self: &Arc<Self>,
|
||||
credentials: &Credentials,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Task<Result<Conn>> {
|
||||
) -> Task<Result<Connection, EstablishConnectionError>> {
|
||||
let request = Request::builder().header(
|
||||
"Authorization",
|
||||
format!("{} {}", credentials.user_id, credentials.access_token),
|
||||
);
|
||||
cx.background().spawn(async move {
|
||||
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
|
||||
let stream = smol::net::TcpStream::connect(host)
|
||||
.await
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
let request = request
|
||||
.uri(format!("wss://{}/rpc", host))
|
||||
.body(())
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
Ok(Conn::new(stream))
|
||||
.context("websocket handshake")
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
Ok(Connection::new(stream))
|
||||
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
||||
let stream = smol::net::TcpStream::connect(host)
|
||||
.await
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
let request = request
|
||||
.uri(format!("ws://{}/rpc", host))
|
||||
.body(())
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
let (stream, _) = async_tungstenite::client_async(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
Ok(Conn::new(stream))
|
||||
.context("websocket handshake")
|
||||
.map_err(EstablishConnectionError::other)?;
|
||||
Ok(Connection::new(stream))
|
||||
} else {
|
||||
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
|
||||
Err(EstablishConnectionError::other(anyhow!(
|
||||
"invalid server url: {}",
|
||||
*ZED_SERVER_URL
|
||||
)))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -591,6 +637,19 @@ mod tests {
|
||||
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
|
||||
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
|
||||
|
||||
server.forbid_connections();
|
||||
server.disconnect().await;
|
||||
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||
|
||||
// Clear cached credentials after authentication fails
|
||||
server.roll_access_token();
|
||||
server.allow_connections();
|
||||
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||
assert_eq!(server.auth_count(), 1);
|
||||
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
|
||||
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -4,7 +4,7 @@ use crate::{
|
||||
fs::RealFs,
|
||||
http::{HttpClient, Request, Response, ServerResponse},
|
||||
language::LanguageRegistry,
|
||||
rpc::{self, Client, Credentials},
|
||||
rpc::{self, Client, Credentials, EstablishConnectionError},
|
||||
settings::{self, ThemeRegistry},
|
||||
time::ReplicaId,
|
||||
user::UserStore,
|
||||
@ -26,7 +26,7 @@ use std::{
|
||||
},
|
||||
};
|
||||
use tempdir::TempDir;
|
||||
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
@ -210,6 +210,8 @@ pub struct FakeServer {
|
||||
connection_id: Mutex<Option<ConnectionId>>,
|
||||
forbid_connections: AtomicBool,
|
||||
auth_count: AtomicUsize,
|
||||
access_token: AtomicUsize,
|
||||
user_id: u64,
|
||||
}
|
||||
|
||||
impl FakeServer {
|
||||
@ -224,6 +226,8 @@ impl FakeServer {
|
||||
connection_id: Default::default(),
|
||||
forbid_connections: Default::default(),
|
||||
auth_count: Default::default(),
|
||||
access_token: Default::default(),
|
||||
user_id: client_user_id,
|
||||
});
|
||||
|
||||
Arc::get_mut(client)
|
||||
@ -232,8 +236,8 @@ impl FakeServer {
|
||||
let server = server.clone();
|
||||
move |cx| {
|
||||
server.auth_count.fetch_add(1, SeqCst);
|
||||
let access_token = server.access_token.load(SeqCst).to_string();
|
||||
cx.spawn(move |_| async move {
|
||||
let access_token = "the-token".to_string();
|
||||
Ok(Credentials {
|
||||
user_id: client_user_id,
|
||||
access_token,
|
||||
@ -244,11 +248,10 @@ impl FakeServer {
|
||||
.override_establish_connection({
|
||||
let server = server.clone();
|
||||
move |credentials, cx| {
|
||||
assert_eq!(credentials.user_id, client_user_id);
|
||||
assert_eq!(credentials.access_token, "the-token");
|
||||
let credentials = credentials.clone();
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
move |cx| async move { server.connect(&cx).await }
|
||||
move |cx| async move { server.establish_connection(&credentials, &cx).await }
|
||||
})
|
||||
}
|
||||
});
|
||||
@ -266,23 +269,39 @@ impl FakeServer {
|
||||
self.incoming.lock().take();
|
||||
}
|
||||
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
|
||||
async fn establish_connection(
|
||||
&self,
|
||||
credentials: &Credentials,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<Connection, EstablishConnectionError> {
|
||||
assert_eq!(credentials.user_id, self.user_id);
|
||||
|
||||
if self.forbid_connections.load(SeqCst) {
|
||||
Err(anyhow!("server is forbidding connections"))
|
||||
} else {
|
||||
let (client_conn, server_conn, _) = Conn::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
Ok(client_conn)
|
||||
Err(EstablishConnectionError::Other(anyhow!(
|
||||
"server is forbidding connections"
|
||||
)))?
|
||||
}
|
||||
|
||||
if credentials.access_token != self.access_token.load(SeqCst).to_string() {
|
||||
Err(EstablishConnectionError::InvalidAccessToken)?
|
||||
}
|
||||
|
||||
let (client_conn, server_conn, _) = Connection::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
Ok(client_conn)
|
||||
}
|
||||
|
||||
pub fn auth_count(&self) -> usize {
|
||||
self.auth_count.load(SeqCst)
|
||||
}
|
||||
|
||||
pub fn roll_access_token(&self) {
|
||||
self.access_token.fetch_add(1, SeqCst);
|
||||
}
|
||||
|
||||
pub fn forbid_connections(&self) {
|
||||
self.forbid_connections.store(true, SeqCst);
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
|
||||
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
|
||||
use std::{io, task::Poll};
|
||||
|
||||
pub struct Conn {
|
||||
pub struct Connection {
|
||||
pub(crate) tx:
|
||||
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||
pub(crate) rx: Box<
|
||||
@ -13,7 +13,7 @@ pub struct Conn {
|
||||
>,
|
||||
}
|
||||
|
||||
impl Conn {
|
||||
impl Connection {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: 'static
|
||||
|
@ -2,5 +2,5 @@ pub mod auth;
|
||||
mod conn;
|
||||
mod peer;
|
||||
pub mod proto;
|
||||
pub use conn::Conn;
|
||||
pub use conn::Connection;
|
||||
pub use peer::*;
|
||||
|
@ -1,5 +1,5 @@
|
||||
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||
use super::Conn;
|
||||
use super::Connection;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_lock::{Mutex, RwLock};
|
||||
use futures::FutureExt as _;
|
||||
@ -79,12 +79,12 @@ impl<T: RequestMessage> TypedEnvelope<T> {
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
connections: RwLock<HashMap<ConnectionId, Connection>>,
|
||||
connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
|
||||
next_connection_id: AtomicU32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Connection {
|
||||
struct ConnectionState {
|
||||
outgoing_tx: mpsc::Sender<proto::Envelope>,
|
||||
next_message_id: Arc<AtomicU32>,
|
||||
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
|
||||
@ -100,7 +100,7 @@ impl Peer {
|
||||
|
||||
pub async fn add_connection(
|
||||
self: &Arc<Self>,
|
||||
conn: Conn,
|
||||
connection: Connection,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
@ -112,16 +112,16 @@ impl Peer {
|
||||
);
|
||||
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
|
||||
let connection = Connection {
|
||||
let connection_state = ConnectionState {
|
||||
outgoing_tx,
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Default::default(),
|
||||
};
|
||||
let mut writer = MessageStream::new(conn.tx);
|
||||
let mut reader = MessageStream::new(conn.rx);
|
||||
let mut writer = MessageStream::new(connection.tx);
|
||||
let mut reader = MessageStream::new(connection.rx);
|
||||
|
||||
let this = self.clone();
|
||||
let response_channels = connection.response_channels.clone();
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let handle_io = async move {
|
||||
loop {
|
||||
let read_message = reader.read_message().fuse();
|
||||
@ -179,7 +179,7 @@ impl Peer {
|
||||
self.connections
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, connection);
|
||||
.insert(connection_id, connection_state);
|
||||
|
||||
(connection_id, handle_io, incoming_rx)
|
||||
}
|
||||
@ -218,7 +218,7 @@ impl Peer {
|
||||
let this = self.clone();
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
async move {
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
@ -252,7 +252,7 @@ impl Peer {
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
@ -272,7 +272,7 @@ impl Peer {
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection(receiver_id).await?;
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
@ -291,7 +291,7 @@ impl Peer {
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection(receipt.sender_id).await?;
|
||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
@ -310,7 +310,7 @@ impl Peer {
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection(receipt.sender_id).await?;
|
||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
@ -322,10 +322,10 @@ impl Peer {
|
||||
}
|
||||
}
|
||||
|
||||
fn connection(
|
||||
fn connection_state(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> impl Future<Output = Result<Connection>> {
|
||||
) -> impl Future<Output = Result<ConnectionState>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connections = this.connections.read().await;
|
||||
@ -352,12 +352,12 @@ mod tests {
|
||||
let client1 = Peer::new();
|
||||
let client2 = Peer::new();
|
||||
|
||||
let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
|
||||
let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
|
||||
let (client1_conn_id, io_task1, _) =
|
||||
client1.add_connection(client1_to_server_conn).await;
|
||||
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||
|
||||
let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
|
||||
let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
|
||||
let (client2_conn_id, io_task3, _) =
|
||||
client2.add_connection(client2_to_server_conn).await;
|
||||
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||
@ -486,7 +486,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_disconnect() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, mut server_conn, _) = Conn::in_memory();
|
||||
let (client_conn, mut server_conn, _) = Connection::in_memory();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
@ -520,7 +520,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_io_error() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, server_conn, _) = Conn::in_memory();
|
||||
let (client_conn, server_conn, _) = Connection::in_memory();
|
||||
drop(server_conn);
|
||||
|
||||
let client = Peer::new();
|
||||
|
Loading…
Reference in New Issue
Block a user