-- | Validation and execution of the @count@ DML query against vanilla
-- Postgres sources.
module Hasura.RQL.DML.Count
  ( CountQueryP1 (..),
    validateCountQWith,
    validateCountQ,
    runCount,
    countQToTx,
  )
where

import Control.Monad.Trans.Control (MonadBaseControl)
import Data.Aeson
import Data.ByteString.Builder qualified as BB
import Data.Sequence qualified as DS
import Database.PG.Query qualified as Q
import Hasura.Backends.Postgres.SQL.DML qualified as S
import Hasura.Backends.Postgres.SQL.Types
import Hasura.Backends.Postgres.Translate.BoolExp
import Hasura.Base.Error
import Hasura.EncJSON
import Hasura.Prelude
import Hasura.RQL.DML.Internal
import Hasura.RQL.DML.Types
import Hasura.RQL.IR.BoolExp
import Hasura.RQL.Types
import Hasura.SQL.Types
import Hasura.Session
import Hasura.Tracing qualified as Tracing

-- | A @count@ query after validation: the table to count rows of, the
-- filters to apply, and the optional columns to count distinct values over.
data CountQueryP1 = CountQueryP1
  { -- | Table whose rows are counted.
    cqp1Table :: !QualifiedTable,
    -- | The role's resolved select-permission filter, paired with the
    -- caller's @where@ clause when one was supplied.
    cqp1Where :: !(AnnBoolExpSQL ('Postgres 'Vanilla), Maybe (AnnBoolExpSQL ('Postgres 'Vanilla))),
    -- | When present, only distinct combinations of these columns are counted.
    cqp1Distinct :: !(Maybe [PGCol])
  }
  deriving (Eq)

-- | Build the SQL statement for a validated count query. The result has one
-- of these two shapes, depending on whether distinct columns were requested:
--
-- > SELECT count(*) FROM (SELECT DISTINCT c1, .. cn FROM .. WHERE ..) r;
-- > SELECT count(*) FROM (SELECT * FROM .. WHERE ..) r;
--
-- The inner @WHERE@ is the permission filter, AND-ed with the caller's
-- @where@ clause if there is one.
mkSQLCount :: CountQueryP1 -> S.Select
mkSQLCount (CountQueryP1 table (permFilter, userFilter) distinctCols) =
  S.mkSelect
    { S.selExtr = [S.Extractor S.countStar Nothing],
      S.selFrom = Just $ S.FromExp [S.mkSelFromExp False subQuery (TableName "r")]
    }
  where
    -- The permission filter always applies; the user filter only narrows it.
    combinedFilter = case userFilter of
      Nothing -> permFilter
      Just wc -> andAnnBoolExps permFilter wc
    whereExp = toSQLBoolExp (S.QualTable table) combinedFilter
    -- Projection of the inner query: either every column, or a DISTINCT
    -- list of the requested columns.
    projection = case distinctCols of
      Nothing ->
        S.mkSelect {S.selExtr = [S.Extractor (S.SEStar Nothing) Nothing]}
      Just cols ->
        S.mkSelect
          { S.selDistinct = Just S.DistinctSimple,
            S.selExtr = [S.Extractor (S.mkSIdenExp col) Nothing | col <- cols]
          }
    subQuery =
      projection
        { S.selFrom = Just (S.mkSimpleFromExp table),
          S.selWhere = Just (S.WhereFrag whereExp)
        }

-- | Check a 'CountQuery' against the current role's permissions and convert
-- it into a 'CountQueryP1'.
--
-- The first argument resolves session variables appearing in permission
-- filters and in the @where@ clause. The second converts literal values in
-- the @where@ clause into SQL expressions.
--
-- Validation steps, in order:
--
--   * The role's select permission on the table is looked up. Any error from
--     that lookup has a note appended saying that @count@ requires @select@
--     permission.
--   * Each distinct column must pass 'checkSelOnCol' and
--     'assertColumnExists'. The latter reports @relInDistColsErr@, so
--     relationships are rejected there.
--   * The @where@ clause, if any, is converted under the @where@ error path.
--   * The role's select filter is resolved.
validateCountQWith ::
  (UserInfoM m, QErrM m, TableInfoRM ('Postgres 'Vanilla) m) =>
  SessionVariableBuilder ('Postgres 'Vanilla) m ->
  (ColumnType ('Postgres 'Vanilla) -> Value -> m S.SQLExp) ->
  CountQuery ->
  m CountQueryP1
