graphql-engine/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs
Vamshi Surabhi ada35c2236 add user information to transactions in ws protocol (fix #1867) (#1869)
This bug was introduced by the refactor in 739ff80a51.
2019-03-26 16:38:27 +05:30

449 lines
16 KiB
Haskell

{-# LANGUAGE RankNTypes #-}
module Hasura.GraphQL.Transport.WebSocket
( createWSServerApp
, createWSServerEnv
) where
import qualified Control.Concurrent.Async as A
import qualified Control.Concurrent.STM as STM
import qualified Data.Aeson as J
import qualified Data.Aeson.Casing as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as Map
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Language.GraphQL.Draft.Syntax as G
import qualified ListT
import qualified Network.HTTP.Client as H
import qualified Network.HTTP.Types as H
import qualified Network.WebSockets as WS
import qualified StmContainers.Map as STMMap
import Control.Concurrent (threadDelay)
import Data.ByteString (ByteString)
import qualified Data.IORef as IORef
import Hasura.EncJSON
import Hasura.GraphQL.Context (GCtx)
import qualified Hasura.GraphQL.Execute as E
import qualified Hasura.GraphQL.Resolve as R
import qualified Hasura.GraphQL.Resolve.LiveQuery as LQ
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import qualified Hasura.GraphQL.Transport.WebSocket.Server as WS
import qualified Hasura.GraphQL.Validate as V
import qualified Hasura.Logging as L
import Hasura.Prelude
import Hasura.RQL.Types
import Hasura.Server.Auth (AuthMode,
getUserInfo)
import Hasura.Server.Cors
import Hasura.Server.Utils (bsToTxt)
-- | Uniquely identifies an operation on the server: the id of the websocket
-- connection paired with the operation id the client sent in its @start@
-- message.
type GOperationId = (WS.WSId, OperationId)

-- | Per-connection map from operation id to the live query started under
-- that id. Only subscriptions are inserted here (see 'WSConnData').
type OperationMap
  = STMMap.Map OperationId LQ.LiveQuery
-- | Headers of the client's websocket handshake request (with the
-- handshake-specific ones filtered out), kept until @connection_init@ so they
-- can be used when resolving the user.
newtype WsHeaders
  = WsHeaders { unWsHeaders :: [H.Header] }
  deriving (Show, Eq)
-- | Authentication state of a websocket connection.
data WSConnState
  -- connection_init not handled yet; holds the headers from the client's
  -- websocket handshake request
  = CSNotInitialised !WsHeaders
  -- connection_init was handled but resolving the user failed; holds the
  -- error message
  | CSInitError Text
  -- connection_init succeeded: the resolved user, plus the headers the client
  -- sent in the connection params (forwarded to remote schemas)
  | CSInitialised UserInfo [H.Header]
-- | Per-connection data attached to every 'WSConn'.
data WSConnData
  = WSConnData
  -- authentication state; the user and the forwarded headers are set only
  -- when a connection_init message is handled
  { _wscUser  :: !(IORef.IORef WSConnState)
  -- live queries started on this connection, by operation id. We only care
  -- about subscriptions; the other operations (queries/mutations) are not
  -- tracked here
  , _wscOpMap :: !OperationMap
  }
-- | Server-wide live query map; each subscriber is identified by its
-- 'GOperationId' (connection id, operation id).
type LiveQueryMap = LQ.LiveQueryMap GOperationId
-- | The websocket server, specialised to this module's per-connection data.
type WSServer = WS.WSServer WSConnData
-- | A single websocket connection carrying 'WSConnData'.
type WSConn = WS.WSConn WSConnData
-- | Encode a protocol message for the client and write it to the connection.
sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m ()
sendMsg wsConn serverMsg = do
  let encoded = encodeServerMsg serverMsg
  liftIO $ WS.sendMsg wsConn encoded
-- | Lifecycle events of a single operation, logged inside 'EOperation'.
data OpDetail
  = ODStarted
  -- a protocol-level problem, e.g. a duplicate operation id
  | ODProtoErr !Text
  -- the operation failed to plan or to execute
  | ODQueryErr !QErr
  | ODCompleted
  | ODStopped
  deriving (Show, Eq)

-- JSON: an object whose "type" is the constructor name without the "OD"
-- prefix in snake_case (e.g. "proto_err"), with any payload under "detail".
$(J.deriveToJSON
  J.defaultOptions { J.constructorTagModifier = J.snakeCase . drop 2
                   , J.sumEncoding = J.TaggedObject "type" "detail"
                   }
  ''OpDetail)
-- | Connection-level events written to the 'WSLog'.
data WSEvent
  -- handshake accepted (also used for the CORS cookie note)
  = EAccepted
  -- handshake rejected, with the reason
  | ERejected !QErr
  -- an undecodable client frame or a failed connection_init
  | EConnErr !ConnErrMsg
  -- a lifecycle event of one operation
  | EOperation !OperationId !(Maybe OperationName) !OpDetail
  | EClosed
  deriving (Show, Eq)

-- JSON: an object whose "type" is the constructor name without the "E"
-- prefix in snake_case (e.g. "conn_err"), with any payload under "detail".
$(J.deriveToJSON
  J.defaultOptions { J.constructorTagModifier = J.snakeCase . drop 1
                   , J.sumEncoding = J.TaggedObject "type" "detail"
                   }
  ''WSEvent)
-- | One log line of the websocket handler.
data WSLog
  = WSLog
  { _wslWebsocketId :: !WS.WSId
  -- the user's session variables, once the connection is initialised
  , _wslUser        :: !(Maybe UserVars)
  , _wslEvent       :: !WSEvent
  -- optional free-form note attached to the event
  , _wslMsg         :: !(Maybe Text)
  } deriving (Show, Eq)
-- JSON keys drop the "_wsl" prefix: websocket_id, user, event, msg
$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
-- | Every websocket log line is emitted at info level under "ws-handler".
instance L.ToEngineLog WSLog where
  toEngineLog wsLog =
    (L.LevelInfo, "ws-handler", J.toJSON wsLog)
-- | Server-wide environment shared by every websocket connection.
data WSServerEnv
  = WSServerEnv
  { _wseLogger     :: !L.Logger
  , _wseServer     :: !WSServer
  -- runs query/mutation transactions; also handed to the live query map
  , _wseRunTx      :: !LQ.TxRunner
  , _wseLiveQMap   :: !LiveQueryMap
  -- NOTE: despite the name, this holds the whole schema cache, which is read
  -- on every start message to build the execution plan
  , _wseGCtxMap    :: !(IORef.IORef SchemaCache)
  -- passed to getUserInfo on connection_init and used for remote-schema calls
  , _wseHManager   :: !H.Manager
  , _wseCorsPolicy :: !CorsPolicy
  , _wseSQLCtx     :: !SQLGenCtx
  }
-- | Handshake handler.
--
-- Only the @/v1alpha1/graphql@ path is accepted. If the request carries an
-- @Origin@ header, the CORS policy is applied:
--
-- * allowed origins are checked against the configured domains (rejecting
--   with 'AccessDenied' on mismatch);
-- * with CORS disabled and cookie reading off, @Cookie@ headers are dropped
--   and a note is logged.
--
-- A rejected handshake is logged as 'ERejected' and answered with the error's
-- HTTP status and JSON body. An accepted one gets the @graphql-ws@
-- subprotocol, fresh connection data holding the filtered handshake headers,
-- and an action that sends a keep-alive every 5 seconds.
onConn :: L.Logger -> CorsPolicy -> WS.OnConnH WSConnData
onConn (L.Logger logger) corsPolicy wsId requestHead =
  runExceptT validateRequest >>= either reject accept
  where
    handshakeHdrs = WS.requestHeaders requestHead

    -- path check, then CORS, then strip the handshake-only headers
    validateRequest = do
      checkPath
      corsCheckedHdrs <- case snd <$> find ((== "Origin") . fst) handshakeHdrs of
        Nothing     -> return handshakeHdrs
        Just origin -> enforceCors origin handshakeHdrs
      return $ WsHeaders $ filterWsHeaders corsCheckedHdrs

    keepAliveIntervalMicros :: Int
    keepAliveIntervalMicros = 5 * 1000 * 1000

    keepAliveAction wsConn = forever $ do
      sendMsg wsConn SMConnKeepAlive
      threadDelay keepAliveIntervalMicros

    accept hdrs = do
      logger $ WSLog wsId Nothing EAccepted Nothing
      userStateRef <- IORef.newIORef (CSNotInitialised hdrs)
      operations   <- STMMap.newIO
      let connData      = WSConnData userStateRef operations
          acceptRequest =
            WS.defaultAcceptRequest { WS.acceptSubprotocol = Just "graphql-ws"}
      return $ Right (connData, acceptRequest, Just keepAliveAction)

    reject qErr = do
      logger $ WSLog wsId Nothing (ERejected qErr) Nothing
      let status = qeStatus qErr
          body   = BL.toStrict $ J.encode $ encodeGQLErr False qErr
      return $ Left $
        WS.RejectRequest (H.statusCode status) (H.statusMessage status) [] body

    checkPath =
      when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
        throw404 "only /v1alpha1/graphql is supported on websockets"

    enforceCors :: ByteString -> [H.Header] -> ExceptT QErr IO [H.Header]
    enforceCors origin reqHdrs = case cpConfig corsPolicy of
      CCAllowAll       -> return reqHdrs
      CCDisabled True  -> return reqHdrs
      CCDisabled False -> do
        -- cookie reading was not requested: drop cookies and say why
        liftIO $ logger $ WSLog wsId Nothing EAccepted (Just corsNote)
        return [ hdr | hdr@(name, _) <- reqHdrs, name /= "Cookie" ]
      CCAllowedOrigins ds ->
        -- exact domain match or wildcard-domain match passes; anything else
        -- is rejected
        let originTxt = bsToTxt origin
        in if originTxt `elem` dmFqdns ds || inWildcardList ds originTxt
             then return reqHdrs
             else corsErr

    handshakeOnlyHdrs :: [H.HeaderName]
    handshakeOnlyHdrs =
      [ "sec-websocket-key"
      , "sec-websocket-version"
      , "upgrade"
      , "connection"
      ]

    filterWsHeaders :: [H.Header] -> [H.Header]
    filterWsHeaders = filter ((`notElem` handshakeOnlyHdrs) . fst)

    corsErr = throw400 AccessDenied
              "received origin header does not match configured CORS domains"

    corsNote = "Cookie is not read when CORS is disabled, because it is a potential "
            <> "security issue. If you're already handling CORS before Hasura and enforcing "
            <> "CORS on websocket connections, then you can use the flag --ws-read-cookie or "
            <> "HASURA_GRAPHQL_WS_READ_COOKIE to force read cookie when CORS is disabled."
-- | Handle a @start@ message for operation @opId@ with GraphQL request @q@.
-- @msgRaw@ is the raw frame; it is re-parsed only for remote-schema
-- operations.
--
-- The body runs in @ExceptT () IO@:
--
-- * 'withComplete' sends an error, then a @complete@ message, and aborts the
--   rest of the handler with @throwError ()@;
-- * 'catchAndIgnore' discards that abort.
--
-- Queries and mutations are executed once and then completed. Subscriptions
-- are registered as live queries, in both the connection's operation map and
-- the server-wide live query map, and are not completed here.
onStart :: WSServerEnv -> WSConn -> StartMsg -> BL.ByteString -> IO ()
onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do

  -- refuse an id already used by a tracked operation (only subscriptions are
  -- tracked in opMap)
  opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap

  when (isJust opM) $ withComplete $ sendConnErr $
    "an operation already exists with this id: " <> unOperationId opId

  -- the connection must have completed connection_init successfully
  userInfoM <- liftIO $ IORef.readIORef userInfoR
  (userInfo, reqHdrs) <- case userInfoM of
    CSInitialised userInfo reqHdrs -> return (userInfo, reqHdrs)
    CSInitError initErr -> do
      let connErr = "cannot start as connection_init failed with : " <> initErr
      withComplete $ sendConnErr connErr
    CSNotInitialised _ -> do
      let connErr = "start received before the connection is initialised"
      withComplete $ sendConnErr connErr

  -- plan the request against the current schema cache, then run it either
  -- locally (Hasura) or against a remote schema
  sc <- liftIO $ IORef.readIORef gCtxMapRef
  execPlanE <- runExceptT $ E.getExecPlan userInfo sc q
  execPlan <- either (withComplete . preExecErr) return execPlanE
  case execPlan of
    E.GExPHasura gCtx rootSelSet ->
      runHasuraGQ userInfo gCtx rootSelSet
    E.GExPRemote rsi opDef ->
      runRemoteGQ userInfo reqHdrs opDef rsi

  where
    -- queries/mutations run as a single transaction; a subscription is
    -- recorded in opMap and added to the live query map with a callback
    -- that pushes each new result to this connection
    runHasuraGQ :: UserInfo -> GCtx -> V.RootSelSet -> ExceptT () IO ()
    runHasuraGQ userInfo gCtx rootSelSet =
      case rootSelSet of
        V.RQuery selSet ->
          execQueryOrMut $ withUserInfo userInfo $
          R.resolveQuerySelSet userInfo gCtx sqlGenCtx selSet
        V.RMutation selSet ->
          execQueryOrMut $ withUserInfo userInfo $
          R.resolveMutSelSet userInfo gCtx sqlGenCtx selSet
        V.RSubscription fld -> do
          let tx = withUserInfo userInfo $
                   R.resolveSubsFld userInfo gCtx sqlGenCtx fld
          let lq = LQ.LiveQuery userInfo q
          liftIO $ STM.atomically $ STMMap.insert lq opId opMap
          liftIO $ LQ.addLiveQuery runTx lqMap lq
            tx (wsId, opId) liveQOnChange
          logOpEv ODStarted

    -- run once, send the result (or the execution error), then complete
    execQueryOrMut tx = do
      logOpEv ODStarted
      resp <- liftIO $ runTx tx
      either postExecErr sendSuccResp resp
      sendCompleted

    runRemoteGQ :: UserInfo -> [H.Header]
                -> G.TypedOperationDefinition -> RemoteSchemaInfo
                -> ExceptT () IO ()
    runRemoteGQ userInfo reqHdrs opDef rsi = do
      when (G._todType opDef == G.OperationTypeSubscription) $
        withComplete $ preExecErr $
        err400 NotSupported "subscription to remote server is not supported"

      -- if it's not a subscription, use HTTP to execute the query on the
      -- remote server: parse the (apollo protocol) websocket frame and
      -- forward only its payload
      sockPayload <- onLeft (J.eitherDecode msgRaw) $
        const $ withComplete $ preExecErr $
        err500 Unexpected "invalid websocket payload"
      let payload = J.encode $ _wpPayload sockPayload
      resp <- runExceptT $ E.execRemoteGQ httpMgr userInfo reqHdrs
              payload rsi opDef
      either postExecErr sendSuccResp resp
      sendCompleted

    WSServerEnv logger _ runTx lqMap gCtxMapRef httpMgr _ sqlGenCtx = serverEnv
    wsId = WS.getWSId wsConn
    WSConnData userInfoR opMap = WS.getData wsConn

    logOpEv opDet =
      logWSEvent logger wsConn $ EOperation opId (_grOperationName q) opDet

    sendConnErr connErr = do
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ J.toJSON connErr
      logOpEv $ ODProtoErr connErr

    sendCompleted = do
      sendMsg wsConn $ SMComplete $ CompletionMsg opId
      logOpEv ODCompleted

    -- execution errors go out as a data message carrying a GraphQL error
    postExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMData $ DataMsg opId $
        GQExecError $ pure $ encodeQErr False qErr

    -- errors before execution go out as a protocol-level error message, not
    -- as a GraphQL data response like 'postExecErr'.
    -- TODO: why wouldn't pre exec error use graphql response?
    preExecErr qErr = do
      logOpEv $ ODQueryErr qErr
      sendMsg wsConn $ SMErr $ ErrorMsg opId $ encodeQErr False qErr

    sendSuccResp encJson =
      sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess $ encJToLBS encJson

    -- run the action (an error report), complete the operation, then abort
    -- the rest of the handler
    withComplete :: ExceptT () IO () -> ExceptT () IO a
    withComplete action = do
      action
      sendCompleted
      throwError ()

    -- callback for the live query: sends every new result to the client as a
    -- data message for this operation
    liveQOnChange resp =
      WS.sendMsg wsConn $ encodeServerMsg $ SMData $ DataMsg opId resp

    -- the () "error" only signals an early exit whose reporting is done
    catchAndIgnore :: ExceptT () IO () -> IO ()
    catchAndIgnore m = void $ runExceptT m
-- | Decode one raw client frame and dispatch it to its handler.
--
-- A frame that does not decode as a client message is logged as 'EConnErr'
-- and answered with a connection error. @GQL_CONNECTION_TERMINATE@ closes the
-- connection.
onMessage
  :: AuthMode
  -> WSServerEnv
  -> WSConn -> BL.ByteString -> IO ()
onMessage authMode serverEnv wsConn msgRaw =
  either onParseFailure dispatch (J.eitherDecode msgRaw)
  where
    logger  = _wseLogger serverEnv
    manager = _wseHManager serverEnv

    onParseFailure parseErr = do
      let connErr = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack parseErr
      logWSEvent logger wsConn $ EConnErr connErr
      sendMsg wsConn $ SMConnErr connErr

    dispatch clientMsg = case clientMsg of
      CMConnInit connParams -> onConnInit logger manager wsConn authMode connParams
      CMStart startMsg      -> onStart serverEnv wsConn startMsg msgRaw
      CMStop stopMsg        -> onStop serverEnv wsConn stopMsg
      CMConnTerm            -> WS.closeConn wsConn "GQL_CONNECTION_TERMINATE received"
-- | Handle a @stop@ message.
--
-- If the id belongs to a tracked live query, 'ODStopped' is logged and the
-- query is removed from the server-wide live query map. The id is then
-- dropped from the connection's operation map either way.
onStop :: WSServerEnv -> WSConn -> StopMsg -> IO ()
onStop serverEnv wsConn (StopMsg opId) = do
  -- TODO: probably wrap the whole thing in a single tx?
  trackedOp <- STM.atomically $ STMMap.lookup opId opMap
  maybe (return ()) stopLiveQuery trackedOp
  STM.atomically $ STMMap.delete opId opMap
  where
    logger = _wseLogger serverEnv
    lqMap  = _wseLiveQMap serverEnv
    wsId   = WS.getWSId wsConn
    opMap  = _wscOpMap $ WS.getData wsConn

    stopLiveQuery liveQ = do
      let opNameM = _grOperationName $ LQ._lqRequest liveQ
      logWSEvent logger wsConn $ EOperation opId opNameM ODStopped
      LQ.removeLiveQuery lqMap liveQ (wsId, opId)
-- | Log a websocket event for a connection.
--
-- The user's session variables are attached once the connection is
-- 'CSInitialised'; before that (or after a failed @connection_init@) the
-- user field is 'Nothing'.
logWSEvent
  :: (MonadIO m)
  => L.Logger -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
  connState <- liftIO $ IORef.readIORef connStateRef
  -- build the Maybe directly with 'Just' rather than a Maybe-monad 'return'
  let userVarsM = case connState of
        CSInitialised userInfo _ -> Just $ userVars userInfo
        _                        -> Nothing
  liftIO $ logger $ WSLog wsId userVarsM wsEv Nothing
  where
    WSConnData connStateRef _ = WS.getData wsConn
    wsId = WS.getWSId wsConn
-- | Handle @connection_init@.
--
-- The user is resolved with 'getUserInfo' from the headers in the connection
-- params, followed by the handshake headers. The handshake headers are only
-- available while the state is still 'CSNotInitialised'.
--
-- * On failure the state becomes 'CSInitError', the error is logged, and a
--   connection error is sent.
-- * On success the state becomes 'CSInitialised' (keeping the param headers
--   for forwarding to remote schemas), and an ack plus one keep-alive are
--   sent.
--
-- NOTE(review): once the state has left 'CSNotInitialised', a repeated
-- connection_init authenticates with the param headers only; the handshake
-- headers are no longer available. Confirm this is intended.
onConnInit
  :: (MonadIO m)
  => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do
  connState <- liftIO $ IORef.readIORef connStateRef
  let headers = paramHeaders ++ handshakeHeaders connState
  authResult <- runExceptT $ getUserInfo logger manager headers authMode
  either onAuthFailure onAuthSuccess authResult
  where
    connStateRef = _wscUser $ WS.getData wsConn

    onAuthFailure qErr = do
      let errTxt = qeError qErr
      liftIO $ IORef.writeIORef connStateRef $ CSInitError errTxt
      let connErr = ConnErrMsg errTxt
      logWSEvent logger wsConn $ EConnErr connErr
      sendMsg wsConn $ SMConnErr connErr

    onAuthSuccess userInfo = do
      liftIO $ IORef.writeIORef connStateRef $
        CSInitialised userInfo paramHeaders
      sendMsg wsConn SMConnAck
      -- one immediate keep-alive; the periodic ones come from the keep-alive
      -- action installed when the handshake was accepted
      sendMsg wsConn SMConnKeepAlive

    -- headers supplied by the client in the connection_init params
    paramHeaders =
      [ (CI.mk $ TE.encodeUtf8 h, TE.encodeUtf8 v)
      | (h, v) <- maybe [] Map.toList $ connParamsM >>= _cpHeaders
      ]

    handshakeHeaders st = case st of
      CSNotInitialised hdrs -> unWsHeaders hdrs
      _                     -> []
-- | Connection-close handler.
--
-- Logs 'EClosed', then concurrently removes every live query the connection
-- still has in its operation map from the server-wide live query map.
onClose
  :: L.Logger
  -> LiveQueryMap
  -> WS.ConnectionException
  -> WSConn
  -> IO ()
onClose logger lqMap _ wsConn = do
  logWSEvent logger wsConn EClosed
  trackedOps <- STM.atomically $ ListT.toList $ STMMap.listT opMap
  void $ A.forConcurrently trackedOps removeOp
  where
    wsId  = WS.getWSId wsConn
    opMap = _wscOpMap $ WS.getData wsConn

    removeOp (opId, liveQ) = LQ.removeLiveQuery lqMap liveQ (wsId, opId)
-- | Build the shared websocket environment.
--
-- The websocket server and a new live query map are created in one STM
-- transaction; the remaining fields are taken from the arguments.
createWSServerEnv
  :: L.Logger
  -> H.Manager -> SQLGenCtx -> IORef.IORef SchemaCache
  -> LQ.TxRunner -> CorsPolicy -> IO WSServerEnv
createWSServerEnv logger httpManager sqlGenCtx cacheRef runTx corsPolicy = do
  (wsServer, lqMap) <- STM.atomically $ do
    server      <- WS.createWSServer logger
    liveQueries <- LQ.newLiveQueryMap
    return (server, liveQueries)
  return WSServerEnv
    { _wseLogger     = logger
    , _wseServer     = wsServer
    , _wseRunTx      = runTx
    , _wseLiveQMap   = lqMap
    , _wseGCtxMap    = cacheRef
    , _wseHManager   = httpManager
    , _wseCorsPolicy = corsPolicy
    , _wseSQLCtx     = sqlGenCtx
    }
-- | Wire this module's connect, message and close handlers into a websocket
-- server application.
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv =
  WS.createServerApp (_wseServer serverEnv) $
    WS.WSHandlers onConnH onMessageH onCloseH
  where
    logger     = _wseLogger serverEnv
    onConnH    = onConn logger (_wseCorsPolicy serverEnv)
    onMessageH = onMessage authMode serverEnv
    onCloseH   = onClose logger (_wseLiveQMap serverEnv)
-- | TODO: this type exists only so that an incoming websocket frame can be
-- parsed and its payload picked out for remote-schema queries.
--
-- Ideally we would reuse 'StartMsg' from
-- "Hasura.GraphQL.Transport.WebSocket.Protocol". However, 'GraphQLRequest'
-- has no 'ToJSON' instance, so this type is used to extract just the payload.
data WebsocketPayload
  = WebsocketPayload
  { _wpId      :: !Text
  , _wpType    :: !Text
  -- kept as raw JSON so it can be re-encoded and forwarded unchanged
  , _wpPayload :: !J.Value
  } deriving (Show, Eq)
-- JSON keys drop the "_wp" prefix: id, type, payload
$(J.deriveJSON (J.aesonDrop 3 J.snakeCase) ''WebsocketPayload)