mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-16 18:42:30 +03:00
2152911e24
GitOrigin-RevId: 0dd10f1ccd338b1cf382ebff59b6ee7f209d39a1
208 lines
7.8 KiB
Haskell
208 lines
7.8 KiB
Haskell
{-# LANGUAGE NumericUnderscores #-}
|
|
|
|
module Hasura.Backends.BigQuery.Connection where
|
|
|
|
import Hasura.Prelude
|
|
|
|
import qualified Data.Aeson as J
|
|
import qualified Data.Aeson.Casing as J
|
|
import qualified Data.Aeson.TH as J
|
|
import qualified Data.ByteArray.Encoding as BAE
|
|
import qualified Data.ByteString as BS
|
|
import qualified Data.ByteString.Char8 as B8
|
|
import qualified Data.ByteString.Lazy as BL
|
|
import qualified Data.Environment as Env
|
|
import qualified Data.Text as T
|
|
import qualified Data.Text.Encoding as TE
|
|
import qualified Data.Text.Encoding.Error as TE
|
|
|
|
import Control.Concurrent.MVar
|
|
import Control.Exception
|
|
import Crypto.Hash.Algorithms (SHA256 (..))
|
|
import Crypto.PubKey.RSA.PKCS15 (signSafer)
|
|
import Crypto.PubKey.RSA.Types as Cry (Error)
|
|
import Data.Bifunctor (bimap)
|
|
import Data.Time.Clock
|
|
import Data.Time.Clock.POSIX (getPOSIXTime)
|
|
import Network.HTTP.Simple
|
|
import Network.HTTP.Types
|
|
|
|
import qualified Hasura.Backends.MSSQL.Connection as MSSQLConn (getEnv)
|
|
|
|
import Hasura.Backends.BigQuery.Source
|
|
import Hasura.Base.Error
|
|
|
|
|
|
-- | An OAuth2 scope, e.g. the @cloud-platform@ scope requested by
-- 'getAccessToken'. The derived 'IsString' instance allows scopes to be
-- written directly as string literals.
newtype Scope = Scope {unScope :: T.Text}
  deriving (Show, Eq, IsString)
|
|
|
|
|
|
-- | JSON body of the token request POSTed to Google's OAuth2 token endpoint.
--
-- The derived JSON instances drop the 5-character @_gatr@ prefix and
-- snake_case the remainder, so the fields appear on the wire as
-- @grant_type@ and @assertion@.
data GoogleAccessTokenRequest = GoogleAccessTokenRequest
  { -- | OAuth2 grant type (see 'mkTokenRequest').
    _gatrGrantType :: !Text,
    -- | The signed JWT, as UTF-8 text.
    _gatrAssertion :: !Text
  }
  deriving (Show, Eq)

$(J.deriveJSON (J.aesonDrop 5 J.snakeCase) {J.omitNothingFields = False} ''GoogleAccessTokenRequest)
|
|
|
|
-- | Wrap a signed JWT assertion in a token request that uses the
-- @urn:ietf:params:oauth:grant-type:jwt-bearer@ grant type.
mkTokenRequest :: Text -> GoogleAccessTokenRequest
mkTokenRequest assertion =
  GoogleAccessTokenRequest
    { _gatrGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer",
      _gatrAssertion = assertion
    }
|
|
|
|
|
|
-- | Ways in which obtaining an access token (see 'getAccessToken') can fail.
data TokenProblem
  = -- | The signed JWT bytes were not valid UTF-8.
    BearerTokenDecodeProblem TE.UnicodeException
  | -- | RSA signing of the JWT failed.
    BearerTokenSignsaferProblem Cry.Error
  | -- | The token endpoint's response body could not be parsed.
    TokenFetchProblem JSONException
  | -- | The token endpoint answered with a status other than 200.
    TokenRequestNonOK Status
  deriving (Show)

instance Exception TokenProblem
|
|
|
|
-- | Failure while loading a service account key file (see 'getServiceAccount').
data ServiceAccountProblem
  = -- | The file contents could not be decoded; carries aeson's error message.
    ServiceAccountFileDecodeProblem String
  deriving (Show)

instance Exception ServiceAccountProblem
|
|
|
|
|
|
-- | Resolve a 'ConfigurationJSON' to its decoded value.
--
-- A value given inline in the metadata is returned as-is. A value given via
-- an environment variable is looked up with 'MSSQLConn.getEnv' and its text is
-- decoded as JSON. A decoding failure is returned as @Left@ with aeson's error
-- message rather than being thrown.
resolveConfigurationJson ::
  (QErrM m, J.FromJSON a) =>
  Env.Environment ->
  ConfigurationJSON a ->
  m (Either String a)
resolveConfigurationJson env configJson =
  case configJson of
    FromYamlJSON inline -> pure (Right inline)
    FromEnvJSON varName ->
      J.eitherDecode . BL.fromStrict . TE.encodeUtf8 <$> MSSQLConn.getEnv env varName
|
|
|
|
|
|
-- | Resolve a single 'ConfigurationInput' to text. The result is either the
-- literal given in the metadata, or the value of the named environment
-- variable as looked up by 'MSSQLConn.getEnv'.
resolveConfigurationInput ::
  QErrM m =>
  Env.Environment ->
  ConfigurationInput ->
  m Text
resolveConfigurationInput env input =
  case input of
    FromYaml literal -> pure literal
    FromEnv varName -> MSSQLConn.getEnv env varName
|
|
|
|
|
|
-- | Resolve a 'ConfigurationInputs' to a list of values.
--
-- Values given in the metadata are returned unchanged. A value taken from an
-- environment variable is split on commas. Each entry has its surrounding
-- whitespace stripped, so @"a, b"@ yields @["a", "b"]@ rather than
-- @["a", " b"]@, and empty entries are dropped.
resolveConfigurationInputs ::
  QErrM m =>
  Env.Environment ->
  ConfigurationInputs ->
  m [Text]
resolveConfigurationInputs env = \case
  FromYamls a -> pure a
  FromEnvs v -> parseCommaList <$> MSSQLConn.getEnv env v
  where
    -- Split on ',', trim each entry, and discard entries left empty.
    parseCommaList :: Text -> [Text]
    parseCommaList = filter (not . T.null) . map T.strip . T.splitOn ","
|
|
|
|
|
|
-- | Obtain a fresh OAuth2 access token for a service account.
--
-- Steps:
--
-- * Build a JWT assertion for the @cloud-platform@ scope.
-- * Sign it with RS256 (RSA PKCS#1 v1.5 over SHA-256) using the service
--   account's private key.
-- * POST it as JSON to Google's token endpoint.
-- * On a 200 response, parse the body as a 'TokenResp'.
-- * Rewrite its '_trExpiresAt' as an absolute POSIX time, by adding the
--   current POSIX time to the value received.
--
-- The following failures are returned as 'TokenProblem':
--
-- * signing fails;
-- * the JWT is not valid UTF-8;
-- * the status is not 200;
-- * the body cannot be parsed.
--
-- Exceptions thrown by the HTTP client itself are not caught here.
getAccessToken :: MonadIO m => ServiceAccount -> m (Either TokenProblem TokenResp)
getAccessToken sa =
  encodeBearerJWT sa ["https://www.googleapis.com/auth/cloud-platform"] >>= \case
    Left problem -> pure (Left problem)
    Right jwt ->
      either (pure . Left . BearerTokenDecodeProblem) requestToken (TE.decodeUtf8' jwt)
  where
    -- Send the signed assertion to the token endpoint and interpret the reply.
    requestToken :: MonadIO n => Text -> n (Either TokenProblem TokenResp)
    requestToken assertion = do
      let request =
            setRequestBodyJSON (mkTokenRequest assertion) $
              parseRequest_ ("POST " <> tokenURL)
      response :: Response (Either JSONException TokenResp) <- httpJSONEither request
      if getResponseStatusCode response == 200
        then either (pure . Left . TokenFetchProblem) withAbsoluteExpiry (getResponseBody response)
        else pure (Left (TokenRequestNonOK (getResponseStatus response)))

    -- Treat the received '_trExpiresAt' as a lifetime and store the absolute
    -- POSIX moment of expiry instead. Checking for (near-)expiry later then
    -- only needs a comparison with the then-current POSIX time.
    withAbsoluteExpiry :: MonadIO n => TokenResp -> n (Either TokenProblem TokenResp)
    withAbsoluteExpiry tr@TokenResp{_trExpiresAt} = do
      now <- liftIO getPOSIXTime
      pure (Right (tr {_trExpiresAt = truncate (fromIntegral _trExpiresAt + now)}))

    -- JSON-encode a value, then base64url-encode the bytes (unpadded).
    -- TODO: use jose for jwt encoding
    b64EncodeJ :: (J.ToJSON a) => a -> BS.ByteString
    b64EncodeJ value = base64 (BL.toStrict (J.encode value))

    base64 :: BS.ByteString -> BS.ByteString
    base64 bytes = BAE.convertToBase BAE.Base64URLUnpadded bytes

    tokenURL :: String
    tokenURL = "https://www.googleapis.com/oauth2/v4/token"

    -- Requested JWT lifetime in seconds: the "exp" claim is "iat" plus this.
    maxTokenLifetime :: Int
    maxTokenLifetime = 3600

    -- Drop a single trailing '=' if present. Base64URLUnpadded output should
    -- contain no padding, so this is presumably a defensive no-op.
    truncateEquals :: B8.ByteString -> B8.ByteString
    truncateEquals bs =
      case B8.unsnoc bs of
        Just (initial, '=') -> initial
        _ -> bs

    -- Build and RS256-sign the JWT "<header>.<claims>.<signature>" for the
    -- given scopes, issued at the current POSIX time (whole seconds).
    encodeBearerJWT :: (MonadIO m) => ServiceAccount -> [Scope] -> m (Either TokenProblem BS.ByteString)
    encodeBearerJWT ServiceAccount{..} scopes = do
      issuedAt <- truncate <$> liftIO getPOSIXTime
      let signingInput = mkSigInput issuedAt
      signed <- liftIO (signSafer (Just SHA256) (unPKey _saPrivateKey) signingInput)
      pure $ case signed of
        Left err -> Left (BearerTokenSignsaferProblem err)
        Right sig -> Right (signingInput <> "." <> truncateEquals (base64 sig))
      where
        -- "<base64url header>.<base64url claim set>", the bytes that get signed.
        mkSigInput :: Int -> BS.ByteString
        mkSigInput iat = jwtHeader <> "." <> claimSet
          where
            jwtHeader =
              b64EncodeJ $
                J.object
                  [ "alg" J..= ("RS256" :: T.Text),
                    "typ" J..= ("JWT" :: T.Text)
                  ]
            claimSet =
              b64EncodeJ $
                J.object
                  [ "aud" J..= tokenURL,
                    "scope" J..= T.intercalate " " (map unScope scopes),
                    "iat" J..= iat,
                    "exp" J..= (iat + maxTokenLifetime),
                    "iss" J..= _saClientEmail
                  ]
|
|
|
|
|
|
-- | Read and decode a service account JSON key file.
--
-- The file is read strictly, so its handle is closed before decoding starts.
-- The previous lazy 'BL.readFile' could leave the handle open until the
-- unconsumed remainder was garbage-collected, for example after a decode
-- failure.
--
-- Decoding failures are returned as 'ServiceAccountFileDecodeProblem' with
-- aeson's error message. I/O errors, such as a missing file, are thrown as
-- before.
getServiceAccount :: MonadIO m => FilePath -> m (Either ServiceAccountProblem ServiceAccount)
getServiceAccount serviceAccountFilePath =
  bimap ServiceAccountFileDecodeProblem id . J.eitherDecodeStrict'
    <$> liftIO (BS.readFile serviceAccountFilePath)
|
|
|
|
|
|
-- | Get a usable access token for the source, refreshing it when needed.
--
-- The cached token lives in '_scAccessTokenMVar'. 'getAccessToken' is called
-- for a new token in two cases:
--
-- * nothing is cached yet;
-- * the current POSIX time is within @tokenExpiryThreshold@ (10 seconds) of
--   the cached '_trExpiresAt', which 'getAccessToken' stores as an absolute
--   POSIX time.
--
-- Everything runs inside 'modifyMVar'. Concurrent callers therefore wait on
-- the MVar instead of refreshing in parallel, and if the action throws, the
-- previous contents are restored.
--
-- When a refresh returns an error, the previously cached value (possibly
-- 'Nothing') is kept and the error is returned.
getUsableToken :: MonadIO m => BigQuerySourceConfig -> m (Either TokenProblem TokenResp)
getUsableToken BigQuerySourceConfig{_scServiceAccount, _scAccessTokenMVar} =
  liftIO $ modifyMVar _scAccessTokenMVar $ \case
    Nothing -> refreshOr Nothing
    Just cached@TokenResp{_trExpiresAt} -> do
      now <- getPOSIXTime
      -- Refresh slightly ahead of the actual expiry moment.
      if now >= fromIntegral _trExpiresAt - tokenExpiryThreshold
        then refreshOr (Just cached)
        else pure (Just cached, Right cached)
  where
    -- How long before expiry a cached token is already considered stale.
    tokenExpiryThreshold :: NominalDiffTime
    tokenExpiryThreshold = 10

    -- Fetch a new token. On success, cache and return it. On failure, leave
    -- @fallback@ in the MVar and return the error.
    refreshOr :: Maybe TokenResp -> IO (Maybe TokenResp, Either TokenProblem TokenResp)
    refreshOr fallback = do
      fetched <- getAccessToken _scServiceAccount
      pure $ case fetched of
        Left e -> (fallback, Left e)
        Right fresh -> (Just fresh, Right fresh)
|
|
|
|
|
|
-- | Problems reported as values (rather than thrown) by 'runBigQuery'.
data BigQueryProblem
  = -- | No usable access token could be obtained.
    TokenProblem TokenProblem
  deriving (Show)
|
|
|
|
|
|
-- | Send a request to the BigQuery API with a bearer token attached.
--
-- A usable token is obtained with 'getUsableToken', refreshing it if needed.
-- If that fails, the request is not sent and the failure is returned as
-- 'TokenProblem'. Otherwise the request is sent with an
-- @Authorization: Bearer <token>@ header and the raw response is returned.
--
-- NOTE: exceptions raised by 'httpLBS' are not caught here.
-- TODO: Make this catch the HTTP exceptions
runBigQuery ::
  (MonadIO m) =>
  BigQuerySourceConfig ->
  Request ->
  m (Either BigQueryProblem (Response BL.ByteString))
runBigQuery sc req =
  getUsableToken sc
    >>= either (pure . Left . TokenProblem) (fmap Right . httpLBS . authorize)
  where
    -- Attach the access token as a bearer Authorization header.
    authorize TokenResp{_trAccessToken} =
      setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (coerce _trAccessToken)] req
|