From 4a9918979e77ef03357fa3ec76e23a523ca126b0 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 14 Sep 2021 19:19:11 -0600 Subject: [PATCH] WIP: Clear cached credentials if authentication fails Still need to actually handle an HTTP response from the server indicating there was an invalid token. Co-Authored-By: Max Brunsfeld --- Cargo.lock | 9 ++--- server/src/rpc.rs | 20 ++++++----- zed/Cargo.toml | 1 + zed/src/rpc.rs | 91 ++++++++++++++++++++++++++++++++++++++--------- zed/src/test.rs | 49 +++++++++++++++++-------- zrpc/src/conn.rs | 4 +-- zrpc/src/lib.rs | 2 +- zrpc/src/peer.rs | 40 ++++++++++----------- 8 files changed, 149 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c749b35f2d..aea0c49da8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5108,18 +5108,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.24" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0f4a65597094d4483ddaed134f409b2cb7c1beccf25201a9f73c719254fa98e" +checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.24" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7765189610d8241a44529806d6fd1f2e0a08734313a35d5b3a556f92b381f3c0" +checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c" dependencies = [ "proc-macro2", "quote", @@ -5914,6 +5914,7 @@ dependencies = [ "smol", "surf", "tempdir", + "thiserror", "time 0.3.2", "tiny_http", "toml 0.5.8", diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 86f369fb8a..1e0fe2465c 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -27,7 +27,7 @@ use time::OffsetDateTime; use zrpc::{ auth::random_token, proto::{self, AnyTypedEnvelope, EnvelopedMessage}, - Conn, ConnectionId, Peer, TypedEnvelope, + Connection, ConnectionId, Peer, TypedEnvelope, }; type ReplicaId = u16; @@ -48,13 +48,13 @@ pub struct Server { #[derive(Default)] struct ServerState { - connections: HashMap, + connections: HashMap, pub worktrees: HashMap, channels: HashMap, next_worktree_id: u64, } -struct Connection { +struct ConnectionState { user_id: UserId, worktrees: HashSet, channels: HashSet, @@ -133,7 +133,7 @@ impl Server { pub fn handle_connection( self: &Arc, - connection: Conn, + connection: Connection, addr: String, user_id: UserId, ) -> impl Future { @@ -211,7 +211,7 @@ impl Server { async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) { self.state.write().await.connections.insert( connection_id, - Connection { + ConnectionState { user_id, worktrees: Default::default(), channels: Default::default(), @@ -972,7 +972,7 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { - server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; + server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; } }); @@ -1023,7 +1023,7 @@ mod tests { editor::{Editor, Insert}, fs::{FakeFs, Fs as _}, language::LanguageRegistry, - rpc::{self, Client, Credentials}, + rpc::{self, Client, Credentials, EstablishConnectionError}, settings, test::FakeHttpClient, user::UserStore, @@ -1941,9 +1941,11 @@ mod tests { let client_name = client_name.clone(); cx.spawn(move |cx| async move { if forbid_connections.load(SeqCst) { - Err(anyhow!("server is forbidding connections")) + Err(EstablishConnectionError::other(anyhow!( + "server is forbidding connections" + ))) } else { - let (client_conn, server_conn, kill_conn) = Conn::in_memory(); + let (client_conn, server_conn, kill_conn) = Connection::in_memory(); connection_killers.lock().insert(client_user_id, kill_conn); cx.background() .spawn(server.handle_connection( diff --git a/zed/Cargo.toml b/zed/Cargo.toml index 2b36db9ba8..d9c2cc6a58 100644 --- a/zed/Cargo.toml +++ b/zed/Cargo.toml @@ -50,6 +50,7 @@ smallvec = { version = "1.6", features = ["union"] } smol = "1.2.5" surf = "2.2" tempdir = { version = "0.3.7", optional = true } +thiserror = "1.0.29" time = { version = "0.3" } tiny_http = "0.8" toml = "0.5" diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 8846b02f4b..3526381cde 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -15,10 +15,11 @@ use std::{ time::{Duration, Instant}, }; use surf::Url; +use thiserror::Error; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; use zrpc::{ proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, - Conn, Peer, Receipt, + Connection, Peer, Receipt, }; lazy_static! { @@ -32,10 +33,32 @@ pub struct Client { authenticate: Option Task>>>, establish_connection: Option< - Box Task>>, + Box< + dyn 'static + + Send + + Sync + + Fn( + &Credentials, + &AsyncAppContext, + ) -> Task>, + >, >, } +#[derive(Error, Debug)] +pub enum EstablishConnectionError { + #[error("invalid access token")] + InvalidAccessToken, + #[error("{0}")] + Other(anyhow::Error), +} + +impl EstablishConnectionError { + pub fn other(error: impl Into + Send + Sync) -> Self { + Self::Other(error.into()) + } +} + #[derive(Copy, Clone, Debug)] pub enum Status { SignedOut, @@ -122,7 +145,10 @@ impl Client { #[cfg(any(test, feature = "test-support"))] pub fn override_establish_connection(&mut self, connect: F) -> &mut Self where - F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task>, + F: 'static + + Send + + Sync + + Fn(&Credentials, &AsyncAppContext) -> Task>, { self.establish_connection = Some(Box::new(connect)); self @@ -288,13 +314,18 @@ impl Client { Ok(()) } Err(err) => { + eprintln!("error in authenticate and connect {}", err); + if matches!(err, EstablishConnectionError::InvalidAccessToken) { + eprintln!("nuking credentials"); + self.state.write().credentials.take(); + } self.set_status(Status::ConnectionError, cx); - Err(err) + Err(err)? } } } - async fn set_connection(self: &Arc, conn: Conn, cx: &AsyncAppContext) { + async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncAppContext) { let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; cx.foreground() .spawn({ @@ -359,7 +390,7 @@ impl Client { self: &Arc, credentials: &Credentials, cx: &AsyncAppContext, - ) -> Task> { + ) -> Task> { if let Some(callback) = self.establish_connection.as_ref() { callback(credentials, cx) } else { @@ -371,28 +402,43 @@ impl Client { self: &Arc, credentials: &Credentials, cx: &AsyncAppContext, - ) -> Task> { + ) -> Task> { let request = Request::builder().header( "Authorization", format!("{} {}", credentials.user_id, credentials.access_token), ); cx.background().spawn(async move { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { - let stream = smol::net::TcpStream::connect(host).await?; - let request = request.uri(format!("wss://{}/rpc", host)).body(())?; + let stream = smol::net::TcpStream::connect(host) + .await + .map_err(EstablishConnectionError::other)?; + let request = request + .uri(format!("wss://{}/rpc", host)) + .body(()) + .map_err(EstablishConnectionError::other)?; let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) .await - .context("websocket handshake")?; - Ok(Conn::new(stream)) + .context("websocket handshake") + .map_err(EstablishConnectionError::other)?; + Ok(Connection::new(stream)) } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { - let stream = smol::net::TcpStream::connect(host).await?; - let request = request.uri(format!("ws://{}/rpc", host)).body(())?; + let stream = smol::net::TcpStream::connect(host) + .await + .map_err(EstablishConnectionError::other)?; + let request = request + .uri(format!("ws://{}/rpc", host)) + .body(()) + .map_err(EstablishConnectionError::other)?; let (stream, _) = async_tungstenite::client_async(request, stream) .await - .context("websocket handshake")?; - Ok(Conn::new(stream)) + .context("websocket handshake") + .map_err(EstablishConnectionError::other)?; + Ok(Connection::new(stream)) } else { - Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL)) + Err(EstablishConnectionError::other(anyhow!( + "invalid server url: {}", + *ZED_SERVER_URL + ))) } }) } @@ -591,6 +637,19 @@ mod tests { cx.foreground().advance_clock(Duration::from_secs(10)); while !matches!(status.recv().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting + + server.forbid_connections(); + server.disconnect().await; + while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} + + // Clear cached credentials after authentication fails + server.roll_access_token(); + server.allow_connections(); + cx.foreground().advance_clock(Duration::from_secs(10)); + assert_eq!(server.auth_count(), 1); + cx.foreground().advance_clock(Duration::from_secs(10)); + while !matches!(status.recv().await, Some(Status::Connected { .. })) {} + assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } #[test] diff --git a/zed/src/test.rs b/zed/src/test.rs index 4f28db9d28..e5ab3154f5 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -4,7 +4,7 @@ use crate::{ fs::RealFs, http::{HttpClient, Request, Response, ServerResponse}, language::LanguageRegistry, - rpc::{self, Client, Credentials}, + rpc::{self, Client, Credentials, EstablishConnectionError}, settings::{self, ThemeRegistry}, time::ReplicaId, user::UserStore, @@ -26,7 +26,7 @@ use std::{ }, }; use tempdir::TempDir; -use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; +use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope}; #[cfg(test)] #[ctor::ctor] @@ -210,6 +210,8 @@ pub struct FakeServer { connection_id: Mutex>, forbid_connections: AtomicBool, auth_count: AtomicUsize, + access_token: AtomicUsize, + user_id: u64, } impl FakeServer { @@ -224,6 +226,8 @@ impl FakeServer { connection_id: Default::default(), forbid_connections: Default::default(), auth_count: Default::default(), + access_token: Default::default(), + user_id: client_user_id, }); Arc::get_mut(client) @@ -232,8 +236,8 @@ impl FakeServer { let server = server.clone(); move |cx| { server.auth_count.fetch_add(1, SeqCst); + let access_token = server.access_token.load(SeqCst).to_string(); cx.spawn(move |_| async move { - let access_token = "the-token".to_string(); Ok(Credentials { user_id: client_user_id, access_token, @@ -244,11 +248,10 @@ impl FakeServer { .override_establish_connection({ let server = server.clone(); move |credentials, cx| { - assert_eq!(credentials.user_id, client_user_id); - assert_eq!(credentials.access_token, "the-token"); + let credentials = credentials.clone(); cx.spawn({ let server = server.clone(); - move |cx| async move { server.connect(&cx).await } + move |cx| async move { server.establish_connection(&credentials, &cx).await } }) } }); @@ -266,23 +269,39 @@ impl FakeServer { self.incoming.lock().take(); } - async fn connect(&self, cx: &AsyncAppContext) -> Result { + async fn establish_connection( + &self, + credentials: &Credentials, + cx: &AsyncAppContext, + ) -> Result { + assert_eq!(credentials.user_id, self.user_id); + if self.forbid_connections.load(SeqCst) { - Err(anyhow!("server is forbidding connections")) - } else { - let (client_conn, server_conn, _) = Conn::in_memory(); - let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - *self.incoming.lock() = Some(incoming); - *self.connection_id.lock() = Some(connection_id); - Ok(client_conn) + Err(EstablishConnectionError::Other(anyhow!( + "server is forbidding connections" + )))? } + + if credentials.access_token != self.access_token.load(SeqCst).to_string() { + Err(EstablishConnectionError::InvalidAccessToken)? + } + + let (client_conn, server_conn, _) = Connection::in_memory(); + let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + *self.incoming.lock() = Some(incoming); + *self.connection_id.lock() = Some(connection_id); + Ok(client_conn) } pub fn auth_count(&self) -> usize { self.auth_count.load(SeqCst) } + pub fn roll_access_token(&self) { + self.access_token.fetch_add(1, SeqCst); + } + pub fn forbid_connections(&self) { self.forbid_connections.store(true, SeqCst); } diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index e67b4fa587..5ca845d13f 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -2,7 +2,7 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _}; use std::{io, task::Poll}; -pub struct Conn { +pub struct Connection { pub(crate) tx: Box>, pub(crate) rx: Box< @@ -13,7 +13,7 @@ pub struct Conn { >, } -impl Conn { +impl Connection { pub fn new(stream: S) -> Self where S: 'static diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index b3973cae19..a7bb44774b 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -2,5 +2,5 @@ pub mod auth; mod conn; mod peer; pub mod proto; -pub use conn::Conn; +pub use conn::Connection; pub use peer::*; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index 75db257f55..eeda034e95 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -1,5 +1,5 @@ use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; -use super::Conn; +use super::Connection; use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; use futures::FutureExt as _; @@ -79,12 +79,12 @@ impl TypedEnvelope { } pub struct Peer { - connections: RwLock>, + connections: RwLock>, next_connection_id: AtomicU32, } #[derive(Clone)] -struct Connection { +struct ConnectionState { outgoing_tx: mpsc::Sender, next_message_id: Arc, response_channels: Arc>>>, @@ -100,7 +100,7 @@ impl Peer { pub async fn add_connection( self: &Arc, - conn: Conn, + connection: Connection, ) -> ( ConnectionId, impl Future> + Send, @@ -112,16 +112,16 @@ impl Peer { ); let (mut incoming_tx, incoming_rx) = mpsc::channel(64); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); - let connection = Connection { + let connection_state = ConnectionState { outgoing_tx, next_message_id: Default::default(), response_channels: Default::default(), }; - let mut writer = MessageStream::new(conn.tx); - let mut reader = MessageStream::new(conn.rx); + let mut writer = MessageStream::new(connection.tx); + let mut reader = MessageStream::new(connection.rx); let this = self.clone(); - let response_channels = connection.response_channels.clone(); + let response_channels = connection_state.response_channels.clone(); let handle_io = async move { loop { let read_message = reader.read_message().fuse(); @@ -179,7 +179,7 @@ impl Peer { self.connections .write() .await - .insert(connection_id, connection); + .insert(connection_id, connection_state); (connection_id, handle_io, incoming_rx) } @@ -218,7 +218,7 @@ impl Peer { let this = self.clone(); let (tx, mut rx) = mpsc::channel(1); async move { - let mut connection = this.connection(receiver_id).await?; + let mut connection = this.connection_state(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -252,7 +252,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection(receiver_id).await?; + let mut connection = this.connection_state(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -272,7 +272,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection(receiver_id).await?; + let mut connection = this.connection_state(receiver_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -291,7 +291,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection(receipt.sender_id).await?; + let mut connection = this.connection_state(receipt.sender_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -310,7 +310,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection(receipt.sender_id).await?; + let mut connection = this.connection_state(receipt.sender_id).await?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -322,10 +322,10 @@ impl Peer { } } - fn connection( + fn connection_state( self: &Arc, connection_id: ConnectionId, - ) -> impl Future> { + ) -> impl Future> { let this = self.clone(); async move { let connections = this.connections.read().await; @@ -352,12 +352,12 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory(); + let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory(); let (client1_conn_id, io_task1, _) = client1.add_connection(client1_to_server_conn).await; let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; - let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory(); + let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory(); let (client2_conn_id, io_task3, _) = client2.add_connection(client2_to_server_conn).await; let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; @@ -486,7 +486,7 @@ mod tests { #[test] fn test_disconnect() { smol::block_on(async move { - let (client_conn, mut server_conn, _) = Conn::in_memory(); + let (client_conn, mut server_conn, _) = Connection::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -520,7 +520,7 @@ mod tests { #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn, _) = Conn::in_memory(); + let (client_conn, server_conn, _) = Connection::in_memory(); drop(server_conn); let client = Peer::new();