-- | WAI middleware that adds CORS response headers according to a
-- 'CorsPolicy'.
module Hasura.Server.Middleware
  ( corsMiddleware,
  )
where

import Control.Applicative
import Data.ByteString qualified as B
import Data.CaseInsensitive qualified as CI
import Data.Text.Encoding qualified as TE
import Hasura.Prelude
import Hasura.Server.Cors
import Hasura.Server.Utils
import Network.HTTP.Types qualified as H
import Network.Wai

-- | Apply a 'CorsPolicy' to every incoming request.
--
-- * Requests without an @Origin@ header are passed to the wrapped
--   application unchanged.
-- * Under 'CCDisabled', or when the origin is not accepted by a
--   'CCAllowedOrigins' policy, the request is passed through without any
--   CORS headers.
-- * Otherwise, @OPTIONS@ requests are answered directly (the application is
--   not called) with an empty 204 preflight response, and every other method
--   is forwarded to the application with CORS headers prepended to its
--   response.
corsMiddleware :: CorsPolicy -> Middleware
corsMiddleware policy app req sendResp = do
  let origin = getRequestHeader "Origin" $ requestHeaders req
  maybe (app req sendResp) handleCors origin
  where
    handleCors origin = case cpConfig policy of
      CCDisabled _ -> app req sendResp
      CCAllowAll -> sendCors origin
      CCAllowedOrigins ds
        -- if the origin is in our cors domains, send cors headers
        | bsToTxt origin `elem` dmFqdns ds -> sendCors origin
        -- if current origin is part of wildcard domain list, send cors
        | inWildcardList ds (bsToTxt origin) -> sendCors origin
        -- otherwise don't send cors headers
        | otherwise -> app req sendResp

    -- Preflight requests are short-circuited here; actual requests get the
    -- CORS headers added to whatever response the application produces.
    sendCors :: B.ByteString -> IO ResponseReceived
    sendCors origin = case requestMethod req of
      "OPTIONS" -> sendResp $ respondPreFlight origin
      _ -> app req $ sendResp . injectCorsHeaders origin

    respondPreFlight :: B.ByteString -> Response
    respondPreFlight origin =
      setHeaders (mkPreFlightHeaders requestedHeaders) $
        injectCorsHeaders origin emptyResponse

    emptyResponse = responseLBS H.status204 [] ""

    -- Echo back whatever headers the browser asked permission for; an
    -- absent request header yields an empty allow-list.
    requestedHeaders =
      fromMaybe "" $
        getRequestHeader "Access-Control-Request-Headers" $
          requestHeaders req

    injectCorsHeaders :: B.ByteString -> Response -> Response
    injectCorsHeaders origin = setHeaders (mkCorsHeaders origin)

    -- 1728000 seconds = 20 days (browsers may apply their own lower cap).
    -- No Content-Length here: RFC 9110 §8.6 forbids it on 204 responses.
    mkPreFlightHeaders allowReqHdrs =
      [ ("Access-Control-Max-Age", "1728000"),
        ("Access-Control-Allow-Headers", allowReqHdrs),
        ("Content-Type", "text/plain; charset=UTF-8")
      ]

    -- The request's own Origin is reflected rather than "*", because a
    -- wildcard origin is not accepted together with
    -- Access-Control-Allow-Credentials: true.
    -- NOTE(review): since the response depends on the request Origin,
    -- consider also emitting "Vary: Origin" so shared caches key on it.
    mkCorsHeaders origin =
      [ ("Access-Control-Allow-Origin", origin),
        ("Access-Control-Allow-Credentials", "true"),
        ( "Access-Control-Allow-Methods",
          B.intercalate "," $ TE.encodeUtf8 <$> cpMethods policy
        )
      ]

    -- Headers are prepended, not merged: a header of the same name already
    -- set by the application is kept after ours rather than replaced.
    setHeaders hdrs = mapResponseHeaders (\h -> mkRespHdrs hdrs ++ h)
    mkRespHdrs = map (\(k, v) -> (CI.mk k, v))