diff --git a/docs/docs/deployment/graphql-engine-flags/reference.mdx b/docs/docs/deployment/graphql-engine-flags/reference.mdx index 7dcf1489c90..7f9b5c830a1 100644 --- a/docs/docs/deployment/graphql-engine-flags/reference.mdx +++ b/docs/docs/deployment/graphql-engine-flags/reference.mdx @@ -1317,8 +1317,8 @@ Enable WebSocket `permessage-deflate` compression. ### Websocket Connection Init Timeout -Used to set the connection initialization timeout for `graphql-ws` clients. This is ignored for -[`subscription-transport-ws` (Apollo) clients](/subscriptions/postgres/index.mdx). +Used to set the connection initialization timeout for GraphQL subscription protocols +(`graphql-transport-ws` and `subscriptions-transport-ws`). | | | | ------------------- | ---------------------------------------------------------- | @@ -1330,7 +1330,7 @@ Used to set the connection initialization timeout for `graphql-ws` clients. This ### Websocket Keepalive -Used to set the `Keep Alive` delay for clients that use the `subscription-transport-ws` (Apollo) protocol. For +Used to set the `Keep Alive` delay for clients that use the `subscriptions-transport-ws` (Apollo) protocol. For `graphql-ws` clients, the `graphql-engine` sends `PING` messages instead. | | | diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index cc5bcf8e861..fd965d2c8de 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -1248,7 +1248,8 @@ onConnInit logger manager wsConn getAuthMode connParamsM onConnInitErrAction kee liftIO $ do $assertNFHere csInit -- so we don't write thunks to mutable vars STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) csInit - + -- mark the connection as initialised in the connection + liftIO $ WS.setConnInitialized wsConn sendMsg wsConn SMConnAck liftIO $ keepAliveMessageAction wsConn where diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Protocol.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Protocol.hs index c863d009e2f..30d8c19896a 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Protocol.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Protocol.hs @@ -15,22 +15,22 @@ module Hasura.GraphQL.Transport.WebSocket.Protocol ServerMsgType (..), StartMsg (StartMsg), StopMsg (StopMsg), - WSConnInitTimerStatus (Done), + WSConnInitTimeoutStatus (..), + WSConnInitTimeout, WSSubProtocol (..), encodeServerErrorMsg, encodeServerMsg, - getNewWSTimer, - getWSTimerState, keepAliveMessage, showSubProtocol, toWSSubProtocol, + newWSConnInitTimeout, + runTimer, -- * exported for testing unsafeMkOperationId, ) where -import Control.Concurrent import Control.Concurrent.Extended (sleep) import Control.Concurrent.STM import Data.Aeson qualified as J @@ -312,32 +312,25 @@ encodeServerMsg msg = ] Nothing -> [encTy msgType] --- This "timer" is necessary while initialising the connection --- with the server. Also, this is specific to the GraphQL-WS protocol. -data WSConnInitTimerStatus = Running | Done +-- Status for connection initialisation in sub-protocol +-- This is used to timeout the 'connection_init' message sent by the client +data WSConnInitTimeoutStatus = Initialized | TimedOut deriving stock (Show, Eq) -type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ()) +type WSConnInitTimeout = TMVar WSConnInitTimeoutStatus -getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus -getWSTimerState (timerState, _) = readTVarIO timerState +newWSConnInitTimeout :: IO WSConnInitTimeout +newWSConnInitTimeout = newEmptyTMVarIO -{-# ANN getNewWSTimer ("HLint: ignore Use withAsync" :: String) #-} -getNewWSTimer :: Seconds -> IO WSConnInitTimer -getNewWSTimer timeout = do - timerState <- newTVarIO Running - timer <- newEmptyTMVarIO - void - $ forkIO - $ do - labelMe "getNewWSTimer" - sleep (seconds timeout) - atomically $ do - runTimerState <- readTVar timerState - case runTimerState of - Running -> do - -- time's up, we set status to "Done" - writeTVar timerState Done - putTMVar timer () - Done -> pure () - pure (timerState, timer) +-- | Run the timer for the given timeout duration +runTimer :: Seconds -> WSConnInitTimeout -> IO () +runTimer timeout timer = do + -- sleep for the timeout duration + sleep (seconds timeout) + atomically $ do + -- check the status of the timer + timerState <- tryReadTMVar timer + -- if the timer is not set, set it to 'TimedOut' + case timerState of + Nothing -> writeTMVar timer TimedOut + Just _ -> pure () diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs index f121b5bda8a..9290c69f013 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs @@ -28,6 +28,7 @@ module Hasura.GraphQL.Transport.WebSocket.Server getData, getRawWebSocketConnection, getWSId, + setConnInitialized, mkWSServerErrorCode, sendMsg, shutdown, @@ -40,7 +41,6 @@ where import Control.Concurrent.Async qualified as A import Control.Concurrent.Async.Lifted.Safe qualified as LA import Control.Concurrent.Extended (sleep) -import Control.Concurrent.STM (readTVarIO) import Control.Concurrent.STM qualified as STM import Control.Exception.Lifted import Control.Monad.Trans.Control qualified as MC @@ -110,6 +110,7 @@ $(J.deriveToJSON hasuraJSON ''MessageDetails) data WSEvent = EConnectionRequest + | EConnectionTimeout | EAccepted | ERejected | EMessageReceived !MessageDetails @@ -201,6 +202,7 @@ data WSConn a = WSConn _wcLogger :: !(L.Logger L.Hasura), _wcConnRaw :: !WS.Connection, _wcSendQ :: !(STM.TQueue WSQueueResponse), + _wsConnInitTimer :: !WSConnInitTimeout, _wcExtraData :: !a } @@ -216,6 +218,10 @@ getWSId = _wcConnId closeConn :: WSConn a -> BL.ByteString -> IO () closeConn wsConn = closeConnWithCode wsConn 1000 -- 1000 is "normal close" +setConnInitialized :: WSConn a -> IO () +setConnInitialized wsConn = + STM.atomically $ STM.writeTMVar (_wsConnInitTimer wsConn) Initialized + -- | Closes a connection with code 1012, which means "Server is restarting" -- good clients will implement a retry logic with a backoff of a few seconds forceConnReconnect :: (MonadIO m) => WSConn a -> BL.ByteString -> m () @@ -380,7 +386,7 @@ websocketConnectionReaper getLatestConfig getSchemaCache ws@(WSServer _ userConf forever $ do (currAuthMode, currEnableAllowlist, currCorsPolicy, currSqlGenCtx, currExperimentalFeatures, currDefaultNamingCase) <- getLatestConfig currAllowlist <- scAllowlist <$> getSchemaCache - SecuritySensitiveUserConfig prevAuthMode prevEnableAllowlist prevAllowlist prevCorsPolicy prevSqlGenCtx prevExperimentalFeatures prevDefaultNamingCase <- readTVarIO userConfRef + SecuritySensitiveUserConfig prevAuthMode prevEnableAllowlist prevAllowlist prevCorsPolicy prevSqlGenCtx prevExperimentalFeatures prevDefaultNamingCase <- STM.readTVarIO userConfRef -- check and close all connections if required checkAndReapConnections (currAuthMode, prevAuthMode) @@ -513,9 +519,9 @@ createServerApp :: createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger writeLog) _ serverStatus) prometheusMetrics wsHandlers !ipAddress !pendingConn = do wsId <- WSId <$> liftIO UUID.nextRandom logWSLog logger $ WSLog wsId EConnectionRequest Nothing - -- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection -- if the client doesn't send a `connection_init` message within the timeout period - wsConnInitTimer <- liftIO $ getNewWSTimer (unrefine $ unWSConnectionInitTimeout wsConnInitTimeout) + -- the server has to close the connection + wsConnInitTimer <- liftIO newWSConnInitTimeout status <- liftIO $ STM.readTVarIO serverStatus case status of AcceptingConns _ -> logUnexpectedExceptions $ do @@ -574,7 +580,7 @@ createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger wr conn <- liftIO $ WS.acceptRequestWith pendingConn acceptWithParams logWSLog logger $ WSLog wsId EAccepted Nothing sendQ <- liftIO STM.newTQueueIO - let !wsConn = WSConn wsId logger conn sendQ a + let !wsConn = WSConn wsId logger conn sendQ wsConnInitTimer a -- TODO there are many thunks here. Difficult to trace how much is retained, and -- how much of that would be shared anyway. -- Requires a fork of 'wai-websockets' and 'websockets', it looks like. @@ -657,31 +663,52 @@ createServerApp getMetricsConfig wsConnInitTimeout (WSServer logger@(L.Logger wr (realToFrac messageWriteTime) logWSLog logger $ WSLog wsId (EMessageSent messageDetails) wsInfo + let connInitTimer = liftIO $ do + labelMe "WebSocket connInitTimer" + runTimer (unrefine $ unWSConnectionInitTimeout wsConnInitTimeout) wsConnInitTimer + -- withAsync lets us be very sure that if e.g. an async exception is raised while we're -- forking that the threads we launched will be cleaned up. See also below. - LA.withAsync rcv $ \rcvRef -> do - LA.withAsync send $ \sendRef -> do - LA.withAsync (liftIO $ labelMe "WebSocket keepAlive" >> keepAlive wsConn) $ \keepAliveRef -> do - LA.withAsync (liftIO $ labelMe "WebSocket onJwtExpiry" >> onJwtExpiry wsConn) $ \onJwtExpiryRef -> do - -- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws` - timeoutStatus <- liftIO $ getWSTimerState wsConnInitTimer - when (timeoutStatus == Done && subProtocol == GraphQLWS) - $ liftIO - $ closeConnWithCode wsConn 4408 "Connection initialisation timed out" - - -- terminates on WS.ConnectionException and JWT expiry - let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef] - -- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in - -- case of async exceptions raised while blocking here: - try (LA.waitAnyCancel waitOnRefs) >>= \case - -- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt - -- exceptions; for now handle all ConnectionException by closing - -- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48 - Left (_ :: WS.ConnectionException) -> do - logWSLog logger $ WSLog (_wcConnId wsConn) ECloseReceived Nothing - -- this will happen when jwt is expired - Right _ -> do - logWSLog logger $ WSLog (_wcConnId wsConn) EJwtExpired Nothing + LA.withAsync connInitTimer $ \connInitTimerRef -> do + LA.withAsync rcv $ \rcvRef -> do + LA.withAsync send $ \sendRef -> do + LA.withAsync (liftIO $ labelMe "WebSocket keepAlive" >> keepAlive wsConn) $ \keepAliveRef -> do + LA.withAsync (liftIO $ labelMe "WebSocket onJwtExpiry" >> onJwtExpiry wsConn) $ \onJwtExpiryRef -> do + -- Wait for connection init status, and then wait for any of the threads to terminate + -- The code below will block until the status is updated by either the timer thread or the GraphQL protocol connection initialization. + liftIO (STM.atomically $ STM.takeTMVar wsConnInitTimer) >>= \case + TimedOut -> do + -- send close message + let timeoutMessage = "Connection initialisation timed out" + case subProtocol of + GraphQLWS -> + -- https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#connectioninit + liftIO $ closeConnWithCode wsConn 4408 timeoutMessage + Apollo -> + -- 1011 is an unexpected condition prevented the request from being fulfilled. + -- NOTE: The protocol spec does not define a close code for connection timeouts, so we're using the close code from the reference implementation. + -- https://github.com/apollographql/subscriptions-transport-ws/blob/36f3f6f780acc1a458b768db13fd39c65e5e6518/src/server.ts#L159 + liftIO $ closeConnWithCode wsConn 1011 timeoutMessage + -- log connection timeout + logWSLog logger $ WSLog (_wcConnId wsConn) EConnectionTimeout Nothing + -- terminate all threads + mapM_ LA.cancel [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef] + Initialized -> do + -- terminate the timer thread + LA.cancel connInitTimerRef + -- terminates on WS.ConnectionException and JWT expiry + let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef] + -- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in + -- case of async exceptions raised while blocking here: + try (LA.waitAnyCancel waitOnRefs) >>= \case + -- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt + -- exceptions; for now handle all ConnectionException by closing + -- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48 + Left (_ :: WS.ConnectionException) -> do + logWSLog logger $ WSLog (_wcConnId wsConn) ECloseReceived Nothing + -- this will happen when jwt is expired + Right _ -> do + logWSLog logger $ WSLog (_wcConnId wsConn) EJwtExpired Nothing onConnClose wsConn = \case ShuttingDown -> pure () diff --git a/server/tests-py/context.py b/server/tests-py/context.py index 8cdfa5202c1..c7111fd3350 100644 --- a/server/tests-py/context.py +++ b/server/tests-py/context.py @@ -43,7 +43,7 @@ class GQLWsClient(): self.ws_queue = queue.Queue(maxsize=-1) self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws', path=endpoint) - self.create_conn() + self.conn_created = False def create_conn(self): self.ws_queue.queue.clear() @@ -60,10 +60,7 @@ class GQLWsClient(): self.wst = threading.Thread(target=self._ws.run_forever) self.wst.daemon = True self.wst.start() - - def recreate_conn(self): - self.teardown() - self.create_conn() + self.conn_created = True def wait_for_connection(self, timeout=10): assert not self.is_closing @@ -78,7 +75,7 @@ class GQLWsClient(): def get_ws_query_event(self, query_id, timeout): return self.ws_id_query_queues[query_id].get(timeout=timeout) - def send(self, frame, count=0): + def send(self, frame, headers={}, count=0): self.wait_for_connection() if frame.get('type') == 'stop': self.ws_active_query_ids.discard( frame.get('id') ) @@ -90,7 +87,7 @@ class GQLWsClient(): if count > 2: raise websocket.WebSocketConnectionClosedException("Connection is already closed and cannot be recreated even after 3 attempts") # Connection closed, try to recreate the connection and send the frame again - self.recreate_conn() + self.init(headers) self.send(frame, count+1) def init_as_admin(self): @@ -100,6 +97,7 @@ class GQLWsClient(): self.init(headers) def init(self, headers={}): + self.create_conn() payload = {'type': 'connection_init', 'payload': {}} if headers and len(headers) > 0: @@ -121,13 +119,10 @@ class GQLWsClient(): return self.gen_id(size, chars) return new_id - def send_query(self, query, query_id=None, headers={}, timeout=60): - graphql.parse(query['query']) - if headers and len(headers) > 0: - #Do init If headers are provided - self.init(headers) - elif not self.init_done: + def send_query(self, query, query_id=None, timeout=60): + if not self.init_done: self.init() + graphql.parse(query['query']) if query_id == None: query_id = self.gen_id() frame = { @@ -169,10 +164,12 @@ class GQLWsClient(): return self.remote_closed or self.is_closing def teardown(self): - self.is_closing = True - if not self.remote_closed: - self._ws.close() - self.wst.join() + if self.conn_created: + self.is_closing = True + if not self.remote_closed: + self._ws.close() + self.wst.join() + self.conn_created = False # NOTE: use this to generate a GraphQL client that uses the `graphql-ws` sub-protocol class GraphQLWSClient(): @@ -182,7 +179,7 @@ class GraphQLWSClient(): self.ws_queue = queue.Queue(maxsize=-1) self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws', path=endpoint) - self.create_conn() + self.conn_created = False def get_queue(self): return self.ws_queue.queue @@ -205,10 +202,7 @@ class GraphQLWSClient(): self.wst = threading.Thread(target=self._ws.run_forever) self.wst.daemon = True self.wst.start() - - def recreate_conn(self): - self.teardown() - self.create_conn() + self.conn_created = True def wait_for_connection(self, timeout=10): assert not self.is_closing @@ -239,6 +233,7 @@ class GraphQLWSClient(): self.init(headers) def init(self, headers={}): + self.create_conn() payload = {'type': 'connection_init', 'payload': {}} if headers and len(headers) > 0: @@ -260,13 +255,9 @@ class GraphQLWSClient(): return self.gen_id(size, chars) return new_id - def send_query(self, query, query_id=None, headers={}, timeout=60): + def send_query(self, query, query_id=None, timeout=60): graphql.parse(query['query']) - if headers and len(headers) > 0: - #Do init If headers are provided - self.clear_queue() - self.init(headers) - elif not self.init_done: + if not self.init_done: self.init() if query_id == None: query_id = self.gen_id() @@ -318,10 +309,11 @@ class GraphQLWSClient(): return self.remote_closed or self.is_closing def teardown(self): - self.is_closing = True - if not self.remote_closed: - self._ws.close() - self.wst.join() + if self.conn_created: + self.is_closing = True + if not self.remote_closed: + self._ws.close() + self.wst.join() class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler): hge_url: str @@ -387,7 +379,7 @@ class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler): elif req_path == "/null-response": resp, status = self.null_response() self._send_response(status, resp) - + elif req_path == "/omitted-response-field": self._send_response( HTTPStatus.OK, @@ -624,7 +616,7 @@ class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler): 'id': 1, 'child': None } - + def get_omitted_response_field(self): return { 'country': 'India' diff --git a/server/tests-py/test_logging.py b/server/tests-py/test_logging.py index cdc3e8b39ac..56a560e59ec 100644 --- a/server/tests-py/test_logging.py +++ b/server/tests-py/test_logging.py @@ -31,7 +31,7 @@ def parse_logs(stream): pytestmark = [ pytest.mark.capture_hge_logs, pytest.mark.admin_secret, - pytest.mark.hge_env('HASURA_GRAPHQL_LOG_LEVEL', 'debug'), + pytest.mark.hge_env('HASURA_GRAPHQL_LOG_LEVEL', 'debug'), ] @@ -259,17 +259,18 @@ class TestWebsocketLogging: @pytest.fixture(scope='class', autouse=True) def make_requests(self, hge_ctx, ws_client): + ''' + Create connection using connection_init + ''' + headers = {'x-request-id': self.query_id} + if hge_ctx.hge_key: + headers['x-hasura-admin-secret'] = hge_ctx.hge_key + ws_client.init(headers=headers) # setup some tables hge_ctx.v1q_f(self.dir + '/setup.yaml') # make a successful websocket query - headers = {'x-request-id': self.query_id} - if hge_ctx.hge_key: - headers['x-hasura-admin-secret'] = hge_ctx.hge_key - - resp = ws_client.send_query(self.query, headers=headers, - query_id=self.query_id, - timeout=5) + resp = ws_client.send_query(self.query, query_id=self.query_id, timeout=5) try: ev = next(resp) assert ev['type'] == 'data' and ev['id'] == self.query_id, ev diff --git a/server/tests-py/test_subscriptions.py b/server/tests-py/test_subscriptions.py index 1fc609694d9..91ae3528e7e 100644 --- a/server/tests-py/test_subscriptions.py +++ b/server/tests-py/test_subscriptions.py @@ -27,6 +27,7 @@ def ws_conn_init_graphql_ws(hge_key, ws_client_graphql_ws): # This is used in other test files! Be careful when modifying it. def init_ws_conn(hge_key, ws_client, payload = None): + ws_client.create_conn() init_msg = { 'type': 'connection_init', 'payload': payload or ws_payload(hge_key), @@ -36,6 +37,7 @@ def init_ws_conn(hge_key, ws_client, payload = None): assert ev['type'] == 'connection_ack', ev def init_graphql_ws_conn(hge_key, ws_client_graphql_ws): + ws_client_graphql_ws.create_conn() init_msg = { 'type': 'connection_init', 'payload': ws_payload(hge_key), @@ -67,7 +69,6 @@ def get_explain_graphql_query_response(hge_ctx, hge_key, query, variables, user_ @pytest.mark.no_admin_secret class TestSubscriptionCtrlWithoutSecret(object): def test_connection(self, ws_client): - ws_client.recreate_conn() init_ws_conn(None, ws_client) obj = { @@ -86,7 +87,6 @@ class TestSubscriptionCtrl(object): ''' def test_connection(self, hge_key, ws_client): - ws_client.recreate_conn() init_ws_conn(hge_key, ws_client) obj = { @@ -104,6 +104,7 @@ class TestSubscriptionBasicNoAuth: def test_closed_connection_apollo(self, ws_client): # sends empty header so that there is not authentication present in the test + ws_client.create_conn() init_msg = { 'type': 'connection_init', 'payload':{'headers':{}} @@ -115,6 +116,7 @@ class TestSubscriptionBasicNoAuth: def test_closed_connection_graphql_ws(self, ws_client_graphql_ws): # sends empty header so that there is not authentication present in the test + ws_client_graphql_ws.create_conn() init_msg = { 'type': 'connection_init', 'payload':{'headers':{}} @@ -383,7 +385,10 @@ class TestSubscriptionLiveQueries: ''' Create connection using connection_init ''' - ws_client.init_as_admin() + headers={} + if hge_key is not None: + headers['X-Hasura-Admin-Secret'] = hge_key + ws_client.init(headers=headers) with open(self.dir() + "/steps.yaml") as c: conf = yaml.load(c) @@ -401,11 +406,8 @@ class TestSubscriptionLiveQueries: liveQs = [] for i, resultLimit in queries: query = queryTmplt.replace('{0}',str(i)) - headers={} - if hge_key is not None: - headers['X-Hasura-Admin-Secret'] = hge_key subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } } - respLive = ws_client.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15) + respLive = ws_client.send_query(subscrPayload, query_id='live_'+str(i), timeout=15) liveQs.append(respLive) ev = next(respLive) assert ev['type'] == 'data', ev @@ -466,7 +468,10 @@ class TestStreamingSubscription: ''' Create connection using connection_init ''' - ws_client.init_as_admin() + headers={} + if hge_key is not None: + headers['X-Hasura-Admin-Secret'] = hge_key + ws_client.init(headers=headers) query = """ subscription ($batch_size: Int!) { @@ -478,15 +483,12 @@ class TestStreamingSubscription: """ liveQs = [] - headers={} articles_to_insert = [] for i in range(10): articles_to_insert.append({"id": i + 1, "title": "Article title {}".format(i + 1)}) insert_many(hge_ctx, {"schema": "hge_tests", "name": "articles"}, articles_to_insert) - if hge_key is not None: - headers['X-Hasura-Admin-Secret'] = hge_key subscrPayload = { 'query': query, 'variables': { 'batch_size': 2 } } - respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15) + respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15) liveQs.append(respLive) for idx in range(5): ev = next(respLive) @@ -511,7 +513,6 @@ class TestStreamingSubscription: Create connection using connection_init ''' ws_client.init_as_admin() - headers={} query = """ subscription ($batch_size: Int!, $initial_created_at: timestamptz!) { hge_tests_stream_query: hge_tests_test_t2_stream(cursor: [{initial_value: {created_at: $initial_created_at}, ordering: ASC}], batch_size: $batch_size) { @@ -525,7 +526,7 @@ class TestStreamingSubscription: conf = yaml.load(c) subscrPayload = { 'query': query, 'variables': { 'batch_size': 2, 'initial_created_at': "2020-01-01" } } - respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15) + respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15) assert isinstance(conf, list) == True, 'Not an list' for index, step in enumerate(conf): @@ -566,7 +567,10 @@ class TestStreamingSubscription: ''' Create connection using connection_init ''' - ws_client.init_as_admin() + headers={} + if hge_key is not None: + headers['X-Hasura-Admin-Secret'] = hge_key + ws_client.init(headers=headers) query = """ subscription ($batch_size: Int!) { @@ -578,11 +582,8 @@ class TestStreamingSubscription: """ liveQs = [] - headers={} - if hge_key is not None: - headers['X-Hasura-Admin-Secret'] = hge_key subscrPayload = { 'query': query, 'variables': { 'batch_size': 1 } } - respLive = ws_client.send_query(subscrPayload, query_id='stream_1', headers=headers, timeout=15) + respLive = ws_client.send_query(subscrPayload, query_id='stream_1', timeout=15) liveQs.append(respLive) for idx in range(2): ev = next(respLive) @@ -614,7 +615,10 @@ class TestSubscriptionLiveQueriesForGraphQLWS: ''' Create connection using connection_init ''' - ws_client_graphql_ws.init_as_admin() + headers={} + if hge_key is not None: + headers['X-Hasura-Admin-Secret'] = hge_key + ws_client_graphql_ws.init(headers=headers) with open(self.dir() + "/steps.yaml") as c: conf = yaml.load(c) @@ -632,11 +636,8 @@ class TestSubscriptionLiveQueriesForGraphQLWS: liveQs = [] for i, resultLimit in queries: query = queryTmplt.replace('{0}',str(i)) - headers={} - if hge_key is not None: - headers['X-Hasura-Admin-Secret'] = hge_key subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } } - respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15) + respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), timeout=15) liveQs.append(respLive) ev = next(respLive) assert ev['type'] == 'next', ev @@ -761,12 +762,12 @@ class TestSubscriptionUDFWithSessionArg: return 'queries/subscriptions/udf_session_args' def test_user_defined_function_with_session_argument(self, hge_key, ws_client): - ws_client.init_as_admin() headers = {'x-hasura-role': 'user', 'x-hasura-user-id': '42'} if hge_key is not None: headers['X-Hasura-Admin-Secret'] = hge_key + ws_client.init(headers=headers) payload = {'query': self.query} - resp = ws_client.send_query(payload, headers=headers, timeout=15) + resp = ws_client.send_query(payload, timeout=15) ev = next(resp) assert ev['type'] == 'data', ev assert ev['payload']['data'] == {'me': [{'id': '42', 'name': 'Charlie'}]}, ev['payload']['data'] diff --git a/server/tests-py/test_v1alpha1_endpoint.py b/server/tests-py/test_v1alpha1_endpoint.py index ce2221e7335..7717c69c585 100644 --- a/server/tests-py/test_v1alpha1_endpoint.py +++ b/server/tests-py/test_v1alpha1_endpoint.py @@ -132,6 +132,7 @@ class TestV1Alpha1GraphQLErrors: def test_v1alpha1_ws_start_error(self, hge_ctx): ws_client = GQLWsClient(hge_ctx, '/v1alpha1/graphql') + ws_client.create_conn() query = {'query': '{ author { name } }'} frame = { 'id': '1', diff --git a/server/tests-py/test_webhook.py b/server/tests-py/test_webhook.py index c1aa57e71bb..aa9741b923d 100644 --- a/server/tests-py/test_webhook.py +++ b/server/tests-py/test_webhook.py @@ -96,10 +96,6 @@ class TestWebhookMetadataInPOSTModeWithTLS(AbstractTestWebhookMetadata): pass class TestWebhookSubscriptionExpiry(object): EXPIRE_TIME_FORMAT = '%a, %d %b %Y %T GMT' - @pytest.fixture(scope='function', autouse=True) - def ws_conn_recreate(self, ws_client): - ws_client.recreate_conn() - def test_expiry_with_no_header(self, ws_client): # no expiry time => the connextion will remain alive self.connect_with(ws_client, {}) diff --git a/server/tests-py/validate.py b/server/tests-py/validate.py index 0b7a377f475..ce298221218 100644 --- a/server/tests-py/validate.py +++ b/server/tests-py/validate.py @@ -302,17 +302,13 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals else: ws_client = hge_ctx.ws_client print(ws_client.ws_url) + if not headers or len(headers) == 0: - ws_client.init({}) + ws_client.init_as_admin() + else: + ws_client.init(headers) - if ws_client.remote_closed or ws_client.is_closing: - ws_client.create_conn() - if not headers or len(headers) == 0 or hge_ctx.hge_key is None: - ws_client.init() - else: - ws_client.init_as_admin() - - query_resp = ws_client.send_query(query, query_id='hge_test', headers=headers, timeout=15) + query_resp = ws_client.send_query(query, query_id='hge_test', timeout=15) resp = next(query_resp) print('websocket resp: ', resp) @@ -320,7 +316,7 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals if retry: #Got query complete before payload. Retry once more print("Got query complete before getting query response payload. Retrying") - ws_client.recreate_conn() + ws_client.tear_down() return validate_gql_ws_q(hge_ctx, query, headers, exp_http_response, False) else: assert resp['type'] in ['data', 'error', 'next'], resp