mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-18 13:02:11 +03:00
338 lines
12 KiB
Haskell
338 lines
12 KiB
Haskell
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE TemplateHaskell #-}
|
|
|
|
module Hasura.GraphQL.Transport.WebSocket
|
|
( createWSServerApp
|
|
, createWSServerEnv
|
|
) where
|
|
|
|
import qualified Control.Concurrent.Async as A
|
|
import qualified Control.Concurrent.STM as STM
|
|
import qualified Data.Aeson as J
|
|
import qualified Data.Aeson.Casing as J
|
|
import qualified Data.Aeson.TH as J
|
|
import qualified Data.ByteString.Lazy as BL
|
|
import qualified Data.CaseInsensitive as CI
|
|
import qualified Data.HashMap.Strict as Map
|
|
import qualified Data.Text as T
|
|
import qualified Data.Text.Encoding as TE
|
|
import qualified Language.GraphQL.Draft.Syntax as G
|
|
import qualified ListT
|
|
import qualified Network.HTTP.Client as H
|
|
import qualified Network.HTTP.Types.Status as H
|
|
import qualified Network.WebSockets as WS
|
|
import qualified STMContainers.Map as STMMap
|
|
|
|
import Control.Concurrent (threadDelay)
|
|
import qualified Data.IORef as IORef
|
|
|
|
import Hasura.GraphQL.Resolve (resolveSelSet)
|
|
import Hasura.GraphQL.Resolve.Context (RespTx)
|
|
import qualified Hasura.GraphQL.Resolve.LiveQuery as LQ
|
|
import Hasura.GraphQL.Schema (GCtxMap, getGCtx)
|
|
import Hasura.GraphQL.Transport.HTTP.Protocol
|
|
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
|
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
|
|
import Hasura.GraphQL.Validate (validateGQ)
|
|
import qualified Hasura.Logging as L
|
|
import Hasura.Prelude
|
|
import Hasura.RQL.Types
|
|
import Hasura.Server.Auth (AuthMode,
|
|
getUserInfo)
|
|
import qualified Hasura.Server.Query as RQ
|
|
|
|
-- | Uniquely identifies an operation: the websocket connection id paired
-- with the operation id sent by the client in its @start@ message.
type GOperationId = (WS.WSId, OperationId)

-- | Executes a resolved transaction, producing either a query error or the
-- encoded response body.
type TxRunner = RespTx -> IO (Either QErr BL.ByteString)

-- | Subscriptions running on a single connection, keyed by the operation id
-- the client supplied when starting them.
type OperationMap = STMMap.Map OperationId LQ.LiveQuery
|
|
|
|
-- | State attached to each websocket connection.
data WSConnData
  = WSConnData
  { _wscUser  :: !(IORef.IORef (Maybe (Either Text UserInfo)))
    -- ^ Result of @connection_init@: 'Nothing' until that message arrives,
    -- then either the authentication error text or the resolved user.
    -- The role and headers are only ever set from @connection_init@.
  , _wscOpMap :: !OperationMap
    -- ^ Only subscriptions are tracked here; queries and mutations run once
    -- and are never recorded.
  }
|
|
|
|
-- | Registry of live queries, keyed by the connection-qualified operation id.
type LiveQueryMap = LQ.LiveQueryMap GOperationId

-- | Websocket server specialised to this module's per-connection state.
type WSServer = WS.WSServer WSConnData

-- | A single websocket connection carrying 'WSConnData'.
type WSConn = WS.WSConn WSConnData
|
|
-- | Encode a protocol-level server message and send it on the connection.
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn serverMsg =
  liftIO $ WS.sendMsg wsConn (encodeServerMsg serverMsg)
|
|
|
|
-- | Lifecycle events of a single operation, used in websocket logs.
--
-- Serialised as @{"type": <tag>, "detail": <payload>}@, where the tag is the
-- constructor name without its @OD@ prefix, snake-cased
-- (e.g. 'ODProtoErr' becomes @"proto_err"@).
data OpDetail
  = ODStarted
  | ODProtoErr !Text
    -- ^ protocol-level error reported back to the client
  | ODQueryErr !QErr
    -- ^ validation or execution error of the operation itself
  | ODCompleted
  | ODStopped
  deriving (Eq, Show)

$(J.deriveToJSON
    J.defaultOptions
      { J.constructorTagModifier = J.snakeCase . drop 2
      , J.sumEncoding            = J.TaggedObject "type" "detail"
      }
    ''OpDetail)
