{-# LANGUAGE DeriveAnyClass #-}

-- | CORS (Cross Origin Resource Sharing) related configuration

module Hasura.Server.Cors
  ( CorsConfig (..)
  , CorsPolicy (..)
  , parseOrigin
  , readCorsDomains
  , mkDefaultCorsPolicy
  , isCorsDisabled
  , Domains (..)
  , inWildcardList
  ) where

import           Hasura.Prelude
import           Hasura.Server.Utils  (fmapL)

import           Control.Applicative  (optional)

import qualified Data.Aeson           as J
import qualified Data.Aeson.Casing    as J
import qualified Data.Aeson.TH        as J
import qualified Data.Attoparsec.Text as AT
import qualified Data.HashSet         as Set
import qualified Data.Text            as T


data DomainParts =
  DomainParts
  { wdScheme :: !Text
  , wdHost   :: !Text -- the hostname part (without the *.)
  , wdPort   :: !(Maybe Int)
  } deriving (Show, Eq, Generic, Hashable)

$(J.deriveToJSON (J.aesonDrop 2 J.snakeCase) ''DomainParts)

data Domains
  = Domains
  { dmFqdns     :: !(Set.HashSet Text)
  , dmWildcards :: !(Set.HashSet DomainParts)
  } deriving (Show, Eq)

$(J.deriveToJSON (J.aesonDrop 2 J.snakeCase) ''Domains)

data CorsConfig
  = CCAllowAll
  | CCAllowedOrigins Domains
  | CCDisabled Bool -- should read cookie?
  deriving (Show, Eq)

instance J.ToJSON CorsConfig where
  toJSON c = case c of
    CCDisabled wsrc    -> toJ True J.Null (Just wsrc)
    CCAllowAll         -> toJ False (J.String "*") Nothing
    CCAllowedOrigins d -> toJ False (J.toJSON d) Nothing
    where
      toJ :: Bool -> J.Value -> Maybe Bool -> J.Value
      toJ dis origs mWsRC =
        J.object [ "disabled" J..= dis
                 , "ws_read_cookie" J..= mWsRC
                 , "allowed_origins" J..= origs
                 ]

isCorsDisabled :: CorsConfig -> Bool
isCorsDisabled = \case
  CCDisabled _ -> True
  _          -> False

readCorsDomains :: String -> Either String CorsConfig
readCorsDomains str
  | str == "*"               = pure CCAllowAll
  | otherwise = do
      let domains = map T.strip $ T.splitOn "," (T.pack str)
      pDomains <- mapM parseOptWildcardDomain domains
      let (fqdns, wcs) = (lefts pDomains, rights pDomains)
      return $ CCAllowedOrigins $ Domains (Set.fromList fqdns) (Set.fromList wcs)


data CorsPolicy
  = CorsPolicy
  { cpConfig  :: !CorsConfig
  , cpMethods :: ![Text]
  , cpMaxAge  :: !Int
  } deriving (Show, Eq)

mkDefaultCorsPolicy :: CorsConfig -> CorsPolicy
mkDefaultCorsPolicy cfg =
  CorsPolicy
  { cpConfig = cfg
  , cpMethods = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]
  , cpMaxAge = 1728000
  }

inWildcardList :: Domains -> Text -> Bool
inWildcardList (Domains _ wildcards) origin =
  either (const False) (`Set.member` wildcards) $ parseOrigin origin


-- | Parsers for wildcard domains

runParser :: AT.Parser a -> Text -> Either String a
runParser = AT.parseOnly

parseOrigin :: Text -> Either String DomainParts
parseOrigin = runParser originParser

originParser :: AT.Parser DomainParts
originParser =
  domainParser (Just ignoreSubdomain)
  where ignoreSubdomain = do
          s <- AT.takeTill (== '.')
          void $ AT.char '.'
          return s

parseOptWildcardDomain :: Text -> Either String (Either Text DomainParts)
parseOptWildcardDomain d =
  fmapL (const errMsg) $ runParser optWildcardDomainParser d
  where

    optWildcardDomainParser :: AT.Parser (Either Text DomainParts)
    optWildcardDomainParser =
      Right <$> wildcardDomainParser <|> Left <$> fqdnParser

    errMsg = "invalid domain: '" <> T.unpack d <> "'. " <> helpMsg
    helpMsg = "All domains should have scheme + (optional wildcard) host + "
              <> "(optional port)"

    wildcardDomainParser :: AT.Parser DomainParts
    wildcardDomainParser = domainParser $ Just (AT.string "*" *> AT.string ".")

    fqdnParser :: AT.Parser Text
    fqdnParser = do
      (DomainParts scheme host port) <- domainParser Nothing
      let sPort = maybe "" (\p -> ":" <> T.pack (show p)) port
      return $ scheme <> host <> sPort


domainParser :: Maybe (AT.Parser Text) -> AT.Parser DomainParts
domainParser parser = do
  scheme <- schemeParser
  forM_ parser void
  host <- hostPortParser
  port <- optional portParser
  return $ DomainParts scheme host port

  where
    schemeParser :: AT.Parser Text
    schemeParser = AT.string "http://" <|> AT.string "https://"

    hostPortParser :: AT.Parser Text
    hostPortParser = hostWithPortParser <|> AT.takeText

    hostWithPortParser :: AT.Parser Text
    hostWithPortParser = do
      h <- AT.takeWhile1 (/= ':')
      void $ AT.char ':'
      return h

    portParser :: AT.Parser Int
    portParser = AT.decimal