graphql-engine/server/src-lib/Hasura/GraphQL/Transport/WSServerApp.hs
2023-07-04 06:55:40 +00:00

187 lines
6.6 KiB
Haskell

module Hasura.GraphQL.Transport.WSServerApp
( createWSServerApp,
stopWSServerApp,
createWSServerEnv,
)
where
import Control.Concurrent.Async.Lifted.Safe qualified as LA
import Control.Concurrent.STM qualified as STM
import Control.Exception.Lifted
import Control.Monad.Trans.Control qualified as MC
import Data.Aeson qualified as J
import Data.Aeson.Encoding qualified as J
import Data.ByteString.Char8 qualified as B (pack)
import Data.Text (pack)
import Hasura.App.State
import Hasura.Backends.DataConnector.Agent.Client (AgentLicenseKey)
import Hasura.CredentialCache
import Hasura.GraphQL.Execute qualified as E
import Hasura.GraphQL.Logging
import Hasura.GraphQL.Transport.HTTP (MonadExecuteQuery)
import Hasura.GraphQL.Transport.HTTP.Protocol (encodeGQExecError)
import Hasura.GraphQL.Transport.Instances ()
import Hasura.GraphQL.Transport.WebSocket
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.GraphQL.Transport.WebSocket.Server qualified as WS
import Hasura.GraphQL.Transport.WebSocket.Types
import Hasura.Logging qualified as L
import Hasura.Metadata.Class
import Hasura.Prelude
import Hasura.QueryTags
import Hasura.RQL.Types.SchemaCache
import Hasura.Server.AppStateRef
import Hasura.Server.Auth (UserAuthentication)
import Hasura.Server.Init.Config
( WSConnectionInitTimeout,
)
import Hasura.Server.Limits
import Hasura.Server.Metrics (ServerMetrics (..))
import Hasura.Server.Prometheus
( PrometheusMetrics (..),
decWebsocketConnections,
incWebsocketConnections,
)
import Hasura.Server.Types (MonadGetPolicies (..))
import Hasura.Services.Network
import Hasura.Tracing qualified as Tracing
import Network.WebSockets qualified as WS
import System.Metrics.Gauge qualified as EKG.Gauge
-- | Build the websocket server application for the given server environment.
--
-- The result takes the client's IP address and the pending connection (both
-- forced via bang patterns) and hands them to 'WS.createServerApp' together
-- with the connect / message / close handlers defined below.
createWSServerApp ::
  ( MonadIO m,
    MC.MonadBaseControl IO m,
    LA.Forall (LA.Pure m),
    UserAuthentication m,
    E.MonadGQLExecutionCheck m,
    WS.MonadWSLog m,
    MonadQueryLog m,
    MonadExecutionLog m,
    MonadExecuteQuery m,
    MonadMetadataStorage m,
    MonadQueryTags m,
    HasResourceLimits m,
    ProvidesNetwork m,
    Tracing.MonadTrace m,
    MonadGetPolicies m
  ) =>
  HashSet (L.EngineLogType L.Hasura) ->
  WSServerEnv impl ->
  WSConnectionInitTimeout ->
  Maybe (CredentialCache AgentLicenseKey) ->
  -- | aka generalized 'WS.ServerApp'
  WS.HasuraServerApp m
createWSServerApp enabledLogTypes serverEnv connInitTimeout licenseKeyCache = \ !ipAddress !pendingConn ->
  WS.createServerApp
    readMetricsConfig
    connInitTimeout
    (_wseServer serverEnv)
    promMetrics
    (WS.WSHandlers handleConnect handleMessage handleClose)
    ipAddress
    pendingConn
  where
    -- Pieces of the server environment that the handlers share.
    appStateRef = _wseAppStateRef serverEnv
    envLogger = _wseLogger serverEnv
    ekgMetrics = _wseServerMetrics serverEnv
    promMetrics = _wsePrometheusMetrics serverEnv

    -- Actions that read the metrics config / auth mode from the app state
    -- reference each time they are run.
    readMetricsConfig = scMetricsConfig <$> getSchemaCache appStateRef
    readAuthMode = acAuthMode <$> getAppContext appStateRef

    -- Sub-protocol specific websocket actions; still awaiting the sub-protocol.
    actionsFor = mkWSActions envLogger

    -- Every handler runs under 'mask_': async exceptions are masked during
    -- event processing to help maintain the integrity of mutable vars.

    -- New connection: bump both connection gauges, then run 'onConn' with the
    -- server environment as its reader context. @subProto@ is the sub-protocol.
    handleConnect rid rh ip subProto = mask_ $ do
      liftIO $ do
        EKG.Gauge.inc (smWebsocketConnections ekgMetrics)
        incWebsocketConnections (pmConnections promMetrics)
      runReaderT (onConn rid rh ip (actionsFor subProto)) serverEnv

    -- Incoming message: delegate to 'onMessage'.
    handleMessage conn bs subProto =
      mask_ (onMessage enabledLogTypes readAuthMode serverEnv conn bs (actionsFor subProto) licenseKeyCache)

    -- Connection closed: fetch the metrics granularity first, decrement both
    -- connection gauges, then delegate to 'onClose'.
    handleClose conn = mask_ $ do
      granularity <- runGetPrometheusMetricsGranularity
      liftIO $ do
        EKG.Gauge.dec (smWebsocketConnections ekgMetrics)
        decWebsocketConnections (pmConnections promMetrics)
      onClose envLogger ekgMetrics promMetrics (_wseSubscriptionState serverEnv) conn granularity
