server: implement protocol connection_init timeout

## Summary by CodeRabbit

## Release Notes

- **Documentation**
	- Updated the configuration documentation for the Hasura GraphQL Engine, including new flags and environment variables, with clarifications on WebSocket connection initialization and deprecated options.

- **Bug Fixes**
	- Enhanced WebSocket connection management and error handling, ensuring proper initialization and cleanup of connections across various components.

- **Tests**
	- Improved tests for WebSocket connection handling and logging, ensuring robust verification of connection states and error responses.

PR-URL: https://github.com/hasura/graphql-engine-mono/pull/11069
GitOrigin-RevId: 8ee25d702a64f3bb04077bbcf0f3e1bd10c916d6
This commit is contained in:
Rakesh Emmadi 2024-11-05 16:55:06 +05:30 committed by hasura-bot
parent c62e5fed32
commit b8629eaa58
10 changed files with 151 additions and 143 deletions

View File

@ -1317,8 +1317,8 @@ Enable WebSocket `permessage-deflate` compression.
### Websocket Connection Init Timeout
Used to set the connection initialization timeout for `graphql-ws` clients. This is ignored for
[`subscription-transport-ws` (Apollo) clients](/subscriptions/postgres/index.mdx).
Used to set the connection initialization timeout for GraphQL subscription protocols
(`graphql-transport-ws` and `subscriptions-transport-ws`).
| | |
| ------------------- | ---------------------------------------------------------- |
@ -1330,7 +1330,7 @@ Used to set the connection initialization timeout for `graphql-ws` clients. This
### Websocket Keepalive
Used to set the `Keep Alive` delay for clients that use the `subscription-transport-ws` (Apollo) protocol. For
Used to set the `Keep Alive` delay for clients that use the `subscriptions-transport-ws` (Apollo) protocol. For
`graphql-ws` clients, the `graphql-engine` sends `PING` messages instead.
| | |

View File

@ -1248,7 +1248,8 @@ onConnInit logger manager wsConn getAuthMode connParamsM onConnInitErrAction kee
liftIO $ do
$assertNFHere csInit -- so we don't write thunks to mutable vars
STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) csInit
-- mark the connection as initialised in the connection
liftIO $ WS.setConnInitialized wsConn
sendMsg wsConn SMConnAck
liftIO $ keepAliveMessageAction wsConn
where

View File

