server: support for graphql-ws protocol

https://github.com/hasura/graphql-engine-mono/pull/1655

Co-authored-by: Rakesh Emmadi <12475069+rakeshkky@users.noreply.github.com>
Co-authored-by: Vijay Prasanna <11921040+vijayprasanna13@users.noreply.github.com>
Co-authored-by: hasura-bot <30118761+hasura-bot@users.noreply.github.com>
Co-authored-by: Brandon Simmons <210815+jberryman@users.noreply.github.com>
Co-authored-by: Varun Choudhary <68095256+Varun-Choudhary@users.noreply.github.com>
Co-authored-by: Divi <32202683+imperfect-fourth@users.noreply.github.com>
GitOrigin-RevId: 9db3902388fef06b94f9513255e2b5333bd23c3e
This commit is contained in:
Sameer Kolhar 2021-08-24 21:55:12 +05:30 committed by hasura-bot
parent 3789405e37
commit edeb8c98fd
19 changed files with 1266 additions and 449 deletions

View File

@ -5,8 +5,10 @@
(Add entries below in the order of server, console, cli, docs, others)
- server: optimize SQL query generation with LIMITs (close #5745)
- server: update non-existent event trigger, action and query collection error msgs (close #7396)
- server: fix broken `untrack_function` for non-default source
- server: Adding support for TLS allowlist by domain and service id (port)
- server: add support for `graphql-ws` clients
- console: fix error due too rendering inconsistent object's message
## v2.0.7
@ -16,7 +18,6 @@
- server: fix GraphQL type for remote relationship field (close #7284)
- server: support EdDSA algorithm and key type for JWT
- server: fix GraphQL type for single-row returning functions (close #7109)
- server: update non-existent event trigger, action and query collection error msgs (close #7396)
- console: add support for creation of indexes for Postgres data sources
- docs: document the cleanup process for scheduled triggers
- console: allow same named queries and unnamed queries on allowlist file upload

View File

@ -54,8 +54,11 @@ Communication protocol
Hasura GraphQL engine uses the `GraphQL over WebSocket Protocol
<https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md>`__ by the
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library
for sending and receiving events.
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library and the
`GraphQL over WebSocket Protocol <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md>`__
by the `graphql-ws <https://github.com/enisdenjo/graphql-ws>`__ library for sending and receiving events. The support for
``graphql-ws`` is currently considered as ``BETA``. The graphql-engine uses the ``Sec-WebSocket-Protocol`` header to determine
the server Implementation that'll be used. By default, the graphql-engine will use the ``apollographql/subscriptions-transport-ws`` protocol.
.. admonition:: Setting headers for subscriptions with Apollo client

View File

@ -52,8 +52,11 @@ Communication protocol
Hasura GraphQL engine uses the `GraphQL over WebSocket Protocol
<https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md>`__ by the
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library
for sending and receiving events.
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library and the
`GraphQL over WebSocket Protocol <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md>`__
by the `graphql-ws <https://github.com/enisdenjo/graphql-ws>`__ library for sending and receiving events. The support for
``graphql-ws`` is currently considered as ``BETA``. The graphql-engine uses the ``Sec-WebSocket-Protocol`` header to determine
the server Implementation that'll be used. By default, the graphql-engine will use the ``apollographql/subscriptions-transport-ws`` implementation.
.. admonition:: Setting headers for subscriptions with Apollo client

View File

@ -324,11 +324,21 @@ For the ``serve`` sub-command these are the available flags and environment vari
*(Available for versions > v2.0.0)*
* - ``--websocket-keepalive``
* - ``--websocket-keepalive <SECONDS>``
- ``HASURA_GRAPHQL_WEBSOCKET_KEEPALIVE``
- WebSocket keep-alive timeout in seconds (default: 5)
- Used to set the ``Keep Alive`` delay for client that use the ``subscription-transport-ws`` (Apollo) protocol.
For ``graphql-ws`` clients the graphql-engine sends ``PING`` messages instead.
(default: ``5``)
*(Available for versions > v2.0.0)*
* - ``--websocket-connection-init-timeout <SECONDS>``
- ``HASURA_GRAPHQL_WEBSOCKET_CONNECTION_INIT_TIMEOUT``
- Used to set the connection initialisation timeout for ``graphql-ws`` clients. This is ignored
for ``subscription-transport-ws`` (Apollo) clients.
(default: ``3``)
.. note::

View File

@ -611,7 +611,9 @@ library
, Hasura.GraphQL.Transport.HTTP
, Hasura.GraphQL.Transport.HTTP.Protocol
, Hasura.GraphQL.Transport.Instances
, Hasura.GraphQL.Transport.WSServerApp
, Hasura.GraphQL.Transport.WebSocket
, Hasura.GraphQL.Transport.WebSocket.Types
, Hasura.GraphQL.Transport.WebSocket.Protocol
, Hasura.GraphQL.Transport.WebSocket.Server

View File

@ -537,6 +537,7 @@ runHGEServer setupHook env ServeOptions{..} ServeCtx{..} initTime postPollHook s
soEnableMaintenanceMode
soExperimentalFeatures
_scEnabledLogTypes
soWebsocketConnectionInitTimeout
let serverConfigCtx =
ServerConfigCtx soInferFunctionPermissions

View File

@ -49,7 +49,7 @@ import Hasura.GraphQL.Execute.LiveQuery.Poll
import Hasura.GraphQL.ParameterizedQueryHash (ParameterizedQueryHash)
import Hasura.GraphQL.Transport.Backend
import Hasura.GraphQL.Transport.HTTP.Protocol (OperationName)
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol (OperationId)
import Hasura.RQL.Types.Action
import Hasura.RQL.Types.Common (SourceName, unNonNegativeDiffTime)
import Hasura.Server.Metrics (ServerMetrics (..))

View File

@ -0,0 +1,145 @@
module Hasura.GraphQL.Transport.WSServerApp
( createWSServerApp
, stopWSServerApp
, createWSServerEnv
) where
import Hasura.Prelude
import qualified Control.Concurrent.Async.Lifted.Safe as LA
import qualified Control.Concurrent.STM as STM
import qualified Control.Monad.Trans.Control as MC
import qualified Data.ByteString.Char8 as B (pack)
import qualified Data.Environment as Env
import qualified Network.HTTP.Client as H
import qualified Network.WebSockets as WS
import qualified System.Metrics.Gauge as EKG.Gauge
import Control.Exception.Lifted
import Data.Aeson (toJSON)
import Data.Text (pack, unpack)
import qualified Hasura.GraphQL.Execute as E
import qualified Hasura.GraphQL.Execute.Backend as EB
import qualified Hasura.GraphQL.Execute.LiveQuery.State as LQ
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import qualified Hasura.Logging as L
import qualified Hasura.Tracing as Tracing
import Hasura.GraphQL.Logging
import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery)
import Hasura.GraphQL.Transport.Instances ()
import Hasura.GraphQL.Transport.WebSocket
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.GraphQL.Transport.WebSocket.Types
import Hasura.Metadata.Class
import Hasura.RQL.Types
import Hasura.Server.Auth (AuthMode, UserAuthentication)
import Hasura.Server.Cors
import Hasura.Server.Init.Config (KeepAliveDelay,
WSConnectionInitTimeout)
import Hasura.Server.Metrics (ServerMetrics (..))
import Hasura.Server.Version (HasVersion)
createWSServerApp ::
( HasVersion
, MonadIO m
, MC.MonadBaseControl IO m
, LA.Forall (LA.Pure m)
, UserAuthentication (Tracing.TraceT m)
, E.MonadGQLExecutionCheck m
, WS.MonadWSLog m
, MonadQueryLog m
, Tracing.HasReporter m
, MonadExecuteQuery m
, MonadMetadataStorage (MetadataStorageT m)
, EB.MonadQueryTags m
)
=> Env.Environment
-> HashSet (L.EngineLogType L.Hasura)
-> AuthMode
-> WSServerEnv
-> WSConnectionInitTimeout
-> WS.HasuraServerApp m
-- -- ^ aka generalized 'WS.ServerApp'
createWSServerApp env enabledLogTypes authMode serverEnv connInitTimeout = \ !ipAddress !pendingConn ->
WS.createServerApp connInitTimeout (_wseServer serverEnv) handlers ipAddress pendingConn
where
handlers = WS.WSHandlers
onConnHandler
onMessageHandler
onCloseHandler
logger = _wseLogger serverEnv
serverMetrics = _wseServerMetrics serverEnv
wsActions = mkWSActions logger
-- Mask async exceptions during event processing to help maintain integrity of mutable vars:
-- here `sp` stands for sub-protocol
onConnHandler rid rh ip sp = mask_ do
liftIO $ EKG.Gauge.inc $ smWebsocketConnections serverMetrics
flip runReaderT serverEnv $ onConn rid rh ip (wsActions sp)
onMessageHandler conn bs sp = mask_ $
onMessage env enabledLogTypes authMode serverEnv conn bs (wsActions sp)
onCloseHandler conn = mask_ do
liftIO $ EKG.Gauge.dec $ smWebsocketConnections serverMetrics
onClose logger serverMetrics (_wseLiveQMap serverEnv) conn
stopWSServerApp :: WSServerEnv -> IO ()
stopWSServerApp wsEnv = WS.shutdown (_wseServer wsEnv)
createWSServerEnv :: (MonadIO m)
=> L.Logger L.Hasura
-> LQ.LiveQueriesState
-> IO (SchemaCache, SchemaCacheVer)
-> H.Manager
-> CorsPolicy
-> SQLGenCtx
-> Bool
-> KeepAliveDelay
-> ServerMetrics
-> m WSServerEnv
createWSServerEnv logger lqState getSchemaCache httpManager
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics = do
wsServer <- liftIO $ STM.atomically $ WS.createWSServer logger
pure $ WSServerEnv logger lqState getSchemaCache httpManager
corsPolicy sqlGenCtx wsServer enableAL keepAliveDelay serverMetrics
mkWSActions :: L.Logger L.Hasura -> WSSubProtocol -> WS.WSActions WSConnData
mkWSActions logger subProtocol =
WS.WSActions
mkPostExecErrMessageAction
mkOnErrorMessageAction
mkConnectionCloseAction
keepAliveAction
getServerMsgType
mkAcceptRequest
where
mkPostExecErrMessageAction wsConn opId execErr =
sendMsg wsConn $ case subProtocol of
Apollo -> SMData $ DataMsg opId $ throwError execErr
GraphQLWS -> SMErr $ ErrorMsg opId $ toJSON execErr
mkOnErrorMessageAction wsConn err mErrMsg = case subProtocol of
Apollo -> sendMsg wsConn $ SMConnErr err
GraphQLWS -> sendCloseWithMsg logger wsConn (GenericError4400 $ (fromMaybe "" mErrMsg) <> (unpack . unConnErrMsg $ err)) Nothing
mkConnectionCloseAction wsConn opId errMsg =
when (subProtocol == GraphQLWS) $
sendCloseWithMsg logger wsConn (GenericError4400 errMsg) (Just . SMErr $ ErrorMsg opId $ toJSON (pack errMsg))
getServerMsgType = case subProtocol of
Apollo -> SMData
GraphQLWS -> SMNext
keepAliveAction wsConn = sendMsg wsConn $
case subProtocol of
Apollo -> SMConnKeepAlive
GraphQLWS -> SMPing . Just $ keepAliveMessage
mkAcceptRequest = WS.defaultAcceptRequest {
WS.acceptSubprotocol = Just . B.pack . showSubProtocol $ subProtocol
}

View File

@ -1,11 +1,15 @@
-- | This file contains the handlers that are used within websocket server
{-# LANGUAGE CPP #-}
module Hasura.GraphQL.Transport.WebSocket
( createWSServerApp
, createWSServerEnv
, stopWSServerApp
, WSServerEnv
, WSLog(..)
( onConn
, onMessage
, onClose
-- ^ the main handlers for the websocket server
, sendMsg
, sendCloseWithMsg
-- ^ helpers for sending messages to the client
) where
-- NOTE!:
@ -15,7 +19,6 @@ module Hasura.GraphQL.Transport.WebSocket
import Hasura.Prelude
import qualified Control.Concurrent.Async.Lifted.Safe as LA
import qualified Control.Concurrent.STM as STM
import qualified Control.Monad.Trans.Control as MC
import qualified Data.Aeson as J
@ -36,13 +39,10 @@ import qualified Language.GraphQL.Draft.Syntax as G
import qualified ListT
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Types as H
import qualified Network.Wai.Extended as Wai
import qualified Network.WebSockets as WS
import qualified StmContainers.Map as STMMap
import qualified System.Metrics.Gauge as EKG.Gauge
import Control.Concurrent.Extended (sleep)
import Control.Exception.Lifted
import Data.String
#ifndef PROFILING
import GHC.AssertNF
@ -76,8 +76,9 @@ import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.Instances ()
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.GraphQL.Transport.WebSocket.Types
import Hasura.Metadata.Class
import Hasura.RQL.Types
import Hasura.RQL.Types.RemoteSchema
import Hasura.Server.Auth (AuthMode, UserAuthentication,
resolveUserInfo)
import Hasura.Server.Cors
@ -87,77 +88,12 @@ import Hasura.Server.Types (RequestId, getReq
import Hasura.Server.Version (HasVersion)
import Hasura.Session
-- | 'LQ.LiveQueryId' comes from 'Hasura.GraphQL.Execute.LiveQuery.State.addLiveQuery'. We use
-- this to track a connection's operations so we can remove them from 'LiveQueryState', and
-- log.
--
-- NOTE!: This must be kept consistent with the global 'LiveQueryState', in 'onClose'
-- and 'onStart'.
type OperationMap
= STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName)
newtype WsHeaders
= WsHeaders { unWsHeaders :: [H.Header] }
deriving (Show, Eq)
data ErrRespType
= ERTLegacy
| ERTGraphqlCompliant
deriving (Show)
data WSConnState
= CSNotInitialised !WsHeaders !Wai.IpAddress
-- ^ headers and IP address from the client for websockets
| CSInitError !Text
| CSInitialised !WsClientState
data WsClientState
= WsClientState
{ wscsUserInfo :: !UserInfo
-- ^ the 'UserInfo' required to execute the GraphQL query
, wscsTokenExpTime :: !(Maybe TC.UTCTime)
-- ^ the JWT/token expiry time, if any
, wscsReqHeaders :: ![H.Header]
-- ^ headers from the client (in conn params) to forward to the remote schema
, wscsIpAddress :: !Wai.IpAddress
-- ^ IP address required for 'MonadGQLAuthorization'
}
data WSConnData
= WSConnData
-- the role and headers are set only on connection_init message
{ _wscUser :: !(STM.TVar WSConnState)
-- we only care about subscriptions,
-- the other operations (query/mutations)
-- are not tracked here
, _wscOpMap :: !OperationMap
, _wscErrRespTy :: !ErrRespType
, _wscAPIType :: !E.GraphQLQueryType
}
type WSServer = WS.WSServer WSConnData
type WSConn = WS.WSConn WSConnData
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn msg =
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse (encodeServerMsg msg) Nothing
sendMsgWithMetadata :: (MonadIO m) => WSConn -> ServerMsg -> LQ.LiveQueryMetadata -> m ()
sendMsgWithMetadata wsConn msg (LQ.LiveQueryMetadata execTime) =
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse bs wsInfo
where
bs = encodeServerMsg msg
(msgType, operationId) = case msg of
(SMData (DataMsg opId _)) -> (Just SMT_GQL_DATA, Just opId)
_ -> (Nothing, Nothing)
wsInfo = Just $! WS.WSEventInfo
{ WS._wseiEventType = msgType
, WS._wseiOperationId = operationId
, WS._wseiQueryExecutionTime = Just $! realToFrac execTime
, WS._wseiResponseSize = Just $! LBS.length bs
}
data OpDetail
= ODStarted
@ -229,35 +165,91 @@ mkWsErrorLog :: Maybe SessionVariables -> WsConnInfo -> WSEvent -> WSLog
mkWsErrorLog uv ci ev =
WSLog L.LevelError $ WSLogInfo uv ci ev
data WSServerEnv
= WSServerEnv
{ _wseLogger :: !(L.Logger L.Hasura)
, _wseLiveQMap :: !LQ.LiveQueriesState
, _wseGCtxMap :: !(IO (SchemaCache, SchemaCacheVer))
-- ^ an action that always returns the latest version of the schema cache. See 'SchemaCacheRef'.
, _wseHManager :: !H.Manager
, _wseCorsPolicy :: !CorsPolicy
, _wseSQLCtx :: !SQLGenCtx
, _wseServer :: !WSServer
, _wseEnableAllowlist :: !Bool
, _wseKeepAliveDelay :: !KeepAliveDelay
, _wseServerMetrics :: !ServerMetrics
}
logWSEvent :: (MonadIO m)
=> L.Logger L.Hasura
-> WSConn
-> WSEvent
-> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
userInfoME <- liftIO $ STM.readTVarIO userInfoR
let (userVarsM, tokenExpM) = case userInfoME of
CSInitialised WsClientState{..} -> ( Just $ _uiSession wscsUserInfo
, wscsTokenExpTime
)
_ -> (Nothing, Nothing)
liftIO $ logger $ WSLog logLevel $ WSLogInfo userVarsM (WsConnInfo wsId tokenExpM Nothing) wsEv
where
WSConnData userInfoR _ _ _ = WS.getData wsConn
wsId = WS.getWSId wsConn
logLevel = bool L.LevelInfo L.LevelError isError
isError = case wsEv of
EAccepted -> False
ERejected _ -> True
EConnErr _ -> True
EClosed -> False
EOperation operation -> case _odOperationType operation of
ODStarted -> False
ODProtoErr _ -> True
ODQueryErr _ -> True
ODCompleted -> False
ODStopped -> False
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn msg =
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse (encodeServerMsg msg) Nothing
sendCloseWithMsg :: (MonadIO m)
=> L.Logger L.Hasura
-> WSConn
-> ServerErrorCode
-> Maybe ServerMsg
-> m ()
sendCloseWithMsg logger wsConn errCode mErrServerMsg = do
case mErrServerMsg of
Just errServerMsg -> do
sendMsg wsConn errServerMsg
Nothing -> pure ()
logWSEvent logger wsConn EClosed
liftIO $ WS.sendClose wsc errMsg
where
wsc = WS.getRawWebSocketConnection wsConn
errMsg = encodeServerErrorMsg errCode
sendMsgWithMetadata :: (MonadIO m) => WSConn -> ServerMsg -> LQ.LiveQueryMetadata -> m ()
sendMsgWithMetadata wsConn msg (LQ.LiveQueryMetadata execTime) =
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse bs wsInfo
where
bs = encodeServerMsg msg
(msgType, operationId) = case msg of
(SMNext (DataMsg opId _)) -> (Just SMT_GQL_NEXT, Just opId)
(SMData (DataMsg opId _)) -> (Just SMT_GQL_DATA, Just opId)
_ -> (Nothing, Nothing)
wsInfo = Just $! WS.WSEventInfo
{ WS._wseiEventType = msgType
, WS._wseiOperationId = operationId
, WS._wseiQueryExecutionTime = Just $! realToFrac execTime
, WS._wseiResponseSize = Just $! LBS.length bs
}
onConn :: (MonadIO m, MonadReader WSServerEnv m)
=> WS.OnConnH m WSConnData
onConn wsId requestHead ipAddress = do
onConn wsId requestHead ipAddress onConnHActions = do
res <- runExceptT $ do
(errType, queryType) <- checkPath
let reqHdrs = WS.requestHeaders requestHead
headers <- maybe (return reqHdrs) (flip enforceCors reqHdrs . snd) getOrigin
return (WsHeaders $ filterWsHeaders headers, errType, queryType)
either reject accept res
where
kaAction = WS._wsaKeepAliveAction onConnHActions
acceptRequest = WS._wsaAcceptRequest onConnHActions
-- NOTE: the "Keep-Alive" delay is something that's mentioned
-- in the Apollo spec. For 'graphql-ws', we're using the Ping
-- messages that are part of the spec.
keepAliveAction keepAliveDelay wsConn = do
liftIO $ forever $ do
sendMsg wsConn SMConnKeepAlive
kaAction wsConn
sleep $ seconds (unKeepAliveDelay keepAliveDelay)
tokenExpiryHandler wsConn = do
@ -272,16 +264,21 @@ onConn wsId requestHead ipAddress = do
accept (hdrs, errType, queryType) = do
(L.Logger logger) <- asks _wseLogger
keepAliveDelay <- asks _wseKeepAliveDelay
keepAliveDelay <- asks _wseKeepAliveDelay
logger $ mkWsInfoLog Nothing (WsConnInfo wsId Nothing Nothing) EAccepted
connData <- liftIO $ WSConnData
<$> STM.newTVarIO (CSNotInitialised hdrs ipAddress)
<*> STMMap.newIO
<*> pure errType
<*> pure queryType
let acceptRequest = WS.defaultAcceptRequest
{ WS.acceptSubprotocol = Just "graphql-ws"}
return $ Right $ WS.AcceptWith connData acceptRequest (keepAliveAction keepAliveDelay) tokenExpiryHandler
pure $ Right $
WS.AcceptWith
connData
acceptRequest
(keepAliveAction keepAliveDelay)
tokenExpiryHandler
reject qErr = do
(L.Logger logger) <- asks _wseLogger
logger $ mkWsErrorLog Nothing (WsConnInfo wsId Nothing Nothing) (ERejected qErr)
@ -351,8 +348,9 @@ onStart
-> WSServerEnv
-> WSConn
-> StartMsg
-> WS.WSActions WSConnData
-> m ()
onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore $ do
onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) onMessageActions = catchAndIgnore $ do
timerTot <- startTimer
opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
@ -517,8 +515,9 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
case resultsE of
Left err -> sendError requestId err
Right results -> do
let dataMsg = SMData $ DataMsg opId $ pure $ encJToLBS $
encJFromInsOrdHashMap $ OMap.mapKeys G.unName results
let dataMsg = sendDataMsg $
DataMsg opId $ pure $ encJToLBS $
encJFromInsOrdHashMap $ OMap.mapKeys G.unName results
sendMsgWithMetadata wsConn dataMsg $ LQ.LiveQueryMetadata dTime
asyncActionQueryLive = LQ.LAAQNoRelationships $
@ -556,6 +555,10 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
liftIO $ logOpEv ODStarted (Just requestId)
where
sendDataMsg = WS._wsaGetDataMessageType onMessageActions
closeConnAction = WS._wsaConnectionCloseAction onMessageActions
postExecErrAction = WS._wsaPostExecErrMessageAction onMessageActions
getExecStepActionWithActionInfo acc execStep = case execStep of
E.ExecStepAction _ actionInfo _remoteJoins -> (actionInfo:acc)
_ -> acc
@ -602,7 +605,7 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
return $ ResultsFragment telemTimeIO_DT Telem.Remote (encJFromOrderedValue value) []
WSServerEnv logger lqMap getSchemaCache httpMgr _ sqlGenCtx
_ enableAL _keepAliveDelay _ = serverEnv
_ enableAL _keepAliveDelay _connInitTime = serverEnv
WSConnData userInfoR opMap errRespTy queryType = WS.getData wsConn
@ -620,9 +623,10 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
sendMsg wsConn $
SMErr $ ErrorMsg opId $ errFn False $ err400 StartFailed e
liftIO $ logOpEv (ODProtoErr e) Nothing
liftIO $ closeConnAction wsConn opId (T.unpack e)
sendCompleted reqId = do
sendMsg wsConn (SMComplete $ CompletionMsg opId)
sendMsg wsConn (SMComplete . CompletionMsg $ opId)
logOpEv ODCompleted reqId
postExecErr :: RequestId -> QErr -> ExceptT () m ()
@ -632,9 +636,7 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
postExecErr' $ GQExecError $ pure $ errFn qErr
postExecErr' :: GQExecError -> ExceptT () m ()
postExecErr' qErr = do
sendMsg wsConn $ SMData $
DataMsg opId $ throwError qErr
postExecErr' qErr = liftIO $ postExecErrAction wsConn opId qErr
-- why wouldn't pre exec error use graphql response?
preExecErr reqId qErr = liftIO $ sendError reqId qErr
@ -649,8 +651,8 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
sendSuccResp :: EncJSON -> LQ.LiveQueryMetadata -> ExceptT () m ()
sendSuccResp encJson =
sendMsgWithMetadata wsConn
(SMData $ DataMsg opId $ pure $ encJToLBS encJson)
sendMsgWithMetadata wsConn $
sendDataMsg $ DataMsg opId $ pure $ encJToLBS encJson
withComplete :: ExceptT () m () -> ExceptT () m a
withComplete action = do
@ -686,54 +688,73 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
liveQOnChange = \case
Right (LQ.LiveQueryResponse bs dTime) ->
sendMsgWithMetadata wsConn
(SMData $ DataMsg opId $ pure $ LBS.fromStrict bs)
(sendDataMsg $ DataMsg opId $ pure $ LBS.fromStrict bs)
(LQ.LiveQueryMetadata dTime)
resp -> sendMsg wsConn $ SMData $ DataMsg opId $ LBS.fromStrict . LQ._lqrPayload <$> resp
resp -> sendMsg wsConn $
sendDataMsg $ DataMsg opId $ LBS.fromStrict . LQ._lqrPayload <$> resp
catchAndIgnore :: ExceptT () m () -> m ()
catchAndIgnore m = void $ runExceptT m
onMessage
:: ( HasVersion
, MonadIO m
, UserAuthentication (Tracing.TraceT m)
, E.MonadGQLExecutionCheck m
, MonadQueryLog m
, Tracing.HasReporter m
, MonadExecuteQuery m
, MC.MonadBaseControl IO m
, MonadMetadataStorage (MetadataStorageT m)
, EB.MonadQueryTags m
)
onMessage ::
( HasVersion
, MonadIO m
, UserAuthentication (Tracing.TraceT m)
, E.MonadGQLExecutionCheck m
, MonadQueryLog m
, Tracing.HasReporter m
, MonadExecuteQuery m
, MC.MonadBaseControl IO m
, MonadMetadataStorage (MetadataStorageT m)
, EB.MonadQueryTags m
)
=> Env.Environment
-> HashSet (L.EngineLogType L.Hasura)
-> AuthMode
-> WSServerEnv
-> WSConn -> LBS.ByteString -> m ()
onMessage env enabledLogTypes authMode serverEnv wsConn msgRaw = Tracing.runTraceT "websocket" do
-> WSConn
-> LBS.ByteString
-> WS.WSActions WSConnData
-> m ()
onMessage env enabledLogTypes authMode serverEnv wsConn msgRaw onMessageActions = Tracing.runTraceT "websocket" do
case J.eitherDecode msgRaw of
Left e -> do
Left e -> do
let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e
logWSEvent logger wsConn $ EConnErr err
sendMsg wsConn $ SMConnErr err
liftIO $ onErrAction wsConn err WS.onClientMessageParseErrorText
Right msg -> case msg of
CMConnInit params -> onConnInit (_wseLogger serverEnv)
(_wseHManager serverEnv)
wsConn authMode params
CMStart startMsg -> onStart env enabledLogTypes serverEnv wsConn startMsg
CMStop stopMsg -> liftIO $ onStop serverEnv wsConn stopMsg
-- The idea is cleanup will be handled by 'onClose', but...
-- NOTE: we need to close the websocket connection when we receive the
-- CMConnTerm message and calling WS.closeConn will definitely throw an
-- exception, but I'm not sure if 'closeConn' is the correct thing here....
-- common to both protocols
CMConnInit params -> onConnInit logger (_wseHManager serverEnv) wsConn
authMode params onErrAction keepAliveMessageAction
CMStart startMsg -> onStart env enabledLogTypes serverEnv wsConn startMsg onMessageActions
CMStop stopMsg -> onStop serverEnv wsConn stopMsg
-- specfic to graphql-ws
CMPing mPayload -> onPing wsConn mPayload
CMPong mPayload -> onPong wsConn mPayload
-- specific to apollo clients
CMConnTerm -> liftIO $ WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
where
logger = _wseLogger serverEnv
onErrAction = WS._wsaOnErrorMessageAction onMessageActions
keepAliveMessageAction = WS._wsaKeepAliveAction onMessageActions
onPing :: (MonadIO m) => WSConn -> Maybe PingPongPayload -> m ()
onPing wsConn mPayload =
liftIO $ sendMsg wsConn (SMPong mPayload)
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
onPong :: (MonadIO m) => WSConn -> Maybe PingPongPayload -> m ()
onPong wsConn mPayload = liftIO $ case mPayload of
Just message -> do
when (message /= keepAliveMessage) $
sendMsg wsConn (SMPing mPayload)
-- NOTE: this is done to avoid sending Ping for every "keepalive" that the server sends
Nothing -> sendMsg wsConn $ SMPing Nothing
onStop :: (MonadIO m) => WSServerEnv -> WSConn -> StopMsg -> m ()
onStop serverEnv wsConn (StopMsg opId) = liftIO $ do
-- When a stop message is received for an operation, it may not be present in OpMap
-- in these cases:
-- 1. If the operation is a query/mutation - as we remove the operation from the
@ -763,38 +784,19 @@ stopOperation serverEnv wsConn opId logWhenOpNotExist = do
opMap = _wscOpMap $ WS.getData wsConn
opDet n = OperationDetails opId Nothing n ODStopped Nothing
logWSEvent
:: (MonadIO m)
=> L.Logger L.Hasura -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
userInfoME <- liftIO $ STM.readTVarIO userInfoR
let (userVarsM, tokenExpM) = case userInfoME of
CSInitialised WsClientState{..} -> ( Just $ _uiSession wscsUserInfo
, wscsTokenExpTime
)
_ -> (Nothing, Nothing)
liftIO $ logger $ WSLog logLevel $ WSLogInfo userVarsM (WsConnInfo wsId tokenExpM Nothing) wsEv
where
WSConnData userInfoR _ _ _ = WS.getData wsConn
wsId = WS.getWSId wsConn
logLevel = bool L.LevelInfo L.LevelError isError
isError = case wsEv of
EAccepted -> False
ERejected _ -> True
EConnErr _ -> True
EClosed -> False
EOperation operation -> case _odOperationType operation of
ODStarted -> False
ODProtoErr _ -> True
ODQueryErr _ -> True
ODCompleted -> False
ODStopped -> False
onConnInit
:: (HasVersion, MonadIO m, UserAuthentication (Tracing.TraceT m))
=> L.Logger L.Hasura -> H.Manager -> WSConn -> AuthMode
-> Maybe ConnParams -> Tracing.TraceT m ()
onConnInit logger manager wsConn authMode connParamsM = do
=> L.Logger L.Hasura
-> H.Manager
-> WSConn
-> AuthMode
-> Maybe ConnParams
-> WS.WSOnErrorMessageAction WSConnData
-- ^ this is the message handler for handling errors on initializing a from the client connection
-> WS.WSKeepAliveMessageAction WSConnData
-- ^ this is the message handler for handling "keep-alive" messages to the client
-> Tracing.TraceT m ()
onConnInit logger manager wsConn authMode connParamsM onConnInitErrAction keepAliveMessageAction = do
-- TODO(from master): what should be the behaviour of connection_init message when a
-- connection is already iniatilized? Currently, we seem to be doing
-- something arbitrary which isn't correct. Ideally, we should stick to
@ -820,7 +822,7 @@ onConnInit logger manager wsConn authMode connParamsM = do
let connErr = ConnErrMsg $ qeError e
logWSEvent logger wsConn $ EConnErr connErr
sendMsg wsConn $ SMConnErr connErr
liftIO $ onConnInitErrAction wsConn connErr WS.onConnInitErrorText
Right (userInfo, expTimeM) -> do
let !csInit = CSInitialised $ WsClientState userInfo expTimeM paramHeaders ipAddress
@ -831,14 +833,12 @@ onConnInit logger manager wsConn authMode connParamsM = do
STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) csInit
sendMsg wsConn SMConnAck
-- TODO(from master): send it periodically? Why doesn't apollo's protocol use
-- ping/pong frames of websocket spec?
sendMsg wsConn SMConnKeepAlive
liftIO $ keepAliveMessageAction wsConn
where
unexpectedInitError e = do
let connErr = ConnErrMsg e
logWSEvent logger wsConn $ EConnErr connErr
sendMsg wsConn $ SMConnErr connErr
liftIO $ onConnInitErrAction wsConn connErr WS.onConnInitErrorText
getIpAddress = \case
CSNotInitialised _ ip -> return ip
@ -871,59 +871,3 @@ onClose logger serverMetrics lqMap wsConn = do
LQ.removeLiveQuery logger serverMetrics lqMap lqId
where
opMap = _wscOpMap $ WS.getData wsConn
createWSServerEnv
:: (MonadIO m)
=> L.Logger L.Hasura
-> LQ.LiveQueriesState
-> IO (SchemaCache, SchemaCacheVer)
-> H.Manager
-> CorsPolicy
-> SQLGenCtx
-> Bool
-> KeepAliveDelay
-> ServerMetrics
-> m WSServerEnv
createWSServerEnv logger lqState getSchemaCache httpManager
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics = do
wsServer <- liftIO $ STM.atomically $ WS.createWSServer logger
return $
WSServerEnv logger lqState getSchemaCache httpManager corsPolicy
sqlGenCtx wsServer enableAL keepAliveDelay serverMetrics
createWSServerApp
:: ( HasVersion
, MonadIO m
, MC.MonadBaseControl IO m
, LA.Forall (LA.Pure m)
, UserAuthentication (Tracing.TraceT m)
, E.MonadGQLExecutionCheck m
, WS.MonadWSLog m
, MonadQueryLog m
, Tracing.HasReporter m
, MonadExecuteQuery m
, MonadMetadataStorage (MetadataStorageT m)
, EB.MonadQueryTags m
)
=> Env.Environment
-> HashSet (L.EngineLogType L.Hasura)
-> AuthMode
-> WSServerEnv
-> WS.HasuraServerApp m
-- -- ^ aka generalized 'WS.ServerApp'
createWSServerApp env enabledLogTypes authMode serverEnv = \ !ipAddress !pendingConn ->
WS.createServerApp (_wseServer serverEnv) handlers ipAddress pendingConn
where
handlers = WS.WSHandlers onConnHandler onMessageHandler onCloseHandler
serverMetrics = _wseServerMetrics serverEnv
-- Mask async exceptions during event processing to help maintain integrity of mutable vars:
onConnHandler rid rh ip = mask_ do
liftIO $ EKG.Gauge.inc $ smWebsocketConnections serverMetrics
flip runReaderT serverEnv $ onConn rid rh ip
onMessageHandler conn bs = mask_ $ onMessage env enabledLogTypes authMode serverEnv conn bs
onCloseHandler conn = mask_ do
liftIO $ EKG.Gauge.dec $ smWebsocketConnections serverMetrics
onClose (_wseLogger serverEnv) serverMetrics (_wseLiveQMap serverEnv) conn
stopWSServerApp :: WSServerEnv -> IO ()
stopWSServerApp wsEnv = WS.shutdown (_wseServer wsEnv)

View File

@ -1,33 +1,84 @@
-- | See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
module Hasura.GraphQL.Transport.WebSocket.Protocol
( OperationId(..)
, ConnParams(..)
, StartMsg(..)
, StopMsg(..)
, ClientMsg(..)
, ServerMsg(..)
, ServerMsgType(..)
, encodeServerMsg
, serverMsgType
, DataMsg(..)
, ErrorMsg(..)
, ConnErrMsg(..)
, CompletionMsg(..)
) where
-- | This file contains types for both the websocket protocols (Apollo) and (graphql-ws)
-- | See Apollo: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
-- | See graphql-ws: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
module Hasura.GraphQL.Transport.WebSocket.Protocol where
import Hasura.Prelude
import qualified Data.Aeson as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as Map
import Control.Concurrent
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM
import Data.Text (pack)
import Hasura.EncJSON
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.Prelude
-- | These come from the client and are websocket connection-local.
-- NOTE: the `subProtocol` is decided based on the `Sec-WebSocket-Protocol`
-- header on every request sent to the server.
data WSSubProtocol = Apollo | GraphQLWS
deriving (Eq, Show)
-- NOTE: Please do not change them, as they're used for to identify the type of client
-- on every request that reaches the server. They are unique to each of the protocols.
showSubProtocol :: WSSubProtocol -> String
showSubProtocol subProtocol = case subProtocol of
-- REF: https://github.com/apollographql/subscriptions-transport-ws/blob/master/src/server.ts#L144
Apollo -> "graphql-ws"
-- REF: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#communication
GraphQLWS -> "graphql-transport-ws"
toWSSubProtocol :: String -> WSSubProtocol
toWSSubProtocol str = case str of
"graphql-transport-ws" -> GraphQLWS
_ -> Apollo
-- This is set by the client when it connects to the server
newtype OperationId
= OperationId { unOperationId :: Text }
deriving (Show, Eq, J.ToJSON, J.FromJSON, Hashable)
deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString, Hashable)
data ServerMsgType =
-- specific to `Apollo` clients
SMT_GQL_CONNECTION_KEEP_ALIVE
| SMT_GQL_CONNECTION_ERROR
| SMT_GQL_DATA
-- specific to `graphql-ws` clients
| SMT_GQL_NEXT
| SMT_GQL_PING
| SMT_GQL_PONG
-- common to clients of both protocols
| SMT_GQL_CONNECTION_ACK
| SMT_GQL_ERROR
| SMT_GQL_COMPLETE
deriving (Eq)
instance Show ServerMsgType where
show = \case
-- specific to `Apollo` clients
SMT_GQL_CONNECTION_KEEP_ALIVE -> "ka"
SMT_GQL_CONNECTION_ERROR -> "connection_error"
SMT_GQL_DATA -> "data"
-- specific to `graphql-ws` clients
SMT_GQL_NEXT -> "next"
SMT_GQL_PING -> "ping"
SMT_GQL_PONG -> "pong"
-- common to clients of both protocols
SMT_GQL_CONNECTION_ACK -> "connection_ack"
SMT_GQL_ERROR -> "error"
SMT_GQL_COMPLETE -> "complete"
instance J.ToJSON ServerMsgType where
toJSON = J.toJSON . show
data ConnParams = ConnParams
{ _cpHeaders :: Maybe (HashMap Text Text) }
deriving stock (Show, Eq)
$(J.deriveJSON hasuraJSON ''ConnParams)
data StartMsg
= StartMsg
@ -42,30 +93,55 @@ data StopMsg
} deriving (Show, Eq)
$(J.deriveJSON hasuraJSON ''StopMsg)
-- Specific to graphql-ws
data PingPongPayload
= PingPongPayload
{ _smMessage:: !(Maybe Text) -- NOTE: this is not within the spec, but is specific to our usecase
} deriving stock (Show, Eq)
$(J.deriveJSON hasuraJSON ''PingPongPayload)
-- Specific to graphql-ws
keepAliveMessage :: PingPongPayload
keepAliveMessage = PingPongPayload . Just . pack $ "keepalive"
-- Specific to graphql-ws
data SubscribeMsg
= SubscribeMsg
{ _subId :: !OperationId
, _subPayload :: !GQLReqUnparsed
} deriving (Show, Eq)
$(J.deriveJSON hasuraJSON ''SubscribeMsg)
data ClientMsg
= CMConnInit !(Maybe ConnParams)
| CMStart !StartMsg
| CMStop !StopMsg
-- specific to apollo clients
| CMConnTerm
-- specific to graphql-ws clients
| CMPing !(Maybe PingPongPayload)
| CMPong !(Maybe PingPongPayload)
deriving (Show, Eq)
data ConnParams
= ConnParams
{ _cpHeaders :: Maybe (Map.HashMap Text Text)
} deriving (Show, Eq)
$(J.deriveJSON hasuraJSON ''ConnParams)
instance J.FromJSON ClientMsg where
parseJSON = J.withObject "ClientMessage" $ \obj -> do
t <- obj J..: "type"
case t of
"connection_init" -> CMConnInit <$> obj J..:? "payload"
"start" -> CMStart <$> J.parseJSON (J.Object obj)
"stop" -> CMStop <$> J.parseJSON (J.Object obj)
"connection_terminate" -> return CMConnTerm
case (t :: String) of
"connection_init" -> CMConnInit <$> parsePayload obj
"start" -> CMStart <$> parseObj obj
"stop" -> CMStop <$> parseObj obj
"connection_terminate" -> pure CMConnTerm
-- graphql-ws specific message types
"complete" -> CMStop <$> parseObj obj
"subscribe" -> CMStart <$> parseObj obj
"ping" -> CMPing <$> parsePayload obj
"pong" -> CMPong <$> parsePayload obj
_ -> fail $ "unexpected type for ClientMessage: " <> t
where
parseObj o = J.parseJSON (J.Object o)
parsePayload py = py J..:? "payload"
-- server to client messages
data DataMsg
= DataMsg
{ _dmId :: !OperationId
@ -82,10 +158,22 @@ newtype CompletionMsg
= CompletionMsg { unCompletionMsg :: OperationId }
deriving (Show, Eq)
instance J.FromJSON CompletionMsg where
parseJSON = J.withObject "CompletionMsg" $ \t ->
CompletionMsg <$> t J..: "id"
instance J.ToJSON CompletionMsg where
toJSON (CompletionMsg opId) = J.String $ tshow opId
newtype ConnErrMsg
= ConnErrMsg { unConnErrMsg :: Text }
deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString)
data ServerErrorMsg =
ServerErrorMsg { unServerErrorMsg :: Text }
deriving stock (Show, Eq)
$(J.deriveJSON hasuraJSON ''ServerErrorMsg)
data ServerMsg
= SMConnAck
| SMConnKeepAlive
@ -93,27 +181,34 @@ data ServerMsg
| SMData !DataMsg
| SMErr !ErrorMsg
| SMComplete !CompletionMsg
-- graphql-ws specific values
| SMNext !DataMsg
| SMPing !(Maybe PingPongPayload)
| SMPong !(Maybe PingPongPayload)
data ServerMsgType
= SMT_GQL_CONNECTION_ACK
| SMT_GQL_CONNECTION_KEEP_ALIVE
| SMT_GQL_CONNECTION_ERROR
| SMT_GQL_DATA
| SMT_GQL_ERROR
| SMT_GQL_COMPLETE
deriving (Eq)
-- | This is sent from the server to the client while closing the websocket
-- on encountering an error.
data ServerErrorCode =
ProtocolError1002
| GenericError4400 !String
| UnAuthorized4401
| Forbidden4403
| ConnectionInitTimeout4408
| NonUniqueSubscription4409 !OperationId
| TooManyRequests4429
deriving stock (Show)
instance Show ServerMsgType where
show = \case
SMT_GQL_CONNECTION_ACK -> "connection_ack"
SMT_GQL_CONNECTION_KEEP_ALIVE -> "ka"
SMT_GQL_CONNECTION_ERROR -> "connection_error"
SMT_GQL_DATA -> "data"
SMT_GQL_ERROR -> "error"
SMT_GQL_COMPLETE -> "complete"
instance J.ToJSON ServerMsgType where
toJSON = J.toJSON . show
encodeServerErrorMsg :: ServerErrorCode -> BL.ByteString
encodeServerErrorMsg ecode = encJToLBS . encJFromJValue $ case ecode of
ProtocolError1002 -> packMsg "1002: Protocol Error"
GenericError4400 msg -> packMsg $ "4400: " <> msg
UnAuthorized4401 -> packMsg "4401: Unauthorized"
Forbidden4403 -> packMsg "4403: Forbidden"
ConnectionInitTimeout4408 -> packMsg "4408: Connection initialisation timeout"
NonUniqueSubscription4409 opId -> packMsg $ "4409: Subscriber for " <> show opId <> " already exists"
TooManyRequests4429 -> packMsg "4429: Too many requests"
where
packMsg = ServerErrorMsg . pack
serverMsgType :: ServerMsg -> ServerMsgType
serverMsgType SMConnAck = SMT_GQL_CONNECTION_ACK
@ -122,6 +217,9 @@ serverMsgType (SMConnErr _) = SMT_GQL_CONNECTION_ERROR
serverMsgType (SMData _) = SMT_GQL_DATA
serverMsgType (SMErr _) = SMT_GQL_ERROR
serverMsgType (SMComplete _) = SMT_GQL_COMPLETE
serverMsgType (SMPing _) = SMT_GQL_PING
serverMsgType (SMPong _) = SMT_GQL_PONG
serverMsgType (SMNext _) = SMT_GQL_NEXT
encodeServerMsg :: ServerMsg -> BL.ByteString
encodeServerMsg msg =
@ -155,5 +253,56 @@ encodeServerMsg msg =
, ("id", encJFromJValue $ unCompletionMsg compMsg)
]
SMPing mPayload ->
encodePingPongPayload mPayload SMT_GQL_PING
SMPong mPayload ->
encodePingPongPayload mPayload SMT_GQL_PONG
SMNext (DataMsg opId payload) ->
[ encTy SMT_GQL_NEXT
, ("id", encJFromJValue opId)
, ("payload", encodeGQResp payload)
]
where
encTy ty = ("type", encJFromJValue ty)
encodePingPongPayload mPayload msgType = case mPayload of
Just payload ->
[ encTy msgType
, ("payload", encJFromJValue payload)
]
Nothing -> [ encTy msgType ]
-- This "timer" is necessary while initialising the connection
-- with the server. Also, this is specific to the GraphQL-WS protocol.
data WSConnInitTimerStatus = Running | Done
deriving stock (Show, Eq)
type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ())
waitForWSTimer :: WSConnInitTimer -> IO ()
waitForWSTimer (_, timer) = atomically $ readTMVar timer
stopWSTimer :: WSConnInitTimer -> IO ()
stopWSTimer (timerState, _) = atomically $ writeTVar timerState Done
getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState (timerState, _) = readTVarIO timerState
getNewWSTimer :: Seconds -> IO WSConnInitTimer
getNewWSTimer timeout = do
timerState <- newTVarIO Running
timer <- newEmptyTMVarIO
forkIO $ do
sleep (seconds timeout)
atomically $ do
runTimerState <- readTVar timerState
case runTimerState of
Running -> do
-- time's up, we set status to "Done"
writeTVar timerState Done
putTMVar timer ()
Done -> pure ()
pure (timerState, timer)

