mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-18 04:51:35 +03:00
ada35c2236
This bug introduced with the refactor in 739ff80a51
.
449 lines
16 KiB
Haskell
449 lines
16 KiB
Haskell
{-# LANGUAGE RankNTypes #-}
|
|
|
|
module Hasura.GraphQL.Transport.WebSocket
|
|
( createWSServerApp
|
|
, createWSServerEnv
|
|
) where
|
|
|
|
import qualified Control.Concurrent.Async as A
|
|
import qualified Control.Concurrent.STM as STM
|
|
import qualified Data.Aeson as J
|
|
import qualified Data.Aeson.Casing as J
|
|
import qualified Data.Aeson.TH as J
|
|
import qualified Data.ByteString.Lazy as BL
|
|
import qualified Data.CaseInsensitive as CI
|
|
import qualified Data.HashMap.Strict as Map
|
|
import qualified Data.Text as T
|
|
import qualified Data.Text.Encoding as TE
|
|
import qualified Language.GraphQL.Draft.Syntax as G
|
|
import qualified ListT
|
|
import qualified Network.HTTP.Client as H
|
|
import qualified Network.HTTP.Types as H
|
|
import qualified Network.WebSockets as WS
|
|
import qualified StmContainers.Map as STMMap
|
|
|
|
import Control.Concurrent (threadDelay)
|
|
import Data.ByteString (ByteString)
|
|
import qualified Data.IORef as IORef
|
|
|
|
import Hasura.EncJSON
|
|
import Hasura.GraphQL.Context (GCtx)
|
|
import qualified Hasura.GraphQL.Execute as E
|
|
import qualified Hasura.GraphQL.Resolve as R
|
|
import qualified Hasura.GraphQL.Resolve.LiveQuery as LQ
|
|
import Hasura.GraphQL.Transport.HTTP.Protocol
|
|
import Hasura.GraphQL.Transport.WebSocket.Protocol
|
|
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
|
|
import qualified Hasura.GraphQL.Validate as V
|
|
import qualified Hasura.Logging as L
|
|
import Hasura.Prelude
|
|
import Hasura.RQL.Types
|
|
import Hasura.Server.Auth (AuthMode,
|
|
getUserInfo)
|
|
import Hasura.Server.Cors
|
|
import Hasura.Server.Utils (bsToTxt)
|
|
|
|
-- | Uniquely identifies an operation across the whole server: the id of the
-- websocket connection paired with the operation id used on that connection.
type GOperationId = (WS.WSId, OperationId)
|
|
|
|
-- | Per-connection table from operation id to the live query registered
-- under it.
type OperationMap = STMMap.Map OperationId LQ.LiveQuery
|
|
|
|
-- | HTTP headers taken from the websocket request.
newtype WsHeaders = WsHeaders
  { unWsHeaders :: [H.Header]
  } deriving (Show, Eq)
|
|
|
|
-- | State of a connection with respect to the @connection_init@ message.
data WSConnState
  = CSNotInitialised !WsHeaders
    -- ^ no @connection_init@ handled yet; carries the headers the client sent
    -- on the websocket request
  | CSInitError Text
    -- ^ @connection_init@ failed; carries the error text
  | CSInitialised UserInfo [H.Header]
    -- ^ @connection_init@ succeeded; carries the resolved user and the
    -- headers from the connection params, to forward to the remote schema
|
|
|
|
-- | Mutable data attached to every websocket connection.
data WSConnData = WSConnData
  { _wscUser  :: !(IORef.IORef WSConnState)
    -- ^ connection state; the user and headers are set only when a
    -- @connection_init@ message is handled
  , _wscOpMap :: !OperationMap
    -- ^ only subscriptions are tracked here; the other operations
    -- (queries/mutations) are not
  }
|
|
|
|
-- | The live query map specialised to 'GOperationId'.
type LiveQueryMap = LQ.LiveQueryMap GOperationId

-- | The websocket server whose connections carry 'WSConnData'.
type WSServer = WS.WSServer WSConnData

-- | A single websocket connection carrying 'WSConnData'.
type WSConn = WS.WSConn WSConnData
|
|
-- | Encode a server message and send it on the given connection.
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn msg =
  liftIO $ WS.sendMsg wsConn (encodeServerMsg msg)
|
|
|
|
-- | What happened to an operation; reported inside 'EOperation' log events.
data OpDetail
  = ODStarted
  | ODProtoErr !Text
  | ODQueryErr !QErr
  | ODCompleted
  | ODStopped
  deriving (Show, Eq)

-- Encoded as a tagged object: the tag goes under "type" and the contents
-- under "detail". The tag is the constructor name with its first two
-- characters ("OD") dropped, converted to snake case.
$(J.deriveToJSON
  J.defaultOptions
    { J.constructorTagModifier = J.snakeCase . drop 2
    , J.sumEncoding = J.TaggedObject "type" "detail"
    }
  ''OpDetail)
|
|
|
|
-- | Events on a websocket connection that are written to the log.
data WSEvent
  = EAccepted
  | ERejected !QErr
  | EConnErr !ConnErrMsg
  | EOperation !OperationId !(Maybe OperationName) !OpDetail
  | EClosed
  deriving (Show, Eq)

-- Encoded as a tagged object: the tag goes under "type" and the contents
-- under "detail". The tag is the constructor name with its first character
-- ("E") dropped, converted to snake case.
$(J.deriveToJSON
  J.defaultOptions
    { J.constructorTagModifier = J.snakeCase . drop 1
    , J.sumEncoding = J.TaggedObject "type" "detail"
    }
  ''WSEvent)
|
|
|
|
-- | One log record emitted by the websocket handlers.
data WSLog = WSLog
  { _wslWebsocketId :: !WS.WSId
    -- ^ connection the event belongs to
  , _wslUser        :: !(Maybe UserVars)
    -- ^ user variables, when they are known for the connection
  , _wslEvent       :: !WSEvent
    -- ^ the event being logged
  , _wslMsg         :: !(Maybe Text)
    -- ^ optional free-form note
  } deriving (Show, Eq)

-- JSON keys are the field names with the first four characters ("_wsl")
-- dropped, in snake case.
$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
|
|
|
|
-- | Websocket logs are emitted at info level under the "ws-handler"
-- component, with the JSON encoding of the 'WSLog' as payload.
instance L.ToEngineLog WSLog where
  toEngineLog wsLog = (L.LevelInfo, "ws-handler", payload)
    where
      payload = J.toJSON wsLog
|
|
|
|
-- | Everything the websocket handlers need, built by 'createWSServerEnv'.
-- The field order is significant: the constructor is also applied
-- positionally.
data WSServerEnv = WSServerEnv
  { _wseLogger     :: !L.Logger
    -- ^ logger used for all websocket events
  , _wseServer     :: !WSServer
    -- ^ the underlying websocket server
  , _wseRunTx      :: !LQ.TxRunner
    -- ^ runs the transactions of queries, mutations and live queries
  , _wseLiveQMap   :: !LiveQueryMap
    -- ^ registry of live queries
  , _wseGCtxMap    :: !(IORef.IORef SchemaCache)
    -- ^ reference to the schema cache, read on every @start@ message
  , _wseHManager   :: !H.Manager
    -- ^ HTTP manager, used for user resolution and remote schema calls
  , _wseCorsPolicy :: !CorsPolicy
    -- ^ CORS policy enforced when a connection is opened
  , _wseSQLCtx     :: !SQLGenCtx
    -- ^ passed to the resolvers
  }
|
|
|
|
-- | Handler run when a client opens a websocket connection.
--
-- The request is rejected unless its path is @/v1alpha1/graphql@ and, when an
-- @Origin@ header is present, the CORS policy allows it. On success the
-- connection starts in the 'CSNotInitialised' state holding the request
-- headers (minus the websocket handshake ones), speaks the @graphql-ws@
-- subprotocol and gets a keep-alive action.
onConn :: L.Logger -> CorsPolicy -> WS.OnConnH WSConnData
onConn (L.Logger logger) corsPolicy wsId requestHead =
  runExceptT validate >>= either reject accept

  where
    -- all checks that can fail; yields the headers to keep for the connection
    validate :: ExceptT QErr IO WsHeaders
    validate = do
      checkPath
      let reqHdrs = WS.requestHeaders requestHead
      headers <- case getOrigin of
        -- CORS is only enforced when the client sent an Origin header
        Nothing          -> return reqHdrs
        Just (_, origin) -> enforceCors origin reqHdrs
      return $ WsHeaders $ filterWsHeaders headers

    -- send a keep-alive message, then sleep for 5 seconds, forever
    keepAliveAction wsConn = forever $
      sendMsg wsConn SMConnKeepAlive >> threadDelay keepAliveMicros

    keepAliveMicros = 5 * 1000 * 1000

    accept hdrs = do
      logger $ WSLog wsId Nothing EAccepted Nothing
      stateRef <- IORef.newIORef (CSNotInitialised hdrs)
      opMap <- STMMap.newIO
      let connData = WSConnData stateRef opMap
          acceptRequest = WS.defaultAcceptRequest
            { WS.acceptSubprotocol = Just "graphql-ws" }
      return $ Right (connData, acceptRequest, Just keepAliveAction)

    -- the rejection carries the error's HTTP status and its GraphQL encoding
    reject qErr = do
      logger $ WSLog wsId Nothing (ERejected qErr) Nothing
      let status = qeStatus qErr
          body = BL.toStrict $ J.encode $ encodeGQLErr False qErr
      return $ Left $
        WS.RejectRequest (H.statusCode status) (H.statusMessage status) [] body

    checkPath =
      when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
        throw404 "only /v1alpha1/graphql is supported on websockets"

    getOrigin =
      find ((==) "Origin" . fst) (WS.requestHeaders requestHead)

    enforceCors :: ByteString -> [H.Header] -> ExceptT QErr IO [H.Header]
    enforceCors origin reqHdrs = case cpConfig corsPolicy of
      CCAllowAll -> return reqHdrs
      CCDisabled readCookie
        | readCookie -> return reqHdrs
        | otherwise  -> do
            -- CORS is disabled and cookies must not be read: note it in the
            -- log and drop the Cookie header
            liftIO $ logger $ WSLog wsId Nothing EAccepted (Just corsNote)
            return $ filter ((/= "Cookie") . fst) reqHdrs
      CCAllowedOrigins ds
        -- the origin is one of the configured domains
        | bsToTxt origin `elem` dmFqdns ds   -> return reqHdrs
        -- the origin matches the wildcard domain list
        | inWildcardList ds (bsToTxt origin) -> return reqHdrs
        -- anything else is refused
        | otherwise                          -> corsErr

    -- headers that belong to the websocket handshake itself are not kept
    filterWsHeaders = filter (\(name, _) -> name `notElem` handshakeHeaders)

    handshakeHeaders =
      [ "sec-websocket-key"
      , "sec-websocket-version"
      , "upgrade"
      , "connection"
      ]

    corsErr = throw400 AccessDenied
      "received origin header does not match configured CORS domains"

    corsNote = mconcat
      [ "Cookie is not read when CORS is disabled, because it is a potential "
      , "security issue. If you're already handling CORS before Hasura and enforcing "
      , "CORS on websocket connections, then you can use the flag --ws-read-cookie or "
      , "HASURA_GRAPHQL_WS_READ_COOKIE to force read cookie when CORS is disabled."
      ]
|
|
|
|
|
|
-- | Handle a @start@ message.
--
-- The operation is refused (error message followed by a completion message)
-- when an operation with the same id is already tracked on the connection,
-- when the connection has not been initialised successfully, or when no
-- execution plan can be built. Otherwise the plan is executed: queries and
-- mutations are run once and completed, subscriptions are registered as live
-- queries, and remote operations are forwarded over HTTP.
onStart :: WSServerEnv -> WSConn -> StartMsg -> BL.ByteString -> IO ()
onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do
  existingOp <- liftIO $ STM.atomically $ STMMap.lookup opId opMap
  case existingOp of
    Nothing -> return ()
    Just _  -> abortWith $
      "an operation already exists with this id: " <> unOperationId opId

  connState <- liftIO $ IORef.readIORef userInfoR
  (userInfo, reqHdrs) <- case connState of
    CSInitialised uInfo hdrs -> return (uInfo, hdrs)
    CSInitError initErr ->
      abortWith $ "cannot start as connection_init failed with : " <> initErr
    CSNotInitialised _ ->
      abortWith "start received before the connection is initialised"

  schemaCache <- liftIO $ IORef.readIORef gCtxMapRef
  planOrErr <- runExceptT $ E.getExecPlan userInfo schemaCache q
  execPlan <- case planOrErr of
    Left qErr  -> withComplete $ preExecErr qErr
    Right plan -> return plan
  case execPlan of
    E.GExPHasura gCtx rootSelSet ->
      runHasuraGQ userInfo gCtx rootSelSet
    E.GExPRemote rsi opDef ->
      runRemoteGQ userInfo reqHdrs opDef rsi
  where
    -- pieces of the server environment and of the connection
    logger     = _wseLogger serverEnv
    runTx      = _wseRunTx serverEnv
    lqMap      = _wseLiveQMap serverEnv
    gCtxMapRef = _wseGCtxMap serverEnv
    httpMgr    = _wseHManager serverEnv
    sqlGenCtx  = _wseSQLCtx serverEnv

    wsId      = WS.getWSId wsConn
    connData  = WS.getData wsConn
    userInfoR = _wscUser connData
    opMap     = _wscOpMap connData

    -- report a protocol error for this operation, complete it and stop
    abortWith :: Text -> ExceptT () IO a
    abortWith = withComplete . sendConnErr

    runHasuraGQ :: UserInfo -> GCtx -> V.RootSelSet -> ExceptT () IO ()
    runHasuraGQ userInfo gCtx (V.RQuery selSet) =
      execQueryOrMut $ withUserInfo userInfo $
        R.resolveQuerySelSet userInfo gCtx sqlGenCtx selSet
    runHasuraGQ userInfo gCtx (V.RMutation selSet) =
      execQueryOrMut $ withUserInfo userInfo $
        R.resolveMutSelSet userInfo gCtx sqlGenCtx selSet
    runHasuraGQ userInfo gCtx (V.RSubscription fld) = do
      let tx = withUserInfo userInfo $
            R.resolveSubsFld userInfo gCtx sqlGenCtx fld
          lq = LQ.LiveQuery userInfo q
      -- track the subscription on the connection first, then register it
      liftIO $ STM.atomically $ STMMap.insert lq opId opMap
      liftIO $ LQ.addLiveQuery runTx lqMap lq tx (wsId, opId) liveQOnChange
      logOpEv ODStarted

    -- run the transaction once, send its outcome and complete the operation
    execQueryOrMut tx = do
      logOpEv ODStarted
      resp <- liftIO $ runTx tx
      case resp of
        Left qErr     -> postExecErr qErr
        Right encJson -> sendSuccResp encJson
      sendCompleted

    runRemoteGQ :: UserInfo -> [H.Header]
                -> G.TypedOperationDefinition -> RemoteSchemaInfo
                -> ExceptT () IO ()
    runRemoteGQ userInfo reqHdrs opDef rsi = do
      when (G._todType opDef == G.OperationTypeSubscription) $
        withComplete $ preExecErr $
          err400 NotSupported "subscription to remote server is not supported"

      -- anything that is not a subscription is executed on the remote server
      -- over HTTP; parse the (apollo protocol) websocket frame and keep only
      -- its payload
      sockPayload <- case J.eitherDecode msgRaw of
        Right parsed -> return parsed
        Left _       -> withComplete $ preExecErr $
          err500 Unexpected "invalid websocket payload"
      let payload = J.encode $ _wpPayload sockPayload
      resp <- runExceptT $
        E.execRemoteGQ httpMgr userInfo reqHdrs payload rsi opDef
      case resp of
        Left qErr     -> postExecErr qErr
        Right encJson -> sendSuccResp encJson
      sendCompleted

    logOpEv opDet =
      logWSEvent logger wsConn $ EOperation opId (_grOperationName q) opDet

    sendConnErr connErr = do
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ J.toJSON connErr
      logOpEv $ ODProtoErr connErr

    sendCompleted = do
      sendMsg wsConn $ SMComplete $ CompletionMsg opId
      logOpEv ODCompleted

    -- errors raised while executing are sent as a data message
    postExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMData $ DataMsg opId $
        GQExecError $ pure $ encodeQErr False qErr

    -- errors raised before executing are sent as an error message
    -- why wouldn't pre exec error use graphql response?
    preExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ encodeQErr False qErr

    sendSuccResp encJson =
      let body = GQSuccess $ encJToLBS encJson
      in sendMsg wsConn $ SMData $ DataMsg opId body

    -- run the action, send the completion message, then abort the handler
    withComplete :: ExceptT () IO () -> ExceptT () IO a
    withComplete action = action >> sendCompleted >> throwError ()

    -- on change, send message on the websocket
    liveQOnChange resp =
      WS.sendMsg wsConn $ encodeServerMsg $ SMData $ DataMsg opId resp

    -- an abort carries no information, so the result is simply dropped
    catchAndIgnore :: ExceptT () IO () -> IO ()
    catchAndIgnore = void . runExceptT
|
|
|
|
-- | Decode a raw client frame and dispatch it to the matching handler.
-- A frame that cannot be decoded is logged and answered with a connection
-- error message.
onMessage
  :: AuthMode
  -> WSServerEnv
  -> WSConn -> BL.ByteString -> IO ()
onMessage authMode serverEnv wsConn msgRaw =
  either onParseErr dispatch $ J.eitherDecode msgRaw
  where
    logger = _wseLogger serverEnv

    onParseErr e = do
      let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e
      logWSEvent logger wsConn $ EConnErr err
      sendMsg wsConn $ SMConnErr err

    dispatch (CMConnInit params) =
      onConnInit logger (_wseHManager serverEnv) wsConn authMode params
    -- the raw frame is passed along as well; remote execution re-parses it
    dispatch (CMStart startMsg) =
      onStart serverEnv wsConn startMsg msgRaw
    dispatch (CMStop stopMsg) =
      onStop serverEnv wsConn stopMsg
    dispatch CMConnTerm =
      WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
|
|
|
|
-- | Handle a @stop@ message: if the operation is tracked on this connection,
-- log it as stopped and remove its live query; in every case drop the
-- operation id from the connection's operation map afterwards.
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
  -- probably wrap the whole thing in a single tx?
  -- NOTE(review): the map entry is deliberately deleted only after the live
  -- query has been removed, so an interrupted stop leaves the entry behind
  -- for 'onClose' to clean up - confirm before reordering.
  tracked <- STM.atomically $ STMMap.lookup opId opMap
  maybe (return ()) stopLiveQuery tracked
  STM.atomically $ STMMap.delete opId opMap
  where
    connData = WS.getData wsConn
    opMap    = _wscOpMap connData

    stopLiveQuery liveQ = do
      let opNameM = _grOperationName $ LQ._lqRequest liveQ
      logWSEvent (_wseLogger serverEnv) wsConn $
        EOperation opId opNameM ODStopped
      LQ.removeLiveQuery (_wseLiveQMap serverEnv) liveQ
        (WS.getWSId wsConn, opId)
|
|
|
|
-- | Log an event for a connection. The user variables are included only when
-- the connection is in the 'CSInitialised' state; the log message field is
-- always left empty.
logWSEvent
  :: (MonadIO m)
  => L.Logger -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
  connState <- liftIO $ IORef.readIORef stateRef
  liftIO $ logger $
    WSLog (WS.getWSId wsConn) (varsOf connState) wsEv Nothing
  where
    stateRef = _wscUser $ WS.getData wsConn

    varsOf (CSInitialised userInfo _) = Just $ userVars userInfo
    varsOf _                          = Nothing
|
|
|
|
-- | Handle a @connection_init@ message: resolve the user from the headers in
-- the connection params followed by the headers stored on the connection,
-- then record the outcome in the connection state.
--
-- On failure the state becomes 'CSInitError' and a connection error is
-- logged and sent; on success it becomes 'CSInitialised' (keeping only the
-- headers from the connection params) and an ack plus one keep-alive message
-- are sent.
--
-- NOTE(review): the stored request headers are only available in the
-- 'CSNotInitialised' state, so a repeated @connection_init@ is resolved from
-- the connection params alone - confirm this is intended.
onConnInit
  :: (MonadIO m)
  => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do
  currentState <- liftIO $ IORef.readIORef stateRef
  let headers = paramHeaders ++ clientHeaders currentState
  authResult <- runExceptT $ getUserInfo logger manager headers authMode
  case authResult of
    Left qErr -> do
      let errTxt = qeError qErr
          connErr = ConnErrMsg errTxt
      setState $ CSInitError errTxt
      logWSEvent logger wsConn $ EConnErr connErr
      sendMsg wsConn $ SMConnErr connErr
    Right userInfo -> do
      setState $ CSInitialised userInfo paramHeaders
      sendMsg wsConn SMConnAck
      -- TODO: send it periodically? Why doesn't apollo's protocol use
      -- ping/pong frames of websocket spec?
      sendMsg wsConn SMConnKeepAlive
  where
    stateRef = _wscUser $ WS.getData wsConn

    setState st = liftIO $ IORef.writeIORef stateRef st

    -- headers given in the connection params, converted to HTTP headers
    paramHeaders =
      map toHeader $ maybe [] Map.toList $ connParamsM >>= _cpHeaders

    toHeader (h, v) = (CI.mk $ TE.encodeUtf8 h, TE.encodeUtf8 v)

    -- headers stored when the connection was opened, if still available
    clientHeaders (CSNotInitialised h) = unWsHeaders h
    clientHeaders _                    = []
|
|
|
|
-- | Handler run when a connection is closed: log the event, then remove the
-- live query of every operation still tracked on the connection, running the
-- removals concurrently. The connection exception is ignored.
onClose
  :: L.Logger
  -> LiveQueryMap
  -> WS.ConnectionException
  -> WSConn
  -> IO ()
onClose logger lqMap _ wsConn = do
  logWSEvent logger wsConn EClosed
  -- snapshot of the operations tracked at this point
  tracked <- STM.atomically $ ListT.toList $ STMMap.listT $
    _wscOpMap $ WS.getData wsConn
  void $ A.mapConcurrently dropLiveQuery tracked
  where
    dropLiveQuery (opId, liveQ) =
      LQ.removeLiveQuery lqMap liveQ (WS.getWSId wsConn, opId)
|
|
|
|
-- | Build the environment shared by all websocket handlers. The websocket
-- server and the live query map are created together in one STM transaction;
-- everything else is taken from the arguments.
createWSServerEnv
  :: L.Logger
  -> H.Manager -> SQLGenCtx -> IORef.IORef SchemaCache
  -> LQ.TxRunner -> CorsPolicy -> IO WSServerEnv
createWSServerEnv logger httpManager sqlGenCtx cacheRef runTx corsPolicy = do
  (wsServer, lqMap) <- STM.atomically $ do
    server <- WS.createWSServer logger
    liveQueries <- LQ.newLiveQueryMap
    return (server, liveQueries)
  return WSServerEnv
    { _wseLogger     = logger
    , _wseServer     = wsServer
    , _wseRunTx      = runTx
    , _wseLiveQMap   = lqMap
    , _wseGCtxMap    = cacheRef
    , _wseHManager   = httpManager
    , _wseCorsPolicy = corsPolicy
    , _wseSQLCtx     = sqlGenCtx
    }
|
|
|
|
-- | The websocket server application: the server from the environment wired
-- to the connect, message and close handlers of this module.
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv =
  WS.createServerApp (_wseServer serverEnv) $
    WS.WSHandlers connectH messageH closeH
  where
    logger   = _wseLogger serverEnv
    connectH = onConn logger (_wseCorsPolicy serverEnv)
    messageH = onMessage authMode serverEnv
    closeH   = onClose logger (_wseLiveQMap serverEnv)
|
|
|
|
|
|
-- | Shape of an incoming websocket frame, used to pick out only the payload
-- for remote schema queries.
--
-- TODO: ideally we should use `StartMsg` from Websocket.Protocol, but as
-- `GraphQLRequest` doesn't have a ToJSON instance we are using our own type
-- to get only the payload.
data WebsocketPayload = WebsocketPayload
  { _wpId      :: !Text
  , _wpType    :: !Text
  , _wpPayload :: !J.Value
  } deriving (Show, Eq)

-- JSON keys are the field names with the first three characters ("_wp")
-- dropped, in snake case.
$(J.deriveJSON (J.aesonDrop 3 J.snakeCase) ''WebsocketPayload)
|