mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-15 01:12:56 +03:00
graphql-ws: Implement connection expiry (#1281)
<!-- The PR description should answer 2 important questions: --> TODO: - ~Add a test for parsing connection expiry from headers~ ### What <!-- What is this PR trying to accomplish (and why, if it's not obvious)? --> <!-- Consider: do we need to add a changelog entry? --> <!-- Does this PR introduce new validation that might break old builds? --> <!-- Consider: do we need to put new checks behind a flag? --> Expire the connection if the duration is set in the `Context`. ### How Spawn a background thread to wait until the expiry. Send a close message after waiting completes. <!-- How is it trying to accomplish it (what are the implementation steps)? --> V3_GIT_ORIGIN_REV_ID: a9d06c57d75cc87abc3c470ee096a99a8f378a9a
This commit is contained in:
parent
161f40636c
commit
a6ed625cd7
@ -801,6 +801,7 @@ async fn handle_websocket_request(
|
||||
) -> impl IntoResponse {
|
||||
// Create the context for the websocket server
|
||||
let context = graphql_ws::Context {
|
||||
connection_expiry: graphql_ws::ConnectionExpiry::Never,
|
||||
http_context: engine_state.http_context,
|
||||
project_id: None, // project_id is not needed for OSS v3-engine.
|
||||
expose_internal_errors: engine_state.expose_internal_errors,
|
||||
|
@ -727,6 +727,7 @@ async fn run_query_graphql_ws(
|
||||
});
|
||||
|
||||
let context = graphql_ws::Context {
|
||||
connection_expiry: graphql_ws::ConnectionExpiry::Never,
|
||||
http_context: http_context.clone(),
|
||||
expose_internal_errors,
|
||||
project_id: project_id.cloned(),
|
||||
|
@ -4,7 +4,7 @@ pub(crate) mod websocket;
|
||||
|
||||
pub use protocol::types::OperationId;
|
||||
pub use websocket::{
|
||||
types::{ActiveConnection, Context, WebSocketId},
|
||||
types::{ActiveConnection, ConnectionExpiry, Context, WebSocketId},
|
||||
WebSocketServer,
|
||||
};
|
||||
|
||||
|
@ -178,6 +178,7 @@ async fn start_websocket_session(
|
||||
parent_span_link,
|
||||
|| {
|
||||
Box::pin(async {
|
||||
let connection_expiry = context.connection_expiry.clone();
|
||||
// Split the socket into a sender and receiver
|
||||
let (websocket_sender, websocket_receiver) = socket.split();
|
||||
|
||||
@ -192,6 +193,19 @@ async fn start_websocket_session(
|
||||
|
||||
let this_span_link = tracing_util::SpanLink::from_current_span();
|
||||
|
||||
let expiry_task = match connection_expiry {
|
||||
types::ConnectionExpiry::Never => None,
|
||||
types::ConnectionExpiry::After(duration) => {
|
||||
// Spawn a task to wait until the connection expires
|
||||
// The task will send a close message to the client after the expiry duration.
|
||||
// Sending a close message will make the outgoing task exit, thus closing the connection.
|
||||
let connection = connection.clone();
|
||||
Some(tokio::spawn(async move {
|
||||
tasks::wait_until_expiry(connection, duration).await;
|
||||
}))
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn a task to verify the graphql-ws connection_init state with a timeout
|
||||
let init_checker_task = tokio::spawn(tasks::verify_connection_init(
|
||||
connection.clone(),
|
||||
@ -246,6 +260,10 @@ async fn start_websocket_session(
|
||||
}
|
||||
};
|
||||
|
||||
// Abort the expiry task
|
||||
if let Some(task) = expiry_task {
|
||||
task.abort();
|
||||
}
|
||||
// Remove the connection from the active connections map
|
||||
connections.drop(&connection.id).await;
|
||||
|
||||
|
@ -71,6 +71,12 @@ async fn wait_for_initialization(connection: types::Connection) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits until the connection expires and sends a close message.
|
||||
pub(crate) async fn wait_until_expiry(connection: types::Connection, expiry: std::time::Duration) {
|
||||
tokio::time::sleep(expiry).await;
|
||||
connection.send(types::Message::conn_expired()).await;
|
||||
}
|
||||
|
||||
/// Handles incoming WebSocket messages from the client.
|
||||
/// This task runs indefinitely until the connection is closed or an error occurs.
|
||||
pub(crate) async fn process_incoming_message(
|
||||
|
@ -15,6 +15,7 @@ use crate::protocol::types as protocol;
|
||||
/// Context required to handle a WebSocket connection
|
||||
#[derive(Clone)] // Cheap to clone as heavy fields are wrapped in `Arc`
|
||||
pub struct Context {
|
||||
pub connection_expiry: ConnectionExpiry,
|
||||
pub http_context: HttpContext,
|
||||
pub expose_internal_errors: ExposeInternalErrors,
|
||||
pub project_id: Option<ProjectId>,
|
||||
@ -40,6 +41,13 @@ impl Default for WebSocketId {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configures the expiry for a WebSocket connection
|
||||
#[derive(Clone)]
|
||||
pub enum ConnectionExpiry {
|
||||
Never,
|
||||
After(std::time::Duration),
|
||||
}
|
||||
|
||||
/// A mutable and free clone-able collection of WebSocket connections.
|
||||
#[derive(Clone)]
|
||||
pub struct Connections(pub Arc<RwLock<HashMap<WebSocketId, Connection>>>);
|
||||
@ -213,6 +221,13 @@ impl Message {
|
||||
pub fn force_reconnect(message: &'static str) -> Self {
|
||||
Self::close_message(1012, message)
|
||||
}
|
||||
|
||||
/// Connection expired
|
||||
pub fn conn_expired() -> Self {
|
||||
// THe 1013 code is used to indicate "Try Again Later".
|
||||
// The session is expired and the client should try to reconnect.
|
||||
Self::close_message(1013, "WebSocket session expired")
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a close WebSocket message with the specified code and reason.
|
||||
|
@ -45,6 +45,13 @@ pub(crate) async fn ws_handler(
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn start_websocket_server() -> TestServer {
|
||||
start_websocket_server_expiry(graphql_ws::ConnectionExpiry::Never).await
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn start_websocket_server_expiry(
|
||||
expiry: graphql_ws::ConnectionExpiry,
|
||||
) -> TestServer {
|
||||
// Create a TCP listener
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
@ -79,6 +86,7 @@ pub(crate) async fn start_websocket_server() -> TestServer {
|
||||
pre_response_plugins: Vec::new(),
|
||||
};
|
||||
let context = Context {
|
||||
connection_expiry: expiry,
|
||||
http_context,
|
||||
expose_internal_errors: execute::ExposeInternalErrors::Expose,
|
||||
project_id: None,
|
||||
|
@ -384,3 +384,29 @@ async fn test_graphql_ws_subscribe_user_1_validation_error() {
|
||||
assert_zero_connections_timeout(connections).await;
|
||||
server_handle.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_graphql_ws_connection_expiry() {
|
||||
// Expiry in 4 seconds
|
||||
let expiry = graphql_ws::ConnectionExpiry::After(std::time::Duration::from_secs(4));
|
||||
let TestServer {
|
||||
connections,
|
||||
mut socket,
|
||||
server_handle,
|
||||
} = start_websocket_server_expiry(expiry).await;
|
||||
// Send connection_init and check ack
|
||||
graphql_ws_connection_init(&mut socket, connection_init_admin()).await;
|
||||
// Wait for the connection to expire
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
// Wait for a close message
|
||||
let message = expect_close_message(&mut socket).await;
|
||||
// Check close code
|
||||
let close_code = tungstenite::protocol::frame::coding::CloseCode::Again;
|
||||
if let tungstenite::Message::Close(Some(close_frame)) = message {
|
||||
assert_eq!(close_frame.code, close_code);
|
||||
assert_eq!(close_frame.reason, "WebSocket session expired");
|
||||
}
|
||||
// Assert zero connections
|
||||
assert_zero_connections_timeout(connections).await;
|
||||
server_handle.abort();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user