{-# LANGUAGE NumericUnderscores #-}

-- | BigQuery connection plumbing: resolving configuration values (inline or
-- from environment variables) and obtaining Google OAuth2 access tokens for
-- a service account.
module Hasura.Backends.BigQuery.Connection where

import Hasura.Prelude

import qualified Data.Aeson as J
import qualified Data.Aeson.Casing as J
import qualified Data.Aeson.TH as J
import qualified Data.ByteArray.Encoding as BAE
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.Environment as Env
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Encoding.Error as TE

import Control.Concurrent.MVar
import Control.Exception
import Crypto.Hash.Algorithms (SHA256 (..))
import Crypto.PubKey.RSA.PKCS15 (signSafer)
import Crypto.PubKey.RSA.Types as Cry (Error)
import Data.Bifunctor (bimap)
import Data.Time.Clock
import Data.Time.Clock.POSIX (getPOSIXTime)
import Network.HTTP.Simple
import Network.HTTP.Types

import qualified Hasura.Backends.MSSQL.Connection as MSSQLConn (getEnv)

import Hasura.Backends.BigQuery.Source
import Hasura.Base.Error

-- | An OAuth2 scope string to request when minting an access token.
newtype Scope = Scope {unScope :: T.Text}
  deriving (Show, Eq, IsString)

-- | Body of the token-exchange request sent to Google's OAuth2 token
-- endpoint. The derived JSON instance drops the five-character @_gatr@
-- prefix and snake-cases the remainder, giving the keys @grant_type@ and
-- @assertion@.
data GoogleAccessTokenRequest = GoogleAccessTokenRequest
  { _gatrGrantType :: !Text
  , _gatrAssertion :: !Text
  }
  deriving (Show, Eq)

$(J.deriveJSON (J.aesonDrop 5 J.snakeCase){J.omitNothingFields=False} ''GoogleAccessTokenRequest)

-- | Build a token request using the JWT-bearer grant type. The argument is
-- the signed JWT assertion.
mkTokenRequest :: Text -> GoogleAccessTokenRequest
mkTokenRequest = GoogleAccessTokenRequest "urn:ietf:params:oauth:grant-type:jwt-bearer"

-- | Ways in which obtaining an access token can fail.
data TokenProblem
  = BearerTokenDecodeProblem TE.UnicodeException
    -- ^ The signed JWT could not be decoded as UTF-8.
  | BearerTokenSignsaferProblem Cry.Error
    -- ^ RSA signing of the JWT failed.
  | TokenFetchProblem JSONException
    -- ^ The token endpoint's response body could not be decoded.
  | TokenRequestNonOK Status
    -- ^ The token endpoint answered with a status other than 200.
  deriving (Show)

instance Exception TokenProblem

-- | Failure to decode a service-account JSON file; carries the decoder's
-- error message.
data ServiceAccountProblem
  = ServiceAccountFileDecodeProblem String
  deriving (Show)

instance Exception ServiceAccountProblem

-- | Resolve a 'ConfigurationJSON' value.
--
-- An inline ('FromYamlJSON') value is returned as-is. For 'FromEnvJSON' the
-- named environment variable is read and its text decoded as JSON; a decode
-- failure is returned as 'Left' with the decoder's message. Errors raised by
-- 'MSSQLConn.getEnv' itself (presumably an unset variable) surface in @m@.
resolveConfigurationJson ::
  (QErrM m, J.FromJSON a) =>
  Env.Environment ->
  ConfigurationJSON a -> -- REVIEW: Can this be made polymorphic?
  m (Either String a)
resolveConfigurationJson env = \case
  FromYamlJSON s -> pure . Right $ s
  FromEnvJSON v ->
    -- The decode result already has the shape we return; just lift it.
    J.eitherDecode . BL.fromStrict . TE.encodeUtf8 <$> MSSQLConn.getEnv env v

-- | Resolve a single textual configuration value, either given inline or
-- read from the named environment variable.
resolveConfigurationInput :: QErrM m => Env.Environment -> ConfigurationInput -> m Text
resolveConfigurationInput env = \case
  FromYaml s -> pure s
  FromEnv v -> MSSQLConn.getEnv env v

-- | Resolve a list of textual configuration values. When read from an
-- environment variable, the value is split on commas and empty items are
-- dropped.
resolveConfigurationInputs :: QErrM m => Env.Environment -> ConfigurationInputs -> m [Text]
resolveConfigurationInputs env = \case
  FromYamls a -> pure a
  FromEnvs v -> filter (not . T.null) . T.splitOn "," <$> MSSQLConn.getEnv env v

-- | Mint a fresh OAuth2 access token for the given service account.
--
-- A JWT requesting the @cloud-platform@ scope is signed with the account's
-- private key (RS256) and exchanged at Google's token endpoint. On success
-- '_trExpiresAt' in the returned value is turned into an absolute POSIX
-- timestamp by adding the current time to the value from the response.
--
-- Signing failures, UTF-8 decode failures, non-200 responses and
-- undecodable bodies are returned as 'TokenProblem'. Exceptions thrown by
-- the HTTP client itself are not caught here.
getAccessToken :: MonadIO m => ServiceAccount -> m (Either TokenProblem TokenResp)
getAccessToken sa = do
  eJwt <- encodeBearerJWT sa ["https://www.googleapis.com/auth/cloud-platform"]
  case eJwt of
    Left tokenProblem -> pure . Left $ tokenProblem
    Right jwt ->
      case TE.decodeUtf8' jwt of
        Left unicodeEx -> pure . Left . BearerTokenDecodeProblem $ unicodeEx
        Right assertion -> do
          tokenFetchResponse :: Response (Either JSONException TokenResp) <-
            httpJSONEither $
              setRequestBodyJSON (mkTokenRequest assertion) $
                parseRequest_ ("POST " <> tokenURL)
          if getResponseStatusCode tokenFetchResponse /= 200
            then pure . Left . TokenRequestNonOK . getResponseStatus $ tokenFetchResponse
            else case getResponseBody tokenFetchResponse of
              Left jsonEx -> pure . Left . TokenFetchProblem $ jsonEx
              Right tr@TokenResp {_trExpiresAt} -> do
                -- We add the current POSIXTime and store the POSIX "moment" at
                -- which this token will expire, so that at the site where
                -- we need to check if a token is nearing expiry, we only
                -- need to compare it with the _then_ "current" POSIXTime.
                expiresAt <- (fromIntegral _trExpiresAt +) <$> liftIO getPOSIXTime
                pure . Right $ tr {_trExpiresAt = truncate expiresAt}
  where
    -- TODO: use jose for jwt encoding

    -- JSON-encode a value and base64url-encode the bytes (one JWT segment).
    b64EncodeJ :: (J.ToJSON a) => a -> BS.ByteString
    b64EncodeJ = base64 . BL.toStrict . J.encode

    -- URL-safe base64 without '=' padding.
    base64 :: BS.ByteString -> BS.ByteString
    base64 = BAE.convertToBase BAE.Base64URLUnpadded

    -- Token endpoint; also used as the JWT "aud" claim.
    tokenURL :: String
    tokenURL = "https://www.googleapis.com/oauth2/v4/token"

    -- Seconds between the JWT's "iat" and "exp" claims.
    maxTokenLifetime :: Int
    maxTokenLifetime = 3600

    -- Drop a single trailing '=' if present. 'base64' is already unpadded,
    -- so this is expected to be a no-op; kept as a safeguard.
    truncateEquals :: B8.ByteString -> B8.ByteString
    truncateEquals bs =
      case B8.unsnoc bs of
        Nothing -> mempty
        Just (bs', x)
          | x == '=' -> bs'
          | otherwise -> bs

    -- Build "header.payload.signature" for the service account, signing
    -- "header.payload" with RSA PKCS#1 v1.5 over SHA-256.
    encodeBearerJWT :: (MonadIO m) => ServiceAccount -> [Scope] -> m (Either TokenProblem BS.ByteString)
    encodeBearerJWT ServiceAccount {..} scopes = do
      inp <- mkSigInput . truncate <$> liftIO getPOSIXTime
      signRes <- liftIO $ signSafer (Just SHA256) (unPKey _saPrivateKey) inp
      case signRes of
        Left e -> pure . Left . BearerTokenSignsaferProblem $ e
        Right sig -> pure . Right $ inp <> "." <> truncateEquals (base64 sig)
      where
        -- The signing input for issue time @n@ (whole POSIX seconds).
        mkSigInput :: Int -> BS.ByteString
        mkSigInput n = header <> "." <> payload
          where
            header =
              b64EncodeJ $
                J.object
                  [ "alg" J..= ("RS256" :: T.Text)
                  , "typ" J..= ("JWT" :: T.Text)
                  ]
            payload =
              b64EncodeJ $
                J.object
                  [ "aud" J..= tokenURL
                  , "scope" J..= T.intercalate " " (map unScope scopes)
                  , "iat" J..= n
                  , "exp" J..= (n + maxTokenLifetime)
                  , "iss" J..= _saClientEmail
                  ]

-- | Read and decode a service-account JSON file.
--
-- The file is read strictly, so the handle is closed immediately and any
-- read error is raised here in @m@. With a lazy read, such an error could
-- instead escape later from pure code when the decoder forced the contents.
getServiceAccount :: MonadIO m => FilePath -> m (Either ServiceAccountProblem ServiceAccount)
getServiceAccount serviceAccountFilePath =
  bimap ServiceAccountFileDecodeProblem id . J.eitherDecode' . BL.fromStrict
    <$> liftIO (BS.readFile serviceAccountFilePath)

-- | Get a usable token. If the token has expired refresh it.
--
-- A cached token is reused while the current POSIX time is earlier than
-- 10 seconds before its '_trExpiresAt'; otherwise a new one is fetched
-- with 'getAccessToken'. The check and any refresh run inside 'modifyMVar'
-- on '_scAccessTokenMVar', so concurrent callers take turns. If an
-- exception escapes the refresh, the MVar keeps its previous contents.
-- When a refresh returns a 'TokenProblem', the previously cached value
-- (possibly none) is left in the MVar and the problem is returned.
getUsableToken :: MonadIO m => BigQuerySourceConfig -> m (Either TokenProblem TokenResp)
getUsableToken BigQuerySourceConfig {_scServiceAccount, _scAccessTokenMVar} =
  liftIO $ modifyMVar _scAccessTokenMVar decide
  where
    -- Given the cached value, either hand it back or refresh it.
    decide Nothing = refreshKeeping Nothing
    decide (Just cached@TokenResp {_trExpiresAt}) = do
      now <- getPOSIXTime
      -- Refresh once we are within 10 seconds of the recorded expiry.
      let refreshAfter = fromIntegral _trExpiresAt - (10 :: NominalDiffTime)
      if now < refreshAfter
        then pure (Just cached, Right cached)
        else refreshKeeping (Just cached)

    -- Fetch a new token; on failure leave @fallback@ in the MVar.
    refreshKeeping fallback =
      either (\problem -> (fallback, Left problem)) (\fresh -> (Just fresh, Right fresh))
        <$> getAccessToken _scServiceAccount

-- | Problems 'runBigQuery' reports as values. Exceptions thrown by the HTTP
-- client are not converted into this type.
data BigQueryProblem
  = TokenProblem TokenProblem
  deriving (Show)

-- | Send a request with a bearer token from 'getUsableToken' set as its
-- @Authorization@ header. A token failure is returned as 'Left'; otherwise
-- the response from 'httpLBS' is returned as 'Right'.
runBigQuery ::
  (MonadIO m) =>
  BigQuerySourceConfig ->
  Request ->
  m (Either BigQueryProblem (Response BL.ByteString))
runBigQuery sc req =
  getUsableToken sc >>= either (pure . Left . TokenProblem) sendWith
  where
    -- Attach the token and perform the request.
    sendWith TokenResp {_trAccessToken} =
      let bearer = "Bearer " <> TE.encodeUtf8 (coerce _trAccessToken)
       in -- TODO: Make this catch the HTTP exceptions
          Right <$> httpLBS (setRequestHeader "Authorization" [bearer] req)