mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-09-17 13:37:26 +03:00
server: support for graphql-ws
protocol
https://github.com/hasura/graphql-engine-mono/pull/1655 Co-authored-by: Rakesh Emmadi <12475069+rakeshkky@users.noreply.github.com> Co-authored-by: Vijay Prasanna <11921040+vijayprasanna13@users.noreply.github.com> Co-authored-by: hasura-bot <30118761+hasura-bot@users.noreply.github.com> Co-authored-by: Brandon Simmons <210815+jberryman@users.noreply.github.com> Co-authored-by: Varun Choudhary <68095256+Varun-Choudhary@users.noreply.github.com> Co-authored-by: Divi <32202683+imperfect-fourth@users.noreply.github.com> GitOrigin-RevId: 9db3902388fef06b94f9513255e2b5333bd23c3e
This commit is contained in:
parent
3789405e37
commit
edeb8c98fd
@ -5,8 +5,10 @@
|
||||
(Add entries below in the order of server, console, cli, docs, others)
|
||||
|
||||
- server: optimize SQL query generation with LIMITs (close #5745)
|
||||
- server: update non-existent event trigger, action and query collection error msgs (close #7396)
|
||||
- server: fix broken `untrack_function` for non-default source
|
||||
- server: Adding support for TLS allowlist by domain and service id (port)
|
||||
- server: add support for `graphql-ws` clients
|
||||
- console: fix error due too rendering inconsistent object's message
|
||||
|
||||
## v2.0.7
|
||||
@ -16,7 +18,6 @@
|
||||
- server: fix GraphQL type for remote relationship field (close #7284)
|
||||
- server: support EdDSA algorithm and key type for JWT
|
||||
- server: fix GraphQL type for single-row returning functions (close #7109)
|
||||
- server: update non-existent event trigger, action and query collection error msgs (close #7396)
|
||||
- console: add support for creation of indexes for Postgres data sources
|
||||
- docs: document the cleanup process for scheduled triggers
|
||||
- console: allow same named queries and unnamed queries on allowlist file upload
|
||||
|
@ -54,8 +54,11 @@ Communication protocol
|
||||
|
||||
Hasura GraphQL engine uses the `GraphQL over WebSocket Protocol
|
||||
<https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md>`__ by the
|
||||
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library
|
||||
for sending and receiving events.
|
||||
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library and the
|
||||
`GraphQL over WebSocket Protocol <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md>`__
|
||||
by the `graphql-ws <https://github.com/enisdenjo/graphql-ws>`__ library for sending and receiving events. The support for
|
||||
``graphql-ws`` is currently considered as ``BETA``. The graphql-engine uses the ``Sec-WebSocket-Protocol`` header to determine
|
||||
the server Implementation that'll be used. By default, the graphql-engine will use the ``apollographql/subscriptions-transport-ws`` protocol.
|
||||
|
||||
.. admonition:: Setting headers for subscriptions with Apollo client
|
||||
|
||||
|
@ -52,8 +52,11 @@ Communication protocol
|
||||
|
||||
Hasura GraphQL engine uses the `GraphQL over WebSocket Protocol
|
||||
<https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md>`__ by the
|
||||
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library
|
||||
for sending and receiving events.
|
||||
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`__ library and the
|
||||
`GraphQL over WebSocket Protocol <https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md>`__
|
||||
by the `graphql-ws <https://github.com/enisdenjo/graphql-ws>`__ library for sending and receiving events. The support for
|
||||
``graphql-ws`` is currently considered as ``BETA``. The graphql-engine uses the ``Sec-WebSocket-Protocol`` header to determine
|
||||
the server Implementation that'll be used. By default, the graphql-engine will use the ``apollographql/subscriptions-transport-ws`` implementation.
|
||||
|
||||
.. admonition:: Setting headers for subscriptions with Apollo client
|
||||
|
||||
|
@ -324,11 +324,21 @@ For the ``serve`` sub-command these are the available flags and environment vari
|
||||
|
||||
*(Available for versions > v2.0.0)*
|
||||
|
||||
* - ``--websocket-keepalive``
|
||||
* - ``--websocket-keepalive <SECONDS>``
|
||||
- ``HASURA_GRAPHQL_WEBSOCKET_KEEPALIVE``
|
||||
- WebSocket keep-alive timeout in seconds (default: 5)
|
||||
- Used to set the ``Keep Alive`` delay for client that use the ``subscription-transport-ws`` (Apollo) protocol.
|
||||
For ``graphql-ws`` clients the graphql-engine sends ``PING`` messages instead.
|
||||
|
||||
(default: ``5``)
|
||||
|
||||
*(Available for versions > v2.0.0)*
|
||||
|
||||
* - ``--websocket-connection-init-timeout <SECONDS>``
|
||||
- ``HASURA_GRAPHQL_WEBSOCKET_CONNECTION_INIT_TIMEOUT``
|
||||
- Used to set the connection initialisation timeout for ``graphql-ws`` clients. This is ignored
|
||||
for ``subscription-transport-ws`` (Apollo) clients.
|
||||
|
||||
(default: ``3``)
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -611,7 +611,9 @@ library
|
||||
, Hasura.GraphQL.Transport.HTTP
|
||||
, Hasura.GraphQL.Transport.HTTP.Protocol
|
||||
, Hasura.GraphQL.Transport.Instances
|
||||
, Hasura.GraphQL.Transport.WSServerApp
|
||||
, Hasura.GraphQL.Transport.WebSocket
|
||||
, Hasura.GraphQL.Transport.WebSocket.Types
|
||||
, Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
, Hasura.GraphQL.Transport.WebSocket.Server
|
||||
|
||||
|
@ -537,6 +537,7 @@ runHGEServer setupHook env ServeOptions{..} ServeCtx{..} initTime postPollHook s
|
||||
soEnableMaintenanceMode
|
||||
soExperimentalFeatures
|
||||
_scEnabledLogTypes
|
||||
soWebsocketConnectionInitTimeout
|
||||
|
||||
let serverConfigCtx =
|
||||
ServerConfigCtx soInferFunctionPermissions
|
||||
|
@ -49,7 +49,7 @@ import Hasura.GraphQL.Execute.LiveQuery.Poll
|
||||
import Hasura.GraphQL.ParameterizedQueryHash (ParameterizedQueryHash)
|
||||
import Hasura.GraphQL.Transport.Backend
|
||||
import Hasura.GraphQL.Transport.HTTP.Protocol (OperationName)
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol (OperationId)
|
||||
import Hasura.RQL.Types.Action
|
||||
import Hasura.RQL.Types.Common (SourceName, unNonNegativeDiffTime)
|
||||
import Hasura.Server.Metrics (ServerMetrics (..))
|
||||
|
145
server/src-lib/Hasura/GraphQL/Transport/WSServerApp.hs
Normal file
145
server/src-lib/Hasura/GraphQL/Transport/WSServerApp.hs
Normal file
@ -0,0 +1,145 @@
|
||||
module Hasura.GraphQL.Transport.WSServerApp
|
||||
( createWSServerApp
|
||||
, stopWSServerApp
|
||||
, createWSServerEnv
|
||||
) where
|
||||
|
||||
import Hasura.Prelude
|
||||
|
||||
import qualified Control.Concurrent.Async.Lifted.Safe as LA
|
||||
import qualified Control.Concurrent.STM as STM
|
||||
import qualified Control.Monad.Trans.Control as MC
|
||||
import qualified Data.ByteString.Char8 as B (pack)
|
||||
import qualified Data.Environment as Env
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.WebSockets as WS
|
||||
import qualified System.Metrics.Gauge as EKG.Gauge
|
||||
|
||||
import Control.Exception.Lifted
|
||||
import Data.Aeson (toJSON)
|
||||
import Data.Text (pack, unpack)
|
||||
|
||||
import qualified Hasura.GraphQL.Execute as E
|
||||
import qualified Hasura.GraphQL.Execute.Backend as EB
|
||||
import qualified Hasura.GraphQL.Execute.LiveQuery.State as LQ
|
||||
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
|
||||
import qualified Hasura.Logging as L
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
|
||||
import Hasura.GraphQL.Logging
|
||||
import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery)
|
||||
import Hasura.GraphQL.Transport.Instances ()
|
||||
import Hasura.GraphQL.Transport.WebSocket
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
import Hasura.GraphQL.Transport.WebSocket.Types
|
||||
import Hasura.Metadata.Class
|
||||
import Hasura.RQL.Types
|
||||
import Hasura.Server.Auth (AuthMode, UserAuthentication)
|
||||
import Hasura.Server.Cors
|
||||
import Hasura.Server.Init.Config (KeepAliveDelay,
|
||||
WSConnectionInitTimeout)
|
||||
import Hasura.Server.Metrics (ServerMetrics (..))
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
|
||||
createWSServerApp ::
|
||||
( HasVersion
|
||||
, MonadIO m
|
||||
, MC.MonadBaseControl IO m
|
||||
, LA.Forall (LA.Pure m)
|
||||
, UserAuthentication (Tracing.TraceT m)
|
||||
, E.MonadGQLExecutionCheck m
|
||||
, WS.MonadWSLog m
|
||||
, MonadQueryLog m
|
||||
, Tracing.HasReporter m
|
||||
, MonadExecuteQuery m
|
||||
, MonadMetadataStorage (MetadataStorageT m)
|
||||
, EB.MonadQueryTags m
|
||||
)
|
||||
=> Env.Environment
|
||||
-> HashSet (L.EngineLogType L.Hasura)
|
||||
-> AuthMode
|
||||
-> WSServerEnv
|
||||
-> WSConnectionInitTimeout
|
||||
-> WS.HasuraServerApp m
|
||||
-- -- ^ aka generalized 'WS.ServerApp'
|
||||
createWSServerApp env enabledLogTypes authMode serverEnv connInitTimeout = \ !ipAddress !pendingConn ->
|
||||
WS.createServerApp connInitTimeout (_wseServer serverEnv) handlers ipAddress pendingConn
|
||||
where
|
||||
handlers = WS.WSHandlers
|
||||
onConnHandler
|
||||
onMessageHandler
|
||||
onCloseHandler
|
||||
|
||||
logger = _wseLogger serverEnv
|
||||
serverMetrics = _wseServerMetrics serverEnv
|
||||
|
||||
wsActions = mkWSActions logger
|
||||
|
||||
-- Mask async exceptions during event processing to help maintain integrity of mutable vars:
|
||||
-- here `sp` stands for sub-protocol
|
||||
onConnHandler rid rh ip sp = mask_ do
|
||||
liftIO $ EKG.Gauge.inc $ smWebsocketConnections serverMetrics
|
||||
flip runReaderT serverEnv $ onConn rid rh ip (wsActions sp)
|
||||
|
||||
onMessageHandler conn bs sp = mask_ $
|
||||
onMessage env enabledLogTypes authMode serverEnv conn bs (wsActions sp)
|
||||
|
||||
onCloseHandler conn = mask_ do
|
||||
liftIO $ EKG.Gauge.dec $ smWebsocketConnections serverMetrics
|
||||
onClose logger serverMetrics (_wseLiveQMap serverEnv) conn
|
||||
|
||||
stopWSServerApp :: WSServerEnv -> IO ()
|
||||
stopWSServerApp wsEnv = WS.shutdown (_wseServer wsEnv)
|
||||
|
||||
createWSServerEnv :: (MonadIO m)
|
||||
=> L.Logger L.Hasura
|
||||
-> LQ.LiveQueriesState
|
||||
-> IO (SchemaCache, SchemaCacheVer)
|
||||
-> H.Manager
|
||||
-> CorsPolicy
|
||||
-> SQLGenCtx
|
||||
-> Bool
|
||||
-> KeepAliveDelay
|
||||
-> ServerMetrics
|
||||
-> m WSServerEnv
|
||||
createWSServerEnv logger lqState getSchemaCache httpManager
|
||||
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics = do
|
||||
wsServer <- liftIO $ STM.atomically $ WS.createWSServer logger
|
||||
pure $ WSServerEnv logger lqState getSchemaCache httpManager
|
||||
corsPolicy sqlGenCtx wsServer enableAL keepAliveDelay serverMetrics
|
||||
|
||||
mkWSActions :: L.Logger L.Hasura -> WSSubProtocol -> WS.WSActions WSConnData
|
||||
mkWSActions logger subProtocol =
|
||||
WS.WSActions
|
||||
mkPostExecErrMessageAction
|
||||
mkOnErrorMessageAction
|
||||
mkConnectionCloseAction
|
||||
keepAliveAction
|
||||
getServerMsgType
|
||||
mkAcceptRequest
|
||||
where
|
||||
mkPostExecErrMessageAction wsConn opId execErr =
|
||||
sendMsg wsConn $ case subProtocol of
|
||||
Apollo -> SMData $ DataMsg opId $ throwError execErr
|
||||
GraphQLWS -> SMErr $ ErrorMsg opId $ toJSON execErr
|
||||
|
||||
mkOnErrorMessageAction wsConn err mErrMsg = case subProtocol of
|
||||
Apollo -> sendMsg wsConn $ SMConnErr err
|
||||
GraphQLWS -> sendCloseWithMsg logger wsConn (GenericError4400 $ (fromMaybe "" mErrMsg) <> (unpack . unConnErrMsg $ err)) Nothing
|
||||
|
||||
mkConnectionCloseAction wsConn opId errMsg =
|
||||
when (subProtocol == GraphQLWS) $
|
||||
sendCloseWithMsg logger wsConn (GenericError4400 errMsg) (Just . SMErr $ ErrorMsg opId $ toJSON (pack errMsg))
|
||||
|
||||
getServerMsgType = case subProtocol of
|
||||
Apollo -> SMData
|
||||
GraphQLWS -> SMNext
|
||||
|
||||
keepAliveAction wsConn = sendMsg wsConn $
|
||||
case subProtocol of
|
||||
Apollo -> SMConnKeepAlive
|
||||
GraphQLWS -> SMPing . Just $ keepAliveMessage
|
||||
|
||||
mkAcceptRequest = WS.defaultAcceptRequest {
|
||||
WS.acceptSubprotocol = Just . B.pack . showSubProtocol $ subProtocol
|
||||
}
|
@ -1,11 +1,15 @@
|
||||
-- | This file contains the handlers that are used within websocket server
|
||||
|
||||
{-# LANGUAGE CPP #-}
|
||||
|
||||
module Hasura.GraphQL.Transport.WebSocket
|
||||
( createWSServerApp
|
||||
, createWSServerEnv
|
||||
, stopWSServerApp
|
||||
, WSServerEnv
|
||||
, WSLog(..)
|
||||
( onConn
|
||||
, onMessage
|
||||
, onClose
|
||||
-- ^ the main handlers for the websocket server
|
||||
, sendMsg
|
||||
, sendCloseWithMsg
|
||||
-- ^ helpers for sending messages to the client
|
||||
) where
|
||||
|
||||
-- NOTE!:
|
||||
@ -15,7 +19,6 @@ module Hasura.GraphQL.Transport.WebSocket
|
||||
|
||||
import Hasura.Prelude
|
||||
|
||||
import qualified Control.Concurrent.Async.Lifted.Safe as LA
|
||||
import qualified Control.Concurrent.STM as STM
|
||||
import qualified Control.Monad.Trans.Control as MC
|
||||
import qualified Data.Aeson as J
|
||||
@ -36,13 +39,10 @@ import qualified Language.GraphQL.Draft.Syntax as G
|
||||
import qualified ListT
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as H
|
||||
import qualified Network.Wai.Extended as Wai
|
||||
import qualified Network.WebSockets as WS
|
||||
import qualified StmContainers.Map as STMMap
|
||||
import qualified System.Metrics.Gauge as EKG.Gauge
|
||||
|
||||
import Control.Concurrent.Extended (sleep)
|
||||
import Control.Exception.Lifted
|
||||
import Data.String
|
||||
#ifndef PROFILING
|
||||
import GHC.AssertNF
|
||||
@ -76,8 +76,9 @@ import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery
|
||||
import Hasura.GraphQL.Transport.HTTP.Protocol
|
||||
import Hasura.GraphQL.Transport.Instances ()
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
import Hasura.GraphQL.Transport.WebSocket.Types
|
||||
import Hasura.Metadata.Class
|
||||
import Hasura.RQL.Types
|
||||
import Hasura.RQL.Types.RemoteSchema
|
||||
import Hasura.Server.Auth (AuthMode, UserAuthentication,
|
||||
resolveUserInfo)
|
||||
import Hasura.Server.Cors
|
||||
@ -87,77 +88,12 @@ import Hasura.Server.Types (RequestId, getReq
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
import Hasura.Session
|
||||
|
||||
|
||||
-- | 'LQ.LiveQueryId' comes from 'Hasura.GraphQL.Execute.LiveQuery.State.addLiveQuery'. We use
|
||||
-- this to track a connection's operations so we can remove them from 'LiveQueryState', and
|
||||
-- log.
|
||||
--
|
||||
-- NOTE!: This must be kept consistent with the global 'LiveQueryState', in 'onClose'
|
||||
-- and 'onStart'.
|
||||
type OperationMap
|
||||
= STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName)
|
||||
|
||||
newtype WsHeaders
|
||||
= WsHeaders { unWsHeaders :: [H.Header] }
|
||||
deriving (Show, Eq)
|
||||
|
||||
data ErrRespType
|
||||
= ERTLegacy
|
||||
| ERTGraphqlCompliant
|
||||
deriving (Show)
|
||||
|
||||
data WSConnState
|
||||
= CSNotInitialised !WsHeaders !Wai.IpAddress
|
||||
-- ^ headers and IP address from the client for websockets
|
||||
| CSInitError !Text
|
||||
| CSInitialised !WsClientState
|
||||
|
||||
data WsClientState
|
||||
= WsClientState
|
||||
{ wscsUserInfo :: !UserInfo
|
||||
-- ^ the 'UserInfo' required to execute the GraphQL query
|
||||
, wscsTokenExpTime :: !(Maybe TC.UTCTime)
|
||||
-- ^ the JWT/token expiry time, if any
|
||||
, wscsReqHeaders :: ![H.Header]
|
||||
-- ^ headers from the client (in conn params) to forward to the remote schema
|
||||
, wscsIpAddress :: !Wai.IpAddress
|
||||
-- ^ IP address required for 'MonadGQLAuthorization'
|
||||
}
|
||||
|
||||
data WSConnData
|
||||
= WSConnData
|
||||
-- the role and headers are set only on connection_init message
|
||||
{ _wscUser :: !(STM.TVar WSConnState)
|
||||
-- we only care about subscriptions,
|
||||
-- the other operations (query/mutations)
|
||||
-- are not tracked here
|
||||
, _wscOpMap :: !OperationMap
|
||||
, _wscErrRespTy :: !ErrRespType
|
||||
, _wscAPIType :: !E.GraphQLQueryType
|
||||
}
|
||||
|
||||
type WSServer = WS.WSServer WSConnData
|
||||
|
||||
type WSConn = WS.WSConn WSConnData
|
||||
|
||||
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
|
||||
sendMsg wsConn msg =
|
||||
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse (encodeServerMsg msg) Nothing
|
||||
|
||||
sendMsgWithMetadata :: (MonadIO m) => WSConn -> ServerMsg -> LQ.LiveQueryMetadata -> m ()
|
||||
sendMsgWithMetadata wsConn msg (LQ.LiveQueryMetadata execTime) =
|
||||
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse bs wsInfo
|
||||
where
|
||||
bs = encodeServerMsg msg
|
||||
(msgType, operationId) = case msg of
|
||||
(SMData (DataMsg opId _)) -> (Just SMT_GQL_DATA, Just opId)
|
||||
_ -> (Nothing, Nothing)
|
||||
wsInfo = Just $! WS.WSEventInfo
|
||||
{ WS._wseiEventType = msgType
|
||||
, WS._wseiOperationId = operationId
|
||||
, WS._wseiQueryExecutionTime = Just $! realToFrac execTime
|
||||
, WS._wseiResponseSize = Just $! LBS.length bs
|
||||
}
|
||||
|
||||
data OpDetail
|
||||
= ODStarted
|
||||
@ -229,35 +165,91 @@ mkWsErrorLog :: Maybe SessionVariables -> WsConnInfo -> WSEvent -> WSLog
|
||||
mkWsErrorLog uv ci ev =
|
||||
WSLog L.LevelError $ WSLogInfo uv ci ev
|
||||
|
||||
data WSServerEnv
|
||||
= WSServerEnv
|
||||
{ _wseLogger :: !(L.Logger L.Hasura)
|
||||
, _wseLiveQMap :: !LQ.LiveQueriesState
|
||||
, _wseGCtxMap :: !(IO (SchemaCache, SchemaCacheVer))
|
||||
-- ^ an action that always returns the latest version of the schema cache. See 'SchemaCacheRef'.
|
||||
, _wseHManager :: !H.Manager
|
||||
, _wseCorsPolicy :: !CorsPolicy
|
||||
, _wseSQLCtx :: !SQLGenCtx
|
||||
, _wseServer :: !WSServer
|
||||
, _wseEnableAllowlist :: !Bool
|
||||
, _wseKeepAliveDelay :: !KeepAliveDelay
|
||||
, _wseServerMetrics :: !ServerMetrics
|
||||
}
|
||||
logWSEvent :: (MonadIO m)
|
||||
=> L.Logger L.Hasura
|
||||
-> WSConn
|
||||
-> WSEvent
|
||||
-> m ()
|
||||
logWSEvent (L.Logger logger) wsConn wsEv = do
|
||||
userInfoME <- liftIO $ STM.readTVarIO userInfoR
|
||||
let (userVarsM, tokenExpM) = case userInfoME of
|
||||
CSInitialised WsClientState{..} -> ( Just $ _uiSession wscsUserInfo
|
||||
, wscsTokenExpTime
|
||||
)
|
||||
_ -> (Nothing, Nothing)
|
||||
liftIO $ logger $ WSLog logLevel $ WSLogInfo userVarsM (WsConnInfo wsId tokenExpM Nothing) wsEv
|
||||
where
|
||||
WSConnData userInfoR _ _ _ = WS.getData wsConn
|
||||
wsId = WS.getWSId wsConn
|
||||
logLevel = bool L.LevelInfo L.LevelError isError
|
||||
isError = case wsEv of
|
||||
EAccepted -> False
|
||||
ERejected _ -> True
|
||||
EConnErr _ -> True
|
||||
EClosed -> False
|
||||
EOperation operation -> case _odOperationType operation of
|
||||
ODStarted -> False
|
||||
ODProtoErr _ -> True
|
||||
ODQueryErr _ -> True
|
||||
ODCompleted -> False
|
||||
ODStopped -> False
|
||||
|
||||
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
|
||||
sendMsg wsConn msg =
|
||||
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse (encodeServerMsg msg) Nothing
|
||||
|
||||
sendCloseWithMsg :: (MonadIO m)
|
||||
=> L.Logger L.Hasura
|
||||
-> WSConn
|
||||
-> ServerErrorCode
|
||||
-> Maybe ServerMsg
|
||||
-> m ()
|
||||
sendCloseWithMsg logger wsConn errCode mErrServerMsg = do
|
||||
case mErrServerMsg of
|
||||
Just errServerMsg -> do
|
||||
sendMsg wsConn errServerMsg
|
||||
Nothing -> pure ()
|
||||
logWSEvent logger wsConn EClosed
|
||||
liftIO $ WS.sendClose wsc errMsg
|
||||
where
|
||||
wsc = WS.getRawWebSocketConnection wsConn
|
||||
errMsg = encodeServerErrorMsg errCode
|
||||
|
||||
sendMsgWithMetadata :: (MonadIO m) => WSConn -> ServerMsg -> LQ.LiveQueryMetadata -> m ()
|
||||
sendMsgWithMetadata wsConn msg (LQ.LiveQueryMetadata execTime) =
|
||||
liftIO $ WS.sendMsg wsConn $ WS.WSQueueResponse bs wsInfo
|
||||
where
|
||||
bs = encodeServerMsg msg
|
||||
(msgType, operationId) = case msg of
|
||||
(SMNext (DataMsg opId _)) -> (Just SMT_GQL_NEXT, Just opId)
|
||||
(SMData (DataMsg opId _)) -> (Just SMT_GQL_DATA, Just opId)
|
||||
_ -> (Nothing, Nothing)
|
||||
wsInfo = Just $! WS.WSEventInfo
|
||||
{ WS._wseiEventType = msgType
|
||||
, WS._wseiOperationId = operationId
|
||||
, WS._wseiQueryExecutionTime = Just $! realToFrac execTime
|
||||
, WS._wseiResponseSize = Just $! LBS.length bs
|
||||
}
|
||||
|
||||
onConn :: (MonadIO m, MonadReader WSServerEnv m)
|
||||
=> WS.OnConnH m WSConnData
|
||||
onConn wsId requestHead ipAddress = do
|
||||
onConn wsId requestHead ipAddress onConnHActions = do
|
||||
res <- runExceptT $ do
|
||||
(errType, queryType) <- checkPath
|
||||
let reqHdrs = WS.requestHeaders requestHead
|
||||
headers <- maybe (return reqHdrs) (flip enforceCors reqHdrs . snd) getOrigin
|
||||
return (WsHeaders $ filterWsHeaders headers, errType, queryType)
|
||||
either reject accept res
|
||||
|
||||
where
|
||||
kaAction = WS._wsaKeepAliveAction onConnHActions
|
||||
acceptRequest = WS._wsaAcceptRequest onConnHActions
|
||||
|
||||
-- NOTE: the "Keep-Alive" delay is something that's mentioned
|
||||
-- in the Apollo spec. For 'graphql-ws', we're using the Ping
|
||||
-- messages that are part of the spec.
|
||||
keepAliveAction keepAliveDelay wsConn = do
|
||||
liftIO $ forever $ do
|
||||
sendMsg wsConn SMConnKeepAlive
|
||||
kaAction wsConn
|
||||
sleep $ seconds (unKeepAliveDelay keepAliveDelay)
|
||||
|
||||
tokenExpiryHandler wsConn = do
|
||||
@ -272,16 +264,21 @@ onConn wsId requestHead ipAddress = do
|
||||
|
||||
accept (hdrs, errType, queryType) = do
|
||||
(L.Logger logger) <- asks _wseLogger
|
||||
keepAliveDelay <- asks _wseKeepAliveDelay
|
||||
keepAliveDelay <- asks _wseKeepAliveDelay
|
||||
logger $ mkWsInfoLog Nothing (WsConnInfo wsId Nothing Nothing) EAccepted
|
||||
connData <- liftIO $ WSConnData
|
||||
<$> STM.newTVarIO (CSNotInitialised hdrs ipAddress)
|
||||
<*> STMMap.newIO
|
||||
<*> pure errType
|
||||
<*> pure queryType
|
||||
let acceptRequest = WS.defaultAcceptRequest
|
||||
{ WS.acceptSubprotocol = Just "graphql-ws"}
|
||||
return $ Right $ WS.AcceptWith connData acceptRequest (keepAliveAction keepAliveDelay) tokenExpiryHandler
|
||||
|
||||
pure $ Right $
|
||||
WS.AcceptWith
|
||||
connData
|
||||
acceptRequest
|
||||
(keepAliveAction keepAliveDelay)
|
||||
tokenExpiryHandler
|
||||
|
||||
reject qErr = do
|
||||
(L.Logger logger) <- asks _wseLogger
|
||||
logger $ mkWsErrorLog Nothing (WsConnInfo wsId Nothing Nothing) (ERejected qErr)
|
||||
@ -351,8 +348,9 @@ onStart
|
||||
-> WSServerEnv
|
||||
-> WSConn
|
||||
-> StartMsg
|
||||
-> WS.WSActions WSConnData
|
||||
-> m ()
|
||||
onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore $ do
|
||||
onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) onMessageActions = catchAndIgnore $ do
|
||||
timerTot <- startTimer
|
||||
opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
|
||||
|
||||
@ -517,8 +515,9 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
case resultsE of
|
||||
Left err -> sendError requestId err
|
||||
Right results -> do
|
||||
let dataMsg = SMData $ DataMsg opId $ pure $ encJToLBS $
|
||||
encJFromInsOrdHashMap $ OMap.mapKeys G.unName results
|
||||
let dataMsg = sendDataMsg $
|
||||
DataMsg opId $ pure $ encJToLBS $
|
||||
encJFromInsOrdHashMap $ OMap.mapKeys G.unName results
|
||||
sendMsgWithMetadata wsConn dataMsg $ LQ.LiveQueryMetadata dTime
|
||||
|
||||
asyncActionQueryLive = LQ.LAAQNoRelationships $
|
||||
@ -556,6 +555,10 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
|
||||
liftIO $ logOpEv ODStarted (Just requestId)
|
||||
where
|
||||
sendDataMsg = WS._wsaGetDataMessageType onMessageActions
|
||||
closeConnAction = WS._wsaConnectionCloseAction onMessageActions
|
||||
postExecErrAction = WS._wsaPostExecErrMessageAction onMessageActions
|
||||
|
||||
getExecStepActionWithActionInfo acc execStep = case execStep of
|
||||
E.ExecStepAction _ actionInfo _remoteJoins -> (actionInfo:acc)
|
||||
_ -> acc
|
||||
@ -602,7 +605,7 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
return $ ResultsFragment telemTimeIO_DT Telem.Remote (encJFromOrderedValue value) []
|
||||
|
||||
WSServerEnv logger lqMap getSchemaCache httpMgr _ sqlGenCtx
|
||||
_ enableAL _keepAliveDelay _ = serverEnv
|
||||
_ enableAL _keepAliveDelay _connInitTime = serverEnv
|
||||
|
||||
WSConnData userInfoR opMap errRespTy queryType = WS.getData wsConn
|
||||
|
||||
@ -620,9 +623,10 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
sendMsg wsConn $
|
||||
SMErr $ ErrorMsg opId $ errFn False $ err400 StartFailed e
|
||||
liftIO $ logOpEv (ODProtoErr e) Nothing
|
||||
liftIO $ closeConnAction wsConn opId (T.unpack e)
|
||||
|
||||
sendCompleted reqId = do
|
||||
sendMsg wsConn (SMComplete $ CompletionMsg opId)
|
||||
sendMsg wsConn (SMComplete . CompletionMsg $ opId)
|
||||
logOpEv ODCompleted reqId
|
||||
|
||||
postExecErr :: RequestId -> QErr -> ExceptT () m ()
|
||||
@ -632,9 +636,7 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
postExecErr' $ GQExecError $ pure $ errFn qErr
|
||||
|
||||
postExecErr' :: GQExecError -> ExceptT () m ()
|
||||
postExecErr' qErr = do
|
||||
sendMsg wsConn $ SMData $
|
||||
DataMsg opId $ throwError qErr
|
||||
postExecErr' qErr = liftIO $ postExecErrAction wsConn opId qErr
|
||||
|
||||
-- why wouldn't pre exec error use graphql response?
|
||||
preExecErr reqId qErr = liftIO $ sendError reqId qErr
|
||||
@ -649,8 +651,8 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
|
||||
sendSuccResp :: EncJSON -> LQ.LiveQueryMetadata -> ExceptT () m ()
|
||||
sendSuccResp encJson =
|
||||
sendMsgWithMetadata wsConn
|
||||
(SMData $ DataMsg opId $ pure $ encJToLBS encJson)
|
||||
sendMsgWithMetadata wsConn $
|
||||
sendDataMsg $ DataMsg opId $ pure $ encJToLBS encJson
|
||||
|
||||
withComplete :: ExceptT () m () -> ExceptT () m a
|
||||
withComplete action = do
|
||||
@ -686,54 +688,73 @@ onStart env enabledLogTypes serverEnv wsConn (StartMsg opId q) = catchAndIgnore
|
||||
liveQOnChange = \case
|
||||
Right (LQ.LiveQueryResponse bs dTime) ->
|
||||
sendMsgWithMetadata wsConn
|
||||
(SMData $ DataMsg opId $ pure $ LBS.fromStrict bs)
|
||||
(sendDataMsg $ DataMsg opId $ pure $ LBS.fromStrict bs)
|
||||
(LQ.LiveQueryMetadata dTime)
|
||||
resp -> sendMsg wsConn $ SMData $ DataMsg opId $ LBS.fromStrict . LQ._lqrPayload <$> resp
|
||||
resp -> sendMsg wsConn $
|
||||
sendDataMsg $ DataMsg opId $ LBS.fromStrict . LQ._lqrPayload <$> resp
|
||||
|
||||
catchAndIgnore :: ExceptT () m () -> m ()
|
||||
catchAndIgnore m = void $ runExceptT m
|
||||
|
||||
onMessage
|
||||
:: ( HasVersion
|
||||
, MonadIO m
|
||||
, UserAuthentication (Tracing.TraceT m)
|
||||
, E.MonadGQLExecutionCheck m
|
||||
, MonadQueryLog m
|
||||
, Tracing.HasReporter m
|
||||
, MonadExecuteQuery m
|
||||
, MC.MonadBaseControl IO m
|
||||
, MonadMetadataStorage (MetadataStorageT m)
|
||||
, EB.MonadQueryTags m
|
||||
)
|
||||
onMessage ::
|
||||
( HasVersion
|
||||
, MonadIO m
|
||||
, UserAuthentication (Tracing.TraceT m)
|
||||
, E.MonadGQLExecutionCheck m
|
||||
, MonadQueryLog m
|
||||
, Tracing.HasReporter m
|
||||
, MonadExecuteQuery m
|
||||
, MC.MonadBaseControl IO m
|
||||
, MonadMetadataStorage (MetadataStorageT m)
|
||||
, EB.MonadQueryTags m
|
||||
)
|
||||
=> Env.Environment
|
||||
-> HashSet (L.EngineLogType L.Hasura)
|
||||
-> AuthMode
|
||||
-> WSServerEnv
|
||||
-> WSConn -> LBS.ByteString -> m ()
|
||||
onMessage env enabledLogTypes authMode serverEnv wsConn msgRaw = Tracing.runTraceT "websocket" do
|
||||
-> WSConn
|
||||
-> LBS.ByteString
|
||||
-> WS.WSActions WSConnData
|
||||
-> m ()
|
||||
onMessage env enabledLogTypes authMode serverEnv wsConn msgRaw onMessageActions = Tracing.runTraceT "websocket" do
|
||||
case J.eitherDecode msgRaw of
|
||||
Left e -> do
|
||||
Left e -> do
|
||||
let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e
|
||||
logWSEvent logger wsConn $ EConnErr err
|
||||
sendMsg wsConn $ SMConnErr err
|
||||
liftIO $ onErrAction wsConn err WS.onClientMessageParseErrorText
|
||||
|
||||
Right msg -> case msg of
|
||||
CMConnInit params -> onConnInit (_wseLogger serverEnv)
|
||||
(_wseHManager serverEnv)
|
||||
wsConn authMode params
|
||||
CMStart startMsg -> onStart env enabledLogTypes serverEnv wsConn startMsg
|
||||
CMStop stopMsg -> liftIO $ onStop serverEnv wsConn stopMsg
|
||||
-- The idea is cleanup will be handled by 'onClose', but...
|
||||
-- NOTE: we need to close the websocket connection when we receive the
|
||||
-- CMConnTerm message and calling WS.closeConn will definitely throw an
|
||||
-- exception, but I'm not sure if 'closeConn' is the correct thing here....
|
||||
-- common to both protocols
|
||||
CMConnInit params -> onConnInit logger (_wseHManager serverEnv) wsConn
|
||||
authMode params onErrAction keepAliveMessageAction
|
||||
CMStart startMsg -> onStart env enabledLogTypes serverEnv wsConn startMsg onMessageActions
|
||||
CMStop stopMsg -> onStop serverEnv wsConn stopMsg
|
||||
|
||||
-- specfic to graphql-ws
|
||||
CMPing mPayload -> onPing wsConn mPayload
|
||||
CMPong mPayload -> onPong wsConn mPayload
|
||||
|
||||
-- specific to apollo clients
|
||||
CMConnTerm -> liftIO $ WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
|
||||
where
|
||||
logger = _wseLogger serverEnv
|
||||
onErrAction = WS._wsaOnErrorMessageAction onMessageActions
|
||||
keepAliveMessageAction = WS._wsaKeepAliveAction onMessageActions
|
||||
|
||||
onPing :: (MonadIO m) => WSConn -> Maybe PingPongPayload -> m ()
|
||||
onPing wsConn mPayload =
|
||||
liftIO $ sendMsg wsConn (SMPong mPayload)
|
||||
|
||||
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
|
||||
onStop serverEnv wsConn (StopMsg opId) = do
|
||||
onPong :: (MonadIO m) => WSConn -> Maybe PingPongPayload -> m ()
|
||||
onPong wsConn mPayload = liftIO $ case mPayload of
|
||||
Just message -> do
|
||||
when (message /= keepAliveMessage) $
|
||||
sendMsg wsConn (SMPing mPayload)
|
||||
-- NOTE: this is done to avoid sending Ping for every "keepalive" that the server sends
|
||||
Nothing -> sendMsg wsConn $ SMPing Nothing
|
||||
|
||||
onStop :: (MonadIO m) => WSServerEnv -> WSConn -> StopMsg -> m ()
|
||||
onStop serverEnv wsConn (StopMsg opId) = liftIO $ do
|
||||
-- When a stop message is received for an operation, it may not be present in OpMap
|
||||
-- in these cases:
|
||||
-- 1. If the operation is a query/mutation - as we remove the operation from the
|
||||
@ -763,38 +784,19 @@ stopOperation serverEnv wsConn opId logWhenOpNotExist = do
|
||||
opMap = _wscOpMap $ WS.getData wsConn
|
||||
opDet n = OperationDetails opId Nothing n ODStopped Nothing
|
||||
|
||||
logWSEvent
|
||||
:: (MonadIO m)
|
||||
=> L.Logger L.Hasura -> WSConn -> WSEvent -> m ()
|
||||
logWSEvent (L.Logger logger) wsConn wsEv = do
|
||||
userInfoME <- liftIO $ STM.readTVarIO userInfoR
|
||||
let (userVarsM, tokenExpM) = case userInfoME of
|
||||
CSInitialised WsClientState{..} -> ( Just $ _uiSession wscsUserInfo
|
||||
, wscsTokenExpTime
|
||||
)
|
||||
_ -> (Nothing, Nothing)
|
||||
liftIO $ logger $ WSLog logLevel $ WSLogInfo userVarsM (WsConnInfo wsId tokenExpM Nothing) wsEv
|
||||
where
|
||||
WSConnData userInfoR _ _ _ = WS.getData wsConn
|
||||
wsId = WS.getWSId wsConn
|
||||
logLevel = bool L.LevelInfo L.LevelError isError
|
||||
isError = case wsEv of
|
||||
EAccepted -> False
|
||||
ERejected _ -> True
|
||||
EConnErr _ -> True
|
||||
EClosed -> False
|
||||
EOperation operation -> case _odOperationType operation of
|
||||
ODStarted -> False
|
||||
ODProtoErr _ -> True
|
||||
ODQueryErr _ -> True
|
||||
ODCompleted -> False
|
||||
ODStopped -> False
|
||||
|
||||
onConnInit
|
||||
:: (HasVersion, MonadIO m, UserAuthentication (Tracing.TraceT m))
|
||||
=> L.Logger L.Hasura -> H.Manager -> WSConn -> AuthMode
|
||||
-> Maybe ConnParams -> Tracing.TraceT m ()
|
||||
onConnInit logger manager wsConn authMode connParamsM = do
|
||||
=> L.Logger L.Hasura
|
||||
-> H.Manager
|
||||
-> WSConn
|
||||
-> AuthMode
|
||||
-> Maybe ConnParams
|
||||
-> WS.WSOnErrorMessageAction WSConnData
|
||||
-- ^ this is the message handler for handling errors on initializing a from the client connection
|
||||
-> WS.WSKeepAliveMessageAction WSConnData
|
||||
-- ^ this is the message handler for handling "keep-alive" messages to the client
|
||||
-> Tracing.TraceT m ()
|
||||
onConnInit logger manager wsConn authMode connParamsM onConnInitErrAction keepAliveMessageAction = do
|
||||
-- TODO(from master): what should be the behaviour of connection_init message when a
|
||||
-- connection is already iniatilized? Currently, we seem to be doing
|
||||
-- something arbitrary which isn't correct. Ideally, we should stick to
|
||||
@ -820,7 +822,7 @@ onConnInit logger manager wsConn authMode connParamsM = do
|
||||
|
||||
let connErr = ConnErrMsg $ qeError e
|
||||
logWSEvent logger wsConn $ EConnErr connErr
|
||||
sendMsg wsConn $ SMConnErr connErr
|
||||
liftIO $ onConnInitErrAction wsConn connErr WS.onConnInitErrorText
|
||||
|
||||
Right (userInfo, expTimeM) -> do
|
||||
let !csInit = CSInitialised $ WsClientState userInfo expTimeM paramHeaders ipAddress
|
||||
@ -831,14 +833,12 @@ onConnInit logger manager wsConn authMode connParamsM = do
|
||||
STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) csInit
|
||||
|
||||
sendMsg wsConn SMConnAck
|
||||
-- TODO(from master): send it periodically? Why doesn't apollo's protocol use
|
||||
-- ping/pong frames of websocket spec?
|
||||
sendMsg wsConn SMConnKeepAlive
|
||||
liftIO $ keepAliveMessageAction wsConn
|
||||
where
|
||||
unexpectedInitError e = do
|
||||
let connErr = ConnErrMsg e
|
||||
logWSEvent logger wsConn $ EConnErr connErr
|
||||
sendMsg wsConn $ SMConnErr connErr
|
||||
liftIO $ onConnInitErrAction wsConn connErr WS.onConnInitErrorText
|
||||
|
||||
getIpAddress = \case
|
||||
CSNotInitialised _ ip -> return ip
|
||||
@ -871,59 +871,3 @@ onClose logger serverMetrics lqMap wsConn = do
|
||||
LQ.removeLiveQuery logger serverMetrics lqMap lqId
|
||||
where
|
||||
opMap = _wscOpMap $ WS.getData wsConn
|
||||
|
||||
createWSServerEnv
|
||||
:: (MonadIO m)
|
||||
=> L.Logger L.Hasura
|
||||
-> LQ.LiveQueriesState
|
||||
-> IO (SchemaCache, SchemaCacheVer)
|
||||
-> H.Manager
|
||||
-> CorsPolicy
|
||||
-> SQLGenCtx
|
||||
-> Bool
|
||||
-> KeepAliveDelay
|
||||
-> ServerMetrics
|
||||
-> m WSServerEnv
|
||||
createWSServerEnv logger lqState getSchemaCache httpManager
|
||||
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics = do
|
||||
wsServer <- liftIO $ STM.atomically $ WS.createWSServer logger
|
||||
return $
|
||||
WSServerEnv logger lqState getSchemaCache httpManager corsPolicy
|
||||
sqlGenCtx wsServer enableAL keepAliveDelay serverMetrics
|
||||
|
||||
createWSServerApp
|
||||
:: ( HasVersion
|
||||
, MonadIO m
|
||||
, MC.MonadBaseControl IO m
|
||||
, LA.Forall (LA.Pure m)
|
||||
, UserAuthentication (Tracing.TraceT m)
|
||||
, E.MonadGQLExecutionCheck m
|
||||
, WS.MonadWSLog m
|
||||
, MonadQueryLog m
|
||||
, Tracing.HasReporter m
|
||||
, MonadExecuteQuery m
|
||||
, MonadMetadataStorage (MetadataStorageT m)
|
||||
, EB.MonadQueryTags m
|
||||
)
|
||||
=> Env.Environment
|
||||
-> HashSet (L.EngineLogType L.Hasura)
|
||||
-> AuthMode
|
||||
-> WSServerEnv
|
||||
-> WS.HasuraServerApp m
|
||||
-- -- ^ aka generalized 'WS.ServerApp'
|
||||
createWSServerApp env enabledLogTypes authMode serverEnv = \ !ipAddress !pendingConn ->
|
||||
WS.createServerApp (_wseServer serverEnv) handlers ipAddress pendingConn
|
||||
where
|
||||
handlers = WS.WSHandlers onConnHandler onMessageHandler onCloseHandler
|
||||
serverMetrics = _wseServerMetrics serverEnv
|
||||
-- Mask async exceptions during event processing to help maintain integrity of mutable vars:
|
||||
onConnHandler rid rh ip = mask_ do
|
||||
liftIO $ EKG.Gauge.inc $ smWebsocketConnections serverMetrics
|
||||
flip runReaderT serverEnv $ onConn rid rh ip
|
||||
onMessageHandler conn bs = mask_ $ onMessage env enabledLogTypes authMode serverEnv conn bs
|
||||
onCloseHandler conn = mask_ do
|
||||
liftIO $ EKG.Gauge.dec $ smWebsocketConnections serverMetrics
|
||||
onClose (_wseLogger serverEnv) serverMetrics (_wseLiveQMap serverEnv) conn
|
||||
|
||||
stopWSServerApp :: WSServerEnv -> IO ()
|
||||
stopWSServerApp wsEnv = WS.shutdown (_wseServer wsEnv)
|
||||
|
@ -1,33 +1,84 @@
|
||||
-- | See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
||||
module Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
( OperationId(..)
|
||||
, ConnParams(..)
|
||||
, StartMsg(..)
|
||||
, StopMsg(..)
|
||||
, ClientMsg(..)
|
||||
, ServerMsg(..)
|
||||
, ServerMsgType(..)
|
||||
, encodeServerMsg
|
||||
, serverMsgType
|
||||
, DataMsg(..)
|
||||
, ErrorMsg(..)
|
||||
, ConnErrMsg(..)
|
||||
, CompletionMsg(..)
|
||||
) where
|
||||
-- | This file contains types for both the websocket protocols (Apollo) and (graphql-ws)
|
||||
-- | See Apollo: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
||||
-- | See graphql-ws: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
|
||||
|
||||
module Hasura.GraphQL.Transport.WebSocket.Protocol where
|
||||
|
||||
import Hasura.Prelude
|
||||
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.HashMap.Strict as Map
|
||||
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.Extended (sleep)
|
||||
import Control.Concurrent.STM
|
||||
import Data.Text (pack)
|
||||
|
||||
import Hasura.EncJSON
|
||||
import Hasura.GraphQL.Transport.HTTP.Protocol
|
||||
import Hasura.Prelude
|
||||
|
||||
-- | These come from the client and are websocket connection-local.
|
||||
-- NOTE: the `subProtocol` is decided based on the `Sec-WebSocket-Protocol`
|
||||
-- header on every request sent to the server.
|
||||
data WSSubProtocol = Apollo | GraphQLWS
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- NOTE: Please do not change them, as they're used for to identify the type of client
|
||||
-- on every request that reaches the server. They are unique to each of the protocols.
|
||||
showSubProtocol :: WSSubProtocol -> String
|
||||
showSubProtocol subProtocol = case subProtocol of
|
||||
-- REF: https://github.com/apollographql/subscriptions-transport-ws/blob/master/src/server.ts#L144
|
||||
Apollo -> "graphql-ws"
|
||||
-- REF: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#communication
|
||||
GraphQLWS -> "graphql-transport-ws"
|
||||
|
||||
toWSSubProtocol :: String -> WSSubProtocol
|
||||
toWSSubProtocol str = case str of
|
||||
"graphql-transport-ws" -> GraphQLWS
|
||||
_ -> Apollo
|
||||
|
||||
-- This is set by the client when it connects to the server
|
||||
newtype OperationId
|
||||
= OperationId { unOperationId :: Text }
|
||||
deriving (Show, Eq, J.ToJSON, J.FromJSON, Hashable)
|
||||
deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString, Hashable)
|
||||
|
||||
data ServerMsgType =
|
||||
-- specific to `Apollo` clients
|
||||
SMT_GQL_CONNECTION_KEEP_ALIVE
|
||||
| SMT_GQL_CONNECTION_ERROR
|
||||
| SMT_GQL_DATA
|
||||
-- specific to `graphql-ws` clients
|
||||
| SMT_GQL_NEXT
|
||||
| SMT_GQL_PING
|
||||
| SMT_GQL_PONG
|
||||
-- common to clients of both protocols
|
||||
| SMT_GQL_CONNECTION_ACK
|
||||
| SMT_GQL_ERROR
|
||||
| SMT_GQL_COMPLETE
|
||||
deriving (Eq)
|
||||
|
||||
instance Show ServerMsgType where
|
||||
show = \case
|
||||
-- specific to `Apollo` clients
|
||||
SMT_GQL_CONNECTION_KEEP_ALIVE -> "ka"
|
||||
SMT_GQL_CONNECTION_ERROR -> "connection_error"
|
||||
SMT_GQL_DATA -> "data"
|
||||
-- specific to `graphql-ws` clients
|
||||
SMT_GQL_NEXT -> "next"
|
||||
SMT_GQL_PING -> "ping"
|
||||
SMT_GQL_PONG -> "pong"
|
||||
-- common to clients of both protocols
|
||||
SMT_GQL_CONNECTION_ACK -> "connection_ack"
|
||||
SMT_GQL_ERROR -> "error"
|
||||
SMT_GQL_COMPLETE -> "complete"
|
||||
|
||||
instance J.ToJSON ServerMsgType where
|
||||
toJSON = J.toJSON . show
|
||||
|
||||
data ConnParams = ConnParams
|
||||
{ _cpHeaders :: Maybe (HashMap Text Text) }
|
||||
deriving stock (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''ConnParams)
|
||||
|
||||
data StartMsg
|
||||
= StartMsg
|
||||
@ -42,30 +93,55 @@ data StopMsg
|
||||
} deriving (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''StopMsg)
|
||||
|
||||
-- Specific to graphql-ws
|
||||
data PingPongPayload
|
||||
= PingPongPayload
|
||||
{ _smMessage:: !(Maybe Text) -- NOTE: this is not within the spec, but is specific to our usecase
|
||||
} deriving stock (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''PingPongPayload)
|
||||
|
||||
-- Specific to graphql-ws
|
||||
keepAliveMessage :: PingPongPayload
|
||||
keepAliveMessage = PingPongPayload . Just . pack $ "keepalive"
|
||||
|
||||
-- Specific to graphql-ws
|
||||
data SubscribeMsg
|
||||
= SubscribeMsg
|
||||
{ _subId :: !OperationId
|
||||
, _subPayload :: !GQLReqUnparsed
|
||||
} deriving (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''SubscribeMsg)
|
||||
|
||||
data ClientMsg
|
||||
= CMConnInit !(Maybe ConnParams)
|
||||
| CMStart !StartMsg
|
||||
| CMStop !StopMsg
|
||||
-- specific to apollo clients
|
||||
| CMConnTerm
|
||||
-- specific to graphql-ws clients
|
||||
| CMPing !(Maybe PingPongPayload)
|
||||
| CMPong !(Maybe PingPongPayload)
|
||||
deriving (Show, Eq)
|
||||
|
||||
data ConnParams
|
||||
= ConnParams
|
||||
{ _cpHeaders :: Maybe (Map.HashMap Text Text)
|
||||
} deriving (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''ConnParams)
|
||||
|
||||
instance J.FromJSON ClientMsg where
|
||||
parseJSON = J.withObject "ClientMessage" $ \obj -> do
|
||||
t <- obj J..: "type"
|
||||
case t of
|
||||
"connection_init" -> CMConnInit <$> obj J..:? "payload"
|
||||
"start" -> CMStart <$> J.parseJSON (J.Object obj)
|
||||
"stop" -> CMStop <$> J.parseJSON (J.Object obj)
|
||||
"connection_terminate" -> return CMConnTerm
|
||||
case (t :: String) of
|
||||
"connection_init" -> CMConnInit <$> parsePayload obj
|
||||
"start" -> CMStart <$> parseObj obj
|
||||
"stop" -> CMStop <$> parseObj obj
|
||||
"connection_terminate" -> pure CMConnTerm
|
||||
-- graphql-ws specific message types
|
||||
"complete" -> CMStop <$> parseObj obj
|
||||
"subscribe" -> CMStart <$> parseObj obj
|
||||
"ping" -> CMPing <$> parsePayload obj
|
||||
"pong" -> CMPong <$> parsePayload obj
|
||||
_ -> fail $ "unexpected type for ClientMessage: " <> t
|
||||
where
|
||||
parseObj o = J.parseJSON (J.Object o)
|
||||
|
||||
parsePayload py = py J..:? "payload"
|
||||
|
||||
-- server to client messages
|
||||
data DataMsg
|
||||
= DataMsg
|
||||
{ _dmId :: !OperationId
|
||||
@ -82,10 +158,22 @@ newtype CompletionMsg
|
||||
= CompletionMsg { unCompletionMsg :: OperationId }
|
||||
deriving (Show, Eq)
|
||||
|
||||
instance J.FromJSON CompletionMsg where
|
||||
parseJSON = J.withObject "CompletionMsg" $ \t ->
|
||||
CompletionMsg <$> t J..: "id"
|
||||
|
||||
instance J.ToJSON CompletionMsg where
|
||||
toJSON (CompletionMsg opId) = J.String $ tshow opId
|
||||
|
||||
newtype ConnErrMsg
|
||||
= ConnErrMsg { unConnErrMsg :: Text }
|
||||
deriving (Show, Eq, J.ToJSON, J.FromJSON, IsString)
|
||||
|
||||
data ServerErrorMsg =
|
||||
ServerErrorMsg { unServerErrorMsg :: Text }
|
||||
deriving stock (Show, Eq)
|
||||
$(J.deriveJSON hasuraJSON ''ServerErrorMsg)
|
||||
|
||||
data ServerMsg
|
||||
= SMConnAck
|
||||
| SMConnKeepAlive
|
||||
@ -93,27 +181,34 @@ data ServerMsg
|
||||
| SMData !DataMsg
|
||||
| SMErr !ErrorMsg
|
||||
| SMComplete !CompletionMsg
|
||||
-- graphql-ws specific values
|
||||
| SMNext !DataMsg
|
||||
| SMPing !(Maybe PingPongPayload)
|
||||
| SMPong !(Maybe PingPongPayload)
|
||||
|
||||
data ServerMsgType
|
||||
= SMT_GQL_CONNECTION_ACK
|
||||
| SMT_GQL_CONNECTION_KEEP_ALIVE
|
||||
| SMT_GQL_CONNECTION_ERROR
|
||||
| SMT_GQL_DATA
|
||||
| SMT_GQL_ERROR
|
||||
| SMT_GQL_COMPLETE
|
||||
deriving (Eq)
|
||||
-- | This is sent from the server to the client while closing the websocket
|
||||
-- on encountering an error.
|
||||
data ServerErrorCode =
|
||||
ProtocolError1002
|
||||
| GenericError4400 !String
|
||||
| UnAuthorized4401
|
||||
| Forbidden4403
|
||||
| ConnectionInitTimeout4408
|
||||
| NonUniqueSubscription4409 !OperationId
|
||||
| TooManyRequests4429
|
||||
deriving stock (Show)
|
||||
|
||||
instance Show ServerMsgType where
|
||||
show = \case
|
||||
SMT_GQL_CONNECTION_ACK -> "connection_ack"
|
||||
SMT_GQL_CONNECTION_KEEP_ALIVE -> "ka"
|
||||
SMT_GQL_CONNECTION_ERROR -> "connection_error"
|
||||
SMT_GQL_DATA -> "data"
|
||||
SMT_GQL_ERROR -> "error"
|
||||
SMT_GQL_COMPLETE -> "complete"
|
||||
|
||||
instance J.ToJSON ServerMsgType where
|
||||
toJSON = J.toJSON . show
|
||||
encodeServerErrorMsg :: ServerErrorCode -> BL.ByteString
|
||||
encodeServerErrorMsg ecode = encJToLBS . encJFromJValue $ case ecode of
|
||||
ProtocolError1002 -> packMsg "1002: Protocol Error"
|
||||
GenericError4400 msg -> packMsg $ "4400: " <> msg
|
||||
UnAuthorized4401 -> packMsg "4401: Unauthorized"
|
||||
Forbidden4403 -> packMsg "4403: Forbidden"
|
||||
ConnectionInitTimeout4408 -> packMsg "4408: Connection initialisation timeout"
|
||||
NonUniqueSubscription4409 opId -> packMsg $ "4409: Subscriber for " <> show opId <> " already exists"
|
||||
TooManyRequests4429 -> packMsg "4429: Too many requests"
|
||||
where
|
||||
packMsg = ServerErrorMsg . pack
|
||||
|
||||
serverMsgType :: ServerMsg -> ServerMsgType
|
||||
serverMsgType SMConnAck = SMT_GQL_CONNECTION_ACK
|
||||
@ -122,6 +217,9 @@ serverMsgType (SMConnErr _) = SMT_GQL_CONNECTION_ERROR
|
||||
serverMsgType (SMData _) = SMT_GQL_DATA
|
||||
serverMsgType (SMErr _) = SMT_GQL_ERROR
|
||||
serverMsgType (SMComplete _) = SMT_GQL_COMPLETE
|
||||
serverMsgType (SMPing _) = SMT_GQL_PING
|
||||
serverMsgType (SMPong _) = SMT_GQL_PONG
|
||||
serverMsgType (SMNext _) = SMT_GQL_NEXT
|
||||
|
||||
encodeServerMsg :: ServerMsg -> BL.ByteString
|
||||
encodeServerMsg msg =
|
||||
@ -155,5 +253,56 @@ encodeServerMsg msg =
|
||||
, ("id", encJFromJValue $ unCompletionMsg compMsg)
|
||||
]
|
||||
|
||||
SMPing mPayload ->
|
||||
encodePingPongPayload mPayload SMT_GQL_PING
|
||||
|
||||
SMPong mPayload ->
|
||||
encodePingPongPayload mPayload SMT_GQL_PONG
|
||||
|
||||
SMNext (DataMsg opId payload) ->
|
||||
[ encTy SMT_GQL_NEXT
|
||||
, ("id", encJFromJValue opId)
|
||||
, ("payload", encodeGQResp payload)
|
||||
]
|
||||
|
||||
where
|
||||
encTy ty = ("type", encJFromJValue ty)
|
||||
|
||||
encodePingPongPayload mPayload msgType = case mPayload of
|
||||
Just payload ->
|
||||
[ encTy msgType
|
||||
, ("payload", encJFromJValue payload)
|
||||
]
|
||||
Nothing -> [ encTy msgType ]
|
||||
|
||||
-- This "timer" is necessary while initialising the connection
|
||||
-- with the server. Also, this is specific to the GraphQL-WS protocol.
|
||||
data WSConnInitTimerStatus = Running | Done
|
||||
deriving stock (Show, Eq)
|
||||
|
||||
type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ())
|
||||
|
||||
waitForWSTimer :: WSConnInitTimer -> IO ()
|
||||
waitForWSTimer (_, timer) = atomically $ readTMVar timer
|
||||
|
||||
stopWSTimer :: WSConnInitTimer -> IO ()
|
||||
stopWSTimer (timerState, _) = atomically $ writeTVar timerState Done
|
||||
|
||||
getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus
|
||||
getWSTimerState (timerState, _) = readTVarIO timerState
|
||||
|
||||
getNewWSTimer :: Seconds -> IO WSConnInitTimer
|
||||
getNewWSTimer timeout = do
|
||||
timerState <- newTVarIO Running
|
||||
timer <- newEmptyTMVarIO
|
||||
forkIO $ do
|
||||
sleep (seconds timeout)
|
||||
atomically $ do
|
||||
runTimerState <- readTVar timerState
|
||||
case runTimerState of
|
||||
Running -> do
|
||||
-- time's up, we set status to "Done"
|
||||
writeTVar timerState Done
|
||||
putTMVar timer ()
|
||||
Done -> pure ()
|
||||
pure (timerState, timer)
|
||||
|
@ -1,35 +1,8 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE NondecreasingIndentation #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Hasura.GraphQL.Transport.WebSocket.Server
|
||||
( WSId(..)
|
||||
, WSLog(..)
|
||||
, WSEvent(..)
|
||||
, MessageDetails(..)
|
||||
, WSConn
|
||||
, getData
|
||||
, getWSId
|
||||
, closeConn
|
||||
, sendMsg
|
||||
|
||||
, AcceptWith(..)
|
||||
, OnConnH
|
||||
, OnCloseH
|
||||
, OnMessageH
|
||||
, WSHandlers(..)
|
||||
|
||||
, WSServer
|
||||
, HasuraServerApp
|
||||
, WSEventInfo(..)
|
||||
, WSQueueResponse(..)
|
||||
, ServerMsgType(..)
|
||||
, createWSServer
|
||||
, closeAll
|
||||
, createServerApp
|
||||
, shutdown
|
||||
|
||||
, MonadWSLog (..)
|
||||
) where
|
||||
module Hasura.GraphQL.Transport.WebSocket.Server where
|
||||
|
||||
import qualified Control.Concurrent.Async as A
|
||||
import qualified Control.Concurrent.Async.Lifted.Safe as LA
|
||||
@ -39,7 +12,9 @@ import qualified Control.Monad.Trans.Control as MC
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Casing as J
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
import Data.String
|
||||
import qualified Data.TByteString as TBS
|
||||
import qualified Data.UUID as UUID
|
||||
@ -56,8 +31,10 @@ import qualified Network.WebSockets as WS
|
||||
import qualified StmContainers.Map as STMMap
|
||||
import qualified System.IO.Error as E
|
||||
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol (OperationId, ServerMsgType (..))
|
||||
import Hasura.GraphQL.Transport.HTTP.Protocol
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
import qualified Hasura.Logging as L
|
||||
import Hasura.Server.Init.Config (WSConnectionInitTimeout (..))
|
||||
|
||||
newtype WSId
|
||||
= WSId { unWSId :: UUID.UUID }
|
||||
@ -150,6 +127,9 @@ data WSConn a
|
||||
, _wcExtraData :: !a
|
||||
}
|
||||
|
||||
getRawWebSocketConnection :: WSConn a -> WS.Connection
|
||||
getRawWebSocketConnection = _wcConnRaw
|
||||
|
||||
getData :: WSConn a -> a
|
||||
getData = _wcExtraData
|
||||
|
||||
@ -233,42 +213,89 @@ data AcceptWith a
|
||||
, _awOnJwtExpiry :: !(WSConn a -> IO ())
|
||||
}
|
||||
|
||||
type OnConnH m a = WSId -> WS.RequestHead -> IpAddress -> m (Either WS.RejectRequest (AcceptWith a))
|
||||
-- | These set of functions or message handlers is used by the
|
||||
-- server while communicating with the client. They are particularly
|
||||
-- useful for the case when the messages being sent to the client
|
||||
-- are different for each of the sub-protocol(s) supported by the server.
|
||||
type WSKeepAliveMessageAction a = WSConn a -> IO ()
|
||||
type WSPostExecErrMessageAction a = WSConn a -> OperationId -> GQExecError -> IO ()
|
||||
type WSOnErrorMessageAction a = WSConn a -> ConnErrMsg -> Maybe String -> IO ()
|
||||
type WSCloseConnAction a = WSConn a -> OperationId -> String -> IO ()
|
||||
|
||||
-- | Used for specific actions within the `onConn` and `onMessage` handlers
|
||||
data WSActions a = WSActions
|
||||
{ _wsaPostExecErrMessageAction :: !(WSPostExecErrMessageAction a)
|
||||
, _wsaOnErrorMessageAction :: !(WSOnErrorMessageAction a)
|
||||
, _wsaConnectionCloseAction :: !(WSCloseConnAction a)
|
||||
, _wsaKeepAliveAction :: !(WSKeepAliveMessageAction a)
|
||||
-- ^ NOTE: keep alive action was made redundant because we need to send this message
|
||||
-- after the connection has been successfully established after `connection_init`
|
||||
, _wsaGetDataMessageType :: !(DataMsg -> ServerMsg)
|
||||
, _wsaAcceptRequest :: !WS.AcceptRequest
|
||||
}
|
||||
|
||||
-- | to be used with `WSOnErrorMessageAction`
|
||||
onClientMessageParseErrorText :: Maybe String
|
||||
onClientMessageParseErrorText = Just "Parsing client message failed: "
|
||||
|
||||
-- | to be used with `WSOnErrorMessageAction`
|
||||
onConnInitErrorText :: Maybe String
|
||||
onConnInitErrorText = Just "Connection initialization failed: "
|
||||
|
||||
type OnConnH m a = WSId -> WS.RequestHead -> IpAddress -> WSActions a -> m (Either WS.RejectRequest (AcceptWith a))
|
||||
type OnMessageH m a = WSConn a -> BL.ByteString -> WSActions a -> m ()
|
||||
type OnCloseH m a = WSConn a -> m ()
|
||||
type OnMessageH m a = WSConn a -> BL.ByteString -> m ()
|
||||
|
||||
-- | aka generalized 'WS.ServerApp' over @m@, which takes an IPAddress
|
||||
type HasuraServerApp m = IpAddress -> WS.PendingConnection -> m ()
|
||||
|
||||
-- | NOTE: The types of `_hOnConn` and `_hOnMessage` were updated from `OnConnH` and `OnMessageH`
|
||||
-- because we needed to pass the subprotcol here to these methods to eventually get to `OnConnH` and `OnMessageH`.
|
||||
-- Please see `createServerApp` to get a better understanding of how these handlers are used.
|
||||
data WSHandlers m a
|
||||
= WSHandlers
|
||||
{ _hOnConn :: OnConnH m a
|
||||
, _hOnMessage :: OnMessageH m a
|
||||
{ _hOnConn :: (WSId -> WS.RequestHead -> IpAddress -> WSSubProtocol -> m (Either WS.RejectRequest (AcceptWith a)))
|
||||
, _hOnMessage :: (WSConn a -> BL.ByteString -> WSSubProtocol -> m ())
|
||||
, _hOnClose :: OnCloseH m a
|
||||
}
|
||||
|
||||
createServerApp
|
||||
:: (MonadIO m, MC.MonadBaseControl IO m, LA.Forall (LA.Pure m), MonadWSLog m)
|
||||
=> WSServer a
|
||||
=> WSConnectionInitTimeout
|
||||
-> WSServer a
|
||||
-> WSHandlers m a
|
||||
-- ^ user provided handlers
|
||||
-> HasuraServerApp m
|
||||
-- ^ aka WS.ServerApp
|
||||
{-# INLINE createServerApp #-}
|
||||
createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !ipAddress !pendingConn = do
|
||||
createServerApp wsConnInitTimeout (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !ipAddress !pendingConn = do
|
||||
wsId <- WSId <$> liftIO UUID.nextRandom
|
||||
logWSLog logger $ WSLog wsId EConnectionRequest Nothing
|
||||
-- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection
|
||||
-- if the client doesn't send a `connection_init` message within the timeout period
|
||||
wsConnInitTimer <- liftIO $ getNewWSTimer (unWSConnectionInitTimeout wsConnInitTimeout)
|
||||
status <- liftIO $ STM.readTVarIO serverStatus
|
||||
case status of
|
||||
AcceptingConns _ -> logUnexpectedExceptions $ do
|
||||
let reqHead = WS.pendingRequest pendingConn
|
||||
onConnRes <- _hOnConn wsHandlers wsId reqHead ipAddress
|
||||
either (onReject wsId) (onAccept wsId) onConnRes
|
||||
onConnRes <- connHandler wsId reqHead ipAddress subProtocol
|
||||
either (onReject wsId) (onAccept wsConnInitTimer wsId) onConnRes
|
||||
|
||||
ShuttingDown ->
|
||||
onReject wsId shuttingDownReject
|
||||
|
||||
where
|
||||
reqHead = WS.pendingRequest pendingConn
|
||||
|
||||
getSubProtocolHeader rhdrs =
|
||||
filter (\(x, _) -> x == (CI.mk . B.pack $ "Sec-WebSocket-Protocol")) $ WS.requestHeaders rhdrs
|
||||
|
||||
subProtocol = case getSubProtocolHeader reqHead of
|
||||
[sph] -> toWSSubProtocol . B.unpack . snd $ sph
|
||||
_ -> Apollo -- NOTE: we default to the apollo implemenation
|
||||
|
||||
connHandler = _hOnConn wsHandlers
|
||||
messageHandler = _hOnMessage wsHandlers
|
||||
closeHandler = _hOnClose wsHandlers
|
||||
|
||||
-- It's not clear what the unexpected exception handling story here should be. So at
|
||||
-- least log properly and re-raise:
|
||||
logUnexpectedExceptions = handle $ \(e :: SomeException) -> do
|
||||
@ -286,7 +313,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
|
||||
liftIO $ WS.rejectRequestWith pendingConn rejectRequest
|
||||
logWSLog logger $ WSLog wsId ERejected Nothing
|
||||
|
||||
onAccept wsId (AcceptWith a acceptWithParams keepAlive onJwtExpiry) = do
|
||||
onAccept wsConnInitTimer wsId (AcceptWith a acceptWithParams keepAlive onJwtExpiry) = do
|
||||
conn <- liftIO $ WS.acceptRequestWith pendingConn acceptWithParams
|
||||
logWSLog logger $ WSLog wsId EAccepted Nothing
|
||||
sendQ <- liftIO STM.newTQueueIO
|
||||
@ -317,7 +344,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
|
||||
-- Bad luck, we were in the process of shutting the server down but a new
|
||||
-- connection was accepted. Let's just close it politely
|
||||
forceConnReconnect wsConn "shutting server down"
|
||||
_hOnClose wsHandlers wsConn
|
||||
closeHandler wsConn
|
||||
|
||||
AcceptingConns _ -> do
|
||||
let rcv = forever $ do
|
||||
@ -332,7 +359,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
|
||||
WS.receiveData conn
|
||||
let message = MessageDetails (TBS.fromLBS msg) (BL.length msg)
|
||||
logWSLog logger $ WSLog wsId (EMessageReceived message) Nothing
|
||||
_hOnMessage wsHandlers wsConn msg
|
||||
messageHandler wsConn msg subProtocol
|
||||
|
||||
let send = forever $ do
|
||||
WSQueueResponse msg wsInfo <- liftIO $ STM.atomically $ STM.readTQueue sendQ
|
||||
@ -347,6 +374,11 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
|
||||
LA.withAsync (liftIO $ keepAlive wsConn) $ \keepAliveRef -> do
|
||||
LA.withAsync (liftIO $ onJwtExpiry wsConn) $ \onJwtExpiryRef -> do
|
||||
|
||||
-- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws`
|
||||
timeoutStatus <- liftIO $ getWSTimerState wsConnInitTimer
|
||||
when (timeoutStatus == Done && subProtocol == GraphQLWS) $
|
||||
liftIO $ closeConnWithCode wsConn 4408 "Connection initialisation timed out"
|
||||
|
||||
-- terminates on WS.ConnectionException and JWT expiry
|
||||
let waitOnRefs = [keepAliveRef, onJwtExpiryRef, rcvRef, sendRef]
|
||||
-- withAnyCancel re-raises exceptions from forkedThreads, and is guarenteed to cancel in
|
||||
@ -365,7 +397,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) serverStatus) wsHandlers !i
|
||||
ShuttingDown -> pure ()
|
||||
AcceptingConns connMap -> do
|
||||
liftIO $ STM.atomically $ STMMap.delete (_wcConnId wsConn) connMap
|
||||
_hOnClose wsHandlers wsConn
|
||||
closeHandler wsConn
|
||||
logWSLog logger $ WSLog (_wcConnId wsConn) EClosed Nothing
|
||||
|
||||
shutdown :: WSServer a -> IO ()
|
||||
|
86
server/src-lib/Hasura/GraphQL/Transport/WebSocket/Types.hs
Normal file
86
server/src-lib/Hasura/GraphQL/Transport/WebSocket/Types.hs
Normal file
@ -0,0 +1,86 @@
|
||||
module Hasura.GraphQL.Transport.WebSocket.Types where
|
||||
|
||||
import Hasura.Prelude
|
||||
|
||||
import qualified Control.Concurrent.STM as STM
|
||||
import qualified Data.Time.Clock as TC
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as H
|
||||
import qualified Network.Wai.Extended as Wai
|
||||
import qualified StmContainers.Map as STMMap
|
||||
|
||||
import qualified Hasura.GraphQL.Execute as E
|
||||
import qualified Hasura.GraphQL.Execute.LiveQuery.State as LQ
|
||||
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
|
||||
import qualified Hasura.Logging as L
|
||||
|
||||
import Hasura.GraphQL.Transport.HTTP.Protocol
|
||||
import Hasura.GraphQL.Transport.Instances ()
|
||||
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
||||
import Hasura.RQL.Types
|
||||
import Hasura.Server.Cors
|
||||
import Hasura.Server.Init.Config (KeepAliveDelay (..))
|
||||
import Hasura.Server.Metrics (ServerMetrics (..))
|
||||
import Hasura.Session
|
||||
|
||||
newtype WsHeaders
|
||||
= WsHeaders { unWsHeaders :: [H.Header] }
|
||||
deriving (Show, Eq)
|
||||
|
||||
data ErrRespType
|
||||
= ERTLegacy
|
||||
| ERTGraphqlCompliant
|
||||
deriving (Show)
|
||||
|
||||
data WSConnState
|
||||
= CSNotInitialised !WsHeaders !Wai.IpAddress
|
||||
-- ^ headers and IP address from the client for websockets
|
||||
| CSInitError !Text
|
||||
| CSInitialised !WsClientState
|
||||
deriving (Show)
|
||||
|
||||
data WsClientState
|
||||
= WsClientState
|
||||
{ wscsUserInfo :: !UserInfo
|
||||
-- ^ the 'UserInfo' required to execute the GraphQL query
|
||||
, wscsTokenExpTime :: !(Maybe TC.UTCTime)
|
||||
-- ^ the JWT/token expiry time, if any
|
||||
, wscsReqHeaders :: ![H.Header]
|
||||
-- ^ headers from the client (in conn params) to forward to the remote schema
|
||||
, wscsIpAddress :: !Wai.IpAddress
|
||||
-- ^ IP address required for 'MonadGQLAuthorization'
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data WSConnData
|
||||
= WSConnData
|
||||
-- the role and headers are set only on connection_init message
|
||||
{ _wscUser :: !(STM.TVar WSConnState)
|
||||
-- we only care about subscriptions,
|
||||
-- the other operations (query/mutations)
|
||||
-- are not tracked here
|
||||
, _wscOpMap :: !OperationMap
|
||||
, _wscErrRespTy :: !ErrRespType
|
||||
, _wscAPIType :: !E.GraphQLQueryType
|
||||
}
|
||||
|
||||
data WSServerEnv
|
||||
= WSServerEnv
|
||||
{ _wseLogger :: !(L.Logger L.Hasura)
|
||||
, _wseLiveQMap :: !LQ.LiveQueriesState
|
||||
, _wseGCtxMap :: !(IO (SchemaCache, SchemaCacheVer))
|
||||
-- ^ an action that always returns the latest version of the schema cache. See 'SchemaCacheRef'.
|
||||
, _wseHManager :: !H.Manager
|
||||
, _wseCorsPolicy :: !CorsPolicy
|
||||
, _wseSQLCtx :: !SQLGenCtx
|
||||
, _wseServer :: !WSServer
|
||||
, _wseEnableAllowlist :: !Bool
|
||||
, _wseKeepAliveDelay :: !KeepAliveDelay
|
||||
, _wseServerMetrics :: !ServerMetrics
|
||||
}
|
||||
|
||||
type OperationMap = STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName)
|
||||
|
||||
type WSServer = WS.WSServer WSConnData
|
||||
|
||||
type WSConn = WS.WSConn WSConnData
|
@ -51,7 +51,7 @@ import qualified Hasura.GraphQL.Execute.LiveQuery.State as EL
|
||||
import qualified Hasura.GraphQL.Explain as GE
|
||||
import qualified Hasura.GraphQL.Transport.HTTP as GH
|
||||
import qualified Hasura.GraphQL.Transport.HTTP.Protocol as GH
|
||||
import qualified Hasura.GraphQL.Transport.WebSocket as WS
|
||||
import qualified Hasura.GraphQL.Transport.WSServerApp as WS
|
||||
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
|
||||
import qualified Hasura.Logging as L
|
||||
import qualified Hasura.Server.API.PGDump as PGD
|
||||
@ -782,16 +782,16 @@ mkWaiApp
|
||||
-> FunctionPermissionsCtx
|
||||
-> WS.ConnectionOptions
|
||||
-> KeepAliveDelay
|
||||
-- ^ Metadata storage connection pool
|
||||
-> MaintenanceMode
|
||||
-> S.HashSet ExperimentalFeature
|
||||
-- ^ Set of the enabled experimental features
|
||||
-> S.HashSet (L.EngineLogType L.Hasura)
|
||||
-> WSConnectionInitTimeout
|
||||
-> m HasuraApp
|
||||
mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enableConsole consoleAssetsDir
|
||||
enableTelemetry instanceId apis lqOpts responseErrorsConfig
|
||||
liveQueryHook schemaCacheRef ekgStore serverMetrics enableRSPermsCtx functionPermsCtx
|
||||
connectionOptions keepAliveDelay maintenanceMode experimentalFeatures enabledLogTypes = do
|
||||
connectionOptions keepAliveDelay maintenanceMode experimentalFeatures enabledLogTypes wsConnInitTimeout = do
|
||||
|
||||
let getSchemaCache = first lastBuiltSchemaCache <$> readIORef (_scrCache schemaCacheRef)
|
||||
|
||||
@ -799,21 +799,21 @@ mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enable
|
||||
postPollHook = fromMaybe (EL.defaultLiveQueryPostPollHook logger) liveQueryHook
|
||||
|
||||
lqState <- liftIO $ EL.initLiveQueriesState lqOpts postPollHook
|
||||
wsServerEnv <- WS.createWSServerEnv logger lqState getSchemaCache httpManager
|
||||
corsPolicy sqlGenCtx enableAL keepAliveDelay serverMetrics
|
||||
wsServerEnv <- WS.createWSServerEnv logger lqState getSchemaCache httpManager corsPolicy
|
||||
sqlGenCtx enableAL keepAliveDelay serverMetrics
|
||||
|
||||
let serverCtx = ServerCtx
|
||||
{ scLogger = logger
|
||||
, scCacheRef = schemaCacheRef
|
||||
, scAuthMode = mode
|
||||
, scManager = httpManager
|
||||
, scSQLGenCtx = sqlGenCtx
|
||||
, scEnabledAPIs = apis
|
||||
, scInstanceId = instanceId
|
||||
, scLQState = lqState
|
||||
, scEnableAllowlist = enableAL
|
||||
, scEkgStore = ekgStore
|
||||
, scEnvironment = env
|
||||
{ scLogger = logger
|
||||
, scCacheRef = schemaCacheRef
|
||||
, scAuthMode = mode
|
||||
, scManager = httpManager
|
||||
, scSQLGenCtx = sqlGenCtx
|
||||
, scEnabledAPIs = apis
|
||||
, scInstanceId = instanceId
|
||||
, scLQState = lqState
|
||||
, scEnableAllowlist = enableAL
|
||||
, scEkgStore = ekgStore
|
||||
, scEnvironment = env
|
||||
, scResponseInternalErrorsConfig = responseErrorsConfig
|
||||
, scRemoteSchemaPermsCtx = enableRSPermsCtx
|
||||
, scFunctionPermsCtx = functionPermsCtx
|
||||
@ -826,7 +826,7 @@ mkWaiApp setupHook env logger sqlGenCtx enableAL httpManager mode corsCfg enable
|
||||
Spock.spockAsApp $ Spock.spockT lowerIO $
|
||||
httpApp setupHook corsCfg serverCtx enableConsole consoleAssetsDir enableTelemetry
|
||||
|
||||
let wsServerApp = WS.createWSServerApp env enabledLogTypes mode wsServerEnv -- TODO: Lyndon: Can we pass environment through wsServerEnv?
|
||||
let wsServerApp = WS.createWSServerApp env enabledLogTypes mode wsServerEnv wsConnInitTimeout -- TODO: Lyndon: Can we pass environment through wsServerEnv?
|
||||
stopWSServer = WS.stopWSServerApp wsServerEnv
|
||||
|
||||
waiApp <- liftWithStateless $ \lowerIO ->
|
||||
|
@ -221,6 +221,9 @@ mkServeOptions rso = do
|
||||
gracefulShutdownTime <-
|
||||
fromMaybe 60 <$> withEnv (rsoGracefulShutdownTimeout rso) (fst gracefulShutdownEnv)
|
||||
|
||||
webSocketConnectionInitTimeout <- WSConnectionInitTimeout . fromIntegral . fromMaybe 3
|
||||
<$> withEnv (rsoWebSocketConnectionInitTimeout rso) (fst webSocketConnectionInitTimeoutEnv)
|
||||
|
||||
pure $ ServeOptions
|
||||
port
|
||||
host
|
||||
@ -256,6 +259,7 @@ mkServeOptions rso = do
|
||||
eventsFetchBatchSize
|
||||
devMode
|
||||
gracefulShutdownTime
|
||||
webSocketConnectionInitTimeout
|
||||
where
|
||||
#ifdef DeveloperAPIs
|
||||
defaultAPIs = [METADATA,GRAPHQL,PGDUMP,CONFIG,DEVELOPER]
|
||||
@ -1196,6 +1200,7 @@ serveOptsToLog so =
|
||||
, "experimental_features" J..= soExperimentalFeatures so
|
||||
, "events_fetch_batch_size" J..= soEventsFetchBatchSize so
|
||||
, "graceful_shutdown_timeout" J..= soGracefulShutdownTimeout so
|
||||
, "websocket_connection_init_timeout" J..= show (soWebsocketConnectionInitTimeout so)
|
||||
]
|
||||
|
||||
mkGenericStrLog :: L.LogLevel -> Text -> String -> StartupLog
|
||||
@ -1252,6 +1257,7 @@ serveOptionsParser =
|
||||
<*> parseExperimentalFeatures
|
||||
<*> parseEventsFetchBatchSize
|
||||
<*> parseGracefulShutdownTimeout
|
||||
<*> parseWebSocketConnectionInitTimeout
|
||||
|
||||
-- | This implements the mapping between application versions
|
||||
-- and catalog schema versions.
|
||||
@ -1299,6 +1305,7 @@ parseWebSocketCompression =
|
||||
help (snd webSocketCompressionEnv)
|
||||
)
|
||||
|
||||
-- NOTE: this is purely used by Apollo-Subscription-Transport-WS
|
||||
webSocketKeepAliveEnv :: (String, String)
|
||||
webSocketKeepAliveEnv =
|
||||
( "HASURA_GRAPHQL_WEBSOCKET_KEEPALIVE"
|
||||
@ -1312,3 +1319,18 @@ parseWebSocketKeepAlive =
|
||||
( long "websocket-keepalive" <>
|
||||
help (snd webSocketKeepAliveEnv)
|
||||
)
|
||||
|
||||
-- NOTE: this is purely used by GraphQL-WS
|
||||
webSocketConnectionInitTimeoutEnv :: (String, String)
|
||||
webSocketConnectionInitTimeoutEnv =
|
||||
( "HASURA_GRAPHQL_WEBSOCKET_CONNECTION_INIT_TIMEOUT" -- FIXME?: maybe a better name
|
||||
, "Control websocket connection_init timeout (default 3 seconds)"
|
||||
)
|
||||
|
||||
parseWebSocketConnectionInitTimeout :: Parser (Maybe Int)
|
||||
parseWebSocketConnectionInitTimeout =
|
||||
optional $
|
||||
option (eitherReader readEither)
|
||||
( long "websocket-connection-init-timeout" <>
|
||||
help (snd webSocketConnectionInitTimeoutEnv)
|
||||
)
|
||||
|
@ -63,44 +63,58 @@ instance ToJSON OptionalInterval where
|
||||
Skip -> toJSON @Milliseconds 0
|
||||
Interval s -> toJSON s
|
||||
|
||||
data API
|
||||
= METADATA
|
||||
| GRAPHQL
|
||||
| PGDUMP
|
||||
| DEVELOPER
|
||||
| CONFIG
|
||||
deriving (Show, Eq, Read, Generic)
|
||||
|
||||
$(J.deriveJSON (J.defaultOptions { J.constructorTagModifier = map toLower })
|
||||
''API)
|
||||
|
||||
instance Hashable API
|
||||
|
||||
data RawServeOptions impl
|
||||
= RawServeOptions
|
||||
{ rsoPort :: !(Maybe Int)
|
||||
, rsoHost :: !(Maybe HostPreference)
|
||||
, rsoConnParams :: !RawConnParams
|
||||
, rsoTxIso :: !(Maybe Q.TxIsolation)
|
||||
, rsoAdminSecret :: !(Maybe AdminSecretHash)
|
||||
, rsoAuthHook :: !RawAuthHook
|
||||
, rsoJwtSecret :: !(Maybe JWTConfig)
|
||||
, rsoUnAuthRole :: !(Maybe RoleName)
|
||||
, rsoCorsConfig :: !(Maybe CorsConfig)
|
||||
, rsoEnableConsole :: !Bool
|
||||
, rsoConsoleAssetsDir :: !(Maybe Text)
|
||||
, rsoEnableTelemetry :: !(Maybe Bool)
|
||||
, rsoWsReadCookie :: !Bool
|
||||
, rsoStringifyNum :: !Bool
|
||||
, rsoDangerousBooleanCollapse :: !(Maybe Bool)
|
||||
, rsoEnabledAPIs :: !(Maybe [API])
|
||||
, rsoMxRefetchInt :: !(Maybe LQ.RefetchInterval)
|
||||
, rsoMxBatchSize :: !(Maybe LQ.BatchSize)
|
||||
, rsoEnableAllowlist :: !Bool
|
||||
, rsoEnabledLogTypes :: !(Maybe [L.EngineLogType impl])
|
||||
, rsoLogLevel :: !(Maybe L.LogLevel)
|
||||
, rsoDevMode :: !Bool
|
||||
, rsoAdminInternalErrors :: !(Maybe Bool)
|
||||
, rsoEventsHttpPoolSize :: !(Maybe Int)
|
||||
, rsoEventsFetchInterval :: !(Maybe Milliseconds)
|
||||
, rsoAsyncActionsFetchInterval :: !(Maybe Milliseconds)
|
||||
, rsoLogHeadersFromEnv :: !Bool
|
||||
, rsoEnableRemoteSchemaPermissions :: !Bool
|
||||
, rsoWebSocketCompression :: !Bool
|
||||
, rsoWebSocketKeepAlive :: !(Maybe Int)
|
||||
, rsoInferFunctionPermissions :: !(Maybe Bool)
|
||||
, rsoEnableMaintenanceMode :: !Bool
|
||||
, rsoSchemaPollInterval :: !(Maybe Milliseconds)
|
||||
, rsoExperimentalFeatures :: !(Maybe [ExperimentalFeature])
|
||||
, rsoEventsFetchBatchSize :: !(Maybe NonNegativeInt)
|
||||
, rsoGracefulShutdownTimeout :: !(Maybe Seconds)
|
||||
{ rsoPort :: !(Maybe Int)
|
||||
, rsoHost :: !(Maybe HostPreference)
|
||||
, rsoConnParams :: !RawConnParams
|
||||
, rsoTxIso :: !(Maybe Q.TxIsolation)
|
||||
, rsoAdminSecret :: !(Maybe AdminSecretHash)
|
||||
, rsoAuthHook :: !RawAuthHook
|
||||
, rsoJwtSecret :: !(Maybe JWTConfig)
|
||||
, rsoUnAuthRole :: !(Maybe RoleName)
|
||||
, rsoCorsConfig :: !(Maybe CorsConfig)
|
||||
, rsoEnableConsole :: !Bool
|
||||
, rsoConsoleAssetsDir :: !(Maybe Text)
|
||||
, rsoEnableTelemetry :: !(Maybe Bool)
|
||||
, rsoWsReadCookie :: !Bool
|
||||
, rsoStringifyNum :: !Bool
|
||||
, rsoDangerousBooleanCollapse :: !(Maybe Bool)
|
||||
, rsoEnabledAPIs :: !(Maybe [API])
|
||||
, rsoMxRefetchInt :: !(Maybe LQ.RefetchInterval)
|
||||
, rsoMxBatchSize :: !(Maybe LQ.BatchSize)
|
||||
, rsoEnableAllowlist :: !Bool
|
||||
, rsoEnabledLogTypes :: !(Maybe [L.EngineLogType impl])
|
||||
, rsoLogLevel :: !(Maybe L.LogLevel)
|
||||
, rsoDevMode :: !Bool
|
||||
, rsoAdminInternalErrors :: !(Maybe Bool)
|
||||
, rsoEventsHttpPoolSize :: !(Maybe Int)
|
||||
, rsoEventsFetchInterval :: !(Maybe Milliseconds)
|
||||
, rsoAsyncActionsFetchInterval :: !(Maybe Milliseconds)
|
||||
, rsoLogHeadersFromEnv :: !Bool
|
||||
, rsoEnableRemoteSchemaPermissions :: !Bool
|
||||
, rsoWebSocketCompression :: !Bool
|
||||
, rsoWebSocketKeepAlive :: !(Maybe Int)
|
||||
, rsoInferFunctionPermissions :: !(Maybe Bool)
|
||||
, rsoEnableMaintenanceMode :: !Bool
|
||||
, rsoSchemaPollInterval :: !(Maybe Milliseconds)
|
||||
, rsoExperimentalFeatures :: !(Maybe [ExperimentalFeature])
|
||||
, rsoEventsFetchBatchSize :: !(Maybe NonNegativeInt)
|
||||
, rsoGracefulShutdownTimeout :: !(Maybe Seconds)
|
||||
, rsoWebSocketConnectionInitTimeout :: !(Maybe Int)
|
||||
}
|
||||
|
||||
-- | @'ResponseInternalErrorsConfig' represents the encoding of the internal
|
||||
@ -118,47 +132,57 @@ shouldIncludeInternal role = \case
|
||||
InternalErrorsAdminOnly -> role == adminRoleName
|
||||
InternalErrorsDisabled -> False
|
||||
|
||||
newtype KeepAliveDelay
|
||||
= KeepAliveDelay
|
||||
{ unKeepAliveDelay :: Seconds
|
||||
} deriving (Eq, Show)
|
||||
newtype KeepAliveDelay = KeepAliveDelay { unKeepAliveDelay :: Seconds }
|
||||
deriving (Eq, Show)
|
||||
$(J.deriveJSON hasuraJSON ''KeepAliveDelay)
|
||||
|
||||
defaultKeepAliveDelay :: KeepAliveDelay
|
||||
defaultKeepAliveDelay = KeepAliveDelay $ fromIntegral (5 :: Int)
|
||||
|
||||
newtype WSConnectionInitTimeout = WSConnectionInitTimeout { unWSConnectionInitTimeout :: Seconds }
|
||||
deriving (Eq, Show)
|
||||
$(J.deriveJSON hasuraJSON ''WSConnectionInitTimeout)
|
||||
|
||||
defaultWSConnectionInitTimeout :: WSConnectionInitTimeout
|
||||
defaultWSConnectionInitTimeout = WSConnectionInitTimeout $ fromIntegral (3 :: Int)
|
||||
|
||||
data ServeOptions impl
|
||||
= ServeOptions
|
||||
{ soPort :: !Int
|
||||
, soHost :: !HostPreference
|
||||
, soConnParams :: !Q.ConnParams
|
||||
, soTxIso :: !Q.TxIsolation
|
||||
, soAdminSecret :: !(Maybe AdminSecretHash)
|
||||
, soAuthHook :: !(Maybe AuthHook)
|
||||
, soJwtSecret :: !(Maybe JWTConfig)
|
||||
, soUnAuthRole :: !(Maybe RoleName)
|
||||
, soCorsConfig :: !CorsConfig
|
||||
, soEnableConsole :: !Bool
|
||||
, soConsoleAssetsDir :: !(Maybe Text)
|
||||
, soEnableTelemetry :: !Bool
|
||||
, soStringifyNum :: !Bool
|
||||
, soDangerousBooleanCollapse :: !Bool
|
||||
, soEnabledAPIs :: !(Set.HashSet API)
|
||||
, soLiveQueryOpts :: !LQ.LiveQueriesOptions
|
||||
, soEnableAllowlist :: !Bool
|
||||
, soEnabledLogTypes :: !(Set.HashSet (L.EngineLogType impl))
|
||||
, soLogLevel :: !L.LogLevel
|
||||
, soResponseInternalErrorsConfig :: !ResponseInternalErrorsConfig
|
||||
, soEventsHttpPoolSize :: !(Maybe Int)
|
||||
, soEventsFetchInterval :: !(Maybe Milliseconds)
|
||||
, soAsyncActionsFetchInterval :: !OptionalInterval
|
||||
, soLogHeadersFromEnv :: !Bool
|
||||
, soEnableRemoteSchemaPermissions :: !RemoteSchemaPermsCtx
|
||||
, soConnectionOptions :: !WS.ConnectionOptions
|
||||
, soWebsocketKeepAlive :: !KeepAliveDelay
|
||||
, soInferFunctionPermissions :: !FunctionPermissionsCtx
|
||||
, soEnableMaintenanceMode :: !MaintenanceMode
|
||||
, soSchemaPollInterval :: !OptionalInterval
|
||||
, soExperimentalFeatures :: !(Set.HashSet ExperimentalFeature)
|
||||
, soEventsFetchBatchSize :: !NonNegativeInt
|
||||
, soDevMode :: !Bool
|
||||
, soGracefulShutdownTimeout :: !Seconds
|
||||
{ soPort :: !Int
|
||||
, soHost :: !HostPreference
|
||||
, soConnParams :: !Q.ConnParams
|
||||
, soTxIso :: !Q.TxIsolation
|
||||
, soAdminSecret :: !(Maybe AdminSecretHash)
|
||||
, soAuthHook :: !(Maybe AuthHook)
|
||||
, soJwtSecret :: !(Maybe JWTConfig)
|
||||
, soUnAuthRole :: !(Maybe RoleName)
|
||||
, soCorsConfig :: !CorsConfig
|
||||
, soEnableConsole :: !Bool
|
||||
, soConsoleAssetsDir :: !(Maybe Text)
|
||||
, soEnableTelemetry :: !Bool
|
||||
, soStringifyNum :: !Bool
|
||||
, soDangerousBooleanCollapse :: !Bool
|
||||
, soEnabledAPIs :: !(Set.HashSet API)
|
||||
, soLiveQueryOpts :: !LQ.LiveQueriesOptions
|
||||
, soEnableAllowlist :: !Bool
|
||||
, soEnabledLogTypes :: !(Set.HashSet (L.EngineLogType impl))
|
||||
, soLogLevel :: !L.LogLevel
|
||||
, soResponseInternalErrorsConfig :: !ResponseInternalErrorsConfig
|
||||
, soEventsHttpPoolSize :: !(Maybe Int)
|
||||
, soEventsFetchInterval :: !(Maybe Milliseconds)
|
||||
, soAsyncActionsFetchInterval :: !OptionalInterval
|
||||
, soLogHeadersFromEnv :: !Bool
|
||||
, soEnableRemoteSchemaPermissions :: !RemoteSchemaPermsCtx
|
||||
, soConnectionOptions :: !WS.ConnectionOptions
|
||||
, soWebsocketKeepAlive :: !KeepAliveDelay
|
||||
, soInferFunctionPermissions :: !FunctionPermissionsCtx
|
||||
, soEnableMaintenanceMode :: !MaintenanceMode
|
||||
, soSchemaPollInterval :: !OptionalInterval
|
||||
, soExperimentalFeatures :: !(Set.HashSet ExperimentalFeature)
|
||||
, soEventsFetchBatchSize :: !NonNegativeInt
|
||||
, soDevMode :: !Bool
|
||||
, soGracefulShutdownTimeout :: !Seconds
|
||||
, soWebsocketConnectionInitTimeout :: !WSConnectionInitTimeout
|
||||
}
|
||||
|
||||
data DowngradeOptions
|
||||
@ -211,19 +235,6 @@ data HGECommandG a
|
||||
| HCDowngrade !DowngradeOptions
|
||||
deriving (Show, Eq)
|
||||
|
||||
data API
|
||||
= METADATA
|
||||
| GRAPHQL
|
||||
| PGDUMP
|
||||
| DEVELOPER
|
||||
| CONFIG
|
||||
deriving (Show, Eq, Read, Generic)
|
||||
|
||||
$(J.deriveJSON (J.defaultOptions { J.constructorTagModifier = map toLower })
|
||||
''API)
|
||||
|
||||
instance Hashable API
|
||||
|
||||
$(J.deriveJSON (J.aesonPrefix J.camelCase){J.omitNothingFields=True} ''PostgresRawConnDetails)
|
||||
|
||||
type HGECommand impl = HGECommandG (ServeOptions impl)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
import time
|
||||
from context import HGECtx, HGECtxError, ActionsWebhookServer, EvtsWebhookServer, HGECtxGQLServer, GQLWsClient, PytestConf
|
||||
from context import HGECtx, HGECtxError, ActionsWebhookServer, EvtsWebhookServer, HGECtxGQLServer, GQLWsClient, PytestConf, GraphQLWSClient
|
||||
import threading
|
||||
import random
|
||||
from datetime import datetime
|
||||
@ -388,6 +388,16 @@ def ws_client(request, hge_ctx):
|
||||
yield client
|
||||
client.teardown()
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def ws_client_graphql_ws(request, hge_ctx):
|
||||
"""
|
||||
This fixture provides an GraphQL-WS client
|
||||
"""
|
||||
client = GraphQLWSClient(hge_ctx, '/v1/graphql')
|
||||
time.sleep(0.1)
|
||||
yield client
|
||||
client.teardown()
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def per_class_tests_db_state(request, hge_ctx):
|
||||
"""
|
||||
|
@ -32,6 +32,7 @@ class PytestConf():
|
||||
class HGECtxError(Exception):
|
||||
pass
|
||||
|
||||
# NOTE: use this to generate a GraphQL client that uses the `Apollo`(subscription-transport-ws) sub-protocol
|
||||
class GQLWsClient():
|
||||
|
||||
def __init__(self, hge_ctx, endpoint):
|
||||
@ -160,6 +161,151 @@ class GQLWsClient():
|
||||
self._ws.close()
|
||||
self.wst.join()
|
||||
|
||||
# NOTE: use this to generate a GraphQL client that uses the `graphql-ws` sub-protocol
|
||||
class GraphQLWSClient():
|
||||
|
||||
def __init__(self, hge_ctx, endpoint):
|
||||
self.hge_ctx = hge_ctx
|
||||
self.ws_queue = queue.Queue(maxsize=-1)
|
||||
self.ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws',
|
||||
path=endpoint)
|
||||
self.create_conn()
|
||||
|
||||
def get_queue(self):
|
||||
return self.ws_queue.queue
|
||||
|
||||
def clear_queue(self):
|
||||
self.ws_queue.queue.clear()
|
||||
|
||||
def create_conn(self):
|
||||
self.ws_queue.queue.clear()
|
||||
self.ws_id_query_queues = dict()
|
||||
self.ws_active_query_ids = set()
|
||||
|
||||
self.connected_event = threading.Event()
|
||||
self.init_done = False
|
||||
self.is_closing = False
|
||||
self.remote_closed = False
|
||||
|
||||
self._ws = websocket.WebSocketApp(self.ws_url.geturl(),
|
||||
on_open=self._on_open, on_message=self._on_message, on_close=self._on_close, subprotocols=["graphql-transport-ws"])
|
||||
self.wst = threading.Thread(target=self._ws.run_forever)
|
||||
self.wst.daemon = True
|
||||
self.wst.start()
|
||||
|
||||
def recreate_conn(self):
|
||||
self.teardown()
|
||||
self.create_conn()
|
||||
|
||||
def wait_for_connection(self, timeout=10):
|
||||
assert not self.is_closing
|
||||
assert self.connected_event.wait(timeout=timeout)
|
||||
|
||||
def get_ws_event(self, timeout):
|
||||
return self.ws_queue.get(timeout=timeout)
|
||||
|
||||
def has_ws_query_events(self, query_id):
|
||||
return not self.ws_id_query_queues[query_id].empty()
|
||||
|
||||
def get_ws_query_event(self, query_id, timeout):
|
||||
print("HELLO", self.ws_active_query_ids)
|
||||
return self.ws_id_query_queues[query_id].get(timeout=timeout)
|
||||
|
||||
def send(self, frame):
|
||||
self.wait_for_connection()
|
||||
if frame.get('type') == 'complete':
|
||||
self.ws_active_query_ids.discard( frame.get('id') )
|
||||
elif frame.get('type') == 'subscribe' and 'id' in frame:
|
||||
self.ws_id_query_queues[frame['id']] = queue.Queue(maxsize=-1)
|
||||
self._ws.send(json.dumps(frame))
|
||||
|
||||
def init_as_admin(self):
|
||||
headers={}
|
||||
if self.hge_ctx.hge_key:
|
||||
headers = {'x-hasura-admin-secret': self.hge_ctx.hge_key}
|
||||
self.init(headers)
|
||||
|
||||
def init(self, headers={}):
|
||||
payload = {'type': 'connection_init', 'payload': {}}
|
||||
|
||||
if headers and len(headers) > 0:
|
||||
payload['payload']['headers'] = headers
|
||||
|
||||
self.send(payload)
|
||||
ev = self.get_ws_event(5)
|
||||
assert ev['type'] == 'connection_ack', ev
|
||||
self.init_done = True
|
||||
|
||||
def stop(self, query_id):
|
||||
data = {'id': query_id, 'type': 'complete'}
|
||||
self.send(data)
|
||||
self.ws_active_query_ids.discard(query_id)
|
||||
|
||||
def gen_id(self, size=6, chars=string.ascii_letters + string.digits):
|
||||
new_id = ''.join(random.choice(chars) for _ in range(size))
|
||||
if new_id in self.ws_active_query_ids:
|
||||
return self.gen_id(size, chars)
|
||||
return new_id
|
||||
|
||||
def send_query(self, query, query_id=None, headers={}, timeout=60):
|
||||
graphql.parse(query['query'])
|
||||
if headers and len(headers) > 0:
|
||||
#Do init If headers are provided
|
||||
self.clear_queue()
|
||||
self.init(headers)
|
||||
elif not self.init_done:
|
||||
self.init()
|
||||
if query_id == None:
|
||||
query_id = self.gen_id()
|
||||
frame = {
|
||||
'id': query_id,
|
||||
'type': 'subscribe',
|
||||
'payload': query,
|
||||
}
|
||||
self.ws_active_query_ids.add(query_id)
|
||||
self.send(frame)
|
||||
while True:
|
||||
yield self.get_ws_query_event(query_id, timeout)
|
||||
|
||||
def _on_open(self):
|
||||
if not self.is_closing:
|
||||
self.connected_event.set()
|
||||
|
||||
def _on_message(self, message):
|
||||
# NOTE: make sure we preserve key ordering so we can test the ordering
|
||||
# properties in the graphql spec properly
|
||||
json_msg = json.loads(message, object_pairs_hook=OrderedDict)
|
||||
if json_msg['type'] == 'ping':
|
||||
new_msg = json_msg
|
||||
new_msg['type'] = 'pong'
|
||||
self.send(json.dumps(new_msg))
|
||||
return
|
||||
|
||||
if 'id' in json_msg:
|
||||
query_id = json_msg['id']
|
||||
if json_msg.get('type') == 'complete':
|
||||
#Remove from active queries list
|
||||
self.ws_active_query_ids.discard( query_id )
|
||||
if not query_id in self.ws_id_query_queues:
|
||||
self.ws_id_query_queues[json_msg['id']] = queue.Queue(maxsize=-1)
|
||||
#Put event in the correponding query_queue
|
||||
self.ws_id_query_queues[query_id].put(json_msg)
|
||||
|
||||
if json_msg['type'] != 'ping':
|
||||
self.ws_queue.put(json_msg)
|
||||
|
||||
def _on_close(self):
|
||||
self.remote_closed = True
|
||||
self.init_done = False
|
||||
|
||||
def get_conn_close_state(self):
|
||||
return self.remote_closed or self.is_closing
|
||||
|
||||
def teardown(self):
|
||||
self.is_closing = True
|
||||
if not self.remote_closed:
|
||||
self._ws.close()
|
||||
self.wst.join()
|
||||
|
||||
class ActionsWebhookHandler(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
@ -501,6 +647,7 @@ class HGECtx:
|
||||
self.ws_client = GQLWsClient(self, '/v1/graphql')
|
||||
self.ws_client_v1alpha1 = GQLWsClient(self, '/v1alpha1/graphql')
|
||||
self.ws_client_relay = GQLWsClient(self, '/v1beta1/relay')
|
||||
self.ws_client_graphql_ws = GraphQLWSClient(self, '/v1/graphql')
|
||||
|
||||
self.backend = config.getoption('--backend')
|
||||
self.default_backend = 'postgres'
|
||||
@ -628,6 +775,7 @@ class HGECtx:
|
||||
self.ws_client.teardown()
|
||||
self.ws_client_v1alpha1.teardown()
|
||||
self.ws_client_relay.teardown()
|
||||
self.ws_client_graphql_ws.teardown()
|
||||
|
||||
def v1GraphqlExplain(self, q, hdrs=None):
|
||||
headers = {}
|
||||
|
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import json
|
||||
import queue
|
||||
@ -10,7 +11,11 @@ usefixtures = pytest.mark.usefixtures
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def ws_conn_init(hge_ctx, ws_client):
|
||||
init_ws_conn(hge_ctx, ws_client)
|
||||
init_ws_conn(hge_ctx, ws_client)
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def ws_conn_init_graphql_ws(hge_ctx, ws_client_graphql_ws):
|
||||
init_graphql_ws_conn(hge_ctx, ws_client_graphql_ws)
|
||||
|
||||
'''
|
||||
Refer: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init
|
||||
@ -34,6 +39,24 @@ def init_ws_conn(hge_ctx, ws_client, payload = None):
|
||||
ev = ws_client.get_ws_event(3)
|
||||
assert ev['type'] == 'connection_ack', ev
|
||||
|
||||
def init_graphql_ws_conn(hge_ctx, ws_client_graphql_ws, payload = None):
|
||||
if payload is None:
|
||||
payload = {}
|
||||
if hge_ctx.hge_key is not None:
|
||||
payload = {
|
||||
'headers' : {
|
||||
'X-Hasura-Admin-Secret': hge_ctx.hge_key
|
||||
}
|
||||
}
|
||||
|
||||
init_msg = {
|
||||
'type': 'connection_init',
|
||||
'payload': payload,
|
||||
}
|
||||
ws_client_graphql_ws.send(init_msg)
|
||||
ev = ws_client_graphql_ws.get_ws_event(3)
|
||||
assert ev['type'] == 'connection_ack', ev
|
||||
|
||||
class TestSubscriptionCtrl(object):
|
||||
|
||||
def test_init_without_payload(self, hge_ctx, ws_client):
|
||||
@ -116,7 +139,6 @@ class TestSubscriptionBasic:
|
||||
'''
|
||||
Refer https://github.com/apollographql/subscriptions-transport-ws/blob/01e0b2b65df07c52f5831cce5c858966ba095993/src/server.ts#L306
|
||||
'''
|
||||
|
||||
@pytest.mark.skip(reason="refer https://github.com/hasura/graphql-engine/pull/387#issuecomment-421343098")
|
||||
def test_start_duplicate(self, ws_client):
|
||||
self.test_start(ws_client)
|
||||
@ -173,6 +195,140 @@ class TestSubscriptionBasic:
|
||||
ev = ws_client.get_ws_query_event('2',3)
|
||||
assert ev['type'] == 'complete' and ev['id'] == '2', ev
|
||||
|
||||
## NOTE: The same tests as in TestSubcscriptionBasic but with
|
||||
## the subscription transport being used is `graphql-ws`
|
||||
## FIXME: There's an issue with the tests being parametrized with both
|
||||
## postgres and mssql data sources enabled(See issue #2084).
|
||||
@usefixtures('per_method_tests_db_state', 'ws_conn_init_graphql_ws')
|
||||
class TestSubscriptionBasicGraphQLWS:
|
||||
@classmethod
|
||||
def dir(cls):
|
||||
return 'queries/subscriptions/basic'
|
||||
|
||||
@pytest.mark.parametrize("transport", ['http', 'websocket', 'subscription'])
|
||||
def test_negative(self, hge_ctx, transport):
|
||||
check_query_f(hge_ctx, self.dir() + '/negative_test.yaml', transport, gqlws=True)
|
||||
|
||||
def test_connection_error(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
ws_client_graphql_ws.send({'type': 'test'})
|
||||
time.sleep(2)
|
||||
ev = ws_client_graphql_ws.get_conn_close_state()
|
||||
assert ev == True, ev
|
||||
|
||||
def test_start(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
query = """
|
||||
subscription {
|
||||
hge_tests_test_t1(order_by: {c1: desc}, limit: 1) {
|
||||
c1,
|
||||
c2
|
||||
}
|
||||
}
|
||||
"""
|
||||
obj = {
|
||||
'id': '1',
|
||||
'payload': {
|
||||
'query': query
|
||||
},
|
||||
'type': 'subscribe'
|
||||
}
|
||||
ws_client_graphql_ws.send(obj)
|
||||
ev = ws_client_graphql_ws.get_ws_query_event('1',15)
|
||||
assert ev['type'] == 'next' and ev['id'] == '1', ev
|
||||
|
||||
@pytest.mark.skip(reason="refer https://github.com/hasura/graphql-engine/pull/387#issuecomment-421343098")
|
||||
def test_start_duplicate(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
self.test_start(ws_client_graphql_ws)
|
||||
|
||||
def test_stop_without_id(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
obj = {
|
||||
'type': 'complete'
|
||||
}
|
||||
ws_client_graphql_ws.send(obj)
|
||||
time.sleep(2)
|
||||
ev = ws_client_graphql_ws.get_conn_close_state()
|
||||
assert ev == True, ev
|
||||
|
||||
def test_stop(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
obj = {
|
||||
'type': 'complete',
|
||||
'id': '1'
|
||||
}
|
||||
ws_client_graphql_ws.send(obj)
|
||||
time.sleep(2)
|
||||
with pytest.raises(queue.Empty):
|
||||
ev = ws_client_graphql_ws.get_ws_event(3)
|
||||
|
||||
def test_start_after_stop(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
self.test_start(hge_ctx, ws_client_graphql_ws)
|
||||
## NOTE: test_start leaves a message in the queue, hence clearing it
|
||||
if len(ws_client_graphql_ws.get_queue()) > 0:
|
||||
ws_client_graphql_ws.clear_queue()
|
||||
self.test_stop(hge_ctx, ws_client_graphql_ws)
|
||||
|
||||
def test_complete(self, hge_ctx, ws_client_graphql_ws):
|
||||
if ws_client_graphql_ws.get_conn_close_state():
|
||||
ws_client_graphql_ws.create_conn()
|
||||
if hge_ctx.hge_key == None:
|
||||
ws_client_graphql_ws.init()
|
||||
else:
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
query = """
|
||||
query {
|
||||
hge_tests_test_t1(order_by: {c1: desc}, limit: 1) {
|
||||
c1,
|
||||
c2
|
||||
}
|
||||
}
|
||||
"""
|
||||
obj = {
|
||||
'id': '2',
|
||||
'payload': {
|
||||
'query': query
|
||||
},
|
||||
'type': 'subscribe'
|
||||
}
|
||||
ws_client_graphql_ws.send(obj)
|
||||
ev = ws_client_graphql_ws.get_ws_query_event('2',3)
|
||||
assert ev['type'] == 'next' and ev['id'] == '2', ev
|
||||
# Check for complete type
|
||||
ev = ws_client_graphql_ws.get_ws_query_event('2',3)
|
||||
assert ev['type'] == 'complete' and ev['id'] == '2', ev
|
||||
|
||||
@usefixtures('per_method_tests_db_state','ws_conn_init')
|
||||
class TestSubscriptionLiveQueries:
|
||||
@ -256,6 +412,86 @@ class TestSubscriptionLiveQueries:
|
||||
with pytest.raises(queue.Empty):
|
||||
ev = ws_client.get_ws_event(3)
|
||||
|
||||
@usefixtures('per_method_tests_db_state','ws_conn_init_graphql_ws')
|
||||
class TestSubscriptionLiveQueriesForGraphQLWS:
|
||||
|
||||
@classmethod
|
||||
def dir(cls):
|
||||
return 'queries/subscriptions/live_queries'
|
||||
|
||||
def test_live_queries(self, hge_ctx, ws_client_graphql_ws):
|
||||
'''
|
||||
Create connection using connection_init
|
||||
'''
|
||||
ws_client_graphql_ws.init_as_admin()
|
||||
|
||||
with open(self.dir() + "/steps.yaml") as c:
|
||||
conf = yaml.safe_load(c)
|
||||
|
||||
queryTmplt = """
|
||||
subscription ($result_limit: Int!) {
|
||||
hge_tests_live_query_{0}: hge_tests_test_t2(order_by: {c1: asc}, limit: $result_limit) {
|
||||
c1,
|
||||
c2
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
queries = [(0, 1), (1, 2), (2, 2)]
|
||||
liveQs = []
|
||||
for i, resultLimit in queries:
|
||||
query = queryTmplt.replace('{0}',str(i))
|
||||
headers={}
|
||||
if hge_ctx.hge_key is not None:
|
||||
headers['X-Hasura-Admin-Secret'] = hge_ctx.hge_key
|
||||
subscrPayload = { 'query': query, 'variables': { 'result_limit': resultLimit } }
|
||||
respLive = ws_client_graphql_ws.send_query(subscrPayload, query_id='live_'+str(i), headers=headers, timeout=15)
|
||||
liveQs.append(respLive)
|
||||
ev = next(respLive)
|
||||
assert ev['type'] == 'next', ev
|
||||
assert ev['id'] == 'live_' + str(i), ev
|
||||
assert ev['payload']['data'] == {'hge_tests_live_query_'+str(i): []}, ev['payload']['data']
|
||||
|
||||
assert isinstance(conf, list) == True, 'Not an list'
|
||||
for index, step in enumerate(conf):
|
||||
mutationPayload = { 'query': step['query'] }
|
||||
if 'variables' in step and step['variables']:
|
||||
mutationPayload['variables'] = json.loads(step['variables'])
|
||||
|
||||
expected_resp = json.loads(step['response'])
|
||||
|
||||
mutResp = ws_client_graphql_ws.send_query(mutationPayload,'mutation_'+str(index),timeout=15)
|
||||
ev = next(mutResp)
|
||||
assert ev['type'] == 'next' and ev['id'] == 'mutation_'+str(index), ev
|
||||
assert ev['payload']['data'] == expected_resp, ev['payload']['data']
|
||||
|
||||
ev = next(mutResp)
|
||||
assert ev['type'] == 'complete' and ev['id'] == 'mutation_'+str(index), ev
|
||||
|
||||
for (i, resultLimit), respLive in zip(queries, liveQs):
|
||||
ev = next(respLive)
|
||||
assert ev['type'] == 'next', ev
|
||||
assert ev['id'] == 'live_' + str(i), ev
|
||||
|
||||
expectedReturnedResponse = []
|
||||
if 'live_response' in step:
|
||||
expectedReturnedResponse = json.loads(step['live_response'])
|
||||
elif 'returning' in expected_resp[step['name']]:
|
||||
expectedReturnedResponse = expected_resp[step['name']]['returning']
|
||||
expectedLimitedResponse = expectedReturnedResponse[:resultLimit]
|
||||
expectedLiveResponse = { 'hge_tests_live_query_'+str(i): expectedLimitedResponse }
|
||||
|
||||
assert ev['payload']['data'] == expectedLiveResponse, ev['payload']['data']
|
||||
|
||||
for i, _ in queries:
|
||||
# stop live operation
|
||||
frame = {
|
||||
'id': 'live_'+str(i),
|
||||
'type': 'complete'
|
||||
}
|
||||
ws_client_graphql_ws.send(frame)
|
||||
ws_client_graphql_ws.clear_queue()
|
||||
|
||||
@pytest.mark.parametrize("backend", ['mssql', 'postgres'])
|
||||
@usefixtures('per_class_tests_db_state')
|
||||
class TestSubscriptionMultiplexing:
|
||||
|
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import time
|
||||
import ruamel.yaml as yaml
|
||||
from ruamel.yaml.compat import ordereddict, StringIO
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
@ -148,7 +149,7 @@ def mk_claims_with_namespace_path(claims,hasura_claims,namespace_path):
|
||||
|
||||
# Returns the response received and a bool indicating whether the test passed
|
||||
# or not (this will always be True unless we are `--accepting`)
|
||||
def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace_path=None):
|
||||
def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace_path=None, gqlws=False):
|
||||
hge_ctx.tests_passed = True
|
||||
headers = {}
|
||||
if 'headers' in conf:
|
||||
@ -211,13 +212,13 @@ def check_query(hge_ctx, conf, transport='http', add_auth=True, claims_namespace
|
||||
conf['status'], conf.get('response'), conf.get('resp_headers'), body=conf.get('body'), method=conf.get('method'))
|
||||
elif transport == 'websocket':
|
||||
print('running on websocket')
|
||||
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True)
|
||||
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, gqlws=gqlws)
|
||||
elif transport == 'subscription':
|
||||
print('running via subscription')
|
||||
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, via_subscription=True)
|
||||
return validate_gql_ws_q(hge_ctx, conf, headers, retry=True, via_subscription=True, gqlws=gqlws)
|
||||
|
||||
|
||||
def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=False):
|
||||
def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=False, gqlws=False):
|
||||
assert 'response' in conf
|
||||
assert conf['url'].endswith('/graphql') or conf['url'].endswith('/relay')
|
||||
endpoint = conf['url']
|
||||
@ -236,11 +237,20 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
|
||||
ws_client = hge_ctx.ws_client_v1alpha1
|
||||
elif endpoint == '/v1beta1/relay':
|
||||
ws_client = hge_ctx.ws_client_relay
|
||||
elif gqlws: # for `graphQL-ws` clients
|
||||
ws_client = hge_ctx.ws_client_graphql_ws
|
||||
else:
|
||||
ws_client = hge_ctx.ws_client
|
||||
print(ws_client.ws_url)
|
||||
if not headers or len(headers) == 0:
|
||||
ws_client.init({})
|
||||
|
||||
if ws_client.remote_closed or ws_client.is_closing:
|
||||
ws_client.create_conn()
|
||||
if not headers or len(headers) == 0 or hge_ctx.hge_key is None:
|
||||
ws_client.init()
|
||||
else:
|
||||
ws_client.init_as_admin()
|
||||
|
||||
query_resp = ws_client.send_query(query, query_id='hge_test', headers=headers, timeout=15)
|
||||
resp = next(query_resp)
|
||||
@ -253,18 +263,22 @@ def validate_gql_ws_q(hge_ctx, conf, headers, retry=False, via_subscription=Fals
|
||||
ws_client.recreate_conn()
|
||||
return validate_gql_ws_q(hge_ctx, query, headers, exp_http_response, False)
|
||||
else:
|
||||
assert resp['type'] in ['data', 'error'], resp
|
||||
assert resp['type'] in ['data', 'error', 'next'], resp
|
||||
|
||||
if 'errors' in exp_http_response or 'error' in exp_http_response:
|
||||
assert resp['type'] in ['data', 'error'], resp
|
||||
assert resp['type'] in ['data', 'error', 'next'], resp
|
||||
else:
|
||||
assert resp['type'] == 'data', resp
|
||||
assert resp['type'] == 'data' or resp['type'] == 'next', resp
|
||||
assert 'payload' in resp, resp
|
||||
|
||||
if via_subscription:
|
||||
ws_client.send({ 'id': 'hge_test', 'type': 'stop' })
|
||||
with pytest.raises(queue.Empty):
|
||||
ws_client.get_ws_event(0)
|
||||
if not gqlws:
|
||||
ws_client.send({ 'id': 'hge_test', 'type': 'stop' })
|
||||
else:
|
||||
ws_client.send({ 'id': 'hge_test', 'type': 'complete' })
|
||||
if not gqlws: # NOTE: for graphql-ws, we have some elements that are left in the queue especially after a 'next' message.
|
||||
with pytest.raises(queue.Empty):
|
||||
ws_client.get_ws_event(0)
|
||||
else:
|
||||
resp_done = next(query_resp)
|
||||
assert resp_done['type'] == 'complete'
|
||||
@ -414,7 +428,7 @@ def get_conf_f(f):
|
||||
with open(f, 'r+') as c:
|
||||
return yaml.YAML().load(c)
|
||||
|
||||
def check_query_f(hge_ctx, f, transport='http', add_auth=True):
|
||||
def check_query_f(hge_ctx, f, transport='http', add_auth=True, gqlws = False):
|
||||
print("Test file: " + f)
|
||||
hge_ctx.may_skip_test_teardown = False
|
||||
print ("transport="+transport)
|
||||
@ -429,14 +443,14 @@ def check_query_f(hge_ctx, f, transport='http', add_auth=True):
|
||||
conf = yml.load(c)
|
||||
if isinstance(conf, list):
|
||||
for ix, sconf in enumerate(conf):
|
||||
actual_resp, matched = check_query(hge_ctx, sconf, transport, add_auth)
|
||||
actual_resp, matched = check_query(hge_ctx, sconf, transport, add_auth, None, gqlws)
|
||||
if PytestConf.config.getoption("--accept") and not matched:
|
||||
conf[ix]['response'] = actual_resp
|
||||
should_write_back = True
|
||||
else:
|
||||
if conf['status'] != 200:
|
||||
hge_ctx.may_skip_test_teardown = True
|
||||
actual_resp, matched = check_query(hge_ctx, conf, transport, add_auth)
|
||||
actual_resp, matched = check_query(hge_ctx, conf, transport, add_auth, None, gqlws)
|
||||
# If using `--accept` write the file back out with the new expected
|
||||
# response set to the actual response we got:
|
||||
if PytestConf.config.getoption("--accept") and not matched:
|
||||
|
Loading…
Reference in New Issue
Block a user