@ -15,22 +15,22 @@ module Hasura.GraphQL.Transport.WebSocket.Protocol
ServerMsgType (..),
StartMsg (StartMsg),
StopMsg (StopMsg),
WSConnInitTimerStatus (Done),
WSConnInitTimeoutStatus (..),
WSConnInitTimeout,
WSSubProtocol (..),
encodeServerErrorMsg,
encodeServerMsg,
getNewWSTimer,
getWSTimerState,
keepAliveMessage,
showSubProtocol,
toWSSubProtocol,
newWSConnInitTimeout,
runTimer,
-- * exported for testing
unsafeMkOperationId,
)
where
import Control.Concurrent
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM
import Data.Aeson qualified as J
@ -312,32 +312,25 @@ encodeServerMsg msg =
]
Nothing -> [encTy msgType]
-- This "timer" is necessary while initialising the connection
-- with the server. Also, this is specific to the GraphQL-WS protocol.
data WSConnInitTimerStatus = Running | Done
-- Status for connection initialisation in sub-protocol
-- This is used to timeout the 'connection_init' message sent by the client
data WSConnInitTimeoutStatus = Initialized | TimedOut
deriving stock (Show, Eq)
type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ())
type WSConnInitTimeout = TMVar WSConnInitTimeoutStatus
getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState (timerState, _) = readTVarIO timerState
newWSConnInitTimeout :: IO WSConnInitTimeout
newWSConnInitTimeout = newEmptyTMVarIO
{-# ANN getNewWSTimer ("HLint: ignore Use withAsync" :: String) #-}
getNewWSTimer :: Seconds -> IO WSConnInitTimer
getNewWSTimer timeout = do
timerState <- newTVarIO Running
timer <- newEmptyTMVarIO
void
$ forkIO
$ do
labelMe "getNewWSTimer"
sleep (seconds timeout)
atomically $ do
runTimerState <- readTVar timerState
case runTimerState of
Running -> do
-- time's up, we set status to "Done"
writeTVar timerState Done
putTMVar timer ()
Done -> pure ()
pure (timerState, timer)
-- | Run the timer for the given timeout duration
runTimer :: Seconds -> WSConnInitTimeout -> IO ()
runTimer timeout timer = do
-- sleep for the timeout duration
sleep (seconds timeout)
atomically $ do
-- check the status of the timer
timerState <- tryReadTMVar timer
-- if the timer is not set, set it to 'TimedOut'
case timerState of
Nothing -> writeTMVar timer TimedOut
Just _ -> pure ()

View File

@ -28,6 +28,7 @@ module Hasura.GraphQL.Transport.WebSocket.Server
getData,
getRawWebSocketConnection,
getWSId,
setConnInitialized,
mkWSServerErrorCode,
sendMsg,
shutdown,
@ -40,7 +41,6 @@ where
import Control.Concurrent.Async qualified as A
import Control.Concurrent.Async.Lifted.Safe qualified as LA
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM (readTVarIO)
import Control.Concurrent.STM qualified as STM
import Control.Exception.Lifted
import Control.Monad.Trans.Control qualified as MC
@ -110,6 +110,7 @@ $(J.deriveToJSON hasuraJSON ''MessageDetails)
data WSEvent
= EConnectionRequest
| EConnectionTimeout
| EAccepted
| ERejected
| EMessageReceived !MessageDetails
@ -201,6 +202,7 @@ data WSConn a = WSConn
_wcLogger :: !(L.Logger L.Hasura),
_wcConnRaw :: !WS.Connection,
_wcSendQ :: !(STM.TQueue WSQueueResponse),
_wsConnInitTimer :: !WSConnInitTimeout,
_wcExtraData :: !a
}
@ -216,6 +218,10 @@ getWSId = _wcConnId
closeConn :: WSConn a -> BL.ByteString -> IO ()
closeConn wsConn = closeConnWithCode wsConn 1000 -- 1000 is "normal close"
setConnInitialized :: WSConn a -> IO ()
setConnInitialized wsConn =
STM.atomically $ STM.writeTMVar (_wsConnInitTimer wsConn) Initialized
-- | Closes a connection with code 1012, which means "Server is restarting"
-- good clients will implement a retry logic with a backoff of a few seconds
forceConnReconnect :: (MonadIO m) => WSConn a -> BL.ByteString -> m ()
@ -380,7 +386,7 @@ websocketConnectionReaper getLatestConfig getSchemaCache ws@(WSServer _ userConf
forever $ do
(currAuthMode, currEnableAllowlist, currCorsPolicy, currSqlGenCtx, currExperimentalFeatures, currDefaultNamingCase) <- getLatestConfig
currAllowlist <- scAllowlist <$> getSchemaCache
SecuritySensitiveUserConfig prevAuthMode prevEnableAllowlist prevAllowlist prevCorsPolicy prevSqlGenCtx prevExperimentalFeatures prevDefaultNamingCase <- readTVarIO userConfRef
SecuritySensitiveUserConfig prevAuthMode prevEnableAllowlist prevAllowlist prevCorsPolicy prevSqlGenCtx prevExperimentalFeatures prevDefaultNamingCase <- STM.readTVarIO userConfRef
-- check and close all connections if required
checkAndReapConnections
(currAuthMode, prevAuthMode)
@ -513,9 +519,9 @@ createServerApp ::
createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger writeLog) _ serverStatus) prometheusMetrics wsHandlers !ipAddress !pendingConn = do
wsId <- WSId <$> liftIO UUID.nextRandom
logWSLog logger $ WSLog wsId EConnectionRequest Nothing
-- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection
-- if the client doesn't send a `connection_init` message within the timeout period
wsConnInitTimer <- liftIO $ getNewWSTimer (unrefine $ unWSConnectionInitTimeout wsConnInitTimeout)
-- the server has to close the connection
wsConnInitTimer <- liftIO newWSConnInitTimeout
status <- liftIO $ STM.readTVarIO serverStatus
case status of
AcceptingConns _ -> logUnexpectedExceptions $ do
@ -574,7 +580,7 @@ createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger wr
conn <- liftIO $ WS.acceptRequestWith pendingConn acceptWithParams
logWSLog logger $ WSLog wsId EAccepted Nothing
sendQ <- liftIO STM.newTQueueIO
let !wsConn = WSConn wsId logger conn sendQ a
let !wsConn = WSConn wsId logger conn sendQ wsConnInitTimer a
-- TODO there are many thunks here. Difficult to trace how much is retained, and
-- how much of that would be shared anyway.
-- Requires a fork of 'wai-websockets' and 'websockets', it looks like.
@ -657,31 +663,52 @@ createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger wr
(realToFrac messageWriteTime)
logWSLog logger $ WSLog wsId (EMessageSent messageDetails) wsInfo
let connInitTimer = liftIO $ do
labelMe "WebSocket connInitTimer"
runTimer (unrefine $ unWSConnectionInitTimeout wsConnInitTimeout) wsConnInitTimer
-- withAsync lets us be very sure that if e.g. an async exception is raised while we're
-- forking that the threads we launched will be cleaned up. See also below.
LA.withAsync rcv $ \rcvRef -> do
LA.withAsync send $ \sendRef -> do
LA.withAsync (liftIO $ labelMe "WebSocket keepAlive" >> keepAlive wsConn) $ \keepAliveRef -> do
LA.withAsync (liftIO $ labelMe "WebSocket onJwtExpiry" >> onJwtExpiry wsConn) $ \onJwtExpiryRef -> do
-- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws`
timeoutStatus <- liftIO $ getWSTimerState wsConnInitTimer
when (timeoutStatus == Done && subProtocol == GraphQLWS)
$ liftIO
$ closeConnWithCode wsConn 4408 "Connection initialisation timed out"
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef]
-- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in
-- case of async exceptions raised while blocking here:
try (LA.waitAnyCancel waitOnRefs) >>= \case
-- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt
-- exceptions; for now handle all ConnectionException by closing
-- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48
Left (_ :: WS.ConnectionException) -> do
logWSLog logger $ WSLog (_wcConnId wsConn) ECloseReceived Nothing
-- this will happen when jwt is expired
Right _ -> do
logWSLog logger $ WSLog (_wcConnId wsConn) EJwtExpired Nothing
LA.withAsync connInitTimer $ \connInitTimerRef -> do
LA.withAsync rcv $ \rcvRef -> do
LA.withAsync send $ \sendRef -> do
LA.withAsync (liftIO $ labelMe "WebSocket keepAlive" >> keepAlive wsConn) $ \keepAliveRef -> do
LA.withAsync (liftIO $ labelMe "WebSocket onJwtExpiry" >> onJwtExpiry wsConn) $ \onJwtExpiryRef -> do
-- Wait for connection init status, and then wait for any of the threads to terminate
-- The code below will block until the status is updated by either the timer thread or the GraphQL protocol connection initialization.
liftIO (STM.atomically $ STM.takeTMVar wsConnInitTimer) >>= \case
TimedOut -> do
-- send close message
let timeoutMessage = "Connection initialisation timed out"
case subProtocol of
GraphQLWS ->
-- https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#connectioninit
liftIO $ closeConnWithCode wsConn 4408 timeoutMessage
Apollo ->
-- 1011 is an unexpected condition prevented the request from being fulfilled.
-- NOTE: The protocol spec does not define a close code for connection timeouts, so we're using the close code from the reference implementation.
-- https://github.com/apollographql/subscriptions-transport-ws/blob/36f3f6f780acc1a458b768db13fd39c65e5e6518/src/server.ts#L159
liftIO $ closeConnWithCode wsConn 1011 timeoutMessage
-- log connection timeout
logWSLog logger $ WSLog (_wcConnId wsConn) EConnectionTimeout Nothing
-- terminate all threads
mapM_ LA.cancel [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef]
Initialized -> do
-- terminate the timer thread
LA.cancel connInitTimerRef
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef]
-- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in
-- case of async exceptions raised while blocking here:
try (LA.waitAnyCancel waitOnRefs) >>= \case
-- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt
-- exceptions; for now handle all ConnectionException by closing
-- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48
Left (_ :: WS.ConnectionException) -> do
logWSLog logger $ WSLog (_wcConnId wsConn) ECloseReceived Nothing
-- this will happen when jwt is expired
Right _ -> do
logWSLog logger $ WSLog (_wcConnId wsConn) EJwtExpired Nothing
onConnClose wsConn = \case
ShuttingDown -> pure ()

View File

@ -43,7 +43,7 @@ class GQLWsClient():
self.ws_queue = queue.Queue(maxsize=-1)
self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws',
path=endpoint)
self.create_conn()
self.conn_created = False
def create_conn(self):
self.ws_queue.queue.clear()
@ -60,10 +60,7 @@ class GQLWsClient():
self.wst = threading.Thread(target=self._ws.run_forever)
self.wst.daemon = True
self.wst.start()
def recreate_conn(self):
self.teardown()
self.create_conn()
self.conn_created = True
def wait_for_connection(self, timeout=10):
assert not self.is_closing
@ -78,7 +75,7 @@ class GQLWsClient():
def get_ws_query_event(self, query_id, timeout):
return self.ws_id_query_queues[query_id].get(timeout=timeout)
def send(self, frame, count=0):
def send(self, frame, headers={}, count=0):
self.wait_for_connection()
if frame.get('type') == 'stop':
self.ws_active_query_ids.discard( frame.get('id') )
@ -90,7 +87,7 @@ class GQLWsClient():
if count > 2:
raise websocket.WebSocketConnectionClosedException("Connection is already closed and cannot be recreated even after 3 attempts")
# Connection closed, try to recreate the connection and send the frame again
self.recreate_conn()
self.init(headers)
self.send(frame, count+1)
def init_as_admin(self):
@ -100,6 +97,7 @@ class GQLWsClient():
self.init(headers)
def init(self, headers={}):
self.create_conn()
payload = {'type': 'connection_init', 'payload': {}}
if headers and len(headers) > 0:
@ -121,13 +119,10 @@ class GQLWsClient():
return self.gen_id(size, chars)
return new_id
def send_query(self, query, query_id=None, headers={}, timeout=60):
graphql.parse(query['query'])
if headers and len(headers) > 0:
#Do init If headers are provided
self.init(headers)
elif not self.init_done:
def send_query(self, query, query_id=None, timeout=60):
if not self.init_done:
self.init()
graphql.parse(query['query'])
if query_id == None:
query_id = self.gen_id()
frame = {
@ -169,10 +164,12 @@ class GQLWsClient():
return self.remote_closed or self.is_closing
def teardown(self):
self.is_closing = True
if not self.remote_closed:
self._ws.close()
self.wst.join()
if self.conn_created:
self.is_closing = True
if not self.remote_closed:
self._ws.close()
self.wst.join()
self.conn_created = False
# NOTE: use this to generate a GraphQL client that uses the `graphql-ws` sub-protocol
class GraphQLWSClient():
@ -182,7 +179,7 @@ class GraphQLWSClient():
self.ws_queue = queue.Queue(maxsize=-1)
self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws',
path=endpoint)
self.create_conn()
self.conn_created = False
def get_queue(self):
return self.ws_queue.queue
@ -205,10 +202,7 @@ class GraphQLWSClient():
self.wst = threading.Thread(target=self._ws.run_forever)
self.wst.daemon = True
self.wst.start()
def recreate_conn(self):
self.teardown()
self.create_conn()
self.conn_created = True
def wait_for_connection(self, timeout=10):
assert not self.is_closing
@ -239,6 +233,7 @@ class GraphQLWSClient():
self.init(headers)
def init(self, headers={}):
self.create_conn()
payload = {'type': 'connection_init', 'payload': {}}
if headers and len(headers) > 0:
@ -260,13 +255,9 @@ class GraphQLWSClient():
return self.gen_id(size, chars)
return new_id
def send_query(self, query, query_id=None, headers={}, timeout=60):
def send_query(self, query, query_id=None, timeout=60):
graphql.parse(query['query'])
if headers and len(headers) > 0:
#Do init If headers are provided
self.clear_queue()
self.init(headers)
elif not self.init_done:
if not self.init_done:
self.init()
if query_id == None:
query_id = self.gen_id()
@ -318,10 +309,11 @@ class GraphQLWSClient():
return self.remote_closed or self.is_closing
def teardown(self):
self.is_closing = True
if not self.remote_closed:
self._ws.close()
self.wst.join()
if self.conn_created:
self.is_closing = True
if not self.remote_closed:
self._ws.close()
self.wst.join()
class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler):
hge_url: str
@ -387,7 +379,7 @@ class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler):
elif req_path == "/null-response":
resp, status = self.null_response()
self._send_response(status, resp)
elif req_path == "/omitted-response-field":
self._send_response(
HTTPStatus.OK,
@ -624,7 +616,7 @@ class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler):
'id': 1,
'child': None
}
def get_omitted_response_field(self):
return {
'country': 'India'

View File

@ -31,7 +31,7 @@ def parse_logs(stream):
pytestmark = [
pytest.mark.capture_hge_logs,
pytest.mark.admin_secret,
pytest.mark.hge_env('HASURA_GRAPHQL_LOG_LEVEL', 'debug'),
pytest.mark.hge_env('HASURA_GRAPHQL_LOG_LEVEL', 'debug'),
]
@ -259,17 +259,18 @@ class TestWebsocketLogging:
@pytest.fixture(scope='class', autouse=True)
def make_requests(self, hge_ctx, ws_client):
'''
Create connection using connection_init
'''
headers = {'x-request-id': self.query_id}
if hge_ctx.hge_key:
headers['x-hasura-admin-secret'] = hge_ctx.hge_key
ws_client.init(headers=headers)
# setup some tables
hge_ctx.v1q_f(self.dir + '/setup.yaml')
# make a successful websocket query
headers = {'x-request-id': self.query_id}
if hge_ctx.hge_key:
headers['x-hasura-admin-secret'] = hge_ctx.hge_key
resp = ws_client.send_query(self.query, headers=headers,
query_id=self.query_id,
timeout=5)
resp = ws_client.send_query(self.query, query_id=self.query_id, timeout=5)
try:
ev = next(resp)
assert ev['type'] == 'data' and ev['id'] == self.query_id, ev

View File

@ -27,6 +27,7 @@ def ws_conn_init_graphql_ws(hge_key, ws_client_graphql_ws):
# This is used in other test files! Be careful when modifying it.
def init_ws_conn(hge_key, ws_client, payload = None):
ws_client.create_conn()
init_msg = {
'type': 'connection_init',
'payload': payload or ws_payload(hge_key),
@ -36,6 +37,7 @@ def init_ws_conn(hge_key, ws_client, payload = None):
assert ev['type'] == 'connection_ack', ev
def init_graphql_ws_conn(hge_key, ws_client_graphql_ws):
ws_client_graphql_ws.create_conn()
init_msg = {
'type': 'connection_init',
'payload': ws_payload(hge_key),
@ -67,7 +69,6 @@ def get_explain_graphql_query_response(hge_ctx, hge_key, query, variables, user_
@pytest.mark.no_admin_secret
class TestSubscriptionCtrlWithoutSecret(object):
def test_connection(self, ws_client):
ws_client.recreate_conn()
init_ws_conn(None, ws_client)
obj = {
@ -86,7 +87,6 @@ class TestSubscriptionCtrl(object):
'''
def test_connection(self, hge_key, ws_client):
ws_client.recreate_conn()
init_ws_conn(hge_key, ws_client)
obj = {
@ -104,6 +104,7 @@ class TestSubscriptionBasicNoAuth:
def test_closed_connection_apollo(self, ws_client):
# sends empty header so that there is not authentication present in the test
ws_client.create_conn()
init_msg = {
'type': 'connection_init',
'payload':{'headers':{}}
@ -115,6 +116,7 @@ class TestSubscriptionBasicNoAuth:
def test_closed_connection_graphql_ws(self, ws_client_graphql_ws):
# sends empty header so that there is not authentication present in the test
ws_client_graphql_ws.create_conn()
init_msg = {
'type': 'connection_init',
'payload':{'headers':{}}
@ -383,7 +385,10 @@ class TestSubscriptionLiveQueries:
'''
Create connection using connection_init
'''
ws_client.init_as_admin()
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
ws_client.init(headers=headers)
with open(self.dir() + "/steps.yaml") as c:
conf = yaml.load(c)
@ -401,11 +406,8 @@ class TestSubscriptionLiveQueries:
liveQs = []
for i, resultLimit in queries:
query = queryTmplt.replace('{0}',str(i))
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } }
respLive = ws_client.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15)
respLive = ws_client.send_query(subscrPayload, query_id='live_'+str(i), timeout=15)
liveQs.append(respLive)
ev = next(respLive)
assert ev['type'] == 'data', ev
@ -466,7 +468,10 @@ class TestStreamingSubscription:
'''
Create connection using connection_init
'''
ws_client.init_as_admin()
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
ws_client.init(headers=headers)
query = """
subscription ($batch_size: Int!) {
@ -478,15 +483,12 @@ class TestStreamingSubscription:
"""
liveQs = []
headers={}
articles_to_insert = []
for i in range(10):
articles_to_insert.append({"id": i + 1, "title": "Article title {}".format(i + 1)})
insert_many(hge_ctx, {"schema": "hge_tests", "name": "articles"}, articles_to_insert)
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
subscrPayload = { 'query': query, 'variables': { 'batch_size': 2 } }
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15)
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15)
liveQs.append(respLive)
for idx in range(5):
ev = next(respLive)
@ -511,7 +513,6 @@ class TestStreamingSubscription:
Create connection using connection_init
'''
ws_client.init_as_admin()
headers={}
query = """
subscription ($batch_size: Int!, $initial_created_at: timestamptz!) {
hge_tests_stream_query: hge_tests_test_t2_stream(cursor: [{initial_value: {created_at: $initial_created_at}, ordering: ASC}], batch_size: $batch_size) {
@ -525,7 +526,7 @@ class TestStreamingSubscription:
conf = yaml.load(c)
subscrPayload = { 'query': query, 'variables': { 'batch_size': 2, 'initial_created_at': "2020-01-01" } }
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15)
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15)
assert isinstance(conf, list) == True, 'Not an list'
for index, step in enumerate(conf):
@ -566,7 +567,10 @@ class TestStreamingSubscription:
'''
Create connection using connection_init
'''
ws_client.init_as_admin()
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
ws_client.init(headers=headers)
query = """
subscription ($batch_size: Int!) {
@ -578,11 +582,8 @@ class TestStreamingSubscription:
"""
liveQs = []
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
subscrPayload = { 'query': query, 'variables': { 'batch_size': 1 } }
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15)
respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15)
liveQs.append(respLive)
for idx in range(2):
ev = next(respLive)
@ -614,7 +615,10 @@ class TestSubscriptionLiveQueriesForGraphQLWS:
'''
Create connection using connection_init
'''
ws_client_graphql_ws.init_as_admin()
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
ws_client_graphql_ws.init(headers=headers)
with open(self.dir() + "/steps.yaml") as c:
conf = yaml.load(c)
@ -632,11 +636,8 @@ class TestSubscriptionLiveQueriesForGraphQLWS:
liveQs = []
for i, resultLimit in queries:
query = queryTmplt.replace('{0}',str(i))
headers={}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } }
respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15)
respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), timeout=15)
liveQs.append(respLive)
ev = next(respLive)
assert ev['type'] == 'next', ev
@ -761,12 +762,12 @@ class TestSubscriptionUDFWithSessionArg:
return 'queries/subscriptions/udf_session_args'
def test_user_defined_function_with_session_argument(self, hge_key, ws_client):
ws_client.init_as_admin()
headers = {'x-hasura-role': 'user', 'x-hasura-user-id': '42'}
if hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_key
ws_client.init(headers=headers)
payload = {'query': self.query}
resp = ws_client.send_query(payload, headers=headers, timeout=15)
resp = ws_client.send_query(payload, timeout=15)
ev = next(resp)
assert ev['type'] == 'data', ev
assert ev['payload']['data'] == {'me': [{'id': '42', 'name': 'Charlie'}]}, ev['payload']['data']

View File

@ -132,6 +132,7 @@ class TestV1Alpha1GraphQLErrors:
def test_v1alpha1_ws_start_error(self, hge_ctx):
ws_client = GQLWsClient(hge_ctx, '/v1alpha1/graphql')
ws_client.create_conn()
query = {'query': '{ author { name } }'}
frame = {
'id': '1',

View File

@ -96,10 +96,6 @@ class TestWebhookMetadataInPOSTModeWithTLS(AbstractTestWebhookMetadata): pass
class TestWebhookSubscriptionExpiry(object):
EXPIRE_TIME_FORMAT = '%a, %d %b %Y %T GMT'
@pytest.fixture(scope='function', autouse=True)
def ws_conn_recreate(self, ws_client):
ws_client.recreate_conn()
def test_expiry_with_no_header(self, ws_client):
# no expiry time => the connextion will remain alive
self.connect_with(ws_client, {})

View File

@ -302,17 +302,13 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
else:
ws_client = hge_ctx.ws_client
print(ws_client.ws_url)
if not headers or len(headers) == 0:
ws_client.init({})
ws_client.init_as_admin()
else:
ws_client.init(headers)
if ws_client.remote_closed or ws_client.is_closing:
ws_client.create_conn()
if not headers or len(headers) == 0 or hge_ctx.hge_key is None:
ws_client.init()
else:
ws_client.init_as_admin()
query_resp = ws_client.send_query(query, query_id='hge_test', headers=headers, timeout=15)
query_resp = ws_client.send_query(query, query_id='hge_test', timeout=15)
resp = next(query_resp)
print('websocket resp: ', resp)
@ -320,7 +316,7 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
if retry:
#Got query complete before payload. Retry once more
print("Got query complete before getting query response payload. Retrying")
ws_client.recreate_conn()
ws_client.tear_down()
return validate_gql_ws_q(hge_ctx, query, headers, exp_http_response, False)
else:
assert resp['type'] in ['data', 'error', 'next'], resp