-- | Shut down the websocket server stored in the given environment by
-- applying 'WS.shutdown' to its '_wseServer' field.
stopWSServerApp :: WSServerEnv impl -> IO ()
stopWSServerApp = WS.shutdown . _wseServer
-- | Assemble a 'WSServerEnv' from the application environment and the given
-- application state reference.
--
-- A fresh websocket server is created (inside an STM transaction) from a
-- snapshot of the current app context, the schema cache's allowlist and the
-- CORS policy. The environment also keeps an action that re-reads the CORS
-- policy from the app state reference whenever it is run.
createWSServerEnv ::
  ( HasAppEnv m,
    MonadIO m
  ) =>
  AppStateRef impl ->
  m (WSServerEnv impl)
createWSServerEnv appStateRef = do
  AppEnv
    { appEnvLoggers,
      appEnvSubscriptionState,
      appEnvManager,
      appEnvEnableReadOnlyMode,
      appEnvWebSocketKeepAlive,
      appEnvServerMetrics,
      appEnvPrometheusMetrics,
      appEnvTraceSamplingPolicy
    } <-
    askAppEnv
  let wsLogger = _lsLogger appEnvLoggers
      readCorsPolicy = acCorsPolicy <$> getAppContext appStateRef
  -- Read the current state and create the server; the reads happen in the
  -- order: app context, schema cache, CORS policy.
  wsServer <- liftIO $ do
    AppContext {acEnableAllowlist, acAuthMode, acSQLGenCtx, acExperimentalFeatures, acDefaultNamingConvention} <-
      getAppContext appStateRef
    allowlist <- scAllowlist <$> getSchemaCache appStateRef
    corsPolicy <- readCorsPolicy
    STM.atomically
      ( WS.createWSServer
          acAuthMode
          acEnableAllowlist
          allowlist
          corsPolicy
          acSQLGenCtx
          acExperimentalFeatures
          acDefaultNamingConvention
          wsLogger
      )
  pure
    ( WSServerEnv
        wsLogger
        appEnvSubscriptionState
        appStateRef
        appEnvManager
        readCorsPolicy
        appEnvEnableReadOnlyMode
        wsServer
        appEnvWebSocketKeepAlive
        appEnvServerMetrics
        appEnvPrometheusMetrics
        appEnvTraceSamplingPolicy
    )
-- | Build the set of websocket actions for a connection, specialised to the
-- negotiated sub-protocol ('Apollo' or 'GraphQLWS').
mkWSActions :: L.Logger L.Hasura -> WSSubProtocol -> WS.WSActions WSConnData
mkWSActions logger subProtocol =
  WS.WSActions
    postExecError
    onErrorMessage
    closeConnection
    keepAlive
    dataMsgWrapper
    acceptRequest
    formatErrors
  where
    -- Report an execution error for an operation: 'Apollo' wraps it in a data
    -- message, 'GraphQLWS' sends a dedicated error message.
    postExecError wsConn opId execErr =
      let reply = case subProtocol of
            Apollo -> SMData (DataMsg opId (throwError execErr))
            GraphQLWS -> SMErr (ErrorMsg opId (encodeGQExecError execErr))
       in sendMsg wsConn reply

    -- React to a connection-level error. 'Apollo' closes the connection only
    -- when connection init failed and otherwise just sends the error;
    -- 'GraphQLWS' always closes, without an accompanying message.
    onErrorMessage wsConn err mErrMsg =
      case (subProtocol, mErrMsg) of
        (Apollo, WS.ConnInitFailed) ->
          sendCloseWithMsg logger wsConn closeCode (Just (SMConnErr err)) Nothing
        (Apollo, WS.ClientMessageParseFailed) ->
          sendMsg wsConn (SMConnErr err)
        (GraphQLWS, _) ->
          sendCloseWithMsg logger wsConn closeCode Nothing Nothing
      where
        closeCode = WS.mkWSServerErrorCode subProtocol mErrMsg err

    -- Close the connection with an error message; this only does anything
    -- for 'GraphQLWS'.
    closeConnection wsConn opId errMsg =
      when (subProtocol == GraphQLWS)
        $ sendCloseWithMsg logger wsConn (GenericError4400 errMsg) (Just errorFrame) (Just 1000)
      where
        errorFrame = SMErr (ErrorMsg opId (J.toEncoding (pack errMsg)))

    -- Send the sub-protocol's keep-alive message.
    keepAlive wsConn = sendMsg wsConn keepAliveMsg

    keepAliveMsg = case subProtocol of
      Apollo -> SMConnKeepAlive
      GraphQLWS -> SMPing (Just keepAliveMessage)

    -- Constructor used to wrap data messages sent to the client.
    dataMsgWrapper = case subProtocol of
      Apollo -> SMData
      GraphQLWS -> SMNext

    -- Accept request that advertises the sub-protocol in use.
    acceptRequest =
      WS.defaultAcceptRequest
        { WS.acceptSubprotocol = Just (B.pack (showSubProtocol subProtocol))
        }

    -- Encode a list of errors: 'Apollo' nests the list under an "errors" key,
    -- 'GraphQLWS' emits the bare list.
    formatErrors errMsgs =
      let encodedErrors = J.list id errMsgs
       in case subProtocol of
            Apollo -> J.pairs (J.pair "errors" encodedErrors)
            GraphQLWS -> encodedErrors