View File

@ -1,35 +1,8 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedStrings #-}
module Hasura.GraphQL.Transport.WebSocket.Server
( WSId(..)
, WSLog(..)
, WSEvent(..)
, MessageDetails(..)
, WSConn
, getData
, getWSId
, closeConn
, sendMsg
, AcceptWith(..)
, OnConnH
, OnCloseH
, OnMessageH
, WSHandlers(..)
, WSServer
, HasuraServerApp
, WSEventInfo(..)
, WSQueueResponse(..)
, ServerMsgType(..)
, createWSServer
, closeAll
, createServerApp
, shutdown
, MonadWSLog (..)
) where
module Hasura.GraphQL.Transport.WebSocket.Server where
import qualified Control.Concurrent.Async as A
import qualified Control.Concurrent.Async.Lifted.Safe as LA
@ -39,7 +12,9 @@ import qualified Control.Monad.Trans.Control as MC
import qualified Data.Aeson as J
import qualified Data.Aeson.Casing as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import Data.String
import qualified Data.TByteString as TBS
import qualified Data.UUID as UUID
@ -56,8 +31,10 @@ import qualified Network.WebSockets as WS
import qualified StmContainers.Map as STMMap
import qualified System.IO.Error as E
import Hasura.GraphQL.Transport.WebSocket.Protocol (OperationId, ServerMsgType (..))
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import qualified Hasura.Logging as L
import Hasura.Server.Init.Config (WSConnectionInitTimeout (..))
newtype WSId
= WSId { unWSId :: UUID.UUID }
@ -150,6 +127,9 @@ data WSConn a
, _wcExtraData :: !a
}
getRawWebSocketConnection :: WSConn a -> WS.Connection
getRawWebSocketConnection = _wcConnRaw
getData :: WSConn a -> a
getData = _wcExtraData
@ -233,42 +213,89 @@ data AcceptWith a
, _awOnJwtExpiry :: !(WSConn a -> IO ())
}
type OnConnH m a = WSId -> WS.RequestHead -> IpAddress -> m (Either WS.RejectRequest (AcceptWith a))
-- | These set of functions or message handlers is used by the
-- server while communicating with the client. They are particularly
-- useful for the case when the messages being sent to the client
-- are different for each of the sub-protocol(s) supported by the server.
type WSKeepAliveMessageAction a = WSConn a -> IO ()
type WSPostExecErrMessageAction a = WSConn a -> OperationId -> GQExecError -> IO ()
type WSOnErrorMessageAction a = WSConn a -> ConnErrMsg -> Maybe String -> IO ()
type WSCloseConnAction a = WSConn a -> OperationId -> String -> IO ()
-- | Used for specific actions within the `onConn` and `onMessage` handlers
data WSActions a = WSActions
{ _wsaPostExecErrMessageAction :: !(WSPostExecErrMessageAction a)
, _wsaOnErrorMessageAction :: !(WSOnErrorMessageAction a)
, _wsaConnectionCloseAction :: !(WSCloseConnAction a)
, _wsaKeepAliveAction :: !(WSKeepAliveMessageAction a)
-- ^ NOTE: keep alive action was made redundant because we need to send this message
-- after the connection has been successfully established after `connection_init`
, _wsaGetDataMessageType :: !(DataMsg -> ServerMsg)
, _wsaAcceptRequest :: !WS.AcceptRequest
}
-- | to be used with `WSOnErrorMessageAction`
onClientMessageParseErrorText :: Maybe String
onClientMessageParseErrorText = Just "Parsing client message failed: "
-- | to be used with `WSOnErrorMessageAction`
onConnInitErrorText :: Maybe String
onConnInitErrorText = Just "Connection initialization failed: "
type OnConnH m a = WSId -> WS.RequestHead -> IpAddress -> WSActions a -> m (Either WS.RejectRequest (AcceptWith a))
type OnMessageH m a = WSConn a -> BL.ByteString -> WSActions a -> m ()
type OnCloseH m a = WSConn a -> m ()
type OnMessageH m a = WSConn a -> BL.ByteString -> m ()
-- | aka generalized 'WS.ServerApp' over @m@, which takes an IPAddress
type HasuraServerApp m = IpAddress -> WS.PendingConnection -> m ()
-- | NOTE: The types of `_hOnConn` and `_hOnMessage` were updated from `OnConnH` and `OnMessageH`
-- because we needed to pass the subprotcol here to these methods to eventually get to `OnConnH` and `OnMessageH`.
-- Please see `createServerApp` to get a better understanding of how these handlers are used.
data WSHandlers m a
= WSHandlers
{ _hOnConn :: OnConnH m a
, _hOnMessage :: OnMessageH m a
{ _hOnConn :: (WSId -> WS.RequestHead -> IpAddress -> WSSubProtocol -> m (Either WS.RejectRequest (AcceptWith a)))
, _hOnMessage :: (WSConn a -> BL.ByteString -> WSSubProtocol -> m ())
, _hOnClose :: OnCloseH m a
}
createServerApp
:: (MonadIO m, MC.MonadBaseControl IO m, LA.Forall (LA.Pure m), MonadWSLog m)
=> WSServer a
=> WSConnectionInitTimeout
-> WSServer a
-> WSHandlers m a
-- ^ user provided handlers
-> HasuraServerApp m
-- ^ aka WS.ServerApp
{-# INLINE createServerApp #-}
createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !ipAddress !pendingConn = do
createServerApp wsConnInitTimeout (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !ipAddress !pendingConn = do
wsId <- WSId <$> liftIO UUID.nextRandom
logWSLog logger $ WSLog wsId EConnectionRequest Nothing
-- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection
-- if the client doesn't send a `connection_init` message within the timeout period
wsConnInitTimer <- liftIO $ getNewWSTimer (unWSConnectionInitTimeout wsConnInitTimeout)
status <- liftIO $ STM.readTVarIO serverStatus
case status of
AcceptingConns _ -> logUnexpectedExceptions $ do
let reqHead = WS.pendingRequest pendingConn
onConnRes <- _hOnConn wsHandlers wsId reqHead ipAddress
either (onReject wsId) (onAccept wsId) onConnRes
onConnRes <- connHandler wsId reqHead ipAddress subProtocol
either (onReject wsId) (onAccept wsConnInitTimer wsId) onConnRes
ShuttingDown ->
onReject wsId shuttingDownReject
where
reqHead = WS.pendingRequest pendingConn
getSubProtocolHeader rhdrs =
filter (\(x, _) -> x == (CI.mk . B.pack $ "Sec-WebSocket-Protocol")) $ WS.requestHeaders rhdrs
subProtocol = case getSubProtocolHeader reqHead of
[sph] -> toWSSubProtocol . B.unpack . snd $ sph
_ -> Apollo -- NOTE: we default to the apollo implemenation
connHandler = _hOnConn wsHandlers
messageHandler = _hOnMessage wsHandlers
closeHandler = _hOnClose wsHandlers
-- It's not clear what the unexpected exception handling story here should be. So at
-- least log properly and re-raise:
logUnexpectedExceptions = handle $ \(e :: SomeException) -> do
@ -286,7 +313,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
liftIO $ WS.rejectRequestWith pendingConn rejectRequest
logWSLog logger $ WSLog wsId ERejected Nothing
onAccept wsId (AcceptWith a acceptWithParams keepAlive onJwtExpiry) = do
onAccept wsConnInitTimer wsId (AcceptWith a acceptWithParams keepAlive onJwtExpiry) = do
conn <- liftIO $ WS.acceptRequestWith pendingConn acceptWithParams
logWSLog logger $ WSLog wsId EAccepted Nothing
sendQ <- liftIO STM.newTQueueIO
@ -317,7 +344,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
-- Bad luck, we were in the process of shutting the server down but a new
-- connection was accepted. Let's just close it politely
forceConnReconnect wsConn "shutting server down"
_hOnClose wsHandlers wsConn
closeHandler wsConn
AcceptingConns _ -> do
let rcv = forever $ do
@ -332,7 +359,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
WS.receiveData conn
let message = MessageDetails (TBS.fromLBS msg) (BL.length msg)
logWSLog logger $ WSLog wsId (EMessageReceived message) Nothing
_hOnMessage wsHandlers wsConn msg
messageHandler wsConn msg subProtocol
let send = forever $ do
WSQueueResponse msg wsInfo <- liftIO $ STM.atomically $ STM.readTQueue sendQ
@ -347,6 +374,11 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
LA.withAsync (liftIO $ keepAlive wsConn) $ \keepAliveRef -> do
LA.withAsync (liftIO $ onJwtExpiry wsConn) $ \onJwtExpiryRef -> do
-- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws`
timeoutStatus <- liftIO $ getWSTimerState wsConnInitTimer
when (timeoutStatus == Done && subProtocol == GraphQLWS) $
liftIO $ closeConnWithCode wsConn 4408 "Connection initialisation timed out"
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef]
-- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in
@ -365,7 +397,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
ShuttingDown -> pure ()
AcceptingConns connMap -> do
liftIO $ STM.atomically $ STMMap.delete (_wcConnId wsConn) connMap
_hOnClose wsHandlers wsConn
closeHandler wsConn
logWSLog logger $ WSLog (_wcConnId wsConn) EClosed Nothing
shutdown :: WSServer a -> IO ()

View File

@ -0,0 +1,86 @@
module Hasura.GraphQL.Transport.WebSocket.Types where
import Hasura.Prelude
import qualified Control.Concurrent.STM as STM
import qualified Data.Time.Clock as TC
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Types as H
import qualified Network.Wai.Extended as Wai
import qualified StmContainers.Map as STMMap
import qualified Hasura.GraphQL.Execute as E
import qualified Hasura.GraphQL.Execute.LiveQuery.State as LQ
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import qualified Hasura.Logging as L
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.Instances ()
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.RQL.Types
import Hasura.Server.Cors
import Hasura.Server.Init.Config (KeepAliveDelay (..))
import Hasura.Server.Metrics (ServerMetrics (..))
import Hasura.Session
newtype WsHeaders
= WsHeaders { unWsHeaders :: [H.Header] }
deriving (Show, Eq)
data ErrRespType
= ERTLegacy
| ERTGraphqlCompliant
deriving (Show)
data WSConnState
= CSNotInitialised !WsHeaders !Wai.IpAddress
-- ^ headers and IP address from the client for websockets
| CSInitError !Text
| CSInitialised !WsClientState
deriving (Show)
data WsClientState
= WsClientState
{ wscsUserInfo :: !UserInfo
-- ^ the 'UserInfo' required to execute the GraphQL query
, wscsTokenExpTime :: !(Maybe TC.UTCTime)
-- ^ the JWT/token expiry time, if any
, wscsReqHeaders :: ![H.Header]
-- ^ headers from the client (in conn params) to forward to the remote schema
, wscsIpAddress :: !Wai.IpAddress
-- ^ IP address required for 'MonadGQLAuthorization'
}
deriving (Show)
data WSConnData
= WSConnData
-- the role and headers are set only on connection_init message
{ _wscUser :: !(STM.TVar WSConnState)
-- we only care about subscriptions,
-- the other operations (query/mutations)
-- are not tracked here
, _wscOpMap :: !OperationMap
, _wscErrRespTy :: !ErrRespType
, _wscAPIType :: !E.GraphQLQueryType
}
data WSServerEnv
= WSServerEnv
{ _wseLogger :: !(L.Logger L.Hasura)
, _wseLiveQMap :: !LQ.LiveQueriesState
, _wseGCtxMap :: !(IO (SchemaCache, SchemaCacheVer))
-- ^ an action that always returns the latest version of the schema cache. See 'SchemaCacheRef'.
, _wseHManager :: !H.Manager
, _wseCorsPolicy :: !CorsPolicy
, _wseSQLCtx :: !SQLGenCtx
, _wseServer :: !WSServer
, _wseEnableAllowlist :: !Bool
, _wseKeepAliveDelay :: !KeepAliveDelay
, _wseServerMetrics :: !ServerMetrics
}
type OperationMap = STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName)
type WSServer = WS.WSServer WSConnData
type WSConn = WS.WSConn WSConnData

View File

@ -51,7 +51,7 @@ import qualified Hasura.GraphQL.Execute.LiveQuery.State as EL
import qualified Hasura.GraphQL.Explain as GE
import qualified Hasura.GraphQL.Transport.HTTP as GH
import qualified Hasura.GraphQL.Transport.HTTP.Protocol as GH
import qualified Hasura.GraphQL.Transport.WebSocket as WS
import qualified Hasura.GraphQL.Transport.WSServerApp as WS
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import qualified Hasura.Logging as L
import qualified Hasura.Server.API.PGDump as PGD
@ -782,16 +782,16 @@ mkWaiApp
-> FunctionPermissionsCtx
-> WS.ConnectionOptions
-> KeepAliveDelay
-- ^ Metadata storage connection pool
-> MaintenanceMode
-> S.HashSet ExperimentalFeature
-- ^ Set of the enabled experimental features
-> S.HashSet (L.EngineLogType L.Hasura)
-> WSConnectionInitTimeout
-> m HasuraApp
mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enableConsole consoleAssetsDir
enableTelemetry instanceId apis lqOpts responseErrorsConfig
liveQueryHook schemaCacheRef ekgStore serverMetrics enableRSPermsCtx functionPermsCtx
connectionOptions keepAliveDelay maintenanceMode experimentalFeatures enabledLogTypes = do
connectionOptions keepAliveDelay maintenanceMode experimentalFeatures enabledLogTypes wsConnInitTimeout = do
let getSchemaCache = first lastBuiltSchemaCache <$> readIORef (_scrCache schemaCacheRef)
@ -799,21 +799,21 @@ mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enable
postPollHook = fromMaybe (EL.defaultLiveQueryPostPollHook logger) liveQueryHook
lqState <- liftIO $ EL.initLiveQueriesState lqOpts postPollHook
wsServerEnv <- WS.createWSServerEnv logger lqState getSchemaCache httpManager
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics
wsServerEnv <- WS.createWSServerEnv logger lqState getSchemaCache httpManager corsPolicy
sqlGenCtx enableAL keepAliveDelay serverMetrics
let serverCtx = ServerCtx
{ scLogger = logger
, scCacheRef = schemaCacheRef
, scAuthMode = mode
, scManager = httpManager
, scSQLGenCtx = sqlGenCtx
, scEnabledAPIs = apis
, scInstanceId = instanceId
, scLQState = lqState
, scEnableAllowlist = enableAL
, scEkgStore = ekgStore
, scEnvironment = env
{ scLogger = logger
, scCacheRef = schemaCacheRef
, scAuthMode = mode
, scManager = httpManager
, scSQLGenCtx = sqlGenCtx
, scEnabledAPIs = apis
, scInstanceId = instanceId
, scLQState = lqState
, scEnableAllowlist = enableAL
, scEkgStore = ekgStore
, scEnvironment = env
, scResponseInternalErrorsConfig = responseErrorsConfig
, scRemoteSchemaPermsCtx = enableRSPermsCtx
, scFunctionPermsCtx = functionPermsCtx
@ -826,7 +826,7 @@ mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enable
Spock.spockAsApp $ Spock.spockT lowerIO $
httpApp setupHook corsCfg serverCtx enableConsole consoleAssetsDir enableTelemetry
let wsServerApp = WS.createWSServerApp env enabledLogTypes mode wsServerEnv -- TODO: Lyndon: Can we pass environment through wsServerEnv?
let wsServerApp = WS.createWSServerApp env enabledLogTypes mode wsServerEnv wsConnInitTimeout -- TODO: Lyndon: Can we pass environment through wsServerEnv?
stopWSServer = WS.stopWSServerApp wsServerEnv
waiApp <- liftWithStateless $ \lowerIO ->

View File

@ -221,6 +221,9 @@ mkServeOptions rso = do
gracefulShutdownTime <-
fromMaybe 60 <$> withEnv (rsoGracefulShutdownTimeout rso) (fst gracefulShutdownEnv)
webSocketConnectionInitTimeout <- WSConnectionInitTimeout . fromIntegral . fromMaybe 3
<$> withEnv (rsoWebSocketConnectionInitTimeout rso) (fst webSocketConnectionInitTimeoutEnv)
pure $ ServeOptions
port
host
@ -256,6 +259,7 @@ mkServeOptions rso = do
eventsFetchBatchSize
devMode
gracefulShutdownTime
webSocketConnectionInitTimeout
where
#ifdef DeveloperAPIs
defaultAPIs = [METADATA,GRAPHQL,PGDUMP,CONFIG,DEVELOPER]
@ -1196,6 +1200,7 @@ serveOptsToLog so =
, "experimental_features" J..= soExperimentalFeatures so
, "events_fetch_batch_size" J..= soEventsFetchBatchSize so
, "graceful_shutdown_timeout" J..= soGracefulShutdownTimeout so
, "websocket_connection_init_timeout" J..= show (soWebsocketConnectionInitTimeout so)
]
mkGenericStrLog :: L.LogLevel -> Text -> String -> StartupLog
@ -1252,6 +1257,7 @@ serveOptionsParser =
<*> parseExperimentalFeatures
<*> parseEventsFetchBatchSize
<*> parseGracefulShutdownTimeout
<*> parseWebSocketConnectionInitTimeout
-- | This implements the mapping between application versions
-- and catalog schema versions.
@ -1299,6 +1305,7 @@ parseWebSocketCompression =
help (snd webSocketCompressionEnv)
)
-- NOTE: this is purely used by Apollo-Subscription-Transport-WS
webSocketKeepAliveEnv :: (String, String)
webSocketKeepAliveEnv =
( "HASURA_GRAPHQL_WEBSOCKET_KEEPALIVE"
@ -1312,3 +1319,18 @@ parseWebSocketKeepAlive =
( long "websocket-keepalive" <>
help (snd webSocketKeepAliveEnv)
)
-- NOTE: this is purely used by GraphQL-WS
webSocketConnectionInitTimeoutEnv :: (String, String)
webSocketConnectionInitTimeoutEnv =
( "HASURA_GRAPHQL_WEBSOCKET_CONNECTION_INIT_TIMEOUT" -- FIXME?: maybe a better name
, "Control websocket connection_init timeout (default 3 seconds)"
)
parseWebSocketConnectionInitTimeout :: Parser (Maybe Int)
parseWebSocketConnectionInitTimeout =
optional $
option (eitherReader readEither)
( long "websocket-connection-init-timeout" <>
help (snd webSocketConnectionInitTimeoutEnv)
)

View File

@ -63,44 +63,58 @@ instance ToJSON OptionalInterval where
Skip -> toJSON @Milliseconds 0
Interval s -> toJSON s
data API
= METADATA
| GRAPHQL
| PGDUMP
| DEVELOPER
| CONFIG
deriving (Show, Eq, Read, Generic)
$(J.deriveJSON (J.defaultOptions { J.constructorTagModifier = map toLower })
''API)
instance Hashable API
data RawServeOptions impl
= RawServeOptions
{ rsoPort :: !(Maybe Int)
, rsoHost :: !(Maybe HostPreference)
, rsoConnParams :: !RawConnParams
, rsoTxIso :: !(Maybe Q.TxIsolation)
, rsoAdminSecret :: !(Maybe AdminSecretHash)
, rsoAuthHook :: !RawAuthHook
, rsoJwtSecret :: !(Maybe JWTConfig)
, rsoUnAuthRole :: !(Maybe RoleName)
, rsoCorsConfig :: !(Maybe CorsConfig)
, rsoEnableConsole :: !Bool
, rsoConsoleAssetsDir :: !(Maybe Text)
, rsoEnableTelemetry :: !(Maybe Bool)
, rsoWsReadCookie :: !Bool
, rsoStringifyNum :: !Bool
, rsoDangerousBooleanCollapse :: !(Maybe Bool)
, rsoEnabledAPIs :: !(Maybe [API])
, rsoMxRefetchInt :: !(Maybe LQ.RefetchInterval)
, rsoMxBatchSize :: !(Maybe LQ.BatchSize)
, rsoEnableAllowlist :: !Bool
, rsoEnabledLogTypes :: !(Maybe [L.EngineLogType impl])
, rsoLogLevel :: !(Maybe L.LogLevel)
, rsoDevMode :: !Bool
, rsoAdminInternalErrors :: !(Maybe Bool)
, rsoEventsHttpPoolSize :: !(Maybe Int)
, rsoEventsFetchInterval :: !(Maybe Milliseconds)
, rsoAsyncActionsFetchInterval :: !(Maybe Milliseconds)
, rsoLogHeadersFromEnv :: !Bool
, rsoEnableRemoteSchemaPermissions :: !Bool
, rsoWebSocketCompression :: !Bool
, rsoWebSocketKeepAlive :: !(Maybe Int)
, rsoInferFunctionPermissions :: !(Maybe Bool)
, rsoEnableMaintenanceMode :: !Bool
, rsoSchemaPollInterval :: !(Maybe Milliseconds)
, rsoExperimentalFeatures :: !(Maybe [ExperimentalFeature])
, rsoEventsFetchBatchSize :: !(Maybe NonNegativeInt)
, rsoGracefulShutdownTimeout :: !(Maybe Seconds)
{ rsoPort :: !(Maybe Int)
, rsoHost :: !(Maybe HostPreference)
, rsoConnParams :: !RawConnParams
, rsoTxIso :: !(Maybe Q.TxIsolation)
, rsoAdminSecret :: !(Maybe AdminSecretHash)
, rsoAuthHook :: !RawAuthHook
, rsoJwtSecret :: !(Maybe JWTConfig)
, rsoUnAuthRole :: !(Maybe RoleName)
, rsoCorsConfig :: !(Maybe CorsConfig)
, rsoEnableConsole :: !Bool
, rsoConsoleAssetsDir :: !(Maybe Text)
, rsoEnableTelemetry :: !(Maybe Bool)
, rsoWsReadCookie :: !Bool
, rsoStringifyNum :: !Bool
, rsoDangerousBooleanCollapse :: !(Maybe Bool)
, rsoEnabledAPIs :: !(Maybe [API])
, rsoMxRefetchInt :: !(Maybe LQ.RefetchInterval)
, rsoMxBatchSize :: !(Maybe LQ.BatchSize)
, rsoEnableAllowlist :: !Bool
, rsoEnabledLogTypes :: !(Maybe [L.EngineLogType impl])
, rsoLogLevel :: !(Maybe L.LogLevel)
, rsoDevMode :: !Bool
, rsoAdminInternalErrors :: !(Maybe Bool)
, rsoEventsHttpPoolSize :: !(Maybe Int)
, rsoEventsFetchInterval :: !(Maybe Milliseconds)
, rsoAsyncActionsFetchInterval :: !(Maybe Milliseconds)
, rsoLogHeadersFromEnv :: !Bool
, rsoEnableRemoteSchemaPermissions :: !Bool
, rsoWebSocketCompression :: !Bool
, rsoWebSocketKeepAlive :: !(Maybe Int)
, rsoInferFunctionPermissions :: !(Maybe Bool)
, rsoEnableMaintenanceMode :: !Bool
, rsoSchemaPollInterval :: !(Maybe Milliseconds)
, rsoExperimentalFeatures :: !(Maybe [ExperimentalFeature])
, rsoEventsFetchBatchSize :: !(Maybe NonNegativeInt)
, rsoGracefulShutdownTimeout :: !(Maybe Seconds)
, rsoWebSocketConnectionInitTimeout :: !(Maybe Int)
}
-- | @'ResponseInternalErrorsConfig' represents the encoding of the internal
@ -118,47 +132,57 @@ shouldIncludeInternal role = \case
InternalErrorsAdminOnly -> role == adminRoleName
InternalErrorsDisabled -> False
newtype KeepAliveDelay
= KeepAliveDelay
{ unKeepAliveDelay :: Seconds
} deriving (Eq, Show)
newtype KeepAliveDelay = KeepAliveDelay { unKeepAliveDelay :: Seconds }
deriving (Eq, Show)
$(J.deriveJSON hasuraJSON ''KeepAliveDelay)
defaultKeepAliveDelay :: KeepAliveDelay
defaultKeepAliveDelay = KeepAliveDelay $ fromIntegral (5 :: Int)
newtype WSConnectionInitTimeout = WSConnectionInitTimeout { unWSConnectionInitTimeout :: Seconds }
deriving (Eq, Show)
$(J.deriveJSON hasuraJSON ''WSConnectionInitTimeout)
defaultWSConnectionInitTimeout :: WSConnectionInitTimeout
defaultWSConnectionInitTimeout = WSConnectionInitTimeout $ fromIntegral (3 :: Int)
data ServeOptions impl
= ServeOptions
{ soPort :: !Int
, soHost :: !HostPreference
, soConnParams :: !Q.ConnParams
, soTxIso :: !Q.TxIsolation
, soAdminSecret :: !(Maybe AdminSecretHash)
, soAuthHook :: !(Maybe AuthHook)
, soJwtSecret :: !(Maybe JWTConfig)
, soUnAuthRole :: !(Maybe RoleName)
, soCorsConfig :: !CorsConfig
, soEnableConsole :: !Bool
, soConsoleAssetsDir :: !(Maybe Text)
, soEnableTelemetry :: !Bool
, soStringifyNum :: !Bool
, soDangerousBooleanCollapse :: !Bool
, soEnabledAPIs :: !(Set.HashSet API)
, soLiveQueryOpts :: !LQ.LiveQueriesOptions
, soEnableAllowlist :: !Bool
, soEnabledLogTypes :: !(Set.HashSet (L.EngineLogType impl))
, soLogLevel :: !L.LogLevel
, soResponseInternalErrorsConfig :: !ResponseInternalErrorsConfig
, soEventsHttpPoolSize :: !(Maybe Int)
, soEventsFetchInterval :: !(Maybe Milliseconds)
, soAsyncActionsFetchInterval :: !OptionalInterval
, soLogHeadersFromEnv :: !Bool
, soEnableRemoteSchemaPermissions :: !RemoteSchemaPermsCtx
, soConnectionOptions :: !WS.ConnectionOptions
, soWebsocketKeepAlive :: !KeepAliveDelay
, soInferFunctionPermissions :: !FunctionPermissionsCtx
, soEnableMaintenanceMode :: !MaintenanceMode
, soSchemaPollInterval :: !OptionalInterval
, soExperimentalFeatures :: !(Set.HashSet ExperimentalFeature)
, soEventsFetchBatchSize :: !NonNegativeInt
, soDevMode :: !Bool
, soGracefulShutdownTimeout :: !Seconds
{ soPort :: !Int
, soHost :: !HostPreference
, soConnParams :: !Q.ConnParams
, soTxIso :: !Q.TxIsolation
, soAdminSecret :: !(Maybe AdminSecretHash)
, soAuthHook :: !(Maybe AuthHook)
, soJwtSecret :: !(Maybe JWTConfig)
, soUnAuthRole :: !(Maybe RoleName)
, soCorsConfig :: !CorsConfig
, soEnableConsole :: !Bool
, soConsoleAssetsDir :: !(Maybe Text)
, soEnableTelemetry :: !Bool
, soStringifyNum :: !Bool
, soDangerousBooleanCollapse :: !Bool
, soEnabledAPIs :: !(Set.HashSet API)
, soLiveQueryOpts :: !LQ.LiveQueriesOptions
, soEnableAllowlist :: !Bool
, soEnabledLogTypes :: !(Set.HashSet (L.EngineLogType impl))
, soLogLevel :: !L.LogLevel
, soResponseInternalErrorsConfig :: !ResponseInternalErrorsConfig
, soEventsHttpPoolSize :: !(Maybe Int)
, soEventsFetchInterval :: !(Maybe Milliseconds)
, soAsyncActionsFetchInterval :: !OptionalInterval
, soLogHeadersFromEnv :: !Bool
, soEnableRemoteSchemaPermissions :: !RemoteSchemaPermsCtx
, soConnectionOptions :: !WS.ConnectionOptions
, soWebsocketKeepAlive :: !KeepAliveDelay
, soInferFunctionPermissions :: !FunctionPermissionsCtx
, soEnableMaintenanceMode :: !MaintenanceMode
, soSchemaPollInterval :: !OptionalInterval
, soExperimentalFeatures :: !(Set.HashSet ExperimentalFeature)
, soEventsFetchBatchSize :: !NonNegativeInt
, soDevMode :: !Bool
, soGracefulShutdownTimeout :: !Seconds
, soWebsocketConnectionInitTimeout :: !WSConnectionInitTimeout
}
data DowngradeOptions
@ -211,19 +235,6 @@ data HGECommandG a
| HCDowngrade !DowngradeOptions
deriving (Show, Eq)
data API
= METADATA
| GRAPHQL
| PGDUMP
| DEVELOPER
| CONFIG
deriving (Show, Eq, Read, Generic)
$(J.deriveJSON (J.defaultOptions { J.constructorTagModifier = map toLower })
''API)
instance Hashable API
$(J.deriveJSON (J.aesonPrefix J.camelCase){J.omitNothingFields=True} ''PostgresRawConnDetails)
type HGECommand impl = HGECommandG (ServeOptions impl)

View File

@ -1,6 +1,6 @@
import pytest
import time
from context import HGECtx, HGECtxError, ActionsWebhookServer, EvtsWebhookServer, HGECtxGQLServer, GQLWsClient, PytestConf
from context import HGECtx, HGECtxError, ActionsWebhookServer, EvtsWebhookServer, HGECtxGQLServer, GQLWsClient, PytestConf, GraphQLWSClient
import threading
import random
from datetime import datetime
@ -388,6 +388,16 @@ def ws_client(request, hge_ctx):
yield client
client.teardown()
@pytest.fixture(scope='class')
def ws_client_graphql_ws(request, hge_ctx):
"""
This fixture provides an GraphQL-WS client
"""
client = GraphQLWSClient(hge_ctx, '/v1/graphql')
time.sleep(0.1)
yield client
client.teardown()
@pytest.fixture(scope='class')
def per_class_tests_db_state(request, hge_ctx):
"""

View File

@ -32,6 +32,7 @@ class PytestConf():
class HGECtxError(Exception):
pass
# NOTE: use this to generate a GraphQL client that uses the `Apollo`(subscription-transport-ws) sub-protocol
class GQLWsClient():
def __init__(self, hge_ctx, endpoint):
@ -160,6 +161,151 @@ class GQLWsClient():
self._ws.close()
self.wst.join()
# NOTE: use this to generate a GraphQL client that uses the `graphql-ws` sub-protocol
class GraphQLWSClient():
def __init__(self, hge_ctx, endpoint):
self.hge_ctx = hge_ctx
self.ws_queue = queue.Queue(maxsize=-1)
self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws',
path=endpoint)
self.create_conn()
def get_queue(self):
return self.ws_queue.queue
def clear_queue(self):
self.ws_queue.queue.clear()
def create_conn(self):
self.ws_queue.queue.clear()
self.ws_id_query_queues = dict()
self.ws_active_query_ids = set()
self.connected_event = threading.Event()
self.init_done = False
self.is_closing = False
self.remote_closed = False
self._ws = websocket.WebSocketApp(self.ws_url.geturl(),
on_open=self._on_open, on_message=self._on_message, on_close=self._on_close, subprotocols=["graphql-transport-ws"])
self.wst = threading.Thread(target=self._ws.run_forever)
self.wst.daemon = True
self.wst.start()
def recreate_conn(self):
self.teardown()
self.create_conn()
def wait_for_connection(self, timeout=10):
assert not self.is_closing
assert self.connected_event.wait(timeout=timeout)
def get_ws_event(self, timeout):
return self.ws_queue.get(timeout=timeout)
def has_ws_query_events(self, query_id):
return not self.ws_id_query_queues[query_id].empty()
def get_ws_query_event(self, query_id, timeout):
print("HELLO", self.ws_active_query_ids)
return self.ws_id_query_queues[query_id].get(timeout=timeout)
def send(self, frame):
self.wait_for_connection()
if frame.get('type') == 'complete':
self.ws_active_query_ids.discard( frame.get('id') )
elif frame.get('type') == 'subscribe' and 'id' in frame:
self.ws_id_query_queues[frame['id']] = queue.Queue(maxsize=-1)
self._ws.send(json.dumps(frame))
def init_as_admin(self):
headers={}
if self.hge_ctx.hge_key:
headers = {'x-hasura-admin-secret': self.hge_ctx.hge_key}
self.init(headers)
def init(self, headers={}):
payload = {'type': 'connection_init', 'payload': {}}
if headers and len(headers) > 0:
payload['payload']['headers'] = headers
self.send(payload)
ev = self.get_ws_event(5)
assert ev['type'] == 'connection_ack', ev
self.init_done = True
def stop(self, query_id):
data = {'id': query_id, 'type': 'complete'}
self.send(data)
self.ws_active_query_ids.discard(query_id)
def gen_id(self, size=6, chars=string.ascii_letters + string.digits):
new_id = ''.join(random.choice(chars) for _ in range(size))
if new_id in self.ws_active_query_ids:
return self.gen_id(size, chars)
return new_id
def send_query(self, query, query_id=None, headers={}, timeout=60):
graphql.parse(query['query'])
if headers and len(headers) > 0:
#Do init If headers are provided
self.clear_queue()
self.init(headers)
elif not self.init_done:
self.init()
if query_id == None:
query_id = self.gen_id()
frame = {
'id': query_id,
'type': 'subscribe',
'payload': query,
}
self.ws_active_query_ids.add(query_id)
self.send(frame)
while True:
yield self.get_ws_query_event(query_id, timeout)
def _on_open(self):
if not self.is_closing:
self.connected_event.set()
def _on_message(self, message):
# NOTE: make sure we preserve key ordering so we can test the ordering
# properties in the graphql spec properly
json_msg = json.loads(message, object_pairs_hook=OrderedDict)
if json_msg['type'] == 'ping':
new_msg = json_msg
new_msg['type'] = 'pong'
self.send(json.dumps(new_msg))
return
if 'id' in json_msg:
query_id = json_msg['id']
if json_msg.get('type') == 'complete':
#Remove from active queries list
self.ws_active_query_ids.discard( query_id )
if not query_id in self.ws_id_query_queues:
self.ws_id_query_queues[json_msg['id']] = queue.Queue(maxsize=-1)
#Put event in the correponding query_queue
self.ws_id_query_queues[query_id].put(json_msg)
if json_msg['type'] != 'ping':
self.ws_queue.put(json_msg)
def _on_close(self):
self.remote_closed = True
self.init_done = False
def get_conn_close_state(self):
return self.remote_closed or self.is_closing
def teardown(self):
self.is_closing = True
if not self.remote_closed:
self._ws.close()
self.wst.join()
class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler):
@ -501,6 +647,7 @@ class HGECtx:
self.ws_client = GQLWsClient(self, '/v1/graphql')
self.ws_client_v1alpha1 = GQLWsClient(self, '/v1alpha1/graphql')
self.ws_client_relay = GQLWsClient(self, '/v1beta1/relay')
self.ws_client_graphql_ws = GraphQLWSClient(self, '/v1/graphql')
self.backend = config.getoption('--backend')
self.default_backend = 'postgres'
@ -628,6 +775,7 @@ class HGECtx:
self.ws_client.teardown()
self.ws_client_v1alpha1.teardown()
self.ws_client_relay.teardown()
self.ws_client_graphql_ws.teardown()
def v1GraphqlExplain(self, q, hdrs=None):
headers = {}

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import time
import pytest
import json
import queue
@ -10,7 +11,11 @@ usefixtures = pytest.mark.usefixtures
@pytest.fixture(scope='class')
def ws_conn_init(hge_ctx, ws_client):
init_ws_conn(hge_ctx, ws_client)
init_ws_conn(hge_ctx, ws_client)
@pytest.fixture(scope='class')
def ws_conn_init_graphql_ws(hge_ctx, ws_client_graphql_ws):
init_graphql_ws_conn(hge_ctx, ws_client_graphql_ws)
'''
Refer: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init
@ -34,6 +39,24 @@ def init_ws_conn(hge_ctx, ws_client, payload = None):
ev = ws_client.get_ws_event(3)
assert ev['type'] == 'connection_ack', ev
def init_graphql_ws_conn(hge_ctx, ws_client_graphql_ws, payload = None):
if payload is None:
payload = {}
if hge_ctx.hge_key is not None:
payload = {
'headers' : {
'X-Hasura-Admin-Secret': hge_ctx.hge_key
}
}
init_msg = {
'type': 'connection_init',
'payload': payload,
}
ws_client_graphql_ws.send(init_msg)
ev = ws_client_graphql_ws.get_ws_event(3)
assert ev['type'] == 'connection_ack', ev
class TestSubscriptionCtrl(object):
def test_init_without_payload(self, hge_ctx, ws_client):
@ -116,7 +139,6 @@ class TestSubscriptionBasic:
'''
Refer https://github.com/apollographql/subscriptions-transport-ws/blob/01e0b2b65df07c52f5831cce5c858966ba095993/src/server.ts#L306
'''
@pytest.mark.skip(reason="refer https://github.com/hasura/graphql-engine/pull/387#issuecomment-421343098")
def test_start_duplicate(self, ws_client):
self.test_start(ws_client)
@ -173,6 +195,140 @@ class TestSubscriptionBasic:
ev = ws_client.get_ws_query_event('2',3)
assert ev['type'] == 'complete' and ev['id'] == '2', ev
## NOTE: The same tests as in TestSubcscriptionBasic but with
## the subscription transport being used is `graphql-ws`
## FIXME: There's an issue with the tests being parametrized with both
## postgres and mssql data sources enabled(See issue #2084).
@usefixtures('per_method_tests_db_state', 'ws_conn_init_graphql_ws')
class TestSubscriptionBasicGraphQLWS:
@classmethod
def dir(cls):
return 'queries/subscriptions/basic'
@pytest.mark.parametrize("transport", ['http', 'websocket', 'subscription'])
def test_negative(self, hge_ctx, transport):
check_query_f(hge_ctx, self.dir() + '/negative_test.yaml', transport, gqlws=True)
def test_connection_error(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
ws_client_graphql_ws.send({'type': 'test'})
time.sleep(2)
ev = ws_client_graphql_ws.get_conn_close_state()
assert ev == True, ev
def test_start(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
query = """
subscription {
hge_tests_test_t1(order_by: {c1: desc}, limit: 1) {
c1,
c2
}
}
"""
obj = {
'id': '1',
'payload': {
'query': query
},
'type': 'subscribe'
}
ws_client_graphql_ws.send(obj)
ev = ws_client_graphql_ws.get_ws_query_event('1',15)
assert ev['type'] == 'next' and ev['id'] == '1', ev
@pytest.mark.skip(reason="refer https://github.com/hasura/graphql-engine/pull/387#issuecomment-421343098")
def test_start_duplicate(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
self.test_start(ws_client_graphql_ws)
def test_stop_without_id(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
obj = {
'type': 'complete'
}
ws_client_graphql_ws.send(obj)
time.sleep(2)
ev = ws_client_graphql_ws.get_conn_close_state()
assert ev == True, ev
def test_stop(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
obj = {
'type': 'complete',
'id': '1'
}
ws_client_graphql_ws.send(obj)
time.sleep(2)
with pytest.raises(queue.Empty):
ev = ws_client_graphql_ws.get_ws_event(3)
def test_start_after_stop(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
self.test_start(hge_ctx, ws_client_graphql_ws)
## NOTE: test_start leaves a message in the queue, hence clearing it
if len(ws_client_graphql_ws.get_queue()) > 0:
ws_client_graphql_ws.clear_queue()
self.test_stop(hge_ctx, ws_client_graphql_ws)
def test_complete(self, hge_ctx, ws_client_graphql_ws):
if ws_client_graphql_ws.get_conn_close_state():
ws_client_graphql_ws.create_conn()
if hge_ctx.hge_key == None:
ws_client_graphql_ws.init()
else:
ws_client_graphql_ws.init_as_admin()
query = """
query {
hge_tests_test_t1(order_by: {c1: desc}, limit: 1) {
c1,
c2
}
}
"""
obj = {
'id': '2',
'payload': {
'query': query
},
'type': 'subscribe'
}
ws_client_graphql_ws.send(obj)
ev = ws_client_graphql_ws.get_ws_query_event('2',3)
assert ev['type'] == 'next' and ev['id'] == '2', ev
# Check for complete type
ev = ws_client_graphql_ws.get_ws_query_event('2',3)
assert ev['type'] == 'complete' and ev['id'] == '2', ev
@usefixtures('per_method_tests_db_state','ws_conn_init')
class TestSubscriptionLiveQueries:
@ -256,6 +412,86 @@ class TestSubscriptionLiveQueries:
with pytest.raises(queue.Empty):
ev = ws_client.get_ws_event(3)
@usefixtures('per_method_tests_db_state','ws_conn_init_graphql_ws')
class TestSubscriptionLiveQueriesForGraphQLWS:
@classmethod
def dir(cls):
return 'queries/subscriptions/live_queries'
def test_live_queries(self, hge_ctx, ws_client_graphql_ws):
'''
Create connection using connection_init
'''
ws_client_graphql_ws.init_as_admin()
with open(self.dir() + "/steps.yaml") as c:
conf = yaml.safe_load(c)
queryTmplt = """
subscription ($result_limit: Int!) {
hge_tests_live_query_{0}: hge_tests_test_t2(order_by: {c1: asc}, limit: $result_limit) {
c1,
c2
}
}
"""
queries = [(0, 1), (1, 2), (2, 2)]
liveQs = []
for i, resultLimit in queries:
query = queryTmplt.replace('{0}',str(i))
headers={}
if hge_ctx.hge_key is not None:
headers['X-Hasura-Admin-Secret'] = hge_ctx.hge_key
subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } }
respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15)
liveQs.append(respLive)
ev = next(respLive)
assert ev['type'] == 'next', ev
assert ev['id'] == 'live_' + str(i), ev
assert ev['payload']['data'] == {'hge_tests_live_query_'+str(i): []}, ev['payload']['data']
assert isinstance(conf, list) == True, 'Not an list'
for index, step in enumerate(conf):
mutationPayload = { 'query': step['query'] }
if 'variables' in step and step['variables']:
mutationPayload['variables'] = json.loads(step['variables'])
expected_resp = json.loads(step['response'])
mutResp = ws_client_graphql_ws.send_query(mutationPayload,'mutation_'+str(index),timeout=15)
ev = next(mutResp)
assert ev['type'] == 'next' and ev['id'] == 'mutation_'+str(index), ev
assert ev['payload']['data'] == expected_resp, ev['payload']['data']
ev = next(mutResp)
assert ev['type'] == 'complete' and ev['id'] == 'mutation_'+str(index), ev
for (i, resultLimit), respLive in zip(queries, liveQs):
ev = next(respLive)
assert ev['type'] == 'next', ev
assert ev['id'] == 'live_' + str(i), ev
expectedReturnedResponse = []
if 'live_response' in step:
expectedReturnedResponse = json.loads(step['live_response'])
elif 'returning' in expected_resp[step['name']]:
expectedReturnedResponse = expected_resp[step['name']]['returning']
expectedLimitedResponse = expectedReturnedResponse[:resultLimit]
expectedLiveResponse = { 'hge_tests_live_query_'+str(i): expectedLimitedResponse }
assert ev['payload']['data'] == expectedLiveResponse, ev['payload']['data']
for i, _ in queries:
# stop live operation
frame = {
'id': 'live_'+str(i),
'type': 'complete'
}
ws_client_graphql_ws.send(frame)
ws_client_graphql_ws.clear_queue()
@pytest.mark.parametrize("backend", ['mssql', 'postgres'])
@usefixtures('per_class_tests_db_state')
class TestSubscriptionMultiplexing:

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import time
import ruamel.yaml as yaml
from ruamel.yaml.compat import ordereddict, StringIO
from ruamel.yaml.comments import CommentedMap
@ -148,7 +149,7 @@ def mk_claims_with_namespace_path(claims,hasura_claims,namespace_path):
# Returns the response received and a bool indicating whether the test passed
# or not (this will always be True unless we are `--accepting`)
def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace_path=None):
def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace_path=None, gqlws=False):
hge_ctx.tests_passed = True
headers = {}
if 'headers' in conf:
@ -211,13 +212,13 @@ def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace
conf['status'], conf.get('response'), conf.get('resp_headers'), body=conf.get('body'), method=conf.get('method'))
elif transport == 'websocket':
print('running on websocket')
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True)
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, gqlws=gqlws)
elif transport == 'subscription':
print('running via subscription')
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, via_subscription=True)
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, via_subscription=True, gqlws=gqlws)
def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=False):
def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=False, gqlws=False):
assert 'response' in conf
assert conf['url'].endswith('/graphql') or conf['url'].endswith('/relay')
endpoint = conf['url']
@ -236,11 +237,20 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
ws_client = hge_ctx.ws_client_v1alpha1
elif endpoint == '/v1beta1/relay':
ws_client = hge_ctx.ws_client_relay
elif gqlws: # for `graphQL-ws` clients
ws_client = hge_ctx.ws_client_graphql_ws
else:
ws_client = hge_ctx.ws_client
print(ws_client.ws_url)
if not headers or len(headers) == 0:
ws_client.init({})
if ws_client.remote_closed or ws_client.is_closing:
ws_client.create_conn()
if not headers or len(headers) == 0 or hge_ctx.hge_key is None:
ws_client.init()
else:
ws_client.init_as_admin()
query_resp = ws_client.send_query(query, query_id='hge_test', headers=headers, timeout=15)
resp = next(query_resp)
@ -253,18 +263,22 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
ws_client.recreate_conn()
return validate_gql_ws_q(hge_ctx, query, headers, exp_http_response, False)
else:
assert resp['type'] in ['data', 'error'], resp
assert resp['type'] in ['data', 'error', 'next'], resp
if 'errors' in exp_http_response or 'error' in exp_http_response:
assert resp['type'] in ['data', 'error'], resp
assert resp['type'] in ['data', 'error', 'next'], resp
else:
assert resp['type'] == 'data', resp
assert resp['type'] == 'data' or resp['type'] == 'next', resp
assert 'payload' in resp, resp
if via_subscription:
ws_client.send({ 'id': 'hge_test', 'type': 'stop' })
with pytest.raises(queue.Empty):
ws_client.get_ws_event(0)
if not gqlws:
ws_client.send({ 'id': 'hge_test', 'type': 'stop' })
else:
ws_client.send({ 'id': 'hge_test', 'type': 'complete' })
if not gqlws: # NOTE: for graphql-ws, we have some elements that are left in the queue especially after a 'next' message.
with pytest.raises(queue.Empty):
ws_client.get_ws_event(0)
else:
resp_done = next(query_resp)
assert resp_done['type'] == 'complete'
@ -414,7 +428,7 @@ def get_conf_f(f):
with open(f, 'r+') as c:
return yaml.YAML().load(c)
def check_query_f(hge_ctx, f, transport='http', add_auth=True):
def check_query_f(hge_ctx, f, transport='http', add_auth=True, gqlws = False):
print("Test file: " + f)
hge_ctx.may_skip_test_teardown = False
print ("transport="+transport)
@ -429,14 +443,14 @@ def check_query_f(hge_ctx, f, transport='http', add_auth=True):
conf = yml.load(c)
if isinstance(conf, list):
for ix, sconf in enumerate(conf):
actual_resp, matched = check_query(hge_ctx, sconf, transport, add_auth)
actual_resp, matched = check_query(hge_ctx, sconf, transport, add_auth, None, gqlws)
if PytestConf.config.getoption("--accept") and not matched:
conf[ix]['response'] = actual_resp
should_write_back = True
else:
if conf['status'] != 200:
hge_ctx.may_skip_test_teardown = True
actual_resp, matched = check_query(hge_ctx, conf, transport, add_auth)
actual_resp, matched = check_query(hge_ctx, conf, transport, add_auth, None, gqlws)
# If using `--accept` write the file back out with the new expected
# response set to the actual response we got:
if PytestConf.config.getoption("--accept") and not matched: