mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-11-10 10:29:12 +03:00
pass gql requests into auth webhook POST body (#149)
* fix arg order in UserAuthentication instance [force ci] * change the constructor name to AHGraphQLRequest Co-authored-by: Stylish Haskell Bot <stylish-haskell@users.noreply.github.com> Co-authored-by: Karthikeyan Chinnakonda <karthikeyan@hasura.io> GitOrigin-RevId: fb3258f4a84efc6c730b0c6222ebd8cea1b91081
This commit is contained in:
parent
1583fa6872
commit
c14dcd5792
@ -21,6 +21,7 @@ ws-metadata-api-disabled
|
||||
remote-schema-permissions
|
||||
query-caching
|
||||
query-logs
|
||||
webhook-request-context
|
||||
post-webhook
|
||||
get-webhook
|
||||
insecure-webhook
|
||||
|
@ -526,7 +526,7 @@ case "$SERVER_TEST_TO_RUN" in
|
||||
ws-init-cookie-read-cors-enabled)
|
||||
# test websocket transport with initial cookie header
|
||||
|
||||
echo -e "\n$(time_elapsped): <########## TEST GRAPHQL-ENGINE WITH COOKIE IN WEBSOCKET INIT ########>\n"
|
||||
echo -e "\n$(time_elapsed): <########## TEST GRAPHQL-ENGINE WITH COOKIE IN WEBSOCKET INIT ########>\n"
|
||||
TEST_TYPE="ws-init-cookie-read-cors-enabled"
|
||||
export HASURA_GRAPHQL_AUTH_HOOK="http://localhost:9876/auth"
|
||||
export HASURA_GRAPHQL_AUTH_HOOK_MODE="POST"
|
||||
@ -737,6 +737,24 @@ case "$SERVER_TEST_TO_RUN" in
|
||||
fi
|
||||
;;
|
||||
|
||||
webhook-request-context)
|
||||
if [ "$RUN_WEBHOOK_TESTS" == "true" ] ; then
|
||||
echo -e "\n$(time_elapsed): <########## TEST WEBHOOK RECEIVES REQUEST DATA AS CONTEXT #########################>\n"
|
||||
TEST_TYPE="webhook-request-context"
|
||||
export HASURA_GRAPHQL_AUTH_HOOK="http://localhost:5594/"
|
||||
export HASURA_GRAPHQL_AUTH_HOOK_MODE="POST"
|
||||
export HASURA_GRAPHQL_ADMIN_SECRET="HGE$RANDOM$RANDOM"
|
||||
|
||||
run_hge_with_args serve
|
||||
wait_for_port 8080
|
||||
|
||||
pytest -s -n 1 -vv --hge-urls "$HGE_URL" --pg-urls "$HASURA_GRAPHQL_DATABASE_URL" --hge-key="$HASURA_GRAPHQL_ADMIN_SECRET" --hge-webhook="$HASURA_GRAPHQL_AUTH_HOOK" --test-webhook-request-context test_webhook_request_context.py
|
||||
|
||||
kill_hge_servers
|
||||
fi
|
||||
;;
|
||||
|
||||
|
||||
get-webhook)
|
||||
if [ "$RUN_WEBHOOK_TESTS" == "true" ] ; then
|
||||
echo -e "\n$(time_elapsed): <########## TEST GRAPHQL-ENGINE WITH ADMIN SECRET & WEBHOOK (GET) #########################>\n"
|
||||
|
@ -96,6 +96,7 @@ have select permissions to the target table of the function.
|
||||
|
||||
(Add entries here in the order of: server, console, cli, docs, others)
|
||||
|
||||
- server: add `request` field to webhook POST body containing the GraphQL query/mutation, its name, and any variables passed (close #2666)
|
||||
- server: fix a regression where variables in fragments weren't accepted (fix #6303)
|
||||
- server: output stack traces when encountering conflicting GraphQL types in the schema
|
||||
- server: add `--websocket-compression` command-line flag for enabling websocket compression (fix #3292)
|
||||
|
@ -673,8 +673,8 @@ instance MonadExecuteQuery PGMetadataStorageApp where
|
||||
cacheStore _ _ = pure ()
|
||||
|
||||
instance UserAuthentication (Tracing.TraceT PGMetadataStorageApp) where
|
||||
resolveUserInfo logger manager headers authMode =
|
||||
runExceptT $ getUserInfoWithExpTime logger manager headers authMode
|
||||
resolveUserInfo logger manager headers authMode reqs =
|
||||
runExceptT $ getUserInfoWithExpTime logger manager headers authMode reqs
|
||||
|
||||
accessDeniedErrMsg :: Text
|
||||
accessDeniedErrMsg =
|
||||
|
@ -636,7 +636,7 @@ onConnInit logger manager wsConn authMode connParamsM = do
|
||||
Left err -> unexpectedInitError err
|
||||
Right ipAddress -> do
|
||||
let headers = mkHeaders connState
|
||||
res <- resolveUserInfo logger manager headers authMode
|
||||
res <- resolveUserInfo logger manager headers authMode Nothing
|
||||
case res of
|
||||
Left e -> do
|
||||
let !initErr = CSInitError $ qeError e
|
||||
|
@ -131,9 +131,18 @@ data APIResp
|
||||
= JSONResp !(HttpResponse EncJSON)
|
||||
| RawResp !(HttpResponse BL.ByteString)
|
||||
|
||||
data APIHandler m a
|
||||
= AHGet !(Handler m APIResp)
|
||||
| AHPost !(a -> Handler m APIResp)
|
||||
type ReqsText = GH.GQLBatchedReqs GH.GQLQueryText
|
||||
|
||||
-- | API request handlers for different endpoints
|
||||
data APIHandler m a where
|
||||
-- | A simple GET request
|
||||
AHGet :: !(Handler m APIResp) -> APIHandler m void
|
||||
-- | A simple POST request that expects a request body from which an 'a' can be extracted
|
||||
AHPost :: !(a -> Handler m APIResp) -> APIHandler m a
|
||||
-- | A general GraphQL request (query or mutation) for which the content of the query
|
||||
-- is made available to the handler for authentication.
|
||||
-- This is a more specific version of the 'AHPost' constructor.
|
||||
AHGraphQLRequest :: !(ReqsText -> Handler m APIResp) -> APIHandler m ReqsText
|
||||
|
||||
boolToText :: Bool -> Text
|
||||
boolToText = bool "false" "true"
|
||||
@ -157,6 +166,9 @@ withSCUpdate scr logger action =
|
||||
(!res, !newSC) <- action
|
||||
liftIO $ do
|
||||
-- update schemacache in IO reference
|
||||
|
||||
|
||||
|
||||
modifyIORef' cacheRef $ \(_, prevVer) ->
|
||||
let !newVer = incSchemaCacheVer prevVer
|
||||
in (newSC, newVer)
|
||||
@ -173,6 +185,9 @@ mkGetHandler = AHGet
|
||||
mkPostHandler :: (a -> Handler m APIResp) -> APIHandler m a
|
||||
mkPostHandler = AHPost
|
||||
|
||||
mkGQLRequestHandler :: (ReqsText -> Handler m APIResp) -> APIHandler m ReqsText
|
||||
mkGQLRequestHandler = AHGraphQLRequest
|
||||
|
||||
mkAPIRespHandler :: (Functor m) => (a -> Handler m (HttpResponse EncJSON)) -> (a -> Handler m APIResp)
|
||||
mkAPIRespHandler = (fmap . fmap) JSONResp
|
||||
|
||||
@ -314,33 +329,43 @@ mkSpockAction serverCtx qErrEncoder qErrModifier apiHandler = do
|
||||
-- can correlate requests and traces
|
||||
lift $ Tracing.attachMetadata [("request_id", unRequestId requestId)]
|
||||
|
||||
userInfoE <- fmap fst <$> lift (resolveUserInfo logger manager headers authMode)
|
||||
userInfo <- onLeft userInfoE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False headers . qErrModifier)
|
||||
|
||||
let handlerState = HandlerCtx serverCtx userInfo headers requestId ipAddress
|
||||
includeInternal = shouldIncludeInternal (_uiRole userInfo) $
|
||||
scResponseInternalErrorsConfig serverCtx
|
||||
|
||||
let getInfo parsedRequest = do
|
||||
userInfoE <- fmap fst <$> lift (resolveUserInfo logger manager headers authMode parsedRequest)
|
||||
userInfo <- onLeft userInfoE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False headers . qErrModifier)
|
||||
let handlerState = HandlerCtx serverCtx userInfo headers requestId ipAddress
|
||||
includeInternal = shouldIncludeInternal (_uiRole userInfo) $
|
||||
scResponseInternalErrorsConfig serverCtx
|
||||
pure (userInfo, handlerState, includeInternal)
|
||||
limits <- lift askResourceLimits
|
||||
let runHandler = runMetadataStorageT . flip runReaderT handlerState . runResourceLimits limits
|
||||
let runHandler st = runMetadataStorageT . flip runReaderT st . runResourceLimits limits
|
||||
|
||||
(serviceTime, (result, q)) <- withElapsedTime $ case apiHandler of
|
||||
(serviceTime, (result, userInfo, includeInternal, query)) <- withElapsedTime $ case apiHandler of
|
||||
-- in the case of a simple get/post we don't have to send the webhook anything
|
||||
AHGet handler -> do
|
||||
res <- lift $ runHandler handler
|
||||
return (res, Nothing)
|
||||
(userInfo, handlerState, includeInternal) <- getInfo Nothing
|
||||
res <- lift $ runHandler handlerState handler
|
||||
return (res, userInfo, includeInternal, Nothing)
|
||||
AHPost handler -> do
|
||||
(userInfo, handlerState, includeInternal) <- getInfo Nothing
|
||||
parsedReqE <- runExceptT $ parseBody reqBody
|
||||
parsedReq <- onLeft parsedReqE (logErrorAndResp (Just userInfo) requestId req (reqBody, Nothing) includeInternal headers . qErrModifier)
|
||||
res <- lift $ runHandler $ handler parsedReq
|
||||
return (res, Just parsedReq)
|
||||
res <- lift $ runHandler handlerState $ handler parsedReq
|
||||
return (res, userInfo, includeInternal, Just parsedReq)
|
||||
-- in this case we parse the request _first_ and then send the request to the webhook for auth
|
||||
AHGraphQLRequest handler -> do
|
||||
parsedReqE <- runExceptT $ parseBody reqBody
|
||||
parsedReq <- onLeft parsedReqE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False headers . qErrModifier)
|
||||
(userInfo, handlerState, includeInternal) <- getInfo (Just parsedReq)
|
||||
res <- lift $ runHandler handlerState $ handler parsedReq
|
||||
return (res, userInfo, includeInternal, Just parsedReq)
|
||||
|
||||
-- apply the error modifier
|
||||
let modResult = fmapL qErrModifier result
|
||||
|
||||
-- log and return result
|
||||
case modResult of
|
||||
Left err -> logErrorAndResp (Just userInfo) requestId req (reqBody, toJSON <$> q) includeInternal headers err
|
||||
Right res -> logSuccessAndResp (Just userInfo) requestId req (reqBody, toJSON <$> q) res (Just (ioWaitTime, serviceTime)) headers
|
||||
Left err -> logErrorAndResp (Just userInfo) requestId req (reqBody, toJSON <$> query) includeInternal headers err
|
||||
Right res -> logSuccessAndResp (Just userInfo) requestId req (reqBody, toJSON <$> query) res (Just (ioWaitTime, serviceTime)) headers
|
||||
|
||||
where
|
||||
logger = scLogger serverCtx
|
||||
@ -787,6 +812,7 @@ mkWaiApp env logger sqlGenCtx enableAL httpManager mode corsCfg enableConsole co
|
||||
initialiseCache :: m SchemaCacheRef
|
||||
initialiseCache = do
|
||||
cacheLock <- liftIO $ newMVar ()
|
||||
|
||||
cacheCell <- liftIO $ newIORef (schemaCache, initSchemaCacheVer)
|
||||
-- planCache <- liftIO $ E.initPlanCache planCacheOptions
|
||||
let cacheRef = SchemaCacheRef cacheLock cacheCell E.clearPlanCache
|
||||
@ -920,13 +946,13 @@ httpApp corsCfg serverCtx enableConsole consoleAssetsDir enableTelemetry = do
|
||||
|
||||
when enableGraphQL $ do
|
||||
Spock.post "v1alpha1/graphql" $ spockAction GH.encodeGQErr id $
|
||||
mkPostHandler $ mkAPIRespHandler $ v1Alpha1GQHandler E.QueryHasura
|
||||
mkGQLRequestHandler $ mkAPIRespHandler $ v1Alpha1GQHandler E.QueryHasura
|
||||
|
||||
Spock.post "v1/graphql" $ spockAction GH.encodeGQErr allMod200 $
|
||||
mkPostHandler $ mkAPIRespHandler v1GQHandler
|
||||
mkGQLRequestHandler $ mkAPIRespHandler v1GQHandler
|
||||
|
||||
Spock.post "v1beta1/relay" $ spockAction GH.encodeGQErr allMod200 $
|
||||
mkPostHandler $ mkAPIRespHandler $ v1GQRelayHandler
|
||||
mkGQLRequestHandler $ mkAPIRespHandler $ v1GQRelayHandler
|
||||
|
||||
when (isDeveloperAPIEnabled serverCtx) $ do
|
||||
Spock.get "dev/ekg" $ spockAction encodeQErr id $
|
||||
|
@ -26,29 +26,28 @@ module Hasura.Server.Auth
|
||||
|
||||
import Hasura.Prelude
|
||||
|
||||
import qualified Crypto.Hash as Crypto
|
||||
import qualified Data.Text.Encoding as T
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Crypto.Hash as Crypto
|
||||
import qualified Data.Text.Encoding as T
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as N
|
||||
|
||||
import Control.Concurrent.Extended (ForkableMonadIO, forkManagedT)
|
||||
import Control.Monad.Trans.Managed (ManagedT)
|
||||
import Control.Monad.Morph (hoist)
|
||||
import Control.Monad.Trans.Control (MonadBaseControl)
|
||||
import Data.IORef (newIORef)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Control.Concurrent.Extended (ForkableMonadIO, forkManagedT)
|
||||
import Control.Monad.Morph (hoist)
|
||||
import Control.Monad.Trans.Control (MonadBaseControl)
|
||||
import Control.Monad.Trans.Managed (ManagedT)
|
||||
import Data.IORef (newIORef)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
|
||||
import Hasura.Logging
|
||||
import Hasura.RQL.Types
|
||||
import Hasura.Server.Auth.JWT hiding (processJwt_)
|
||||
import Hasura.Server.Auth.JWT hiding (processJwt_)
|
||||
import Hasura.Server.Auth.WebHook
|
||||
import Hasura.Server.Utils
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
import Hasura.Session
|
||||
|
||||
|
||||
-- | Typeclass representing the @UserInfo@ authorization and resolving effect
|
||||
class (Monad m) => UserAuthentication m where
|
||||
resolveUserInfo
|
||||
@ -58,6 +57,7 @@ class (Monad m) => UserAuthentication m where
|
||||
-> [N.Header]
|
||||
-- ^ request headers
|
||||
-> AuthMode
|
||||
-> Maybe ReqsText
|
||||
-> m (Either QErr (UserInfo, Maybe UTCTime))
|
||||
|
||||
-- | The hashed admin password. 'hashAdminSecret' is our public interface for
|
||||
@ -174,9 +174,9 @@ setupAuthMode mAdminSecretHash mWebHook mJwtSecret mUnAuthRole httpManager logge
|
||||
res <- runExceptT act
|
||||
onLeft res $ \case
|
||||
-- when fetching JWK initially, except expiry parsing error, all errors are critical
|
||||
JFEHttpException _ msg -> throwError msg
|
||||
JFEHttpError _ _ _ e -> throwError e
|
||||
JFEJwkParseError _ e -> throwError e
|
||||
JFEHttpException _ msg -> throwError msg
|
||||
JFEHttpError _ _ _ e -> throwError e
|
||||
JFEJwkParseError _ e -> throwError e
|
||||
JFEExpiryParseError _ _ -> return Nothing
|
||||
|
||||
getUserInfo
|
||||
@ -185,8 +185,9 @@ getUserInfo
|
||||
-> H.Manager
|
||||
-> [N.Header]
|
||||
-> AuthMode
|
||||
-> Maybe ReqsText
|
||||
-> m UserInfo
|
||||
getUserInfo l m r a = fst <$> getUserInfoWithExpTime l m r a
|
||||
getUserInfo l m r a reqs = fst <$> getUserInfoWithExpTime l m r a reqs
|
||||
|
||||
-- | Authenticate the request using the headers and the configured 'AuthMode'.
|
||||
getUserInfoWithExpTime
|
||||
@ -195,13 +196,14 @@ getUserInfoWithExpTime
|
||||
-> H.Manager
|
||||
-> [N.Header]
|
||||
-> AuthMode
|
||||
-> Maybe ReqsText
|
||||
-> m (UserInfo, Maybe UTCTime)
|
||||
getUserInfoWithExpTime = getUserInfoWithExpTime_ userInfoFromAuthHook processJwt
|
||||
|
||||
-- Broken out for testing with mocks:
|
||||
getUserInfoWithExpTime_
|
||||
:: forall m _Manager _Logger_Hasura. (MonadIO m, MonadError QErr m)
|
||||
=> (_Logger_Hasura -> _Manager -> AuthHook -> [N.Header] -> m (UserInfo, Maybe UTCTime))
|
||||
=> (_Logger_Hasura -> _Manager -> AuthHook -> [N.Header] -> Maybe ReqsText -> m (UserInfo, Maybe UTCTime))
|
||||
-- ^ mock 'userInfoFromAuthHook'
|
||||
-> (JWTCtx -> [N.Header] -> Maybe RoleName -> m (UserInfo, Maybe UTCTime))
|
||||
-- ^ mock 'processJwt'
|
||||
@ -209,8 +211,9 @@ getUserInfoWithExpTime_
|
||||
-> _Manager
|
||||
-> [N.Header]
|
||||
-> AuthMode
|
||||
-> Maybe ReqsText
|
||||
-> m (UserInfo, Maybe UTCTime)
|
||||
getUserInfoWithExpTime_ userInfoFromAuthHook_ processJwt_ logger manager rawHeaders = \case
|
||||
getUserInfoWithExpTime_ userInfoFromAuthHook_ processJwt_ logger manager rawHeaders authMode reqs = case authMode of
|
||||
|
||||
AMNoAuth -> withNoExpTime $ mkUserInfoFallbackAdminRole UAuthNotSet
|
||||
|
||||
@ -227,8 +230,9 @@ getUserInfoWithExpTime_ userInfoFromAuthHook_ processJwt_ logger manager rawHead
|
||||
Just unAuthRole ->
|
||||
mkUserInfo (URBPreDetermined unAuthRole) UAdminSecretNotSent sessionVariables
|
||||
|
||||
-- this is the case that actually ends up consuming the request AST
|
||||
AMAdminSecretAndHook realAdminSecretHash hook ->
|
||||
checkingSecretIfSent realAdminSecretHash $ userInfoFromAuthHook_ logger manager hook rawHeaders
|
||||
checkingSecretIfSent realAdminSecretHash $ userInfoFromAuthHook_ logger manager hook rawHeaders reqs
|
||||
|
||||
AMAdminSecretAndJWT realAdminSecretHash jwtSecret unAuthRole ->
|
||||
checkingSecretIfSent realAdminSecretHash $ processJwt_ jwtSecret rawHeaders unAuthRole
|
||||
|
@ -3,26 +3,29 @@ module Hasura.Server.Auth.WebHook
|
||||
, AuthHookG (..)
|
||||
, AuthHook
|
||||
, userInfoFromAuthHook
|
||||
, userInfoFromAuthHook'
|
||||
, type ReqsText
|
||||
) where
|
||||
|
||||
import Control.Exception.Lifted (try)
|
||||
import Control.Exception.Lifted (try)
|
||||
import Control.Lens
|
||||
import Control.Monad.Trans.Control (MonadBaseControl)
|
||||
import Control.Monad.Trans.Control (MonadBaseControl)
|
||||
import Control.Monad.Trans.Maybe
|
||||
import Data.Aeson
|
||||
import Data.Time.Clock (UTCTime, addUTCTime, getCurrentTime)
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
import Data.Time.Clock (UTCTime, addUTCTime, getCurrentTime)
|
||||
import Hasura.Server.Version (HasVersion)
|
||||
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.HashMap.Strict as Map
|
||||
import qualified Data.Text as T
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Network.Wreq as Wreq
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.HashMap.Strict as Map
|
||||
import qualified Data.Text as T
|
||||
import qualified Network.HTTP.Client as H
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Network.Wreq as Wreq
|
||||
|
||||
import Data.Parser.CacheControl
|
||||
import Data.Parser.Expires
|
||||
import qualified Hasura.GraphQL.Transport.HTTP.Protocol as GH
|
||||
import Hasura.HTTP
|
||||
import Hasura.Logging
|
||||
import Hasura.Prelude
|
||||
@ -30,7 +33,7 @@ import Hasura.RQL.Types
|
||||
import Hasura.Server.Logging
|
||||
import Hasura.Server.Utils
|
||||
import Hasura.Session
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
|
||||
data AuthHookType
|
||||
= AHTGet
|
||||
@ -55,11 +58,9 @@ hookMethod authHook = case ahType authHook of
|
||||
AHTGet -> N.GET
|
||||
AHTPost -> N.POST
|
||||
|
||||
type ReqsText = GH.GQLBatchedReqs GH.GQLQueryText
|
||||
|
||||
-- | Makes an authentication request to the given AuthHook and returns
|
||||
-- UserInfo parsed from the response, plus an expiration time if one
|
||||
-- was returned.
|
||||
userInfoFromAuthHook
|
||||
userInfoFromAuthHook'
|
||||
:: forall m
|
||||
. (HasVersion, MonadIO m, MonadBaseControl IO m, MonadError QErr m, Tracing.MonadTrace m)
|
||||
=> Logger Hasura
|
||||
@ -67,7 +68,22 @@ userInfoFromAuthHook
|
||||
-> AuthHook
|
||||
-> [N.Header]
|
||||
-> m (UserInfo, Maybe UTCTime)
|
||||
userInfoFromAuthHook logger manager hook reqHeaders = do
|
||||
userInfoFromAuthHook' l m h r = userInfoFromAuthHook l m h r Nothing
|
||||
|
||||
-- | Makes an authentication request to the given AuthHook and returns
|
||||
-- UserInfo parsed from the response, plus an expiration time if one
|
||||
-- was returned. Optionally passes a batch of raw GraphQL requests
|
||||
-- for finer-grained auth. (#2666)
|
||||
userInfoFromAuthHook
|
||||
:: forall m
|
||||
. (HasVersion, MonadIO m, MonadBaseControl IO m, MonadError QErr m, Tracing.MonadTrace m)
|
||||
=> Logger Hasura
|
||||
-> H.Manager
|
||||
-> AuthHook
|
||||
-> [N.Header]
|
||||
-> Maybe ReqsText
|
||||
-> m (UserInfo, Maybe UTCTime)
|
||||
userInfoFromAuthHook logger manager hook reqHeaders reqs = do
|
||||
resp <- (`onLeft` logAndThrow) =<< try performHTTPRequest
|
||||
let status = resp ^. Wreq.responseStatus
|
||||
respBody = resp ^. Wreq.responseBody
|
||||
@ -88,7 +104,8 @@ userInfoFromAuthHook logger manager hook reqHeaders = do
|
||||
headersPayload = J.toJSON $ Map.fromList $ hdrsToText reqHeaders
|
||||
H.httpLbs (req' { H.method = "POST"
|
||||
, H.requestHeaders = addDefaultHeaders [contentType]
|
||||
, H.requestBody = H.RequestBodyLBS . J.encode $ object ["headers" J..= headersPayload]
|
||||
, H.requestBody = H.RequestBodyLBS . J.encode $ object ["headers" J..= headersPayload,
|
||||
"request" J..= reqs]
|
||||
}) manager
|
||||
|
||||
logAndThrow :: H.HttpException -> m a
|
||||
|
@ -6,20 +6,21 @@ import Hasura.Logging
|
||||
import Hasura.Prelude
|
||||
import Hasura.Server.Version
|
||||
|
||||
import Control.Monad.Trans.Managed (lowerManagedT)
|
||||
import Control.Monad.Trans.Control
|
||||
import Control.Lens hiding ((.=))
|
||||
import Control.Monad.Trans.Control
|
||||
import Control.Monad.Trans.Managed (lowerManagedT)
|
||||
import qualified Crypto.JOSE.JWK as Jose
|
||||
import qualified Crypto.JWT as JWT
|
||||
import Data.Aeson ((.=))
|
||||
import qualified Data.Aeson as J
|
||||
import Data.Parser.JSONPath
|
||||
import qualified Data.HashMap.Strict as Map
|
||||
import Data.Parser.JSONPath
|
||||
import qualified Network.HTTP.Types as N
|
||||
|
||||
import Hasura.RQL.Types
|
||||
import Hasura.Server.Auth hiding (getUserInfoWithExpTime, processJwt)
|
||||
import Hasura.Server.Auth.JWT hiding (processJwt)
|
||||
import Hasura.Server.Auth.WebHook (ReqsText)
|
||||
import Hasura.Server.Utils
|
||||
import Hasura.Session
|
||||
import qualified Hasura.Tracing as Tracing
|
||||
@ -39,23 +40,24 @@ defaultRoleClaimText = sessionVariableToText defaultRoleClaim
|
||||
|
||||
-- Unit test the core of our authentication code. This doesn't test the details
|
||||
-- of resolving roles from JWT or webhook.
|
||||
-- TODO(swann): does this need to also test passing
|
||||
getUserInfoWithExpTimeTests :: Spec
|
||||
getUserInfoWithExpTimeTests = describe "getUserInfo" $ do
|
||||
---- FUNCTION UNDER TEST:
|
||||
let getUserInfoWithExpTime
|
||||
let gqlUserInfoWithExpTime
|
||||
:: J.Object
|
||||
-- ^ For JWT, inject the raw claims object as though returned from 'processAuthZHeader'
|
||||
-- acting on an 'Authorization' header from the request
|
||||
-> [N.Header] -> AuthMode -> IO (Either Code RoleName)
|
||||
getUserInfoWithExpTime claims rawHeaders =
|
||||
-> [N.Header] -> AuthMode -> Maybe ReqsText -> IO (Either Code RoleName)
|
||||
gqlUserInfoWithExpTime claims rawHeaders authMode =
|
||||
runExceptT
|
||||
. withExceptT qeCode -- just look at Code for purposes of tests
|
||||
. fmap _uiRole -- just look at RoleName for purposes of tests
|
||||
. fmap fst -- disregard Nothing expiration
|
||||
. getUserInfoWithExpTime_ userInfoFromAuthHook processJwt () () rawHeaders
|
||||
. getUserInfoWithExpTime_ userInfoFromAuthHook processJwt () () rawHeaders authMode
|
||||
where
|
||||
-- mock authorization callbacks:
|
||||
userInfoFromAuthHook _ _ _hook _reqHeaders = do
|
||||
userInfoFromAuthHook _ _ _hook _reqHeaders _optionalReqs = do
|
||||
(, Nothing) <$> _UserInfo "hook"
|
||||
where
|
||||
-- we don't care about details here; we'll just check role name in tests:
|
||||
@ -67,6 +69,13 @@ getUserInfoWithExpTimeTests = describe "getUserInfo" $ do
|
||||
-- processAuthZHeader:
|
||||
\_jwtCtx _authzHeader -> return (mapKeys mkSessionVariable claims, Nothing)
|
||||
|
||||
let getUserInfoWithExpTime
|
||||
:: J.Object
|
||||
-- ^ For JWT, inject the raw claims object as though returned from 'processAuthZHeader'
|
||||
-- acting on an 'Authorization' header from the request
|
||||
-> [N.Header] -> AuthMode -> IO (Either Code RoleName)
|
||||
getUserInfoWithExpTime o claims authMode = gqlUserInfoWithExpTime o claims authMode Nothing
|
||||
|
||||
let setupAuthMode'E a b c d =
|
||||
either (const $ error "fixme") id <$> setupAuthMode' a b c d
|
||||
|
||||
@ -549,7 +558,7 @@ mkCustomDefaultRoleClaim claimPath defVal =
|
||||
-- as the literal value by removing the `Maybe` of defVal
|
||||
case claimPath of
|
||||
Just path -> JWTCustomClaimsMapJSONPath (mkJSONPathE path) $ defRoleName
|
||||
Nothing -> JWTCustomClaimsMapStatic $ fromMaybe (mkRoleNameE "user") defRoleName
|
||||
Nothing -> JWTCustomClaimsMapStatic $ fromMaybe (mkRoleNameE "user") defRoleName
|
||||
where
|
||||
defRoleName = mkRoleNameE <$> defVal
|
||||
|
||||
@ -572,7 +581,7 @@ mkCustomOtherClaim claimPath defVal =
|
||||
-- as the literal value by removing the `Maybe` of defVal
|
||||
case claimPath of
|
||||
Just path -> JWTCustomClaimsMapJSONPath (mkJSONPathE path) $ defVal
|
||||
Nothing -> JWTCustomClaimsMapStatic $ fromMaybe "default claim value" defVal
|
||||
Nothing -> JWTCustomClaimsMapStatic $ fromMaybe "default claim value" defVal
|
||||
|
||||
fakeJWTConfig :: JWTConfig
|
||||
fakeJWTConfig =
|
||||
|
@ -32,6 +32,10 @@ def pytest_addoption(parser):
|
||||
"--test-webhook-insecure", action="store_true",
|
||||
help="Run Test cases for insecure https webhook"
|
||||
)
|
||||
parser.addoption(
|
||||
"--test-webhook-request-context", action="store_true",
|
||||
help="Run Test cases for testing webhook request context"
|
||||
)
|
||||
parser.addoption(
|
||||
"--hge-jwt-key-file", metavar="HGE_JWT_KEY_FILE", help="File containting the private key used to encode jwt tokens using RS512 algorithm", required=False
|
||||
)
|
||||
|
34
server/tests-py/queries/webhooks/request_context/setup.yaml
Normal file
34
server/tests-py/queries/webhooks/request_context/setup.yaml
Normal file
@ -0,0 +1,34 @@
|
||||
type: bulk
|
||||
args:
|
||||
|
||||
- type: run_sql
|
||||
args:
|
||||
sql: |
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT
|
||||
);
|
||||
INSERT INTO users VALUES (1, 'alice');
|
||||
INSERT INTO users VALUES (2, 'bob');
|
||||
|
||||
- type: track_table
|
||||
args:
|
||||
schema: public
|
||||
name: users
|
||||
|
||||
- type: create_select_permission
|
||||
args:
|
||||
table: users
|
||||
role: user
|
||||
permission:
|
||||
columns: '*'
|
||||
filter:
|
||||
id: X-Hasura-User-Id
|
||||
|
||||
- type: create_insert_permission
|
||||
args:
|
||||
table: users
|
||||
role: user
|
||||
permission:
|
||||
check:
|
||||
id: X-Hasura-User-Id
|
@ -0,0 +1,7 @@
|
||||
type: bulk
|
||||
args:
|
||||
|
||||
- type: run_sql
|
||||
args:
|
||||
sql: |
|
||||
DROP TABLE users;
|
138
server/tests-py/test_webhook_request_context.py
Normal file
138
server/tests-py/test_webhook_request_context.py
Normal file
@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
import time
|
||||
import json
|
||||
import http
|
||||
import queue
|
||||
import socket
|
||||
from context import (
|
||||
HGECtx,
|
||||
HGECtxError,
|
||||
ActionsWebhookServer,
|
||||
EvtsWebhookServer,
|
||||
HGECtxGQLServer,
|
||||
GQLWsClient,
|
||||
PytestConf,
|
||||
)
|
||||
import threading
|
||||
import random
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from validate import check_query
|
||||
|
||||
if not PytestConf.config.getoption("--test-webhook-request-context"):
|
||||
pytest.skip("--test-webhook-https-request-context flag is missing, skipping tests", allow_module_level=True)
|
||||
|
||||
|
||||
class QueryEchoWebhookHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.log_message("get")
|
||||
self.send_response(http.HTTPStatus.OK)
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self):
|
||||
self.log_message("post")
|
||||
content_len = self.headers.get("Content-Length")
|
||||
req_body = self.rfile.read(int(content_len)).decode("utf-8")
|
||||
req_json = json.loads(req_body)
|
||||
req_headers = self.headers
|
||||
print(json.dumps(req_json))
|
||||
user_id_header = req_json["headers"]["auth-user-id"]
|
||||
req_path = self.path
|
||||
h = {
|
||||
"x-hasura-role":"user",
|
||||
"x-hasura-user-id": user_id_header
|
||||
}
|
||||
self.send_response(http.HTTPStatus.OK)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.server.resp_queue.put({"request": req_json["request"]})
|
||||
self.wfile.write(json.dumps(h).encode('utf-8'))
|
||||
|
||||
|
||||
class QueryEchoWebhookServer(http.server.HTTPServer):
|
||||
def __init__(self, server_address):
|
||||
self.resp_queue = queue.Queue(maxsize=1)
|
||||
self.error_queue = queue.Queue()
|
||||
super().__init__(server_address, QueryEchoWebhookHandler)
|
||||
|
||||
def server_bind(self):
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.server_address)
|
||||
|
||||
def get_event(self, timeout):
|
||||
return self.resp_queue.get(timeout=timeout)
|
||||
|
||||
def teardown(self):
|
||||
self.evt_trggr_httpd.shutdown()
|
||||
self.evt_trggr_httpd.server_close()
|
||||
graphql_server.stop_server(self.graphql_server)
|
||||
self.gql_srvr_thread.join()
|
||||
self.evt_trggr_web_server.join()
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def query_echo_webhook(request):
|
||||
# TODO(swann): is this the right port?
|
||||
webhook_httpd = QueryEchoWebhookServer(server_address=("127.0.0.1", 5594))
|
||||
web_server = threading.Thread(target=webhook_httpd.serve_forever)
|
||||
web_server.start()
|
||||
yield webhook_httpd
|
||||
webhook_httpd.shutdown()
|
||||
webhook_httpd.server_close()
|
||||
web_server.join()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("per_method_tests_db_state")
|
||||
class TestWebhookRequestContext(object):
|
||||
@classmethod
|
||||
def dir(cls):
|
||||
return "queries/webhooks/request_context"
|
||||
|
||||
def test_query(self, hge_ctx, query_echo_webhook):
|
||||
query = """
|
||||
query allUsers {
|
||||
users {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
"""
|
||||
query_obj = {
|
||||
"query": query,
|
||||
"operationName": "allUsers"
|
||||
}
|
||||
headers = dict()
|
||||
headers['auth-user-id'] = '1'
|
||||
code, resp, _ = hge_ctx.anyq('/v1/graphql', query_obj, headers)
|
||||
assert code == 200, resp
|
||||
|
||||
ev_full = query_echo_webhook.get_event(3)
|
||||
exp_result = {"request": query_obj}
|
||||
assert ev_full['request'] == query_obj
|
||||
|
||||
def test_mutation_with_vars(self, hge_ctx, query_echo_webhook):
|
||||
query = """
|
||||
mutation insert_single_user($id: Int!, $name: String!) {
|
||||
insert_users_one(
|
||||
object: {
|
||||
id: $id,
|
||||
name: $name
|
||||
}
|
||||
) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {"id": 4, "name": "danish"}
|
||||
query_obj = {"query": query, "variables": variables, "operationName": "insert_single_user"}
|
||||
headers = dict()
|
||||
headers['auth-user-id'] = '4'
|
||||
code, resp, _ = hge_ctx.anyq('/v1/graphql', query_obj, headers)
|
||||
assert code == 200, resp
|
||||
|
||||
ev_full = query_echo_webhook.get_event(3)
|
||||
exp_result = {"request": query_obj}
|
||||
assert ev_full['request'] == query_obj
|
@ -13,10 +13,6 @@ import http.server
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
# FIXME(swann):
|
||||
print("in webhook.py")
|
||||
print("__name__ == {}".format(__name__))
|
||||
|
||||
class S(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
|
||||
@ -59,6 +55,7 @@ class S(http.server.BaseHTTPRequestHandler):
|
||||
if 'headers' in req_json:
|
||||
self.handle_headers(req_json['headers'])
|
||||
else:
|
||||
# TODO: is this a typo?
|
||||
self.handler_headers({})
|
||||
|
||||
def run(keyfile, certfile, server_class=http.server.HTTPServer, handler_class=S, port=9090):
|
||||
|
Loading…
Reference in New Issue
Block a user