WIP: Clear cached credentials if authentication fails

Still need to actually handle an HTTP response from the server indicating there was an invalid token.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Nathan Sobo 2021-09-14 19:19:11 -06:00
parent 77a4a36eb3
commit 4a9918979e
8 changed files with 149 additions and 67 deletions

9
Cargo.lock generated
View File

@ -5108,18 +5108,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.24"
version = "1.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0f4a65597094d4483ddaed134f409b2cb7c1beccf25201a9f73c719254fa98e"
checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.24"
version = "1.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7765189610d8241a44529806d6fd1f2e0a08734313a35d5b3a556f92b381f3c0"
checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
dependencies = [
"proc-macro2",
"quote",
@ -5914,6 +5914,7 @@ dependencies = [
"smol",
"surf",
"tempdir",
"thiserror",
"time 0.3.2",
"tiny_http",
"toml 0.5.8",

View File

@ -27,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
Conn, ConnectionId, Peer, TypedEnvelope,
Connection, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
@ -48,13 +48,13 @@ pub struct Server {
#[derive(Default)]
struct ServerState {
connections: HashMap<ConnectionId, Connection>,
connections: HashMap<ConnectionId, ConnectionState>,
pub worktrees: HashMap<u64, Worktree>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
struct Connection {
struct ConnectionState {
user_id: UserId,
worktrees: HashSet<u64>,
channels: HashSet<ChannelId>,
@ -133,7 +133,7 @@ impl Server {
pub fn handle_connection(
self: &Arc<Self>,
connection: Conn,
connection: Connection,
addr: String,
user_id: UserId,
) -> impl Future<Output = ()> {
@ -211,7 +211,7 @@ impl Server {
async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
self.state.write().await.connections.insert(
connection_id,
Connection {
ConnectionState {
user_id,
worktrees: Default::default(),
channels: Default::default(),
@ -972,7 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
}
});
@ -1023,7 +1023,7 @@ mod tests {
editor::{Editor, Insert},
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
rpc::{self, Client, Credentials},
rpc::{self, Client, Credentials, EstablishConnectionError},
settings,
test::FakeHttpClient,
user::UserStore,
@ -1941,9 +1941,11 @@ mod tests {
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
if forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
Err(EstablishConnectionError::other(anyhow!(
"server is forbidding connections"
)))
} else {
let (client_conn, server_conn, kill_conn) = Conn::in_memory();
let (client_conn, server_conn, kill_conn) = Connection::in_memory();
connection_killers.lock().insert(client_user_id, kill_conn);
cx.background()
.spawn(server.handle_connection(

View File

@ -50,6 +50,7 @@ smallvec = { version = "1.6", features = ["union"] }
smol = "1.2.5"
surf = "2.2"
tempdir = { version = "0.3.7", optional = true }
thiserror = "1.0.29"
time = { version = "0.3" }
tiny_http = "0.8"
toml = "0.5"

View File

@ -15,10 +15,11 @@ use std::{
time::{Duration, Instant},
};
use surf::Url;
use thiserror::Error;
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Conn, Peer, Receipt,
Connection, Peer, Receipt,
};
lazy_static! {
@ -32,10 +33,32 @@ pub struct Client {
authenticate:
Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
establish_connection: Option<
Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>,
Box<
dyn 'static
+ Send
+ Sync
+ Fn(
&Credentials,
&AsyncAppContext,
) -> Task<Result<Connection, EstablishConnectionError>>,
>,
>,
}
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
#[error("invalid access token")]
InvalidAccessToken,
#[error("{0}")]
Other(anyhow::Error),
}
impl EstablishConnectionError {
pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
Self::Other(error.into())
}
}
#[derive(Copy, Clone, Debug)]
pub enum Status {
SignedOut,
@ -122,7 +145,10 @@ impl Client {
#[cfg(any(test, feature = "test-support"))]
pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>,
F: 'static
+ Send
+ Sync
+ Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
{
self.establish_connection = Some(Box::new(connect));
self
@ -288,13 +314,18 @@ impl Client {
Ok(())
}
Err(err) => {
eprintln!("error in authenticate and connect {}", err);
if matches!(err, EstablishConnectionError::InvalidAccessToken) {
eprintln!("nuking credentials");
self.state.write().credentials.take();
}
self.set_status(Status::ConnectionError, cx);
Err(err)
Err(err)?
}
}
}
async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) {
async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground()
.spawn({
@ -359,7 +390,7 @@ impl Client {
self: &Arc<Self>,
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
) -> Task<Result<Connection, EstablishConnectionError>> {
if let Some(callback) = self.establish_connection.as_ref() {
callback(credentials, cx)
} else {
@ -371,28 +402,43 @@ impl Client {
self: &Arc<Self>,
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
) -> Task<Result<Connection, EstablishConnectionError>> {
let request = Request::builder().header(
"Authorization",
format!("{} {}", credentials.user_id, credentials.access_token),
);
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let stream = smol::net::TcpStream::connect(host)
.await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("wss://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let stream = smol::net::TcpStream::connect(host)
.await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("ws://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream))
} else {
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
Err(EstablishConnectionError::other(anyhow!(
"invalid server url: {}",
*ZED_SERVER_URL
)))
}
})
}
@ -591,6 +637,19 @@ mod tests {
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
server.forbid_connections();
server.disconnect().await;
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
// Clear cached credentials after authentication fails
server.roll_access_token();
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
assert_eq!(server.auth_count(), 1);
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
}
#[test]

View File

@ -4,7 +4,7 @@ use crate::{
fs::RealFs,
http::{HttpClient, Request, Response, ServerResponse},
language::LanguageRegistry,
rpc::{self, Client, Credentials},
rpc::{self, Client, Credentials, EstablishConnectionError},
settings::{self, ThemeRegistry},
time::ReplicaId,
user::UserStore,
@ -26,7 +26,7 @@ use std::{
},
};
use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)]
#[ctor::ctor]
@ -210,6 +210,8 @@ pub struct FakeServer {
connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
auth_count: AtomicUsize,
access_token: AtomicUsize,
user_id: u64,
}
impl FakeServer {
@ -224,6 +226,8 @@ impl FakeServer {
connection_id: Default::default(),
forbid_connections: Default::default(),
auth_count: Default::default(),
access_token: Default::default(),
user_id: client_user_id,
});
Arc::get_mut(client)
@ -232,8 +236,8 @@ impl FakeServer {
let server = server.clone();
move |cx| {
server.auth_count.fetch_add(1, SeqCst);
let access_token = server.access_token.load(SeqCst).to_string();
cx.spawn(move |_| async move {
let access_token = "the-token".to_string();
Ok(Credentials {
user_id: client_user_id,
access_token,
@ -244,11 +248,10 @@ impl FakeServer {
.override_establish_connection({
let server = server.clone();
move |credentials, cx| {
assert_eq!(credentials.user_id, client_user_id);
assert_eq!(credentials.access_token, "the-token");
let credentials = credentials.clone();
cx.spawn({
let server = server.clone();
move |cx| async move { server.connect(&cx).await }
move |cx| async move { server.establish_connection(&credentials, &cx).await }
})
}
});
@ -266,23 +269,39 @@ impl FakeServer {
self.incoming.lock().take();
}
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
async fn establish_connection(
&self,
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Result<Connection, EstablishConnectionError> {
assert_eq!(credentials.user_id, self.user_id);
if self.forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
} else {
let (client_conn, server_conn, _) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
Err(EstablishConnectionError::Other(anyhow!(
"server is forbidding connections"
)))?
}
if credentials.access_token != self.access_token.load(SeqCst).to_string() {
Err(EstablishConnectionError::InvalidAccessToken)?
}
let (client_conn, server_conn, _) = Connection::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
}
pub fn auth_count(&self) -> usize {
self.auth_count.load(SeqCst)
}
pub fn roll_access_token(&self) {
self.access_token.fetch_add(1, SeqCst);
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}

View File

@ -2,7 +2,7 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
use std::{io, task::Poll};
pub struct Conn {
pub struct Connection {
pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
pub(crate) rx: Box<
@ -13,7 +13,7 @@ pub struct Conn {
>,
}
impl Conn {
impl Connection {
pub fn new<S>(stream: S) -> Self
where
S: 'static

View File

@ -2,5 +2,5 @@ pub mod auth;
mod conn;
mod peer;
pub mod proto;
pub use conn::Conn;
pub use conn::Connection;
pub use peer::*;

View File

@ -1,5 +1,5 @@
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Conn;
use super::Connection;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use futures::FutureExt as _;
@ -79,12 +79,12 @@ impl<T: RequestMessage> TypedEnvelope<T> {
}
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Connection>>,
connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
next_connection_id: AtomicU32,
}
#[derive(Clone)]
struct Connection {
struct ConnectionState {
outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>,
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
@ -100,7 +100,7 @@ impl Peer {
pub async fn add_connection(
self: &Arc<Self>,
conn: Conn,
connection: Connection,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
@ -112,16 +112,16 @@ impl Peer {
);
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
let connection = Connection {
let connection_state = ConnectionState {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Default::default(),
};
let mut writer = MessageStream::new(conn.tx);
let mut reader = MessageStream::new(conn.rx);
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection.response_channels.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
loop {
let read_message = reader.read_message().fuse();
@ -179,7 +179,7 @@ impl Peer {
self.connections
.write()
.await
.insert(connection_id, connection);
.insert(connection_id, connection_state);
(connection_id, handle_io, incoming_rx)
}
@ -218,7 +218,7 @@ impl Peer {
let this = self.clone();
let (tx, mut rx) = mpsc::channel(1);
async move {
let mut connection = this.connection(receiver_id).await?;
let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -252,7 +252,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receiver_id).await?;
let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -272,7 +272,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receiver_id).await?;
let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -291,7 +291,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receipt.sender_id).await?;
let mut connection = this.connection_state(receipt.sender_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -310,7 +310,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let mut connection = this.connection(receipt.sender_id).await?;
let mut connection = this.connection_state(receipt.sender_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@ -322,10 +322,10 @@ impl Peer {
}
}
fn connection(
fn connection_state(
self: &Arc<Self>,
connection_id: ConnectionId,
) -> impl Future<Output = Result<Connection>> {
) -> impl Future<Output = Result<ConnectionState>> {
let this = self.clone();
async move {
let connections = this.connections.read().await;
@ -352,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@ -486,7 +486,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
let (client_conn, mut server_conn, _) = Conn::in_memory();
let (client_conn, mut server_conn, _) = Connection::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@ -520,7 +520,7 @@ mod tests {
#[test]
fn test_io_error() {
smol::block_on(async move {
let (client_conn, server_conn, _) = Conn::in_memory();
let (client_conn, server_conn, _) = Connection::in_memory();
drop(server_conn);
let client = Peer::new();