mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-17 20:41:49 +03:00
e69d4c445f
PR-URL: https://github.com/hasura/graphql-engine-mono/pull/9707 GitOrigin-RevId: 6695d107169cd82b598ea53623b7a16974749fb0
187 lines
6.6 KiB
Haskell
187 lines
6.6 KiB
Haskell
module Hasura.GraphQL.Transport.WSServerApp
|
|
( createWSServerApp,
|
|
stopWSServerApp,
|
|
createWSServerEnv,
|
|
)
|
|
where
|
|
|
|
import Control.Concurrent.Async.Lifted.Safe qualified as LA
|
|
import Control.Concurrent.STM qualified as STM
|
|
import Control.Exception.Lifted
|
|
import Control.Monad.Trans.Control qualified as MC
|
|
import Data.Aeson qualified as J
|
|
import Data.Aeson.Encoding qualified as J
|
|
import Data.ByteString.Char8 qualified as B (pack)
|
|
import Data.Text (pack)
|
|
import Hasura.App.State
|
|
import Hasura.Backends.DataConnector.Agent.Client (AgentLicenseKey)
|
|
import Hasura.CredentialCache
|
|
import Hasura.GraphQL.Execute qualified as E
|
|
import Hasura.GraphQL.Logging
|
|
import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery)
|
|
import Hasura.GraphQL.Transport.HTTP.Protocol (encodeGQExecError)
|
|
import Hasura.GraphQL.Transport.Instances ()
|
|
import Hasura.GraphQL.Transport.WebSocket
|
|
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
|
import Hasura.GraphQL.Transport.WebSocket.Server qualified as WS
|
|
import Hasura.GraphQL.Transport.WebSocket.Types
|
|
import Hasura.Logging qualified as L
|
|
import Hasura.Metadata.Class
|
|
import Hasura.Prelude
|
|
import Hasura.QueryTags
|
|
import Hasura.RQL.Types.SchemaCache
|
|
import Hasura.Server.AppStateRef
|
|
import Hasura.Server.Auth (UserAuthentication)
|
|
import Hasura.Server.Init.Config
|
|
( WSConnectionInitTimeout,
|
|
)
|
|
import Hasura.Server.Limits
|
|
import Hasura.Server.Metrics (ServerMetrics (..))
|
|
import Hasura.Server.Prometheus
|
|
( PrometheusMetrics (..),
|
|
decWebsocketConnections,
|
|
incWebsocketConnections,
|
|
)
|
|
import Hasura.Server.Types (MonadGetPolicies (..))
|
|
import Hasura.Services.Network
|
|
import Hasura.Tracing qualified as Tracing
|
|
import Network.WebSockets qualified as WS
|
|
import System.Metrics.Gauge qualified as EKG.Gauge
|
|
|
|
-- | Build the WebSocket server application for the GraphQL transport.
--
-- Three handlers are wired into 'WS.createServerApp':
-- connection setup, per-message processing and connection teardown.
-- The connect handler increments the EKG and Prometheus websocket-connection
-- gauges. The close handler decrements them and then delegates to 'onClose'.
-- Every handler runs under 'mask_', so asynchronous exceptions are held off
-- while it executes.
createWSServerApp ::
  ( MonadIO m,
    MC.MonadBaseControl IO m,
    LA.Forall (LA.Pure m),
    UserAuthentication m,
    E.MonadGQLExecutionCheck m,
    WS.MonadWSLog m,
    MonadQueryLog m,
    MonadExecutionLog m,
    MonadExecuteQuery m,
    MonadMetadataStorage m,
    MonadQueryTags m,
    HasResourceLimits m,
    ProvidesNetwork m,
    Tracing.MonadTrace m,
    MonadGetPolicies m
  ) =>
  HashSet (L.EngineLogType L.Hasura) ->
  WSServerEnv impl ->
  WSConnectionInitTimeout ->
  Maybe (CredentialCache AgentLicenseKey) ->
  -- | aka generalized 'WS.ServerApp'
  WS.HasuraServerApp m
createWSServerApp enabledLogTypes serverEnv connInitTimeout licenseKeyCache = \ !ipAddress !pendingConn ->
  WS.createServerApp readMetricsConfig connInitTimeout wsServer promMetrics wsHandlers ipAddress pendingConn
  where
    -- Pieces of the environment used by the handlers below.
    appStateRef = _wseAppStateRef serverEnv
    wsServer = _wseServer serverEnv
    wsLogger = _wseLogger serverEnv
    ekgMetrics = _wseServerMetrics serverEnv
    promMetrics = _wsePrometheusMetrics serverEnv

    -- These are IO actions, not values: each run reads the current schema
    -- cache / app context from the shared state reference.
    readMetricsConfig = scMetricsConfig <$> getSchemaCache appStateRef
    readAuthMode = acAuthMode <$> getAppContext appStateRef

    -- Protocol-specific actions for a given sub-protocol.
    actionsFor = mkWSActions wsLogger

    wsHandlers = WS.WSHandlers handleConnect handleMessage handleClose

    -- Mask async exceptions during event processing to help maintain the
    -- integrity of mutable vars. `subProto` is the negotiated sub-protocol.
    handleConnect requestId requestHead clientIp subProto = mask_ do
      liftIO (EKG.Gauge.inc (smWebsocketConnections ekgMetrics))
      liftIO (incWebsocketConnections (pmConnections promMetrics))
      runReaderT (onConn requestId requestHead clientIp (actionsFor subProto)) serverEnv

    handleMessage conn msgBytes subProto =
      mask_ (onMessage enabledLogTypes readAuthMode serverEnv conn msgBytes (actionsFor subProto) licenseKeyCache)

    handleClose conn = mask_ do
      granularity <- runGetPrometheusMetricsGranularity
      liftIO (EKG.Gauge.dec (smWebsocketConnections ekgMetrics))
      liftIO (decWebsocketConnections (pmConnections promMetrics))
      onClose wsLogger ekgMetrics promMetrics (_wseSubscriptionState serverEnv) conn granularity