|
|
|
|
-- | Connection-level events written to the websocket log.
--
-- Serialised as @{"type": <tag>, "detail": <payload>}@, where the tag is the
-- constructor name without its leading @E@, snake-cased
-- (e.g. 'EConnErr' becomes @"conn_err"@).
data WSEvent
  = EAccepted
  | ERejected !QErr
    -- ^ the upgrade request was refused
  | EConnErr !ConnErrMsg
  | EOperation !OperationId !(Maybe OperationName) !OpDetail
    -- ^ an event concerning one operation on the connection
  | EClosed
  deriving (Eq, Show)

$(J.deriveToJSON
    J.defaultOptions
      { J.constructorTagModifier = J.snakeCase . drop 1
      , J.sumEncoding            = J.TaggedObject "type" "detail"
      }
    ''WSEvent)
|
|
|
|
-- | One websocket log line. JSON field names drop the @_wsl@ prefix and are
-- snake-cased (@websocket_id@, @user@, @event@).
data WSLog
  = WSLog
  { _wslWebsocketId :: !WS.WSId
  , _wslUser        :: !(Maybe UserVars)
    -- ^ present only once @connection_init@ has succeeded
  , _wslEvent       :: !WSEvent
  } deriving (Eq, Show)

$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
|
|
|
|
-- | Every websocket log entry is emitted at info level under the
-- @ws-handler@ log type.
instance L.ToEngineLog WSLog where
  toEngineLog = (,,) L.LevelInfo "ws-handler" . J.toJSON
|
|
|
|
-- | Environment shared by all websocket handlers.
data WSServerEnv
  = WSServerEnv
  { _wseLogger   :: !L.Logger
  , _wseServer   :: !WSServer
  , _wseRunTx    :: !TxRunner
    -- ^ executes the transactions built for queries and subscriptions
  , _wseLiveQMap :: !LiveQueryMap
  , _wseGCtxMap  :: !(IORef.IORef (SchemaCache, GCtxMap))
    -- ^ schema cache and per-role GraphQL contexts; read on each @start@
  , _wseHManager :: !H.Manager
    -- ^ HTTP manager handed to authentication in @connection_init@
  }
|
|
|
|
-- | Upgrade handler for incoming websocket requests.
--
-- Only requests for @/v1alpha1/graphql@ are accepted. An accepted connection
-- gets fresh per-connection state, negotiates the @graphql-ws@ subprotocol,
-- and registers a keep-alive action that sends 'SMConnKeepAlive' every
-- 5 seconds. Any other path is rejected with the 404 produced by
-- 'throw404'.
onConn :: L.Logger -> WS.OnConnH WSConnData
onConn (L.Logger logger) wsId requestHead =
  runExceptT checkPath >>= either reject (const accept)
  where
    checkPath =
      when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
        throw404 "only /v1alpha1/graphql is supported on websockets"

    accept = do
      logger $ WSLog wsId Nothing EAccepted
      -- user is unknown until connection_init; no operations yet
      userRef <- IORef.newIORef Nothing
      opMap   <- STMMap.newIO
      let acceptReq = WS.defaultAcceptRequest
            { WS.acceptSubprotocol = Just "graphql-ws" }
      return $ Right (WSConnData userRef opMap, acceptReq, Just keepAlive)

    reject qErr = do
      logger $ WSLog wsId Nothing $ ERejected qErr
      let status = qeStatus qErr
          -- internal error details are not exposed (encodeQErr False)
          body   = BL.toStrict $ J.encode $ encodeQErr False qErr
      return $ Left $
        WS.RejectRequest (H.statusCode status) (H.statusMessage status) [] body

    keepAlive wsConn = forever $ do
      sendMsg wsConn SMConnKeepAlive
      threadDelay $ 5 * 1000 * 1000
