diff --git a/server/src-lib/Hasura/Server/App.hs b/server/src-lib/Hasura/Server/App.hs index 2748cdcafd0..1497c7c1eb3 100644 --- a/server/src-lib/Hasura/Server/App.hs +++ b/server/src-lib/Hasura/Server/App.hs @@ -218,11 +218,11 @@ isDeveloperAPIEnabled :: ServerCtx -> Bool isDeveloperAPIEnabled sc = S.member DEVELOPER $ scEnabledAPIs sc -- {-# SCC parseBody #-} -parseBody :: (FromJSON a, MonadError QErr m) => BL.ByteString -> m a +parseBody :: (FromJSON a, MonadError QErr m) => BL.ByteString -> m (Value, a) parseBody reqBody = case eitherDecode' reqBody of Left e -> throw400 InvalidJSON (T.pack e) - Right jVal -> decodeValue jVal + Right jVal -> (jVal, ) <$> decodeValue jVal onlyAdmin :: (MonadError QErr m, MonadReader HandlerCtx m) => m () onlyAdmin = do @@ -302,7 +302,7 @@ instance HasResourceLimits m => HasResourceLimits (ExceptT e m) instance HasResourceLimits m => HasResourceLimits (Tracing.TraceT m) mkSpockAction - :: (HasVersion, MonadIO m, MonadBaseControl IO m, FromJSON a, ToJSON a, UserAuthentication (Tracing.TraceT m), HttpLog m, Tracing.HasReporter m, HasResourceLimits m) + :: (HasVersion, MonadIO m, MonadBaseControl IO m, FromJSON a, UserAuthentication (Tracing.TraceT m), HttpLog m, Tracing.HasReporter m, HasResourceLimits m) => ServerCtx -> (Bool -> QErr -> Value) -- ^ `QErr` JSON encoder function @@ -310,116 +310,108 @@ mkSpockAction -- ^ `QErr` modifier -> APIHandler (Tracing.TraceT m) a -> Spock.ActionT m () -mkSpockAction serverCtx qErrEncoder qErrModifier apiHandler = do - req <- Spock.request - -- Bytes are actually read from the socket here. Time this. - (ioWaitTime, reqBody) <- withElapsedTime $ liftIO $ Wai.strictRequestBody req - let origHeaders = Wai.requestHeaders req - authMode = scAuthMode serverCtx - manager = scManager serverCtx - ipAddress = Wai.getSourceFromFallback req - pathInfo = Wai.rawPathInfo req +mkSpockAction serverCtx@ServerCtx{..} qErrEncoder qErrModifier apiHandler = do + req <- Spock.request + let origHeaders = Wai.requestHeaders req + ipAddress = Wai.getSourceFromFallback req + pathInfo = Wai.rawPathInfo req - (requestId, headers) <- getRequestId origHeaders + -- Bytes are actually read from the socket here. Time this. + (ioWaitTime, reqBody) <- withElapsedTime $ liftIO $ Wai.strictRequestBody req - tracingCtx <- liftIO $ Tracing.extractHttpContext headers + (requestId, headers) <- getRequestId origHeaders + tracingCtx <- liftIO $ Tracing.extractHttpContext headers + limits <- lift askResourceLimits - let runTraceT - :: forall m a - . (MonadIO m, Tracing.HasReporter m) - => Tracing.TraceT m a - -> m a - runTraceT = maybe - Tracing.runTraceT - Tracing.runTraceTInContext - tracingCtx - (fromString (B8.unpack pathInfo)) + let runTraceT + :: forall m a + . (MonadIO m, Tracing.HasReporter m) + => Tracing.TraceT m a + -> m a + runTraceT = maybe + Tracing.runTraceT + Tracing.runTraceTInContext + tracingCtx + (fromString (B8.unpack pathInfo)) - mapActionT runTraceT $ do - -- Add the request ID to the tracing metadata so that we - -- can correlate requests and traces - lift $ Tracing.attachMetadata [("request_id", unRequestId requestId)] + runHandler + :: MonadBaseControl IO m + => HandlerCtx + -> ReaderT HandlerCtx (MetadataStorageT m) a + -> m (Either QErr a) + runHandler st = runMetadataStorageT . flip runReaderT st . runResourceLimits limits - let getInfo parsedRequest = do - userInfoE <- fmap fst <$> lift (resolveUserInfo logger manager headers authMode parsedRequest) - userInfo <- onLeft userInfoE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False origHeaders . qErrModifier) - let handlerState = HandlerCtx serverCtx userInfo headers requestId ipAddress - includeInternal = shouldIncludeInternal (_uiRole userInfo) $ - scResponseInternalErrorsConfig serverCtx - pure (userInfo, handlerState, includeInternal) - limits <- lift askResourceLimits - let runHandler - :: MonadBaseControl IO m - => HandlerCtx - -> ReaderT HandlerCtx (MetadataStorageT m) a - -> m (Either QErr a) - runHandler st = runMetadataStorageT . flip runReaderT st . runResourceLimits limits + getInfo parsedRequest = do + userInfoE <- fmap fst <$> lift (resolveUserInfo scLogger scManager headers scAuthMode parsedRequest) + userInfo <- onLeft userInfoE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False origHeaders . qErrModifier) + pure ( userInfo + , HandlerCtx serverCtx userInfo headers requestId ipAddress + , shouldIncludeInternal (_uiRole userInfo) scResponseInternalErrorsConfig + ) - (serviceTime, (result, userInfo, includeInternal, query)) <- withElapsedTime $ case apiHandler of - -- in the case of a simple get/post we don't have to send the webhook anything - AHGet handler -> do - (userInfo, handlerState, includeInternal) <- getInfo Nothing - res <- lift $ runHandler handlerState handler - return (res, userInfo, includeInternal, Nothing) - AHPost handler -> do - (userInfo, handlerState, includeInternal) <- getInfo Nothing - parsedReqE <- runExceptT $ parseBody reqBody - parsedReq <- onLeft parsedReqE (logErrorAndResp (Just userInfo) requestId req (reqBody, Nothing) includeInternal origHeaders . qErrModifier) - res <- lift $ runHandler handlerState $ handler parsedReq - return (res, userInfo, includeInternal, Just parsedReq) - -- in this case we parse the request _first_ and then send the request to the webhook for auth - AHGraphQLRequest handler -> do - parsedReqE <- runExceptT $ parseBody reqBody - parsedReq <- onLeft parsedReqE (logErrorAndResp Nothing requestId req (reqBody, Nothing) False origHeaders . qErrModifier) - (userInfo, handlerState, includeInternal) <- getInfo (Just parsedReq) - res <- lift $ runHandler handlerState $ handler parsedReq - return (res, userInfo, includeInternal, Just parsedReq) + mapActionT runTraceT $ do + -- Add the request ID to the tracing metadata so that we + -- can correlate requests and traces + lift $ Tracing.attachMetadata [("request_id", unRequestId requestId)] - -- apply the error modifier - let modResult = fmapL qErrModifier result + (serviceTime, (result, userInfo, includeInternal, queryJSON)) <- withElapsedTime $ case apiHandler of + -- in the case of a simple get/post we don't have to send the webhook anything + AHGet handler -> do + (userInfo, handlerState, includeInternal) <- getInfo Nothing + res <- lift $ runHandler handlerState handler + pure (res , userInfo, includeInternal, Nothing) + AHPost handler -> do + (userInfo, handlerState, includeInternal) <- getInfo Nothing + (queryJSON, parsedReq) <- runExcept (parseBody reqBody) `onLeft` \e -> + logErrorAndResp (Just userInfo) requestId req (reqBody, Nothing) includeInternal origHeaders $ qErrModifier e + res <- lift $ runHandler handlerState $ handler parsedReq + pure (res, userInfo, includeInternal, Just queryJSON) + -- in this case we parse the request _first_ and then send the request to the webhook for auth + AHGraphQLRequest handler -> do + (queryJSON, parsedReq) <- runExcept (parseBody reqBody) `onLeft` \e -> + logErrorAndResp Nothing requestId req (reqBody, Nothing) False origHeaders $ qErrModifier e + (userInfo, handlerState, includeInternal) <- getInfo (Just parsedReq) + res <- lift $ runHandler handlerState $ handler parsedReq + pure (res, userInfo, includeInternal, Just queryJSON) - -- log and return result - case modResult of - Left err -> logErrorAndResp (Just userInfo) requestId req (reqBody, toJSON <$> query) includeInternal headers err - Right (httpLoggingMetadata, res) -> - logSuccessAndResp (Just userInfo) requestId req (reqBody, toJSON <$> query) res (Just (ioWaitTime, serviceTime)) origHeaders httpLoggingMetadata + -- apply the error modifier + let modResult = fmapL qErrModifier result - where - logger = scLogger serverCtx + -- log and return result + case modResult of + Left err -> + logErrorAndResp (Just userInfo) requestId req (reqBody, queryJSON) includeInternal headers err + Right (httpLoggingMetadata, res) -> + logSuccessAndResp (Just userInfo) requestId req (reqBody, queryJSON) res (Just (ioWaitTime, serviceTime)) origHeaders httpLoggingMetadata - logErrorAndResp - :: (MonadIO m, HttpLog m) - => Maybe UserInfo - -> RequestId - -> Wai.Request - -> (BL.ByteString, Maybe Value) - -> Bool - -> [HTTP.Header] - -> QErr - -> Spock.ActionCtxT ctx m a - logErrorAndResp userInfo reqId waiReq req includeInternal headers qErr = do - lift $ logHttpError logger userInfo reqId waiReq req qErr headers - Spock.setStatus $ qeStatus qErr - Spock.json $ qErrEncoder includeInternal qErr + where + logErrorAndResp + :: (MonadIO m, HttpLog m) + => Maybe UserInfo + -> RequestId + -> Wai.Request + -> (BL.ByteString, Maybe Value) + -> Bool + -> [HTTP.Header] + -> QErr + -> Spock.ActionCtxT ctx m a + logErrorAndResp userInfo reqId waiReq req includeInternal headers qErr = do + lift $ logHttpError scLogger userInfo reqId waiReq req qErr headers + Spock.setStatus $ qeStatus qErr + Spock.json $ qErrEncoder includeInternal qErr - logSuccessAndResp userInfo reqId waiReq reqBody result qTime reqHeaders httpLoggingMetadata = - case result of - JSONResp (HttpResponse encJson h) -> - possiblyCompressedLazyBytes userInfo reqId waiReq reqBody qTime (encJToLBS encJson) - (pure jsonHeader <> h) reqHeaders httpLoggingMetadata - RawResp (HttpResponse rawBytes h) -> - possiblyCompressedLazyBytes userInfo reqId waiReq reqBody qTime rawBytes h reqHeaders httpLoggingMetadata + logSuccessAndResp userInfo reqId waiReq req result qTime reqHeaders httpLoggingMetadata = do + let (respBytes, respHeaders) = case result of + JSONResp (HttpResponse encJson h) -> (encJToLBS encJson, pure jsonHeader <> h) + RawResp (HttpResponse rawBytes h) -> (rawBytes, h) + (compressedResp, mEncodingHeader, mCompressionType) = compressResponse (Wai.requestHeaders waiReq) respBytes + encodingHeader = onNothing mEncodingHeader [] + reqIdHeader = (requestIdHeader, txtToBs $ unRequestId reqId) + allRespHeaders = pure reqIdHeader <> encodingHeader <> respHeaders + lift $ logHttpSuccess scLogger userInfo reqId waiReq req respBytes compressedResp qTime mCompressionType reqHeaders httpLoggingMetadata + mapM_ setHeader allRespHeaders + Spock.lazyBytes compressedResp - possiblyCompressedLazyBytes userInfo reqId waiReq req qTime respBytes respHeaders reqHeaders httpLoggingMetadata = do - let (compressedResp, mEncodingHeader, mCompressionType) = - compressResponse (Wai.requestHeaders waiReq) respBytes - encodingHeader = onNothing mEncodingHeader [] - reqIdHeader = (requestIdHeader, txtToBs $ unRequestId reqId) - allRespHeaders = pure reqIdHeader <> encodingHeader <> respHeaders - lift $ logHttpSuccess logger userInfo reqId waiReq - req respBytes compressedResp qTime mCompressionType reqHeaders httpLoggingMetadata - traverse_ setHeader allRespHeaders - Spock.lazyBytes compressedResp v1QueryHandler :: ( HasVersion, MonadIO m, MonadBaseControl IO m, MonadMetadataApiAuthorization m, Tracing.MonadTrace m @@ -1066,8 +1058,7 @@ httpApp setupHook corsCfg serverCtx enableConsole consoleAssetsDir enableTelemet spockAction :: forall a n - . (FromJSON a, ToJSON a, MonadIO n, MonadBaseControl IO n, UserAuthentication (Tracing.TraceT n), HttpLog n, - Tracing.HasReporter n, HasResourceLimits n) + . (FromJSON a, MonadIO n, MonadBaseControl IO n, UserAuthentication (Tracing.TraceT n), HttpLog n, Tracing.HasReporter n, HasResourceLimits n) => (Bool -> QErr -> Value) -> (QErr -> QErr) -> APIHandler (Tracing.TraceT n) a -> Spock.ActionT n () spockAction = mkSpockAction serverCtx