-- graphql-engine/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs
--
-- 311 lines
-- 11 KiB
-- Haskell

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
module Hasura.GraphQL.Transport.WebSocket
( createWSServerApp
, createWSServerEnv
) where
import qualified Control.Concurrent.Async as A
import qualified Control.Concurrent.STM as STM
import qualified Data.Aeson as J
import qualified Data.Aeson.Casing as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as Map
import qualified Data.TByteString as TBS
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Language.GraphQL.Draft.Syntax as G
import qualified ListT
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Types.Status as H
import qualified Network.WebSockets as WS
import qualified STMContainers.Map as STMMap
import Control.Concurrent (threadDelay)
import qualified Data.IORef as IORef
import Hasura.GraphQL.Resolve (resolveSelSet)
import Hasura.GraphQL.Resolve.Context (RespTx)
import qualified Hasura.GraphQL.Resolve.LiveQuery as LQ
import Hasura.GraphQL.Schema (GCtxMap, getGCtx)
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import Hasura.GraphQL.Validate (validateGQ)
import qualified Hasura.Logging as L
import Hasura.Prelude
import Hasura.RQL.Types
import Hasura.Server.Auth (AuthMode,
getUserInfo)
import qualified Hasura.Server.Query as RQ
-- | Uniquely identifies an operation across all connections: the id of
-- the websocket connection paired with the client-supplied operation id.
type GOperationId = (WS.WSId, OperationId)

-- | Runs a resolved transaction, producing either an error or the
-- response as a lazy bytestring.
type TxRunner = RespTx -> IO (Either QErr BL.ByteString)

-- | The live queries started on a single connection, keyed by their
-- client-supplied operation id.
type OperationMap
  = STMMap.Map OperationId LQ.LiveQuery
-- | State kept for each websocket connection.
data WSConnData
  = WSConnData
  { _wscUser  :: !(IORef.IORef (Maybe UserInfo))
    -- ^ the user of this connection; 'Nothing' until a
    -- @connection_init@ message has been authenticated successfully
  , _wscOpMap :: !OperationMap
    -- ^ only subscriptions are tracked here; queries and mutations
    -- are not recorded
  }
-- | The live-query map shared by all connections, parameterised by the
-- global operation id.
type LiveQueryMap = LQ.LiveQueryMap GOperationId

-- | The websocket server, specialised to this module's connection state.
type WSServer = WS.WSServer WSConnData

-- | A single websocket connection carrying 'WSConnData'.
type WSConn = WS.WSConn WSConnData
-- | Encode a protocol message and write it to the given connection.
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn serverMsg =
  liftIO $ WS.sendMsg wsConn $ encodeServerMsg serverMsg
-- | Lifecycle events of a subscription, as recorded in the websocket logs.
data SubsDetail
  = SDStarted
  | SDStopped
  deriving (Show, Eq)

-- constructor tags drop the two-letter @SD@ prefix and are snake_cased
$(J.deriveToJSON
  J.defaultOptions
    { J.sumEncoding            = J.TaggedObject "type" "detail"
    , J.constructorTagModifier = J.snakeCase . drop 2
    }
  ''SubsDetail)
-- | Outcome of a non-subscription operation (or of a failed start), as
-- recorded in the websocket logs.
data OpDetail
  = ODCompleted
  | ODError !QErr
  deriving (Show, Eq)

-- the "type" tag is the snake_cased constructor name without the @OD@
-- prefix; a payload, if any, goes under "detail"
$(J.deriveToJSON
  J.defaultOptions
    { J.sumEncoding            = J.TaggedObject "type" "detail"
    , J.constructorTagModifier = J.snakeCase . drop 2
    }
  ''OpDetail)
-- | Events emitted by the websocket handlers; each is logged inside a
-- 'WSLog' together with the connection id.
data WSEvent
  = EAccepted
    -- ^ the handshake was accepted
  | ERejected !QErr
    -- ^ the handshake was rejected with this error
  | EProtocolError !TBS.TByteString !ConnErrMsg
    -- ^ the offending raw message and the error reported for it
  | EOperation !OperationId !OpDetail
  | ESubscription !OperationId !SubsDetail
  | EClosed
    -- ^ the connection was closed
  deriving (Show, Eq)

