graphql-ws: Tests for WebSocket server (#1191)

<!-- The PR description should answer 2 important questions: -->

### What

<!-- What is this PR trying to accomplish (and why, if it's not
obvious)? -->

<!-- Consider: do we need to add a changelog entry? -->

<!-- Does this PR introduce new validation that might break old builds?
-->

<!-- Consider: do we need to put new checks behind a flag? -->
Write tests to confirm websocket connection behavior in conjunction with
[graphl-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
subprotocol.

### How

<!-- How is it trying to accomplish it (what are the implementation
steps)? -->
Test the websocket by spinning up a server in an async tokio task. Use
tokio-tungstenite for websocket client.

V3_GIT_ORIGIN_REV_ID: 32c19298b6a5b23649b22d8d820ef8d47ef1d293
This commit is contained in:
Rakesh Emmadi 2024-10-03 09:51:55 +05:30 committed by hasura-bot
parent 721ea64cc0
commit 454ca0575f
13 changed files with 6832 additions and 14 deletions

4
v3/Cargo.lock generated
View File

@ -2254,12 +2254,16 @@ dependencies = [
"hasura-authn-core",
"indexmap 2.5.0",
"lang-graphql",
"metadata-resolve",
"nonempty",
"open-dds",
"reqwest",
"serde",
"serde_json",
"smol_str",
"thiserror",
"tokio",
"tokio-tungstenite",
"tracing-util",
"uuid",
]

View File

@ -141,6 +141,7 @@ syn = "2"
thiserror = "1"
tokio = "1"
tokio-test = "0.4"
tokio-tungstenite = "0.20.1"
tower = "0.4"
tower-http = "0.4"
tracing = "0.1"

View File

@ -44,5 +44,14 @@ services:
environment:
RUST_LOG: info
# Required to test graphql-ws; subscriptions with custom session variables
auth_hook:
build:
dockerfile: dev-auth-webhook.Dockerfile
environment:
RUST_LOG: debug
ports:
- "3050:3050"
volumes:
postgres:

View File

@ -26,5 +26,12 @@ thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "parking_lot", "rt-multi-thread", "signal"] }
uuid = { workspace = true, features = ["v4"] }
[dev-dependencies]
metadata-resolve = { path = "../metadata-resolve" }
open-dds = { path = "../open-dds" }
reqwest = { workspace = true, features = ["json", "multipart"] }
tokio-tungstenite = { workspace = true }
[lints]
workspace = true

View File