validateCountQWith sessVarBldr prepValBldr (CountQuery qt _ mDistCols mWhere) = do
  tableInfo <- askTableInfoSource qt
  selPerm <-
    modifyErr (<> selNecessaryMsg) $
      askSelPermInfo tableInfo
  let fieldInfoMap = _tciFieldInfoMap $ _tiCoreInfo tableInfo
      convertWhere boolExp =
        withPathK "where" $
          convBoolExp
            fieldInfoMap
            selPerm
            boolExp
            sessVarBldr
            qt
            (valueParserWithCollectableType prepValBldr)
  forM_ mDistCols $
    withPathK "distinct"
      . verifyAsrns
        [ checkSelOnCol selPerm,
          assertColumnExists fieldInfoMap relInDistColsErr
        ]
  userFilter <- forM mWhere convertWhere
  permFilter <- convAnnBoolExpPartialSQL sessVarBldr (spiFilter selPerm)
  pure $ CountQueryP1 qt (permFilter, userFilter) mDistCols
  where
    selNecessaryMsg =
      "; \"count\" is only allowed if the role "
        <> "has \"select\" permissions on the table"
    relInDistColsErr =
      "Relationships can't be used in \"distinct\"."
-- | Validate a 'CountQuery' using the table cache of the query's source.
--
-- Session variables are resolved with 'sessVarFromCurrentSetting', and values
-- are built with 'binRHSBuilder'. Returns the validated query together with
-- the prepared-statement arguments accumulated during validation.
validateCountQ ::
  (QErrM m, UserInfoM m, CacheRM m) =>
  CountQuery ->
  m (CountQueryP1, DS.Seq Q.PrepArg)
validateCountQ query = do
  let sourceName = cqSource query
  tableCache :: TableCache ('Postgres 'Vanilla) <- fold <$> askTableCache sourceName
  runTableCacheRT
    (runDMLP1T $ validateCountQWith sessVarFromCurrentSetting binRHSBuilder query)
    (sourceName, tableCache)

-- | Execute a validated count query in a transaction.
--
-- The prepared arguments are passed positionally to the query. The single
-- result row is rendered as the JSON object @{"count":<n>}@.
countQToTx ::
  (QErrM m, MonadTx m) =>
  (CountQueryP1, DS.Seq Q.PrepArg) ->
  m EncJSON
countQToTx (countQuery, prepArgs) =
  renderCount <$> liftTx (Q.rawQE dmlTxErrorHandler sqlQuery (toList prepArgs) True)
  where
    sqlQuery = Q.fromBuilder $ toSQL $ mkSQLCount countQuery
    -- The JSON is written directly as bytes rather than going through an
    -- aeson 'Value'.
    renderCount (Q.SingleRow (Identity n)) =
      encJFromBuilder $
        BB.byteString "{\"count\":" <> BB.intDec n <> BB.char7 '}'

-- | Entry point for the @count@ query. It looks up the Postgres source
-- configuration, validates the query, and then runs it in a read-only
-- transaction using that source's execution context.
runCount ::
  ( QErrM m,
    UserInfoM m,
    CacheRM m,
    MonadIO m,
    MonadBaseControl IO m,
    Tracing.MonadTrace m,
    MetadataM m
  ) =>
  CountQuery ->
  m EncJSON
runCount q = do
  sourceConfig <- askSourceConfig @('Postgres 'Vanilla) (cqSource q)
  validated <- validateCountQ q
  runTxWithCtx (_pscExecCtx sourceConfig) Q.ReadOnly (countQToTx validated)