-- the "type" tag is the snake_cased constructor name without the
-- leading @E@; a payload, if any, goes under "detail"
$(J.deriveToJSON
  J.defaultOptions
    { J.sumEncoding            = J.TaggedObject "type" "detail"
    , J.constructorTagModifier = J.snakeCase . drop 1
    }
  ''WSEvent)
-- | One websocket log entry: which connection, and what happened on it.
data WSLog
  = WSLog
  { _wslWebsocketId :: !WS.WSId
  , _wslEvent       :: !WSEvent
  } deriving (Show, Eq)

-- JSON keys drop the four-character @_wsl@ prefix and are snake_cased
$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
-- | Websocket logs are emitted at info level, tagged @"ws-handler"@.
instance L.ToEngineLog WSLog where
  toEngineLog entry = (L.LevelInfo, "ws-handler", J.toJSON entry)
-- | Environment shared by the websocket handlers of all connections.
data WSServerEnv
  = WSServerEnv
  { _wseLogger   :: !L.Logger
  , _wseServer   :: !WSServer
  , _wseRunTx    :: !TxRunner
    -- ^ runs queries/mutations; also handed to the live-query machinery
  , _wseLiveQMap :: !LiveQueryMap
  , _wseGCtxMap  :: !(IORef.IORef (SchemaCache, GCtxMap))
    -- ^ schema cache paired with the GraphQL contexts; the handlers
    -- here only read the contexts
  , _wseHManager :: !H.Manager
    -- ^ passed to authentication when handling @connection_init@
  }
-- | Handshake handler. Only requests for @/v1alpha1/graphql@ are accepted.
--
-- An accepted connection negotiates the @graphql-ws@ subprotocol. It
-- starts with no user (set later by @connection_init@) and an empty
-- operation map, and gets a keep-alive action that sends a keep-alive
-- message every five seconds. A rejected request is answered with the
-- error's HTTP status and its JSON encoding as the body.
onConn :: L.Logger -> WS.OnConnH WSConnData
onConn logger wsId requestHead =
  runExceptT checkPath >>= either reject accept
  where
    -- delay between keep-alive messages, in microseconds
    keepAliveInterval = 5 * 1000 * 1000

    keepAliveAction wsConn = forever $ do
      sendMsg wsConn SMConnKeepAlive
      threadDelay keepAliveInterval

    accept _ = do
      logger $ WSLog wsId EAccepted
      userRef <- IORef.newIORef Nothing
      opMap   <- STMMap.newIO
      let acceptRequest = WS.defaultAcceptRequest
            { WS.acceptSubprotocol = Just "graphql-ws" }
      return $ Right
        (WSConnData userRef opMap, acceptRequest, Just keepAliveAction)

    reject qErr = do
      logger $ WSLog wsId $ ERejected qErr
      let status = qeStatus qErr
      return $ Left $ WS.RejectRequest
        (H.statusCode status)
        (H.statusMessage status)
        []
        (BL.toStrict $ J.encode $ encodeQErr False qErr)

    checkPath =
      when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
        throw404 "only /v1alpha1/graphql is supported on websockets"
