graphql-ws: Implement connection expiry (#1281)

<!-- The PR description should answer 2 important questions: -->
TODO:
- ~Add a test for parsing connection expiry from headers~

### What

<!-- What is this PR trying to accomplish (and why, if it's not
obvious)? -->

<!-- Consider: do we need to add a changelog entry? -->

<!-- Does this PR introduce new validation that might break old builds?
-->

<!-- Consider: do we need to put new checks behind a flag? -->
Expire the connection if the duration is set in the `Context`.

### How
Spawn a background thread to wait until the expiry. Send a close message
after waiting completes.

<!-- How is it trying to accomplish it (what are the implementation
steps)? -->

V3_GIT_ORIGIN_REV_ID: a9d06c57d75cc87abc3c470ee096a99a8f378a9a
This commit is contained in:
Rakesh Emmadi 2024-10-28 20:11:16 +05:30 committed by hasura-bot
parent 161f40636c
commit a6ed625cd7
8 changed files with 76 additions and 1 deletions

View File

@ -801,6 +801,7 @@ async fn handle_websocket_request(
) -> impl IntoResponse {
// Create the context for the websocket server
let context = graphql_ws::Context {
connection_expiry: graphql_ws::ConnectionExpiry::Never,
http_context: engine_state.http_context,
project_id: None, // project_id is not needed for OSS v3-engine.
expose_internal_errors: engine_state.expose_internal_errors,

View File

@ -727,6 +727,7 @@ async fn run_query_graphql_ws(
});
let context = graphql_ws::Context {
connection_expiry: graphql_ws::ConnectionExpiry::Never,
http_context: http_context.clone(),
expose_internal_errors,
project_id: project_id.cloned(),

View File

@ -4,7 +4,7 @@ pub(crate) mod websocket;
pub use protocol::types::OperationId;
pub use websocket::{
types::{ActiveConnection, Context, WebSocketId},
types::{ActiveConnection, ConnectionExpiry, Context, WebSocketId},
WebSocketServer,
};

View File

@ -178,6 +178,7 @@ async fn start_websocket_session(
parent_span_link,
|| {
Box::pin(async {
let connection_expiry = context.connection_expiry.clone();
// Split the socket into a sender and receiver
let (websocket_sender, websocket_receiver) = socket.split();
@ -192,6 +193,19 @@ async fn start_websocket_session(
let this_span_link = tracing_util::SpanLink::from_current_span();
let expiry_task = match connection_expiry {
types::ConnectionExpiry::Never => None,
types::ConnectionExpiry::After(duration) => {
// Spawn a task to wait until the connection expires
// The task will send a close message to the client after the expiry duration.
// Sending a close message will make the outgoing task exit, thus closing the connection.
let connection = connection.clone();
Some(tokio::spawn(async move {
tasks::wait_until_expiry(connection, duration).await;
}))
}
};
// Spawn a task to verify the graphql-ws connection_init state with a timeout
let init_checker_task = tokio::spawn(tasks::verify_connection_init(
connection.clone(),
@ -246,6 +260,10 @@ async fn start_websocket_session(
}
};
// Abort the expiry task
if let Some(task) = expiry_task {
task.abort();
}
// Remove the connection from the active connections map
connections.drop(&connection.id).await;

View File

@ -71,6 +71,12 @@ async fn wait_for_initialization(connection: types::Connection) {
}
}
/// Waits until the connection expires and sends a close message.
pub(crate) async fn wait_until_expiry(connection: types::Connection, expiry: std::time::Duration) {
tokio::time::sleep(expiry).await;
connection.send(types::Message::conn_expired()).await;
}
/// Handles incoming WebSocket messages from the client.
/// This task runs indefinitely until the connection is closed or an error occurs.
pub(crate) async fn process_incoming_message(

View File

@ -15,6 +15,7 @@ use crate::protocol::types as protocol;
/// Context required to handle a WebSocket connection
#[derive(Clone)] // Cheap to clone as heavy fields are wrapped in `Arc`
pub struct Context {
pub connection_expiry: ConnectionExpiry,
pub http_context: HttpContext,
pub expose_internal_errors: ExposeInternalErrors,
pub project_id: Option<ProjectId>,
@ -40,6 +41,13 @@ impl Default for WebSocketId {
}
}
/// Configures the expiry for a WebSocket connection
#[derive(Clone)]
pub enum ConnectionExpiry {
Never,
After(std::time::Duration),
}
/// A mutable and free clone-able collection of WebSocket connections.
#[derive(Clone)]
pub struct Connections(pub Arc<RwLock<HashMap<WebSocketId, Connection>>>);
@ -213,6 +221,13 @@ impl Message {
pub fn force_reconnect(message: &'static str) -> Self {
Self::close_message(1012, message)
}
/// Connection expired
pub fn conn_expired() -> Self {
// THe 1013 code is used to indicate "Try Again Later".
// The session is expired and the client should try to reconnect.
Self::close_message(1013, "WebSocket session expired")
}
}
/// Creates a close WebSocket message with the specified code and reason.

View File

@ -45,6 +45,13 @@ pub(crate) async fn ws_handler(
#[allow(dead_code)]
pub(crate) async fn start_websocket_server() -> TestServer {
start_websocket_server_expiry(graphql_ws::ConnectionExpiry::Never).await
}
#[allow(dead_code)]
pub(crate) async fn start_websocket_server_expiry(
expiry: graphql_ws::ConnectionExpiry,
) -> TestServer {
// Create a TCP listener
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
@ -79,6 +86,7 @@ pub(crate) async fn start_websocket_server() -> TestServer {
pre_response_plugins: Vec::new(),
};
let context = Context {
connection_expiry: expiry,
http_context,
expose_internal_errors: execute::ExposeInternalErrors::Expose,
project_id: None,

View File

@ -384,3 +384,29 @@ async fn test_graphql_ws_subscribe_user_1_validation_error() {
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_connection_expiry() {
// Expiry in 4 seconds
let expiry = graphql_ws::ConnectionExpiry::After(std::time::Duration::from_secs(4));
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server_expiry(expiry).await;
// Send connection_init and check ack
graphql_ws_connection_init(&mut socket, connection_init_admin()).await;
// Wait for the connection to expire
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::Again;
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(close_frame.reason, "WebSocket session expired");
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}