|
|
|
|
-- | Shut down the WebSocket server held in the given environment.
stopWSServerApp :: WSServerEnv impl -> IO ()
stopWSServerApp = WS.shutdown . _wseServer
|
|
|
|
-- | Build the environment shared by all WebSocket connections.
--
-- The 'AppContext' is read exactly once. The following settings used to
-- create the underlying WebSocket server all come from that single snapshot:
-- the auth mode, the allowlist flag, the SQL generation context, the
-- experimental features, the default naming convention and the initial CORS
-- policy. The allowlist itself is read from the schema cache.
--
-- The CORS policy stored in the returned environment is an IO action, not a
-- value. It re-reads the 'AppContext' each time it is run.
createWSServerEnv ::
  ( HasAppEnv m,
    MonadIO m
  ) =>
  AppStateRef impl ->
  m (WSServerEnv impl)
createWSServerEnv appStateRef = do
  AppEnv {..} <- askAppEnv
  let getCorsPolicy = acCorsPolicy <$> getAppContext appStateRef
      logger = _lsLogger appEnvLoggers

  -- Read the app context once, so that every setting below (including the
  -- initial CORS policy) comes from the same snapshot.
  appContext@AppContext
    { acEnableAllowlist,
      acAuthMode,
      acSQLGenCtx,
      acExperimentalFeatures,
      acDefaultNamingConvention
    } <-
    liftIO $ getAppContext appStateRef
  let corsPolicy = acCorsPolicy appContext
  allowlist <- liftIO $ scAllowlist <$> getSchemaCache appStateRef

  wsServer <-
    liftIO
      $ STM.atomically
      $ WS.createWSServer
        acAuthMode
        acEnableAllowlist
        allowlist
        corsPolicy
        acSQLGenCtx
        acExperimentalFeatures
        acDefaultNamingConvention
        logger

  pure
    $ WSServerEnv
      logger
      appEnvSubscriptionState
      appStateRef
      appEnvManager
      getCorsPolicy
      appEnvEnableReadOnlyMode
      wsServer
      appEnvWebSocketKeepAlive
      appEnvServerMetrics
      appEnvPrometheusMetrics
      appEnvTraceSamplingPolicy
|
|
|
|
-- | Build the callbacks the WebSocket server uses to talk to a client.
--
-- Each action branches on the negotiated sub-protocol ('Apollo' or
-- 'GraphQLWS'). The two protocols differ in how they report execution
-- errors, connection errors, keep-alives, data frames and error payloads.
mkWSActions :: L.Logger L.Hasura -> WSSubProtocol -> WS.WSActions WSConnData
mkWSActions logger subProtocol =
  WS.WSActions
    postExecErrorAction
    errorMessageAction
    connectionCloseAction
    keepAlive
    serverMsgType
    acceptRequest
    formatErrors
  where
    -- Report an execution error for operation `opId`.
    -- graphql-ws sends an error message; Apollo sends a data message
    -- carrying the error.
    postExecErrorAction conn opId execErr =
      sendMsg conn
        $ case subProtocol of
          GraphQLWS -> SMErr (ErrorMsg opId (encodeGQExecError execErr))
          Apollo -> SMData (DataMsg opId (throwError execErr))

    -- React to a server-side error.
    -- graphql-ws always closes the connection. Apollo closes it only when
    -- connection init failed, and otherwise just sends a connection-error
    -- message.
    errorMessageAction conn err errKind =
      case subProtocol of
        GraphQLWS -> sendCloseWithMsg logger conn closeCode Nothing Nothing
        Apollo -> case errKind of
          WS.ClientMessageParseFailed -> sendMsg conn (SMConnErr err)
          WS.ConnInitFailed -> sendCloseWithMsg logger conn closeCode (Just (SMConnErr err)) Nothing
      where
        closeCode = WS.mkWSServerErrorCode subProtocol errKind err

    -- Only graphql-ws closes the connection here; for Apollo this is a no-op.
    -- NOTE(review): the final argument is presumably the close status code;
    -- 1000 is RFC 6455 "normal closure".
    connectionCloseAction conn opId errMsg =
      when (subProtocol == GraphQLWS) do
        let payload = SMErr (ErrorMsg opId (J.toEncoding (pack errMsg)))
        sendCloseWithMsg logger conn (GenericError4400 errMsg) (Just payload) (Just 1000)

    -- Constructor used to wrap data frames for this sub-protocol.
    serverMsgType = case subProtocol of
      GraphQLWS -> SMNext
      Apollo -> SMData

    keepAlive conn = sendMsg conn keepAliveMsg

    keepAliveMsg = case subProtocol of
      GraphQLWS -> SMPing (Just keepAliveMessage)
      Apollo -> SMConnKeepAlive

    -- Accept the connection, echoing back the negotiated sub-protocol name.
    acceptRequest =
      WS.defaultAcceptRequest {WS.acceptSubprotocol = Just (B.pack (showSubProtocol subProtocol))}

    -- graphql-ws sends a bare JSON array of errors; Apollo wraps the same
    -- array in an object under the "errors" key.
    formatErrors errMsgs =
      let errorList = J.list id errMsgs
       in case subProtocol of
            GraphQLWS -> errorList
            Apollo -> J.pairs (J.pair "errors" errorList)
|