mirror of
https://github.com/hasura/graphql-engine.git
synced 2024-12-17 04:24:35 +03:00
b167120f96
We'll see if this improves compile times at all, but I think it's worth doing as at least the most minimal form of module documentation. This was accomplished by first compiling everything with -ddump-minimal-imports and then a bunch of scripting (with help from ormolu). **EDIT:** it doesn't seem to improve CI compile times, but the noise floor is high, as it looks like we're no longer caching library dependencies. PR-URL: https://github.com/hasura/graphql-engine-mono/pull/2730 GitOrigin-RevId: 667eb8de1e0f1af70420cbec90402922b8b84cb4
203 lines
7.4 KiB
Haskell
203 lines
7.4 KiB
Haskell
{-# LANGUAGE NumericUnderscores #-}
|
|
|
|
module Hasura.Backends.BigQuery.Connection
|
|
( BigQueryProblem,
|
|
resolveConfigurationInput,
|
|
resolveConfigurationInputs,
|
|
resolveConfigurationJson,
|
|
runBigQuery,
|
|
)
|
|
where
|
|
|
|
import Control.Concurrent.MVar
|
|
import Control.Exception
|
|
import Crypto.Hash.Algorithms (SHA256 (..))
|
|
import Crypto.PubKey.RSA.PKCS15 (signSafer)
|
|
import Crypto.PubKey.RSA.Types as Cry (Error)
|
|
import Data.Aeson qualified as J
|
|
import Data.Aeson.Casing qualified as J
|
|
import Data.Aeson.TH qualified as J
|
|
import Data.ByteArray.Encoding qualified as BAE
|
|
import Data.ByteString qualified as BS
|
|
import Data.ByteString.Char8 qualified as B8
|
|
import Data.ByteString.Lazy qualified as BL
|
|
import Data.Environment qualified as Env
|
|
import Data.Text qualified as T
|
|
import Data.Text.Encoding qualified as TE
|
|
import Data.Text.Encoding.Error qualified as TE
|
|
import Data.Time.Clock
|
|
import Data.Time.Clock.POSIX (getPOSIXTime)
|
|
import Hasura.Backends.BigQuery.Source
|
|
import Hasura.Backends.MSSQL.Connection qualified as MSSQLConn (getEnv)
|
|
import Hasura.Base.Error
|
|
import Hasura.Prelude
|
|
import Network.HTTP.Simple
|
|
import Network.HTTP.Types
|
|
|
|
-- | An OAuth2 scope URL, e.g. @https://www.googleapis.com/auth/cloud-platform@.
--
-- When a JWT assertion is built, the requested scopes are joined with single
-- spaces into its @scope@ claim.
newtype Scope = Scope
  { unScope :: T.Text
  }
  deriving (Show, Eq, IsString)
|
|
|
|
-- | JSON body POSTed to Google's OAuth2 token endpoint.
--
-- The 'J.deriveJSON' splice below drops the five-character @_gatr@ prefix
-- and snake-cases the remainder, so the fields are serialised as
-- @grant_type@ and @assertion@.
data GoogleAccessTokenRequest = GoogleAccessTokenRequest
  { -- | OAuth2 grant type; 'mkTokenRequest' fills in the JWT-bearer grant.
    _gatrGrantType :: !Text,
    -- | The signed JWT bearer assertion.
    _gatrAssertion :: !Text
  }
  deriving (Show, Eq)

$( J.deriveJSON
     (J.aesonDrop 5 J.snakeCase) {J.omitNothingFields = False}
     ''GoogleAccessTokenRequest
 )
|
|
|
|
-- | Wrap a signed JWT assertion in a token request that uses the
-- JWT-bearer grant type (@urn:ietf:params:oauth:grant-type:jwt-bearer@).
mkTokenRequest :: Text -> GoogleAccessTokenRequest
mkTokenRequest assertion =
  GoogleAccessTokenRequest
    { _gatrGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer",
      _gatrAssertion = assertion
    }
|
|
|
|
-- | The ways in which obtaining an OAuth2 access token from Google can fail.
data TokenProblem
  = -- | The signed JWT assertion could not be decoded as UTF-8.
    BearerTokenDecodeProblem TE.UnicodeException
  | -- | Signing the JWT with the service account's private key failed.
    BearerTokenSignsaferProblem Cry.Error
  | -- | The token endpoint's response body could not be decoded as a token
    -- response.
    TokenFetchProblem JSONException
  | -- | The token endpoint replied with a status code other than 200.
    TokenRequestNonOK Status
  deriving (Show)

instance Exception TokenProblem
|
|
|
|
-- | Problems encountered while loading a service account definition.
data ServiceAccountProblem
  = -- | The service account file could not be decoded.
    -- NOTE(review): the 'String' presumably carries the decoder's error
    -- message; no producer is visible in this module, so confirm at the
    -- construction site.
    ServiceAccountFileDecodeProblem String
  deriving (Show)

instance Exception ServiceAccountProblem
|
|
|
|
-- | Resolve a 'ConfigurationJSON' into a decoded value of type @a@.
--
-- * 'FromYamlJSON': the value was given inline in the metadata and is
--   returned unchanged.
-- * 'FromEnvJSON': the referenced value is read with 'MSSQLConn.getEnv',
--   UTF-8 encoded, and decoded as JSON. A decoding failure is returned as
--   @Left@ carrying aeson's error message.
--
-- Anything raised by 'MSSQLConn.getEnv' itself is not caught here and
-- surfaces through @m@.
resolveConfigurationJson ::
  (QErrM m, J.FromJSON a) =>
  Env.Environment ->
  ConfigurationJSON a ->
  m (Either String a)
resolveConfigurationJson env = \case
  FromYamlJSON value -> pure (Right value)
  FromEnvJSON v -> do
    raw <- MSSQLConn.getEnv env v
    -- 'J.eitherDecode' already yields the @Either String a@ we return, so no
    -- re-wrapping is needed.
    pure . J.eitherDecode . BL.fromStrict . TE.encodeUtf8 $ raw
|
|
|
|
-- | Resolve a 'ConfigurationInput' to its text value.
--
-- An inline value ('FromYaml') is returned unchanged. An environment
-- reference ('FromEnv') is looked up with 'MSSQLConn.getEnv'.
resolveConfigurationInput ::
  QErrM m =>
  Env.Environment ->
  ConfigurationInput ->
  m Text
resolveConfigurationInput env input =
  case input of
    FromYaml literal -> pure literal
    FromEnv varName -> MSSQLConn.getEnv env varName
|
|
|
|
-- | Resolve a 'ConfigurationInputs' to a list of text values.
--
-- Inline values ('FromYamls') are returned unchanged. For an environment
-- reference ('FromEnvs'), the looked-up value is split on commas and empty
-- segments are discarded; e.g. @"a,,b,"@ yields @["a", "b"]@. Segments are
-- not trimmed, so surrounding whitespace is kept.
resolveConfigurationInputs ::
  QErrM m =>
  Env.Environment ->
  ConfigurationInputs ->
  m [Text]
resolveConfigurationInputs env input =
  case input of
    FromYamls values -> pure values
    FromEnvs varName -> splitCommaList <$> MSSQLConn.getEnv env varName
  where
    -- Comma-separated list with empty entries removed.
    splitCommaList :: Text -> [Text]
    splitCommaList raw = [item | item <- T.splitOn "," raw, not (T.null item)]
|
|
|
|
-- | Obtain a fresh OAuth2 access token for the given service account.
--
-- The steps are:
--
-- 1. Build a JWT bearer assertion for the @cloud-platform@ scope and sign it
--    with the account's private key (see @encodeBearerJWT@ below).
-- 2. POST the assertion to Google's token endpoint.
--
-- On success, the returned 'TokenResp' has its '_trExpiresAt' converted from
-- a relative lifetime in seconds into an absolute POSIX timestamp. Callers
-- then only need to compare it with the current time.
--
-- The following failures are returned as 'Left':
--
-- * signing the JWT fails;
-- * the assertion is not valid UTF-8;
-- * the endpoint answers with a status other than 200;
-- * the response body cannot be decoded.
--
-- Exceptions thrown by the HTTP client itself are not caught here.
getAccessToken :: MonadIO m => ServiceAccount -> m (Either TokenProblem TokenResp)
getAccessToken sa = do
  eJwt <- encodeBearerJWT sa ["https://www.googleapis.com/auth/cloud-platform"]
  case eJwt of
    Left tokenProblem -> pure . Left $ tokenProblem
    Right jwt ->
      case TE.decodeUtf8' jwt of
        Left unicodeEx -> pure . Left . BearerTokenDecodeProblem $ unicodeEx
        Right assertion -> do
          -- Record the time *before* sending the request. The endpoint's
          -- lifetime counts from when the token was issued, which precedes
          -- our receipt of the response. Anchoring on the request time keeps
          -- the computed expiry conservative (never later than the real one).
          requestedAt <- liftIO getPOSIXTime
          tokenFetchResponse :: Response (Either JSONException TokenResp) <-
            httpJSONEither $
              setRequestBodyJSON (mkTokenRequest assertion) $
                parseRequest_ ("POST " <> tokenURL)
          if getResponseStatusCode tokenFetchResponse /= 200
            then pure . Left . TokenRequestNonOK . getResponseStatus $ tokenFetchResponse
            else case getResponseBody tokenFetchResponse of
              Left jsonEx -> pure . Left . TokenFetchProblem $ jsonEx
              Right tr@TokenResp {_trExpiresAt} ->
                -- Store the POSIX "moment" at which this token expires, so
                -- that wherever we check whether a token is nearing expiry
                -- we only need to compare it with the _then_ current
                -- POSIXTime.
                let expiresAt = requestedAt + fromIntegral _trExpiresAt
                 in pure . Right $ tr {_trExpiresAt = truncate expiresAt}
  where
    -- TODO: use jose for jwt encoding
    b64EncodeJ :: (J.ToJSON a) => a -> BS.ByteString
    b64EncodeJ = base64 . BL.toStrict . J.encode

    -- JWT segments use the URL-safe base64 alphabet without padding.
    base64 :: BS.ByteString -> BS.ByteString
    base64 = BAE.convertToBase BAE.Base64URLUnpadded

    tokenURL :: String
    tokenURL = "https://www.googleapis.com/oauth2/v4/token"

    -- Lifetime requested for the assertion, in seconds: exp = iat + this.
    maxTokenLifetime :: Int
    maxTokenLifetime = 3600

    -- Drop every trailing '=' padding character. 'base64' is already
    -- unpadded, so this is purely defensive. Unlike stripping a single
    -- character, it also handles "==" padding correctly.
    truncateEquals :: B8.ByteString -> B8.ByteString
    truncateEquals = fst . B8.spanEnd (== '=')

    -- Build and sign a compact JWT of the form
    -- base64url(header) . base64url(claims) . base64url(signature),
    -- signed with RS256 (RSA PKCS#1 v1.5 over SHA-256).
    encodeBearerJWT :: (MonadIO m) => ServiceAccount -> [Scope] -> m (Either TokenProblem BS.ByteString)
    encodeBearerJWT ServiceAccount {..} scopes = do
      inp <- mkSigInput . truncate <$> liftIO getPOSIXTime
      signRes <- liftIO $ signSafer (Just SHA256) (unPKey _saPrivateKey) inp
      case signRes of
        Left e -> pure . Left . BearerTokenSignsaferProblem $ e
        Right sig -> pure . Right $ inp <> "." <> truncateEquals (base64 sig)
      where
        -- The signing input "header.claims"; n is the issue time in whole
        -- POSIX seconds.
        mkSigInput :: Int -> BS.ByteString
        mkSigInput n = header <> "." <> payload
          where
            header =
              b64EncodeJ $
                J.object
                  [ "alg" J..= ("RS256" :: T.Text),
                    "typ" J..= ("JWT" :: T.Text)
                  ]
            payload =
              b64EncodeJ $
                J.object
                  [ "aud" J..= tokenURL,
                    "scope" J..= T.intercalate " " (map unScope scopes),
                    "iat" J..= n,
                    "exp" J..= (n + maxTokenLifetime),
                    "iss" J..= _saClientEmail
                  ]
|
|
|
|
-- | Return an access token that is safe to use right now.
--
-- A new token is fetched via 'getAccessToken' when either:
--
-- * no token is cached yet; or
-- * the cached token expires within the next 10 seconds (or already has).
--
-- The cache is '_scAccessTokenMVar'. It is held through 'modifyMVar' for
-- the whole check-and-refresh, so concurrent callers block until it is put
-- back. When a refresh fails, the previously cached value (possibly
-- 'Nothing') stays in the MVar and the error is returned.
getUsableToken :: MonadIO m => BigQuerySourceConfig -> m (Either TokenProblem TokenResp)
getUsableToken BigQuerySourceConfig {_scServiceAccount, _scAccessTokenMVar} =
  liftIO . modifyMVar _scAccessTokenMVar $ \cached ->
    case cached of
      Nothing -> refreshKeeping Nothing
      Just current@TokenResp {_trExpiresAt} -> do
        now <- getPOSIXTime
        -- Refresh once we are within the 10 second safety margin.
        let refreshAt = fromIntegral _trExpiresAt - (10 :: NominalDiffTime)
        if now < refreshAt
          then pure (Just current, Right current)
          else refreshKeeping (Just current)
  where
    -- Fetch a new token. On success, cache and return it. On failure, leave
    -- the given fallback in the MVar and return the error.
    refreshKeeping fallback = do
      fetched <- getAccessToken _scServiceAccount
      pure $ case fetched of
        Left err -> (fallback, Left err)
        Right fresh -> (Just fresh, Right fresh)
|
|
|
|
-- | Failures that 'runBigQuery' reports as a 'Left' value.
data BigQueryProblem
  = -- | Acquiring or refreshing the OAuth2 access token failed.
    TokenProblem TokenProblem
  deriving (Show)
|
|
|
|
-- | Send a request to BigQuery, authenticated with a bearer token taken from
-- 'getUsableToken' (which refreshes the cached token if needed).
--
-- Token failures are returned as @Left ('TokenProblem' _)@. The request
-- itself is sent with 'httpLBS', and its exceptions are not caught here.
runBigQuery ::
  (MonadIO m) =>
  BigQuerySourceConfig ->
  Request ->
  m (Either BigQueryProblem (Response BL.ByteString))
runBigQuery sc req =
  getUsableToken sc >>= \case
    Left problem -> pure (Left (TokenProblem problem))
    Right TokenResp {_trAccessToken} -> do
      let bearer = "Bearer " <> TE.encodeUtf8 (coerce _trAccessToken)
          authorized = setRequestHeader "Authorization" [bearer] req
      -- TODO: Make this catch the HTTP exceptions
      Right <$> httpLBS authorized
|