|
|
|
|
-- | Handle a @start@ message for operation @opId@.
--
-- The start is refused in two cases:
--
-- * the id is already tracked in the connection's operation map;
-- * the connection has no successfully initialised user.
--
-- Otherwise the request is validated against the user's role context and
-- turned into a transaction:
--
-- * a subscription is recorded in the operation map and registered as a
--   live query whose changes are pushed to the client;
-- * any other operation runs once, its result or error is sent, and the
--   operation is completed.
--
-- Each early exit sends its own error (and, via 'abortWith', a completion)
-- before short-circuiting through the @ExceptT ()@ layer, which
-- 'catchAndIgnore' discards.
onStart :: WSServerEnv -> WSConn -> StartMsg -> IO ()
onStart serverEnv wsConn (StartMsg opId q) = catchAndIgnore $ do

  existingOp <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
  when (isJust existingOp) $ abortWith $ sendConnErr $
    "an operation already exists with this id: " <> unOperationId opId

  userInfo <- liftIO (IORef.readIORef userInfoR) >>= requireUser

  -- validate and build tx
  gCtxMap <- liftIO $ snd <$> IORef.readIORef gCtxMapRef
  let gCtx = getGCtx (userRole userInfo) gCtxMap
  (opTy, fields) <- either (abortWith . preExecErr) return $
    runReaderT (validateGQ q) gCtx
  let qTx = RQ.setHeadersTx (userVars userInfo)
            >> resolveSelSet userInfo gCtx opTy fields

  case opTy of
    G.OperationTypeSubscription -> do
      let liveQ = LQ.LiveQuery userInfo q
      liftIO $ do
        STM.atomically $ STMMap.insert liveQ opId opMap
        LQ.addLiveQuery runTx lqMap liveQ qTx (wsId, opId) liveQOnChange
      logOpEv ODStarted
    _ -> do
      logOpEv ODStarted
      liftIO (runTx qTx) >>= either postExecErr sendSuccResp
      sendCompleted

  where
    WSServerEnv logger _ runTx lqMap gCtxMapRef _ = serverEnv
    WSConnData userInfoR opMap = WS.getData wsConn
    wsId = WS.getWSId wsConn

    -- the outcome of connection_init decides whether we may proceed
    requireUser userInfoM = case userInfoM of
      Just (Right userInfo) -> return userInfo
      Just (Left initErr)   -> abortWith $ sendConnErr $
        "cannot start as connection_init failed with : " <> initErr
      Nothing               -> abortWith $ sendConnErr
        "start received before the connection is initialised"

    logOpEv opDet =
      logWSEvent logger wsConn $ EOperation opId (_grOperationName q) opDet

    sendConnErr connErr = do
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ J.toJSON connErr
      logOpEv $ ODProtoErr connErr

    sendCompleted = do
      sendMsg wsConn $ SMComplete $ CompletionMsg opId
      logOpEv ODCompleted

    -- execution errors travel inside a graphql data message
    postExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMData $ DataMsg opId $
        GQExecError $ pure $ encodeQErr False qErr

    -- NOTE(review): validation errors are sent as a protocol error message
    -- rather than a graphql response; unclear whether that is intended.
    preExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ encodeQErr False qErr

    sendSuccResp bs =
      sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess bs

    -- run the error reporter, complete the operation, then abort
    abortWith :: ExceptT () IO () -> ExceptT () IO a
    abortWith report = do
      report
      sendCompleted
      throwError ()

    -- pushes each live-query result to the client as a data message
    liveQOnChange =
      WS.sendMsg wsConn . encodeServerMsg . SMData . DataMsg opId

    catchAndIgnore :: ExceptT () IO () -> IO ()
    catchAndIgnore action = void $ runExceptT action
|
|
|
|
-- | Decode a raw client frame and dispatch it to the matching handler.
--
-- A frame that does not parse as a client message is logged and answered
-- with a connection error; the connection stays open.
onMessage
  :: AuthMode
  -> WSServerEnv
  -> WSConn -> BL.ByteString -> IO ()
onMessage authMode serverEnv wsConn msgRaw =
  either onParseErr dispatch $ J.eitherDecode msgRaw
  where
    logger = _wseLogger serverEnv

    onParseErr parseErr = do
      let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack parseErr
      logWSEvent logger wsConn $ EConnErr err
      sendMsg wsConn $ SMConnErr err

    dispatch clientMsg = case clientMsg of
      CMConnInit params ->
        onConnInit logger (_wseHManager serverEnv) wsConn authMode params
      CMStart startMsg  -> onStart serverEnv wsConn startMsg
      CMStop stopMsg    -> onStop serverEnv wsConn stopMsg
      CMConnTerm        ->
        WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
|
|
|
|
-- | Handle a @stop@ message for operation @opId@.
--
-- If @opId@ names a subscription on this connection, the entry is taken out
-- of the operation map, the stop is logged, and its live query is removed.
-- Stops for unknown ids (including queries and mutations, which are never
-- tracked) are silently ignored.
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
  -- Look up and delete in a single STM transaction so the check and the
  -- removal cannot interleave with other map updates, and so the entry is
  -- gone even if tearing down the live query below throws.
  opM <- STM.atomically $ do
    op <- STMMap.lookup opId opMap
    STMMap.delete opId opMap
    return op
  case opM of
    Just liveQ -> do
      let opNameM = _grOperationName $ LQ._lqRequest liveQ
      logWSEvent logger wsConn $ EOperation opId opNameM ODStopped
      LQ.removeLiveQuery lqMap liveQ (wsId, opId)
    Nothing -> return ()
  where
    logger = _wseLogger serverEnv
    lqMap  = _wseLiveQMap serverEnv
    wsId   = WS.getWSId wsConn
    opMap  = _wscOpMap $ WS.getData wsConn
