graphql-engine/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs

338 lines
12 KiB
Haskell

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
module Hasura.GraphQL.Transport.WebSocket
( createWSServerApp
, createWSServerEnv
) where
import qualified Control.Concurrent.Async as A
import qualified Control.Concurrent.STM as STM
import qualified Data.Aeson as J
import qualified Data.Aeson.Casing as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as Map
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Language.GraphQL.Draft.Syntax as G
import qualified ListT
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Types.Status as H
import qualified Network.WebSockets as WS
import qualified STMContainers.Map as STMMap
import Control.Concurrent (threadDelay)
import qualified Data.IORef as IORef
import Hasura.GraphQL.Resolve (resolveSelSet)
import Hasura.GraphQL.Resolve.Context (RespTx)
import qualified Hasura.GraphQL.Resolve.LiveQuery as LQ
import Hasura.GraphQL.Schema (GCtxMap, getGCtx)
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import Hasura.GraphQL.Validate (validateGQ)
import qualified Hasura.Logging as L
import Hasura.Prelude
import Hasura.RQL.Types
import Hasura.Server.Auth (AuthMode,
getUserInfo)
import qualified Hasura.Server.Query as RQ
-- uniquely identifies an operation
type GOperationId = (WS.WSId, OperationId)
type TxRunner = RespTx -> IO (Either QErr BL.ByteString)
type OperationMap
= STMMap.Map OperationId LQ.LiveQuery
data WSConnData
= WSConnData
-- the role and headers are set only on connection_init message
{ _wscUser :: !(IORef.IORef (Maybe (Either Text UserInfo)))
-- we only care about subscriptions,
-- the other operations (query/mutations)
-- are not tracked here
, _wscOpMap :: !OperationMap
}
type LiveQueryMap = LQ.LiveQueryMap GOperationId
type WSServer = WS.WSServer WSConnData
type WSConn = WS.WSConn WSConnData
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn =
liftIO . WS.sendMsg wsConn . encodeServerMsg
data OpDetail
= ODStarted
| ODProtoErr !Text
| ODQueryErr !QErr
| ODCompleted
| ODStopped
deriving (Show, Eq)
$(J.deriveToJSON
J.defaultOptions { J.constructorTagModifier = J.snakeCase . drop 2
, J.sumEncoding = J.TaggedObject "type" "detail"
}
''OpDetail)
data WSEvent
= EAccepted
| ERejected !QErr
| EConnErr !ConnErrMsg
| EOperation !OperationId !(Maybe OperationName) !OpDetail
| EClosed
deriving (Show, Eq)
$(J.deriveToJSON
J.defaultOptions { J.constructorTagModifier = J.snakeCase . drop 1
, J.sumEncoding = J.TaggedObject "type" "detail"
}
''WSEvent)
data WSLog
= WSLog
{ _wslWebsocketId :: !WS.WSId
, _wslUser :: !(Maybe UserVars)
, _wslEvent :: !WSEvent
} deriving (Show, Eq)
$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
instance L.ToEngineLog WSLog where
toEngineLog wsLog =
(L.LevelInfo, "ws-handler", J.toJSON wsLog)
data WSServerEnv
= WSServerEnv
{ _wseLogger :: !L.Logger
, _wseServer :: !WSServer
, _wseRunTx :: !TxRunner
, _wseLiveQMap :: !LiveQueryMap
, _wseGCtxMap :: !(IORef.IORef (SchemaCache, GCtxMap))
, _wseHManager :: !H.Manager
}
onConn :: L.Logger -> WS.OnConnH WSConnData
onConn (L.Logger logger) wsId requestHead = do
res <- runExceptT checkPath
either reject accept res
where
keepAliveAction wsConn = forever $ do
sendMsg wsConn SMConnKeepAlive
threadDelay $ 5 * 1000 * 1000
accept _ = do
logger $ WSLog wsId Nothing EAccepted
connData <- WSConnData <$> IORef.newIORef Nothing <*> STMMap.newIO
let acceptRequest = WS.defaultAcceptRequest
{ WS.acceptSubprotocol = Just "graphql-ws"}
return $ Right (connData, acceptRequest, Just keepAliveAction)
reject qErr = do
logger $ WSLog wsId Nothing $ ERejected qErr
return $ Left $ WS.RejectRequest
(H.statusCode $ qeStatus qErr)
(H.statusMessage $ qeStatus qErr) []
(BL.toStrict $ J.encode $ encodeQErr False qErr)
checkPath =
when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
throw404 "only /v1alpha1/graphql is supported on websockets"
onStart :: WSServerEnv -> WSConn -> StartMsg -> IO ()
onStart serverEnv wsConn (StartMsg opId q) = catchAndIgnore $ do
opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
when (isJust opM) $ withComplete $ sendConnErr $
"an operation already exists with this id: " <> unOperationId opId
userInfoM <- liftIO $ IORef.readIORef userInfoR
userInfo <- case userInfoM of
Just (Right userInfo) -> return userInfo
Just (Left initErr) -> do
let connErr = "cannot start as connection_init failed with : " <> initErr
withComplete $ sendConnErr connErr
Nothing -> do
let connErr = "start received before the connection is initialised"
withComplete $ sendConnErr connErr
-- validate and build tx
gCtxMap <- fmap snd $ liftIO $ IORef.readIORef gCtxMapRef
let gCtx = getGCtx (userRole userInfo) gCtxMap
(opTy, fields) <- either (withComplete . preExecErr) return $
runReaderT (validateGQ q) gCtx
let qTx = RQ.setHeadersTx (userVars userInfo) >>
resolveSelSet userInfo gCtx opTy fields
case opTy of
G.OperationTypeSubscription -> do
let lq = LQ.LiveQuery userInfo q
liftIO $ STM.atomically $ STMMap.insert lq opId opMap
liftIO $ LQ.addLiveQuery runTx lqMap lq
qTx (wsId, opId) liveQOnChange
logOpEv ODStarted
_ -> do
logOpEv ODStarted
resp <- liftIO $ runTx qTx
either postExecErr sendSuccResp resp
sendCompleted
where
WSServerEnv logger _ runTx lqMap gCtxMapRef _ = serverEnv
wsId = WS.getWSId wsConn
WSConnData userInfoR opMap = WS.getData wsConn
logOpEv opDet =
logWSEvent logger wsConn $ EOperation opId (_grOperationName q) opDet
sendConnErr connErr = do
sendMsg wsConn $ SMErr $ ErrorMsg opId $ J.toJSON connErr
logOpEv $ ODProtoErr connErr
sendCompleted = do
sendMsg wsConn $ SMComplete $ CompletionMsg opId
logOpEv ODCompleted
postExecErr qErr = do
logOpEv $ ODQueryErr qErr
sendMsg wsConn $ SMData $ DataMsg opId $
GQExecError $ pure $ encodeQErr False qErr
-- why wouldn't pre exec error use graphql response?
preExecErr qErr = do
logOpEv $ ODQueryErr qErr
sendMsg wsConn $ SMErr $ ErrorMsg opId $ encodeQErr False qErr
sendSuccResp bs =
sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess bs
withComplete :: ExceptT () IO () -> ExceptT () IO a
withComplete action = do
action
sendCompleted
throwError ()
-- on change, send message on the websocket
liveQOnChange resp =
WS.sendMsg wsConn $ encodeServerMsg $ SMData $ DataMsg opId resp
catchAndIgnore :: ExceptT () IO () -> IO ()
catchAndIgnore m = void $ runExceptT m
onMessage
:: AuthMode
-> WSServerEnv
-> WSConn -> BL.ByteString -> IO ()
onMessage authMode serverEnv wsConn msgRaw =
case J.eitherDecode msgRaw of
Left e -> do
let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e
logWSEvent logger wsConn $ EConnErr err
sendMsg wsConn $ SMConnErr err
Right msg -> case msg of
CMConnInit params -> onConnInit (_wseLogger serverEnv)
(_wseHManager serverEnv)
wsConn authMode params
CMStart startMsg -> onStart serverEnv wsConn startMsg
CMStop stopMsg -> onStop serverEnv wsConn stopMsg
CMConnTerm -> WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
where
logger = _wseLogger serverEnv
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
-- probably wrap the whole thing in a single tx?
opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
case opM of
Just liveQ -> do
let opNameM = _grOperationName $ LQ._lqRequest liveQ
logWSEvent logger wsConn $ EOperation opId opNameM ODStopped
LQ.removeLiveQuery lqMap liveQ (wsId, opId)
Nothing -> return ()
STM.atomically $ STMMap.delete opId opMap
where
logger = _wseLogger serverEnv
lqMap = _wseLiveQMap serverEnv
wsId = WS.getWSId wsConn
opMap = _wscOpMap $ WS.getData wsConn
logWSEvent
:: (MonadIO m)
=> L.Logger -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
userInfoME <- liftIO $ IORef.readIORef userInfoR
let userInfoM = case userInfoME of
Just (Right userInfo) -> return $ userVars userInfo
_ -> Nothing
liftIO $ logger $ WSLog wsId userInfoM wsEv
where
WSConnData userInfoR _ = WS.getData wsConn
wsId = WS.getWSId wsConn
onConnInit
:: (MonadIO m)
=> L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do
res <- runExceptT $ getUserInfo logger manager headers authMode
case res of
Left e -> do
liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $
Just $ Left $ qeError e
let connErr = ConnErrMsg $ qeError e
logWSEvent logger wsConn $ EConnErr connErr
sendMsg wsConn $ SMConnErr connErr
Right userInfo -> do
liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $
Just $ Right userInfo
sendMsg wsConn SMConnAck
-- TODO: send it periodically? Why doesn't apollo's protocol use
-- ping/pong frames of websocket spec?
sendMsg wsConn SMConnKeepAlive
where
headers = [ (CI.mk $ TE.encodeUtf8 h, TE.encodeUtf8 v)
| (h, v) <- maybe [] Map.toList $ connParamsM >>= _cpHeaders
]
onClose
:: L.Logger
-> LiveQueryMap
-> WS.ConnectionException
-> WSConn
-> IO ()
onClose logger lqMap _ wsConn = do
logWSEvent logger wsConn EClosed
operations <- STM.atomically $ ListT.toList $ STMMap.stream opMap
void $ A.forConcurrently operations $ \(opId, liveQ) ->
LQ.removeLiveQuery lqMap liveQ (wsId, opId)
where
wsId = WS.getWSId wsConn
opMap = _wscOpMap $ WS.getData wsConn
createWSServerEnv
:: L.Logger
-> H.Manager -> IORef.IORef (SchemaCache, GCtxMap)
-> TxRunner -> IO WSServerEnv
createWSServerEnv logger httpManager gCtxMapRef runTx = do
(wsServer, lqMap) <-
STM.atomically $ (,) <$> WS.createWSServer logger <*> LQ.newLiveQueryMap
return $ WSServerEnv logger wsServer runTx lqMap gCtxMapRef httpManager
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv =
WS.createServerApp (_wseServer serverEnv) handlers
where
handlers =
WS.WSHandlers
(onConn $ _wseLogger serverEnv)
(onMessage authMode serverEnv)
(onClose (_wseLogger serverEnv) $ _wseLiveQMap serverEnv)