read cookie while initialising websocket connection (fix #1660) (#1668)

* read cookie while initialising websocket connection (fix #1660)

* add tests for cookie on websocket init

* fix logic for tests

* enforce cors, and flag to force read cookie when cors disabled

  - as browsers don't enforce SOP on websockets, we enforce CORS policy
  on websocket handshake
  - if CORS is disabled, by default cookie is not read (because XSS
  risk!). Add special flag to force override this behaviour

* add log and forward origin header to webhook

  - add log notice when cors is disabled, and cookie is not read on
  websocket handshake
  - forward origin header to webhook in POST mode. So that when CORS is
  disabled, webhook can also enforce CORS independently.

* add docs, and forward all client headers to webhook
This commit is contained in:
Anon Ray 2019-03-04 07:46:53 +00:00 committed by Shahidh K Muhammed
parent f794653b69
commit 02d80c9ac6
14 changed files with 376 additions and 49 deletions

View File

@ -198,6 +198,53 @@ kill_hge_and_combine_hpc_reports
unset HASURA_GRAPHQL_CORS_DOMAIN unset HASURA_GRAPHQL_CORS_DOMAIN
# test websocket transport with initial cookie header
echo -e "\n<########## TEST GRAPHQL-ENGINE WITH COOKIE IN WEBSOCKET INIT ########>\n"
export HASURA_GRAPHQL_AUTH_HOOK="http://localhost:9876/auth"
export HASURA_GRAPHQL_AUTH_HOOK_MODE="POST"
python3 test_cookie_webhook.py > "$OUTPUT_FOLDER/cookie_webhook.log" 2>&1 & WHC_PID=$!
wait_for_port 9876
"$GRAPHQL_ENGINE" serve >> "$OUTPUT_FOLDER/graphql-engine.log" 2>&1 & PID=$!
wait_for_port 8080
echo "testcase 1: read cookie, cors enabled"
pytest -vv --hge-url="$HGE_URL" --pg-url="$HASURA_GRAPHQL_DATABASE_URL" --hge-key="$HASURA_GRAPHQL_ADMIN_SECRET" --test-ws-init-cookie=read test_websocket_init_cookie.py
kill -INT $PID
sleep 1
echo "testcase 2: no read cookie, cors disabled"
"$GRAPHQL_ENGINE" serve --disable-cors >> "$OUTPUT_FOLDER/graphql-engine.log" 2>&1 & PID=$!
wait_for_port 8080
pytest -vv --hge-url="$HGE_URL" --pg-url="$HASURA_GRAPHQL_DATABASE_URL" --hge-key="$HASURA_GRAPHQL_ADMIN_SECRET" --test-ws-init-cookie=noread test_websocket_init_cookie.py
kill -INT $PID
sleep 1
echo "testcase 3: read cookie, cors disabled and ws-read-cookie"
export HASURA_GRAPHQL_WS_READ_COOKIE="true"
"$GRAPHQL_ENGINE" serve --disable-cors >> "$OUTPUT_FOLDER/graphql-engine.log" 2>&1 & PID=$!
wait_for_port 8080
pytest -vv --hge-url="$HGE_URL" --pg-url="$HASURA_GRAPHQL_DATABASE_URL" --hge-key="$HASURA_GRAPHQL_ADMIN_SECRET" --test-ws-init-cookie=read test_websocket_init_cookie.py
kill -INT $PID
kill -INT $WHC_PID
unset HASURA_GRAPHQL_WS_READ_COOKIE
unset HASURA_GRAPHQL_AUTH_HOOK
unset HASURA_GRAPHQL_AUTH_HOOK_MODE
sleep 4
combine_hpc_reports
echo -e "\n<########## TEST GRAPHQL-ENGINE WITH GRAPHQL DISABLED ########>\n" echo -e "\n<########## TEST GRAPHQL-ENGINE WITH GRAPHQL DISABLED ########>\n"
export HASURA_GRAPHQL_ENABLED_APIS="metadata" export HASURA_GRAPHQL_ENABLED_APIS="metadata"
@ -241,6 +288,7 @@ pytest -vv --hge-url="$HGE_URL" --pg-url="$HASURA_GRAPHQL_DATABASE_URL" --hge-ke
kill_hge_and_combine_hpc_reports kill_hge_and_combine_hpc_reports
# webhook tests # webhook tests
if [ $EUID != 0 ] ; then if [ $EUID != 0 ] ; then

View File

@ -124,3 +124,8 @@ Examples:
Top-level domains are not considered as part of wildcard domains. You Top-level domains are not considered as part of wildcard domains. You
have to add them separately. E.g - ``https://*.foo.com`` doesn't include have to add them separately. E.g - ``https://*.foo.com`` doesn't include
``https://foo.com``. ``https://foo.com``.
You can tell Hasura to disable handling CORS entirely via the ``--disable-cors``
flag. Hasura will not respond with CORS headers. You can use this option if
you're already handling CORS on a reverse proxy etc.

View File

@ -107,6 +107,13 @@ For ``serve`` sub-command these are the flags and ENV variables available:
- N/A - N/A
- Disable CORS. Do not send any CORS headers on any request. - Disable CORS. Do not send any CORS headers on any request.
* - ``--ws-read-cookie``
- ``HASURA_GRAPHQL_WS_READ_COOKIE``
- Read cookie on WebSocket initial handshake, even when CORS is disabled.
This can be a potential security flaw! Please make sure you know what
you're doing.This configuration is only applicable when CORS is disabled.
(``"true"`` or ``"false"``. Default: false)
* - ``--enable-telemetry <true|false>`` * - ``--enable-telemetry <true|false>``
- ``HASURA_GRAPHQL_ENABLE_TELEMETRY`` - ``HASURA_GRAPHQL_ENABLE_TELEMETRY``
- Enable anonymous telemetry (default: true) - Enable anonymous telemetry (default: true)

View File

@ -35,6 +35,34 @@ Hasura GraphQL engine uses the `GraphQL over Websocket Protocol
`apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`_ library `apollographql/subscriptions-transport-ws <https://github.com/apollographql/subscriptions-transport-ws>`_ library
for sending and receiving events. for sending and receiving events.
Cookie and Websockets
---------------------
Hasura GraphQL engine will read cookies sent by the browser when initiating a
websocket connection. Browser will send the cookie only if it is a secure cookie
(``secure`` flag in the cookie) and if the cookie has a ``HttpOnly`` flag.
Hasura will read this cookie and use it as headers when resolving authorization
(i.e. when resolving the auth webhook).
Cookies, Websockets and CORS
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
As browsers don't enforce Same Origin Policy (SOP) for Websockets, Hasura server
enforces the CORS rules when accepting the websocket connection.
It uses the provided CORS configuration (as per :ref:`configure-cors`).
1. When it is ``*``, the cookie is read the CORS check is not enforced.
2. When there are explicit domains, only if the request originates from one of
the listed domains, the cookie will be read.
3. If CORS is disabled, the default behaviour is, the cookie won't be read
(because of potential security issues). To override the behaviour, you can
use the flag ``--ws-read-cookie`` or environment variable
``HASURA_GRAPHQL_WS_READ_COOKIE``. See
:doc:`../deployment/graphql-engine-flags/reference` for the setting.
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:hidden: :hidden:

View File

@ -72,9 +72,11 @@ parseHGECommand =
<*> parseCorsConfig <*> parseCorsConfig
<*> parseEnableConsole <*> parseEnableConsole
<*> parseEnableTelemetry <*> parseEnableTelemetry
<*> parseWsReadCookie
<*> parseStringifyNum <*> parseStringifyNum
<*> parseEnabledAPIs <*> parseEnabledAPIs
parseArgs :: IO HGEOptions parseArgs :: IO HGEOptions
parseArgs = do parseArgs = do
rawHGEOpts <- execParser opts rawHGEOpts <- execParser opts
@ -104,7 +106,7 @@ main = do
let logger = mkLogger loggerCtx let logger = mkLogger loggerCtx
case hgeCmd of case hgeCmd of
HCServe so@(ServeOptions port host cp isoL mAdminSecret mAuthHook mJwtSecret HCServe so@(ServeOptions port host cp isoL mAdminSecret mAuthHook mJwtSecret
mUnAuthRole corsCfg enableConsole enableTelemetry strfyNum enabledAPIs) -> do mUnAuthRole corsCfg enableConsole enableTelemetry strfyNum enabledAPIs) -> do
-- log serve options -- log serve options
unLogger logger $ serveOptsToLog so unLogger logger $ serveOptsToLog so
hloggerCtx <- mkLoggerCtx $ defaultLoggerSettings False hloggerCtx <- mkLoggerCtx $ defaultLoggerSettings False

View File

@ -23,6 +23,7 @@ import qualified Network.WebSockets as WS
import qualified STMContainers.Map as STMMap import qualified STMContainers.Map as STMMap
import Control.Concurrent (threadDelay) import Control.Concurrent (threadDelay)
import Data.ByteString (ByteString)
import qualified Data.IORef as IORef import qualified Data.IORef as IORef
import Hasura.GraphQL.Resolve (resolveSelSet) import Hasura.GraphQL.Resolve (resolveSelSet)
@ -42,6 +43,8 @@ import Hasura.Prelude
import Hasura.RQL.Types import Hasura.RQL.Types
import Hasura.Server.Auth (AuthMode, import Hasura.Server.Auth (AuthMode,
getUserInfo) getUserInfo)
import Hasura.Server.Cors
import Hasura.Server.Utils (bsToTxt)
-- uniquely identifies an operation -- uniquely identifies an operation
type GOperationId = (WS.WSId, OperationId) type GOperationId = (WS.WSId, OperationId)
@ -51,9 +54,15 @@ type TxRunner = LazyRespTx -> IO (Either QErr BL.ByteString)
type OperationMap type OperationMap
= STMMap.Map OperationId LQ.LiveQuery = STMMap.Map OperationId LQ.LiveQuery
newtype WsHeaders
= WsHeaders { unWsHeaders :: [H.Header] }
deriving (Show, Eq)
data WSConnState data WSConnState
= CSNotInitialised -- headers from the client for websockets
= CSNotInitialised !WsHeaders
| CSInitError Text | CSInitError Text
-- headers from the client (in conn params) to forward to the remote schema
| CSInitialised UserInfo [H.Header] | CSInitialised UserInfo [H.Header]
data WSConnData data WSConnData
@ -105,6 +114,7 @@ data WSLog
{ _wslWebsocketId :: !WS.WSId { _wslWebsocketId :: !WS.WSId
, _wslUser :: !(Maybe UserVars) , _wslUser :: !(Maybe UserVars)
, _wslEvent :: !WSEvent , _wslEvent :: !WSEvent
, _wslMsg :: !(Maybe Text)
} deriving (Show, Eq) } deriving (Show, Eq)
$(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog) $(J.deriveToJSON (J.aesonDrop 4 J.snakeCase) ''WSLog)
@ -114,36 +124,41 @@ instance L.ToEngineLog WSLog where
data WSServerEnv data WSServerEnv
= WSServerEnv = WSServerEnv
{ _wseLogger :: !L.Logger { _wseLogger :: !L.Logger
, _wseServer :: !WSServer , _wseServer :: !WSServer
, _wseRunTx :: !TxRunner , _wseRunTx :: !TxRunner
, _wseLiveQMap :: !LiveQueryMap , _wseLiveQMap :: !LiveQueryMap
, _wseGCtxMap :: !(IORef.IORef SchemaCache) , _wseGCtxMap :: !(IORef.IORef SchemaCache)
, _wseHManager :: !H.Manager , _wseHManager :: !H.Manager
, _wseSQLCtx :: !SQLGenCtx , _wseCorsPolicy :: !CorsPolicy
, _wseSQLCtx :: !SQLGenCtx
} }
onConn :: L.Logger -> WS.OnConnH WSConnData onConn :: L.Logger -> CorsPolicy -> WS.OnConnH WSConnData
onConn (L.Logger logger) wsId requestHead = do onConn (L.Logger logger) corsPolicy wsId requestHead = do
res <- runExceptT checkPath res <- runExceptT $ do
checkPath
let reqHdrs = WS.requestHeaders requestHead
headers <- maybe (return reqHdrs) (flip enforceCors reqHdrs . snd) getOrigin
return $ WsHeaders $ filterWsHeaders headers
either reject accept res either reject accept res
where
where
keepAliveAction wsConn = forever $ do keepAliveAction wsConn = forever $ do
sendMsg wsConn SMConnKeepAlive sendMsg wsConn SMConnKeepAlive
threadDelay $ 5 * 1000 * 1000 threadDelay $ 5 * 1000 * 1000
accept _ = do accept hdrs = do
logger $ WSLog wsId Nothing EAccepted logger $ WSLog wsId Nothing EAccepted Nothing
connData <- WSConnData connData <- WSConnData
<$> IORef.newIORef CSNotInitialised <$> IORef.newIORef (CSNotInitialised hdrs)
<*> STMMap.newIO <*> STMMap.newIO
let acceptRequest = WS.defaultAcceptRequest let acceptRequest = WS.defaultAcceptRequest
{ WS.acceptSubprotocol = Just "graphql-ws"} { WS.acceptSubprotocol = Just "graphql-ws"}
return $ Right (connData, acceptRequest, Just keepAliveAction) return $ Right (connData, acceptRequest, Just keepAliveAction)
reject qErr = do reject qErr = do
logger $ WSLog wsId Nothing $ ERejected qErr logger $ WSLog wsId Nothing (ERejected qErr) Nothing
return $ Left $ WS.RejectRequest return $ Left $ WS.RejectRequest
(H.statusCode $ qeStatus qErr) (H.statusCode $ qeStatus qErr)
(H.statusMessage $ qeStatus qErr) [] (H.statusMessage $ qeStatus qErr) []
@ -153,6 +168,42 @@ onConn (L.Logger logger) wsId requestHead = do
when (WS.requestPath requestHead /= "/v1alpha1/graphql") $ when (WS.requestPath requestHead /= "/v1alpha1/graphql") $
throw404 "only /v1alpha1/graphql is supported on websockets" throw404 "only /v1alpha1/graphql is supported on websockets"
getOrigin =
find ((==) "Origin" . fst) (WS.requestHeaders requestHead)
enforceCors :: ByteString -> [H.Header] -> ExceptT QErr IO [H.Header]
enforceCors origin reqHdrs = case cpConfig corsPolicy of
CCAllowAll -> return reqHdrs
CCDisabled readCookie ->
if readCookie
then return reqHdrs
else do
liftIO $ logger $ WSLog wsId Nothing EAccepted (Just corsNote)
return $ filter (\h -> fst h /= "Cookie") reqHdrs
CCAllowedOrigins ds
-- if the origin is in our cors domains, no error
| bsToTxt origin `elem` dmFqdns ds -> return reqHdrs
-- if current origin is part of wildcard domain list, no error
| inWildcardList ds (bsToTxt origin) -> return reqHdrs
-- otherwise error
| otherwise -> corsErr
filterWsHeaders hdrs = flip filter hdrs $ \(n, _) ->
n `notElem` [ "sec-websocket-key"
, "sec-websocket-version"
, "upgrade"
, "connection"
]
corsErr = throw400 AccessDenied
"received origin header does not match configured CORS domains"
corsNote = "Cookie is not read when CORS is disabled, because it is a potential "
<> "security issue. If you're already handling CORS before Hasura and enforcing "
<> "CORS on websocket connections, then you can use the flag --ws-read-cookie or "
<> "HASURA_GRAPHQL_WS_READ_COOKIE to force read cookie when CORS is disabled."
onStart :: WSServerEnv -> WSConn -> StartMsg -> BL.ByteString -> IO () onStart :: WSServerEnv -> WSConn -> StartMsg -> BL.ByteString -> IO ()
onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do
@ -167,7 +218,7 @@ onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do
CSInitError initErr -> do CSInitError initErr -> do
let connErr = "cannot start as connection_init failed with : " <> initErr let connErr = "cannot start as connection_init failed with : " <> initErr
withComplete $ sendConnErr connErr withComplete $ sendConnErr connErr
CSNotInitialised -> do CSNotInitialised _ -> do
let connErr = "start received before the connection is initialised" let connErr = "start received before the connection is initialised"
withComplete $ sendConnErr connErr withComplete $ sendConnErr connErr
@ -219,7 +270,8 @@ onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do
either postExecErr sendSuccResp resp either postExecErr sendSuccResp resp
sendCompleted sendCompleted
WSServerEnv logger _ runTx lqMap gCtxMapRef httpMgr sqlGenCtx = serverEnv
WSServerEnv logger _ runTx lqMap gCtxMapRef httpMgr _ sqlGenCtx = serverEnv
wsId = WS.getWSId wsConn wsId = WS.getWSId wsConn
WSConnData userInfoR opMap = WS.getData wsConn WSConnData userInfoR opMap = WS.getData wsConn
@ -306,7 +358,7 @@ logWSEvent (L.Logger logger) wsConn wsEv = do
let userInfoM = case userInfoME of let userInfoM = case userInfoME of
CSInitialised userInfo _ -> return $ userVars userInfo CSInitialised userInfo _ -> return $ userVars userInfo
_ -> Nothing _ -> Nothing
liftIO $ logger $ WSLog wsId userInfoM wsEv liftIO $ logger $ WSLog wsId userInfoM wsEv Nothing
where where
WSConnData userInfoR _ = WS.getData wsConn WSConnData userInfoR _ = WS.getData wsConn
wsId = WS.getWSId wsConn wsId = WS.getWSId wsConn
@ -315,6 +367,7 @@ onConnInit
:: (MonadIO m) :: (MonadIO m)
=> L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m () => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do onConnInit logger manager wsConn authMode connParamsM = do
headers <- mkHeaders <$> liftIO (IORef.readIORef (_wscUser $ WS.getData wsConn))
res <- runExceptT $ getUserInfo logger manager headers authMode res <- runExceptT $ getUserInfo logger manager headers authMode
case res of case res of
Left e -> do Left e -> do
@ -325,15 +378,23 @@ onConnInit logger manager wsConn authMode connParamsM = do
sendMsg wsConn $ SMConnErr connErr sendMsg wsConn $ SMConnErr connErr
Right userInfo -> do Right userInfo -> do
liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $ liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $
CSInitialised userInfo headers CSInitialised userInfo paramHeaders
sendMsg wsConn SMConnAck sendMsg wsConn SMConnAck
-- TODO: send it periodically? Why doesn't apollo's protocol use -- TODO: send it periodically? Why doesn't apollo's protocol use
-- ping/pong frames of websocket spec? -- ping/pong frames of websocket spec?
sendMsg wsConn SMConnKeepAlive sendMsg wsConn SMConnKeepAlive
where where
headers = [ (CI.mk $ TE.encodeUtf8 h, TE.encodeUtf8 v) mkHeaders st =
| (h, v) <- maybe [] Map.toList $ connParamsM >>= _cpHeaders paramHeaders ++ getClientHdrs st
]
paramHeaders =
[ (CI.mk $ TE.encodeUtf8 h, TE.encodeUtf8 v)
| (h, v) <- maybe [] Map.toList $ connParamsM >>= _cpHeaders
]
getClientHdrs st = case st of
CSNotInitialised h -> unWsHeaders h
_ -> []
onClose onClose
:: L.Logger :: L.Logger
@ -353,11 +414,11 @@ onClose logger lqMap _ wsConn = do
createWSServerEnv createWSServerEnv
:: L.Logger :: L.Logger
-> H.Manager -> SQLGenCtx -> IORef.IORef SchemaCache -> H.Manager -> SQLGenCtx -> IORef.IORef SchemaCache
-> TxRunner -> IO WSServerEnv -> TxRunner -> CorsPolicy -> IO WSServerEnv
createWSServerEnv logger httpManager sqlGenCtx cacheRef runTx = do createWSServerEnv logger httpManager sqlGenCtx cacheRef runTx corsPolicy = do
(wsServer, lqMap) <- (wsServer, lqMap) <-
STM.atomically $ (,) <$> WS.createWSServer logger <*> LQ.newLiveQueryMap STM.atomically $ (,) <$> WS.createWSServer logger <*> LQ.newLiveQueryMap
return $ WSServerEnv logger wsServer runTx lqMap cacheRef httpManager sqlGenCtx return $ WSServerEnv logger wsServer runTx lqMap cacheRef httpManager corsPolicy sqlGenCtx
createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp createWSServerApp :: AuthMode -> WSServerEnv -> WS.ServerApp
createWSServerApp authMode serverEnv = createWSServerApp authMode serverEnv =
@ -365,6 +426,6 @@ createWSServerApp authMode serverEnv =
where where
handlers = handlers =
WS.WSHandlers WS.WSHandlers
(onConn $ _wseLogger serverEnv) (onConn (_wseLogger serverEnv) (_wseCorsPolicy serverEnv))
(onMessage authMode serverEnv) (onMessage authMode serverEnv)
(onClose (_wseLogger serverEnv) $ _wseLiveQMap serverEnv) (onClose (_wseLogger serverEnv) $ _wseLiveQMap serverEnv)

View File

@ -321,9 +321,12 @@ mkWaiApp isoLevel loggerCtx pool httpManager strfyNum mode corsCfg enableConsole
httpApp corsCfg serverCtx enableConsole enableTelemetry httpApp corsCfg serverCtx enableConsole enableTelemetry
let runTx tx = runExceptT $ runLazyTx pool isoLevel tx let runTx tx = runExceptT $ runLazyTx pool isoLevel tx
corsPolicy = mkDefaultCorsPolicy corsCfg
sqlGenCtx = SQLGenCtx strfyNum sqlGenCtx = SQLGenCtx strfyNum
wsServerEnv <- WS.createWSServerEnv (scLogger serverCtx) httpManager sqlGenCtx cacheRef runTx wsServerEnv <- WS.createWSServerEnv (scLogger serverCtx) httpManager sqlGenCtx
cacheRef runTx corsPolicy
let wsServerApp = WS.createWSServerApp mode wsServerEnv let wsServerApp = WS.createWSServerApp mode wsServerEnv
return (WS.websocketsOr WS.defaultConnectionOptions wsServerApp spockApp, cacheRef) return (WS.websocketsOr WS.defaultConnectionOptions wsServerApp spockApp, cacheRef)

View File

@ -10,6 +10,7 @@ module Hasura.Server.Cors
, mkDefaultCorsPolicy , mkDefaultCorsPolicy
, isCorsDisabled , isCorsDisabled
, Domains (..) , Domains (..)
, inWildcardList
) where ) where
import Hasura.Prelude import Hasura.Prelude
@ -45,23 +46,25 @@ $(J.deriveToJSON (J.aesonDrop 2 J.snakeCase) ''Domains)
data CorsConfig data CorsConfig
= CCAllowAll = CCAllowAll
| CCAllowedOrigins Domains | CCAllowedOrigins Domains
| CCDisabled | CCDisabled Bool -- should read cookie?
deriving (Show, Eq) deriving (Show, Eq)
instance J.ToJSON CorsConfig where instance J.ToJSON CorsConfig where
toJSON c = case c of toJSON c = case c of
CCDisabled -> toJ True J.Null CCDisabled wsrc -> toJ True J.Null (Just wsrc)
CCAllowAll -> toJ False (J.String "*") CCAllowAll -> toJ False (J.String "*") Nothing
CCAllowedOrigins d -> toJ False $ J.toJSON d CCAllowedOrigins d -> toJ False (J.toJSON d) Nothing
where where
toJ dis origs = toJ :: Bool -> J.Value -> Maybe Bool -> J.Value
toJ dis origs mWsRC =
J.object [ "disabled" J..= dis J.object [ "disabled" J..= dis
, "ws_read_cookie" J..= mWsRC
, "allowed_origins" J..= origs , "allowed_origins" J..= origs
] ]
isCorsDisabled :: CorsConfig -> Bool isCorsDisabled :: CorsConfig -> Bool
isCorsDisabled = \case isCorsDisabled = \case
CCDisabled -> True CCDisabled _ -> True
_ -> False _ -> False
readCorsDomains :: String -> Either String CorsConfig readCorsDomains :: String -> Either String CorsConfig
@ -89,6 +92,11 @@ mkDefaultCorsPolicy cfg =
, cpMaxAge = 1728000 , cpMaxAge = 1728000
} }
inWildcardList :: Domains -> Text -> Bool
inWildcardList (Domains _ wildcards) origin =
either (const False) (`Set.member` wildcards) $ parseOrigin origin
-- | Parsers for wildcard domains -- | Parsers for wildcard domains
runParser :: AT.Parser a -> Text -> Either String a runParser :: AT.Parser a -> Text -> Either String a