|
|
|
|
-- | Log a websocket event for this connection.
--
-- The user's session variables are attached only after a successful
-- @connection_init@; before that, or after a failed init, the user is
-- logged as absent.
logWSEvent
  :: (MonadIO m)
  => L.Logger -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = liftIO $ do
  initResult <- IORef.readIORef userInfoR
  let userVarsM = initResult >>= either (const Nothing) (Just . userVars)
  logger $ WSLog wsId userVarsM wsEv
  where
    WSConnData userInfoR _ = WS.getData wsConn
    wsId = WS.getWSId wsConn
|
|
|
|
-- | Handle @connection_init@: authenticate using the headers supplied in the
-- connection params and store the outcome on the connection.
--
-- On failure the error text is stored, logged, and sent as a connection
-- error. On success the user is stored and the client receives an ack
-- followed by one immediate keep-alive.
onConnInit
  :: (MonadIO m)
  => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do
  res <- runExceptT $ getUserInfo logger manager headers authMode
  either onAuthFailure onAuthSuccess res
  where
    userRef = _wscUser $ WS.getData wsConn

    onAuthFailure e = do
      let errTxt = qeError e
      liftIO $ IORef.writeIORef userRef $ Just $ Left errTxt
      let connErr = ConnErrMsg errTxt
      logWSEvent logger wsConn $ EConnErr connErr
      sendMsg wsConn $ SMConnErr connErr

    onAuthSuccess userInfo = do
      liftIO $ IORef.writeIORef userRef $ Just $ Right userInfo
      sendMsg wsConn SMConnAck
      -- One keep-alive right after the ack. Periodic keep-alives come from
      -- the keep-alive action registered in onConn (every 5s), presumably
      -- driven by the WS server.
      sendMsg wsConn SMConnKeepAlive

    -- connection params carry headers as text; convert to HTTP headers
    headers = map toHeader $ maybe [] Map.toList (connParamsM >>= _cpHeaders)
    toHeader (name, value) = (CI.mk (TE.encodeUtf8 name), TE.encodeUtf8 value)
|
|
|
|
-- | Connection-close handler: log the close, then remove every live query
-- still registered for this connection, concurrently. The connection
-- exception itself is ignored.
onClose
  :: L.Logger
  -> LiveQueryMap
  -> WS.ConnectionException
  -> WSConn
  -> IO ()
onClose logger lqMap _ wsConn = do
  logWSEvent logger wsConn EClosed
  activeOps <- STM.atomically $ ListT.toList $ STMMap.stream opMap
  void $ A.mapConcurrently stopOp activeOps
  where
    wsId  = WS.getWSId wsConn
    opMap = _wscOpMap $ WS.getData wsConn
    stopOp (opId, liveQ) = LQ.removeLiveQuery lqMap liveQ (wsId, opId)
|
|
|
|
-- | Build the shared websocket environment. The websocket server and the
-- live-query map are created together in one STM transaction.
createWSServerEnv
  :: L.Logger
  -> H.Manager -> IORef.IORef (SchemaCache, GCtxMap)
  -> TxRunner -> IO WSServerEnv
createWSServerEnv logger httpManager gCtxMapRef runTx = do
  (wsServer, lqMap) <- STM.atomically $ do
    server  <- WS.createWSServer logger
    liveQMap <- LQ.newLiveQueryMap
    return (server, liveQMap)
  return WSServerEnv
    { _wseLogger   = logger
    , _wseServer   = wsServer
    , _wseRunTx    = runTx
    , _wseLiveQMap = lqMap
    , _wseGCtxMap  = gCtxMapRef
    , _wseHManager = httpManager
    }
|
|
|
|
-- | Wire the connect, message and close handlers into a websocket
-- application backed by the environment's server.
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv =
  WS.createServerApp server $
    WS.WSHandlers
      (onConn logger)
      (onMessage authMode serverEnv)
      (onClose logger liveQMap)
  where
    server   = _wseServer serverEnv
    logger   = _wseLogger serverEnv
    liveQMap = _wseLiveQMap serverEnv
|