@ -12,5 +12,9 @@ pub use websocket::{
pub use protocol::{
subscribe::{execute_request_internal, send_request_error},
types::ServerMessage,
GRAPHQL_WS_PROTOCOL,
};
pub use websocket::{
types::{Connection, Connections, Message},
SEC_WEBSOCKET_PROTOCOL,
};
pub use websocket::types::{Connection, Message};

View File

@ -10,12 +10,12 @@ use futures_util::StreamExt;
use crate::protocol;
static SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol";
pub static SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol";
static WEBSOCKET_CHANNEL_SIZE: usize = 50;
/// GraphQL WebSocket server implementation.
pub struct WebSocketServer {
connections: types::Connections,
pub connections: types::Connections,
}
impl WebSocketServer {
@ -194,9 +194,11 @@ async fn start_websocket_session(
Ok(Err(tasks::ConnectionTimeOutError)) => {
// Connection not initialized within the specified time, send close message
connection.send(types::Message::conn_init_timeout()).await;
// Abort all tasks
// Abort incoming task
incoming_task.abort();
outgoing_task.abort();
// A close message is handled by the outgoing task and it makes the task exit.
// So we need to wait for the task to complete
let _ = outgoing_task.await;
}
Err(_e) => {
// Handle internal server error

View File

@ -12,6 +12,7 @@ use crate::poller;
use crate::protocol::types as protocol;
/// Context required to handle a WebSocket connection
#[derive(Clone)]
pub struct Context {
pub http_context: HttpContext,
pub expose_internal_errors: ExposeInternalErrors,
@ -39,10 +40,10 @@ impl Default for WebSocketId {
/// A mutable and free clone-able collection of WebSocket connections.
#[derive(Clone)]
pub(crate) struct Connections(pub(crate) Arc<RwLock<HashMap<WebSocketId, Connection>>>);
pub struct Connections(pub Arc<RwLock<HashMap<WebSocketId, Connection>>>);
impl Connections {
pub(crate) fn new() -> Self {
pub fn new() -> Self {
Self(Arc::new(RwLock::new(HashMap::new())))
}
@ -69,20 +70,26 @@ impl Connections {
}
}
impl Default for Connections {
fn default() -> Self {
Self::new()
}
}
/// Represents an internal WebSocket connection.
/// Designed for efficient cloning, as all contained fields are inexpensive to clone.
#[derive(Clone)]
pub struct Connection {
// Unique WebSocket connection ID
pub(crate) id: WebSocketId,
pub id: WebSocketId,
// Manages the WebSocket protocol state
pub(crate) protocol_init_state: Arc<RwLock<protocol::ConnectionInitState>>,
pub protocol_init_state: Arc<RwLock<protocol::ConnectionInitState>>,
// Shared connection context
pub(crate) context: Arc<Context>,
pub context: Arc<Context>,
// Channel for sending messages over the WebSocket
pub(crate) send_channel: Sender<Message>,
pub send_channel: Sender<Message>,
// Active pollers associated with operations. A web socket connection can have multiple active subscriptions.
pub(crate) pollers: Arc<RwLock<HashMap<protocol::OperationId, poller::Poller>>>,
pub pollers: Arc<RwLock<HashMap<protocol::OperationId, poller::Poller>>>,
}
/// A representation of an active WebSocket connection.

View File

@ -0,0 +1,267 @@
use axum::{extract::State, response::IntoResponse, routing::get};
use execute::HttpContext;
use futures_util::{SinkExt, StreamExt};
use graphql_ws::Context;
use graphql_ws::GRAPHQL_WS_PROTOCOL;
use std::{net::TcpListener, path::PathBuf, sync::Arc};
use tokio::{net::TcpStream, task::JoinHandle};
use tokio_tungstenite::{
connect_async,
tungstenite::{self, client::IntoClientRequest},
MaybeTlsStream, WebSocketStream,
};
#[allow(dead_code)]
static METADATA_PATH: &str = "tests/static/metadata.json";
#[allow(dead_code)]
static AUTH_CONFIG_PATH: &str = "tests/static/auth_config_v2.json";
#[allow(dead_code)]
pub(crate) struct ServerState {
pub(crate) ws_server: graphql_ws::WebSocketServer,
pub(crate) context: Context,
}
#[allow(dead_code)]
pub(crate) struct TestServer {
pub(crate) connections: graphql_ws::Connections,
pub(crate) socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
pub(crate) server_handle: JoinHandle<()>,
}
#[allow(dead_code)]
pub(crate) async fn ws_handler(
headers: axum::http::header::HeaderMap,
State(state): State<Arc<ServerState>>,
ws: axum::extract::ws::WebSocketUpgrade,
) -> impl IntoResponse {
let context = state.context.clone();
state
.ws_server
.upgrade_and_handle_websocket(ws, &headers, context)
.into_response()
}
#[allow(dead_code)]
pub(crate) async fn start_websocket_server() -> TestServer {
// Create a TCP listener
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
// Auth Config
let auth_config_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(AUTH_CONFIG_PATH);
let raw_auth_config = std::fs::read_to_string(auth_config_path).unwrap();
let (auth_config, _auth_warnings) =
hasura_authn::resolve_auth_config(&raw_auth_config).unwrap();
// Metadata
let metadata_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(METADATA_PATH);
let raw_metadata = std::fs::read_to_string(metadata_path).unwrap();
let metadata = open_dds::Metadata::from_json_str(&raw_metadata).unwrap();
let metadata_resolve_configuration = metadata_resolve::configuration::Configuration {
allow_unknown_subgraphs: false,
unstable_features: metadata_resolve::configuration::UnstableFeatures {
enable_subscriptions: true,
..Default::default()
},
};
let (resolved_metadata, _warnings) =
metadata_resolve::resolve(metadata, &metadata_resolve_configuration).unwrap();
let schema = graphql_schema::GDS {
metadata: resolved_metadata.into(),
}
.build_schema()
.unwrap();
// Init context
let http_context = HttpContext {
client: reqwest::Client::new(),
ndc_response_size_limit: None,
};
let context = Context {
http_context,
expose_internal_errors: execute::ExposeInternalErrors::Expose,
project_id: None,
schema,
auth_config,
};
let connections = graphql_ws::Connections::new();
let ws_server = graphql_ws::WebSocketServer {
connections: connections.clone(),
};
// Spawn a server
let state = ServerState { ws_server, context };
let server_handle = tokio::spawn(async move {
let app = axum::Router::new()
.route("/ws", get(ws_handler))
.with_state(Arc::new(state));
axum::Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service())
.await
.unwrap();
});
let url = format!("ws://{addr}/ws");
let mut request = url.into_client_request().unwrap();
request.headers_mut().insert(
graphql_ws::SEC_WEBSOCKET_PROTOCOL,
GRAPHQL_WS_PROTOCOL.parse().unwrap(),
);
let (socket, _response) = connect_async(request)
.await
.expect("Failed to connect to WebSocket server");
TestServer {
connections,
socket,
server_handle,
}
}
#[allow(dead_code)]
pub(crate) async fn assert_zero_connections_timeout(connections: graphql_ws::Connections) {
// Closure of a websocket connection is not immediate. So, we keep checking zero connections
// for at most 5 seconds.
let result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
loop {
let conns = connections.0.read().await.len();
if conns == 0 {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
})
.await;
assert!(result.is_ok(), "Connections are not empty");
}
#[allow(dead_code)]
pub(crate) async fn assert_zero_operations_timeout(connections: &graphql_ws::Connections) {
// One connection should be present in an active test
let connections = connections.0.read().await;
let (_, connection) = connections.iter().next().unwrap();
// Removal of an operation is not immediate. So, we keep checking zero operations
// for at most 5 seconds.
let result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
loop {
let operations = connection.pollers.read().await.len();
if operations == 0 {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
})
.await;
assert!(result.is_ok(), "Operations are not empty");
}
#[allow(dead_code)]
pub(crate) async fn expect_close_message(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> tungstenite::Message {
let message = socket.next().await.unwrap();
let message = message.unwrap();
// Check close message
assert!(message.is_close(), "Expected close message");
message
}
#[allow(dead_code)]
pub(crate) async fn expect_text_message(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> tungstenite::Message {
let message = socket.next().await.unwrap();
let message = message.unwrap();
// Check text message
assert!(message.is_text(), "Expected text message");
message
}
#[allow(dead_code)]
pub(crate) fn connection_init_admin() -> serde_json::Value {
serde_json::json!(
{
"type": "connection_init",
"payload": {
"headers": {
"x-hasura-role": "admin"
}
}
}
)
}
#[allow(dead_code)]
pub(crate) fn connection_init_user_1_id_2() -> serde_json::Value {
serde_json::json!(
{
"type": "connection_init",
"payload": {
"headers": {
"x-hasura-role": "user_1",
"x-hasura-user-id": "2"
}
}
}
)
}
#[allow(dead_code)]
pub(crate) fn subscribe_article_by_id(operation_id: &str) -> serde_json::Value {
let query = r"
subscription {
ArticleByID(article_id: 1) {
article_id
title
Author {
author_id
first_name
}
}
}
";
serde_json::json!({
"type": "subscribe",
"id": operation_id,
"payload": {
"operationName": null,
"query": query
}
})
}
#[allow(dead_code)]
pub(crate) async fn graphql_ws_connection_init(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
init_payload: serde_json::Value,
) {
// Send connection init with required headers for authentication.
let json_message = serde_json::to_string(&init_payload).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a text message
let message = expect_text_message(socket).await;
// Check for connection_ack message
if let tungstenite::Message::Text(message) = message {
let message_json: serde_json::Value =
serde_json::from_str(message.as_str()).expect("Expected a valid JSON");
assert_eq!(message_json, serde_json::json!({"type": "connection_ack"}));
}
}
#[allow(dead_code)]
pub(crate) async fn check_operation_id(operation_id: &str, connections: &graphql_ws::Connections) {
let operation_id = graphql_ws::OperationId(operation_id.to_string());
// One connection should be present in an active test
let connections = connections.0.read().await;
let (_, connection) = connections.iter().next().unwrap();
assert!(connection.pollers.read().await.contains_key(&operation_id));
}

View File

@ -0,0 +1,386 @@
mod common;
use common::*;
use futures_util::SinkExt;
use tokio_tungstenite::tungstenite;
#[tokio::test]
async fn test_graphql_ws_connection_init_timeout() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Wait for the connection to be timed out and closed by the server
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::from(4408);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(close_frame.reason, "Connection initialization timeout");
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_invalid_message_format() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Only JSON text messages are allowed. Sending non-JSON messages result in websocket closure.
let text_message = "Hello!";
socket
.send(tungstenite::Message::Text(text_message.into()))
.await
.unwrap();
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::from(4400);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(
close_frame.reason,
"Invalid message format: expected value at line 1 column 1"
);
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_invalid_json() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send JSON message not supported by the graphql-ws protocol
let json_message = "{\"hello\": \"world\"}";
socket
.send(tungstenite::Message::Text(json_message.into()))
.await
.unwrap();
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code_4400 = tungstenite::protocol::frame::coding::CloseCode::from(4400);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code_4400);
assert_eq!(
close_frame.reason,
"Invalid message format: missing field `type` at line 1 column 18"
);
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_connection_init_no_headers() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send connection init without headers. Connection initialization fails with a forbidden message.
let connection_init_no_headers = serde_json::json!({
"type": "connection_init",
"payload": {
"headers": {}
}
});
let json_message = serde_json::to_string(&connection_init_no_headers).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::from(4403);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(close_frame.reason, "Forbidden");
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_too_many_connection_inits() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send connection_init and check ack
graphql_ws_connection_init(&mut socket, connection_init_admin()).await;
// Sending connection_init again results in connection closure
let json_message = serde_json::to_string(&connection_init_admin()).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::from(4429);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(close_frame.reason, "Too many initialization requests");
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_subscribe_admin() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send connection_init and check ack
graphql_ws_connection_init(&mut socket, connection_init_admin()).await;
// Send a subscription
let operation_id = "some-operation-id";
let json_message = serde_json::to_string(&subscribe_article_by_id(operation_id)).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a text message
let message = expect_text_message(&mut socket).await;
// Check message
if let tungstenite::Message::Text(message) = message {
let message_json: serde_json::Value =
serde_json::from_str(message.as_str()).expect("Expected a valid JSON");
let expected = serde_json::json!({
"type": "next",
"id": operation_id,
"payload": {
"data": {
"ArticleByID": {
"article_id": 1,
"title": "The Next 700 Programming Languages",
"Author": {
"author_id": 1,
"first_name": "Peter"
}
}
}
}
});
assert_eq!(message_json, expected);
}
// Check operation id
check_operation_id(operation_id, &connections).await;
// Send another subscription with same operation_id
let json_message = serde_json::to_string(&subscribe_article_by_id(operation_id)).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a close message
let message = expect_close_message(&mut socket).await;
// Check close code
let close_code = tungstenite::protocol::frame::coding::CloseCode::from(4409);
if let tungstenite::Message::Close(Some(close_frame)) = message {
assert_eq!(close_frame.code, close_code);
assert_eq!(
close_frame.reason,
"Subscriber for some-operation-id already exists"
);
}
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_subscribe_user_1() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send connection_init and check ack
graphql_ws_connection_init(&mut socket, connection_init_user_1_id_2()).await;
// Send a subscription
let operation_id = "some-operation-id";
let query = r"
subscription {
ArticleMany{
article_id
author_id
}
}
";
let subscribe_message = serde_json::json!({
"type": "subscribe",
"id": operation_id,
"payload": {
"query": query,
"variables": {}
}
});
let json_message = serde_json::to_string(&subscribe_message).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a text message
let message = expect_text_message(&mut socket).await;
// Check message
if let tungstenite::Message::Text(message) = message {
let message_json: serde_json::Value =
serde_json::from_str(message.as_str()).expect("Expected a valid JSON");
// Expects data with author_id = 2
let expected = serde_json::json!({
"type": "next",
"id": operation_id,
"payload": {
"data": {
"ArticleMany": [
{
"article_id": 2,
"author_id": 2
},
{
"article_id": 3,
"author_id": 2
},
{
"article_id": 5,
"author_id": 2
}
]
}
}
});
assert_eq!(message_json, expected);
}
// Check operation id
check_operation_id(operation_id, &connections).await;
// stop subscription
let stop_message = serde_json::json!({
"id": operation_id,
"type": "complete"
});
socket
.send(tungstenite::Message::Text(
serde_json::to_string(&stop_message).unwrap(),
))
.await
.unwrap();
// Assert zero operations
assert_zero_operations_timeout(&connections).await;
// Send close frame from client
socket
.send(tungstenite::Message::Close(None))
.await
.unwrap();
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}
#[tokio::test]
async fn test_graphql_ws_subscribe_user_1_validation_error() {
let TestServer {
connections,
mut socket,
server_handle,
} = start_websocket_server().await;
// Send connection_init and check ack
graphql_ws_connection_init(&mut socket, connection_init_user_1_id_2()).await;
// Send a subscription
let operation_id = "some-operation-id";
let query = r"
subscription {
ArticleByID(article_id: 1) {
article_id
title
Author {
author_id
first_name
}
}
}
";
let subscribe_message = serde_json::json!({
"type": "subscribe",
"id": operation_id,
"payload": {
"query": query,
"variables": {}
}
});
let json_message = serde_json::to_string(&subscribe_message).unwrap();
socket
.send(tungstenite::Message::Text(json_message))
.await
.unwrap();
// Wait for a text message
let message = expect_text_message(&mut socket).await;
// Check message
if let tungstenite::Message::Text(message) = message {
let message_json: serde_json::Value =
serde_json::from_str(message.as_str()).expect("Expected a valid JSON");
// Expects data with author_id = 2
let expected = serde_json::json!({
"type": "error",
"id": operation_id,
"payload": [
{
"message": "validation failed: no such field on type Article: title"
}
]
});
assert_eq!(message_json, expected);
}
// The above operation resulted in an error.
// Assert zero operations
assert_zero_operations_timeout(&connections).await;
// Send close frame from client
socket
.send(tungstenite::Message::Close(None))
.await
.unwrap();
// Assert zero connections
assert_zero_connections_timeout(connections).await;
server_handle.abort();
}

View File

@ -0,0 +1,11 @@
{
"version": "v2",
"definition": {
"mode": {
"webhook": {
"url": "http://127.0.0.1:3050/validate-request",
"method": "Post"
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -44,7 +44,7 @@ run-local-with-shell:
# start all the docker deps for running tests (not engine)
start-docker-test-deps:
# start connectors and wait for health
docker compose -f ci.docker-compose.yaml up --wait postgres postgres_connector custom_connector custom_connector_ndc_v01
docker compose -f ci.docker-compose.yaml up --wait postgres postgres_connector custom_connector custom_connector_ndc_v01 auth_hook
# start all the docker run time deps for the engine
start-docker-run-deps:

View File

@ -3,7 +3,7 @@
"definition": {
"mode": {
"webhook": {
"url": "http://auth_hook:3050/validate-request",
"url": "http://127.0.0.1:3050/validate-request",
"method": "Post"
}
}