View File

@ -48,6 +48,7 @@ data RawServeOptions
, rsoCorsConfig :: !(Maybe CorsConfig) , rsoCorsConfig :: !(Maybe CorsConfig)
, rsoEnableConsole :: !Bool , rsoEnableConsole :: !Bool
, rsoEnableTelemetry :: !(Maybe Bool) , rsoEnableTelemetry :: !(Maybe Bool)
, rsoWsReadCookie :: !Bool
, rsoStringifyNum :: !Bool , rsoStringifyNum :: !Bool
, rsoEnabledAPIs :: !(Maybe [API]) , rsoEnabledAPIs :: !(Maybe [API])
} deriving (Show, Eq) } deriving (Show, Eq)
@ -265,8 +266,18 @@ mkServeOptions rso = do
authHookTyEnv mType = fromMaybe AHTGet <$> authHookTyEnv mType = fromMaybe AHTGet <$>
withEnv mType "HASURA_GRAPHQL_AUTH_HOOK_TYPE" withEnv mType "HASURA_GRAPHQL_AUTH_HOOK_TYPE"
mkCorsConfig mCfg = mkCorsConfig mCfg = do
fromMaybe CCAllowAll <$> withEnv mCfg (fst corsDomainEnv) corsCfg <- fromMaybe CCAllowAll <$> withEnv mCfg (fst corsDomainEnv)
readCookVal <- withEnvBool (rsoWsReadCookie rso) (fst wsReadCookieEnv)
wsReadCookie <- case (isCorsDisabled corsCfg, readCookVal) of
(True, _) -> return readCookVal
(False, True) -> throwError $ fst wsReadCookieEnv
<> " can only be used when CORS is disabled"
(False, False) -> return False
return $ case corsCfg of
CCDisabled _ -> CCDisabled wsReadCookie
_ -> corsCfg
mkExamplesDoc :: [[String]] -> PP.Doc mkExamplesDoc :: [[String]] -> PP.Doc
mkExamplesDoc exampleLines = mkExamplesDoc exampleLines =
@ -350,7 +361,7 @@ serveCmdFooter =
, pgUsePrepareEnv, txIsoEnv, adminSecretEnv , pgUsePrepareEnv, txIsoEnv, adminSecretEnv
, accessKeyEnv, authHookEnv, authHookModeEnv , accessKeyEnv, authHookEnv, authHookModeEnv
, jwtSecretEnv, unAuthRoleEnv, corsDomainEnv, enableConsoleEnv , jwtSecretEnv, unAuthRoleEnv, corsDomainEnv, enableConsoleEnv
, enableTelemetryEnv, stringifyNumEnv , enableTelemetryEnv, wsReadCookieEnv, stringifyNumEnv, enabledAPIsEnv
] ]
eventEnvs = eventEnvs =
@ -460,13 +471,22 @@ enableTelemetryEnv =
, "Enable anonymous telemetry (default: true)" , "Enable anonymous telemetry (default: true)"
) )
wsReadCookieEnv :: (String, String)
wsReadCookieEnv =
( "HASURA_GRAPHQL_WS_READ_COOKIE"
, "Read cookie on WebSocket initial handshake, even when CORS is disabled."
++ " This can be a potential security flaw! Please make sure you know "
++ "what you're doing."
++ "This configuration is only applicable when CORS is disabled."
)
stringifyNumEnv :: (String, String) stringifyNumEnv :: (String, String)
stringifyNumEnv = stringifyNumEnv =
( "HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES" ( "HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES"
, "Stringify numeric types (default: false)" , "Stringify numeric types (default: false)"
) )
enabledAPIsEnv :: (String,String) enabledAPIsEnv :: (String, String)
enabledAPIsEnv = enabledAPIsEnv =
( "HASURA_GRAPHQL_ENABLED_APIS" ( "HASURA_GRAPHQL_ENABLED_APIS"
, "List of comma separated list of allowed APIs. (default: metadata,graphql)" , "List of comma separated list of allowed APIs. (default: metadata,graphql)"
@ -674,7 +694,7 @@ parseCorsConfig = mapCC <$> disableCors <*> corsDomain
) )
mapCC isDisabled domains = mapCC isDisabled domains =
bool domains (Just CCDisabled) isDisabled bool domains (Just $ CCDisabled False) isDisabled
parseEnableConsole :: Parser Bool parseEnableConsole :: Parser Bool
parseEnableConsole = parseEnableConsole =
@ -689,6 +709,12 @@ parseEnableTelemetry = optional $
help (snd enableTelemetryEnv) help (snd enableTelemetryEnv)
) )
parseWsReadCookie :: Parser Bool
parseWsReadCookie =
switch ( long "ws-read-cookie" <>
help (snd wsReadCookieEnv)
)
parseStringifyNum :: Parser Bool parseStringifyNum :: Parser Bool
parseStringifyNum = parseStringifyNum =
switch ( long "stringify-numeric-types" <> switch ( long "stringify-numeric-types" <>

View File

@ -11,7 +11,6 @@ import Hasura.Server.Utils
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.CaseInsensitive as CI import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as Set
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types as H
@ -22,8 +21,8 @@ corsMiddleware policy app req sendResp =
where where
handleCors origin = case cpConfig policy of handleCors origin = case cpConfig policy of
CCDisabled -> app req sendResp CCDisabled _ -> app req sendResp
CCAllowAll -> sendCors origin CCAllowAll -> sendCors origin
CCAllowedOrigins ds CCAllowedOrigins ds
-- if the origin is in our cors domains, send cors headers -- if the origin is in our cors domains, send cors headers
| bsToTxt origin `elem` dmFqdns ds -> sendCors origin | bsToTxt origin `elem` dmFqdns ds -> sendCors origin
@ -66,7 +65,3 @@ corsMiddleware policy app req sendResp =
setHeaders hdrs = mapResponseHeaders (\h -> mkRespHdrs hdrs ++ h) setHeaders hdrs = mapResponseHeaders (\h -> mkRespHdrs hdrs ++ h)
mkRespHdrs = map (\(k,v) -> (CI.mk k, v)) mkRespHdrs = map (\(k,v) -> (CI.mk k, v))
inWildcardList :: Domains -> Text -> Bool
inWildcardList (Domains _ wildcards) origin =
either (const False) (`Set.member` wildcards) $ parseOrigin origin

View File

@ -33,6 +33,13 @@ def pytest_addoption(parser):
help="Run testcases for CORS configuration" help="Run testcases for CORS configuration"
) )
parser.addoption(
"--test-ws-init-cookie",
metavar="read|noread",
required=False,
help="Run testcases for testing cookie sending over websockets"
)
parser.addoption( parser.addoption(
"--test-metadata-disabled", action="store_true", "--test-metadata-disabled", action="store_true",
help="Run Test cases with metadata queries being disabled" help="Run Test cases with metadata queries being disabled"
@ -54,7 +61,7 @@ def hge_ctx(request):
webhook_insecure = request.config.getoption('--test-webhook-insecure') webhook_insecure = request.config.getoption('--test-webhook-insecure')
hge_jwt_key_file = request.config.getoption('--hge-jwt-key-file') hge_jwt_key_file = request.config.getoption('--hge-jwt-key-file')
hge_jwt_conf = request.config.getoption('--hge-jwt-conf') hge_jwt_conf = request.config.getoption('--hge-jwt-conf')
test_cors = request.config.getoption('--test-cors') ws_read_cookie = request.config.getoption('--test-ws-init-cookie')
metadata_disabled = request.config.getoption('--test-metadata-disabled') metadata_disabled = request.config.getoption('--test-metadata-disabled')
try: try:
hge_ctx = HGECtx( hge_ctx = HGECtx(
@ -65,6 +72,7 @@ def hge_ctx(request):
webhook_insecure=webhook_insecure, webhook_insecure=webhook_insecure,
hge_jwt_key_file=hge_jwt_key_file, hge_jwt_key_file=hge_jwt_key_file,
hge_jwt_conf=hge_jwt_conf, hge_jwt_conf=hge_jwt_conf,
ws_read_cookie=ws_read_cookie,
metadata_disabled=metadata_disabled metadata_disabled=metadata_disabled
) )
except HGECtxError as e: except HGECtxError as e:

View File

@ -74,7 +74,8 @@ class WebhookServer(http.server.HTTPServer):
class HGECtx: class HGECtx:
def __init__(self, hge_url, pg_url, hge_key, hge_webhook, webhook_insecure, hge_jwt_key_file, hge_jwt_conf, metadata_disabled): def __init__(self, hge_url, pg_url, hge_key, hge_webhook, webhook_insecure,
hge_jwt_key_file, hge_jwt_conf, metadata_disabled, ws_read_cookie):
server_address = ('0.0.0.0', 5592) server_address = ('0.0.0.0', 5592)
self.resp_queue = queue.Queue(maxsize=1) self.resp_queue = queue.Queue(maxsize=1)
@ -115,6 +116,8 @@ class HGECtx:
self.gql_srvr_thread = threading.Thread(target=self.graphql_server.serve_forever) self.gql_srvr_thread = threading.Thread(target=self.graphql_server.serve_forever)
self.gql_srvr_thread.start() self.gql_srvr_thread.start()
self.ws_read_cookie = ws_read_cookie
result = subprocess.run(['../../scripts/get-version.sh'], shell=False, stdout=subprocess.PIPE, check=True) result = subprocess.run(['../../scripts/get-version.sh'], shell=False, stdout=subprocess.PIPE, check=True)
self.version = result.stdout.decode('utf-8').strip() self.version = result.stdout.decode('utf-8').strip()
if not self.metadata_disabled: if not self.metadata_disabled:

View File

@ -0,0 +1,38 @@
"""
Sample auth webhook to receive a cookie and respond
"""
from http import HTTPStatus
from webserver import RequestHandler, WebServer, MkHandlers, Response
class CookieAuth(RequestHandler):
def get(self, request):
headers = {k.lower(): v for k, v in request.headers.items()}
print(headers)
if 'cookie' in headers and headers['cookie']:
res = {'x-hasura-role': 'admin'}
return Response(HTTPStatus.OK, res)
return Response(HTTPStatus.UNAUTHORIZED)
def post(self, request):
headers = {k.lower(): v for k, v in request.json['headers'].items()}
print(headers)
if 'cookie' in headers and headers['cookie']:
res = {'x-hasura-role': 'admin'}
return Response(HTTPStatus.OK, res)
return Response(HTTPStatus.UNAUTHORIZED)
handlers = MkHandlers({
'/auth': CookieAuth,
})
def create_server(host='127.0.0.1', port=9876):
return WebServer((host, port), handlers)
def stop_server(server):
server.shutdown()
server.server_close()
if __name__ == '__main__':
s = create_server(host='0.0.0.0')
s.serve_forever()

View File

@ -0,0 +1,95 @@
import json
import threading
from urllib.parse import urlparse
import websocket
import pytest
from validate import check_query
if not pytest.config.getoption("--test-ws-init-cookie"):
pytest.skip("--test-ws-init-cookie flag is missing, skipping tests", allow_module_level=True)
def url(hge_ctx):
ws_url = urlparse(hge_ctx.hge_url)._replace(scheme='ws', path='/v1alpha1/graphql')
return ws_url.geturl()
class TestWebsocketInitCookie():
"""
test if cookie is sent when initing the websocket connection, is our auth
webhook receiving the cookie
"""
dir = 'queries/remote_schemas'
@pytest.fixture(autouse=True)
def transact(self, hge_ctx):
st_code, resp = hge_ctx.v1q_f(self.dir + '/person_table.yaml')
assert st_code == 200, resp
yield
assert st_code == 200, resp
st_code, resp = hge_ctx.v1q_f(self.dir + '/drop_person_table.yaml')
def _send_query(self, hge_ctx):
ws_url = url(hge_ctx)
headers = {'Cookie': 'foo=bar;'}
ws = websocket.create_connection(ws_url, header=headers)
init_payload = {
'type': 'connection_init',
'payload': {'headers': {}}
}
ws.send(json.dumps(init_payload))
payload = {
'type': 'start',
'id': '1',
'payload': {'query': 'query { person {name}}'}
}
ws.send(json.dumps(payload))
return ws
def test_websocket_init_cookie_used(self, hge_ctx):
if hge_ctx.ws_read_cookie == 'noread':
pytest.skip('cookie is not to be read')
ws = self._send_query(hge_ctx)
it = 0
while True:
raw = ws.recv()
frame = json.loads(raw)
if frame['type'] == 'data':
assert 'person' in frame['payload']['data']
break
elif it == 10:
print('max try over')
assert False
break
elif frame['type'] == 'connection_error' or frame['type'] == 'error':
print(frame)
assert False
break
it = it + 1
def test_websocket_init_cookie_not_used(self, hge_ctx):
if hge_ctx.ws_read_cookie == 'read':
pytest.skip('cookie is read')
ws = self._send_query(hge_ctx)
it = 0
while True:
raw = ws.recv()
frame = json.loads(raw)
if frame['type'] == 'data':
print('got data')
assert False
break
elif it == 10:
print('max try over')
assert False
break
elif frame['type'] == 'connection_error':
print(frame)
assert frame['payload'] == 'Authentication hook unauthorized this request'
break
elif frame['type'] == 'error':
print(frame)
assert False
break
it = it + 1