-- | Handle a @start@ message.
--
-- The request is refused if an operation with the same id is already
-- tracked on this connection (only subscriptions are tracked). It is also
-- refused if @connection_init@ has not yet set a user. Otherwise the query
-- is validated against the GraphQL context of the user's role, and then:
--
-- * a subscription is recorded in the connection's operation map and
--   registered as a live query;
-- * a query or mutation is run once. Its result, or its execution error,
--   is sent in a data message, and the operation is then completed.
--   The protocol requires @complete@ after the data message even when
--   execution fails; otherwise the client never finishes the operation.
--
-- Any 'ServerMsg' thrown inside the 'ExceptT' is sent back to the client.
onStart :: WSServerEnv -> WSConn -> StartMsg -> IO ()
onStart serverEnv wsConn msg@(StartMsg opId q) = catchAndSend $ do
  opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
  when (isJust opM) $ withExceptT preExecErr $ loggingQErr $
    throw400 UnexpectedPayload $
      "an operation already exists with this id: " <> unOperationId opId

  userInfoM <- liftIO $ IORef.readIORef userInfoR
  userInfo <- case userInfoM of
    Just userInfo -> return userInfo
    Nothing -> do
      let err = "start received before the connection is initialised"
      liftIO $ logger $ WSLog wsId $
        -- TODO: we are encoding the start msg back into a bytestring
        -- should we be throwing protocol error here?
        EProtocolError (TBS.fromLBS $ J.encode msg) err
      throwError $ SMConnErr err

  -- validate and build tx
  gCtxMap <- fmap snd $ liftIO $ IORef.readIORef gCtxMapRef
  let gCtx = getGCtx (userRole userInfo) gCtxMap
  (opTy, fields) <- withExceptT preExecErr $ loggingQErr $
    runReaderT (validateGQ q) gCtx
  let qTx = RQ.setHeadersTx userInfo >>
            resolveSelSet userInfo gCtx opTy fields

  case opTy of
    G.OperationTypeSubscription -> do
      let lq = LQ.LiveQuery userInfo q
      liftIO $ STM.atomically $ STMMap.insert lq opId opMap
      liftIO $ LQ.addLiveQuery runTx lqMap lq
        qTx (wsId, opId) liveQOnChange
      liftIO $ logger $ WSLog wsId $ ESubscription opId SDStarted
    _ -> do
      resp <- liftIO $ runTx qTx
      case resp of
        Left qErr -> sendExecErr qErr
        Right respBody -> do
          sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess respBody
          sendCompleted
          liftIO $ logger $ WSLog wsId $ EOperation opId ODCompleted
  where
    (WSServerEnv logger _ runTx lqMap gCtxMapRef _) = serverEnv
    wsId = WS.getWSId wsConn
    (WSConnData userInfoR opMap) = WS.getData wsConn

    -- on change, send message on the websocket
    liveQOnChange resp = WS.sendMsg wsConn $ encodeServerMsg $ SMData $
      DataMsg opId resp

    -- record an error for this operation in the websocket log
    logOpErr :: QErr -> IO ()
    logOpErr qErr = logger $ WSLog wsId $ EOperation opId $ ODError qErr

    -- log any QErr raised by the wrapped action, then re-throw it
    loggingQErr m = catchError m $ \qErr -> do
      liftIO $ logOpErr qErr
      throwError qErr

    -- errors raised before execution are reported with an @error@
    -- message; no @complete@ follows
    preExecErr qErr = SMErr $ ErrorMsg opId $ encodeQErr False qErr

    -- execution errors of queries/mutations are reported inside a data
    -- message, after which the operation is still completed
    sendExecErr :: QErr -> ExceptT ServerMsg IO ()
    sendExecErr qErr = do
      liftIO $ logOpErr qErr
      sendMsg wsConn $ SMData $ DataMsg opId $
        GQExecError [encodeQErr False qErr]
      sendCompleted

    sendCompleted :: ExceptT ServerMsg IO ()
    sendCompleted = sendMsg wsConn $ SMComplete $ CompletionMsg opId

    catchAndSend :: ExceptT ServerMsg IO () -> IO ()
    catchAndSend m = do
      res <- runExceptT m
      either (sendMsg wsConn) return res
-- | Handle one raw frame from the client.
--
-- A frame that does not decode as a client message is logged as a
-- protocol error and answered with a connection error. Otherwise the
-- decoded message is routed to its handler.
onMessage
  :: AuthMode
  -> WSServerEnv
  -> WSConn -> BL.ByteString -> IO ()
onMessage authMode serverEnv wsConn msgRaw =
  either onDecodeFailure dispatch $ J.eitherDecode msgRaw
  where
    logger = _wseLogger serverEnv

    onDecodeFailure e = do
      let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e
      logger $ WSLog (WS.getWSId wsConn) $
        EProtocolError (TBS.fromLBS msgRaw) err
      sendMsg wsConn $ SMConnErr err

    dispatch clientMsg = case clientMsg of
      CMConnInit params ->
        onConnInit logger (_wseHManager serverEnv) wsConn authMode params
      CMStart startMsg -> onStart serverEnv wsConn startMsg
      CMStop stopMsg   -> onStop serverEnv wsConn stopMsg
      CMConnTerm       ->
        WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
-- | Handle a @stop@ message.
--
-- If the id names a subscription tracked on this connection, the stop is
-- logged and its live query removed. In every case the id is then deleted
-- from the connection's operation map.
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
  -- TODO(review): the lookup and the delete run in separate transactions;
  -- consider wrapping the whole thing in a single one
  liveQM <- STM.atomically $ STMMap.lookup opId opMap
  maybe (return ()) stopLiveQuery liveQM
  STM.atomically $ STMMap.delete opId opMap
  where
    logger = _wseLogger serverEnv
    lqMap  = _wseLiveQMap serverEnv
    wsId   = WS.getWSId wsConn
    opMap  = _wscOpMap $ WS.getData wsConn

    stopLiveQuery liveQ = do
      logger $ WSLog wsId $ ESubscription opId SDStopped
      LQ.removeLiveQuery lqMap liveQ (wsId, opId)
-- | Handle @connection_init@: authenticate using the headers supplied in
-- the connection params.
--
-- On failure the connection is closed, with the error message as the
-- close reason. On success the user is stored on the connection, and a
-- connection ack followed by one keep-alive message is sent. Later
-- keep-alives are sent periodically by the keep-alive action that
-- 'onConn' installs when the connection is accepted.
onConnInit
  :: (MonadIO m)
  => L.Logger -> H.Manager -> WSConn -> AuthMode -> ConnParams -> m ()
onConnInit logger manager wsConn authMode connParams = do
  authRes <- runExceptT $ getUserInfo logger manager headers authMode
  either rejectInit acceptInit authRes
  where
    headers =
      [ (CI.mk $ TE.encodeUtf8 name, TE.encodeUtf8 value)
      | (name, value) <- maybe [] Map.toList $ _cpHeaders connParams
      ]

    rejectInit qErr =
      liftIO $ WS.closeConn wsConn $
        BL.fromStrict $ TE.encodeUtf8 $ qeError qErr

    acceptInit userInfo = do
      liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $ Just userInfo
      sendMsg wsConn SMConnAck
      sendMsg wsConn SMConnKeepAlive
-- | Connection-close handler.
--
-- It logs the close, then concurrently removes the live query of every
-- operation still recorded in the connection's operation map.
onClose
  :: L.Logger
  -> LiveQueryMap
  -> WS.ConnectionException
  -> WSConn
  -> IO ()
onClose logger lqMap _ wsConn = do
  logger $ WSLog wsId EClosed
  activeOps <- STM.atomically $ ListT.toList $ STMMap.stream opMap
  void $ A.mapConcurrently removeOp activeOps
  where
    wsId  = WS.getWSId wsConn
    opMap = _wscOpMap $ WS.getData wsConn

    removeOp (opId, liveQ) = LQ.removeLiveQuery lqMap liveQ (wsId, opId)
-- | Build the shared handler environment. The websocket server and the
-- live-query map are created in one STM transaction.
createWSServerEnv
  :: L.Logger
  -> H.Manager -> IORef.IORef (SchemaCache, GCtxMap)
  -> TxRunner -> IO WSServerEnv
createWSServerEnv logger httpManager gCtxMapRef runTx = do
  (wsServer, lqMap) <- STM.atomically $ do
    server      <- WS.createWSServer logger
    liveQueries <- LQ.newLiveQueryMap
    return (server, liveQueries)
  return WSServerEnv
    { _wseLogger   = logger
    , _wseServer   = wsServer
    , _wseRunTx    = runTx
    , _wseLiveQMap = lqMap
    , _wseGCtxMap  = gCtxMapRef
    , _wseHManager = httpManager
    }
-- | The websocket application. It wires the handshake, message and close
-- handlers, all sharing one environment, to the environment's server.
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv =
  WS.createServerApp (_wseServer serverEnv) $
    WS.WSHandlers connHandler msgHandler closeHandler
  where
    logger       = _wseLogger serverEnv
    connHandler  = onConn logger
    msgHandler   = onMessage authMode serverEnv
    closeHandler = onClose logger $ _wseLiveQMap serverEnv