module Database.MSSQL.TransactionSpec (spec) where import Control.Exception.Base (bracket) import Data.ByteString (ByteString) import Database.MSSQL.Pool import Database.MSSQL.Transaction import Database.ODBC.SQLServer as ODBC ( ODBCException (DataRetrievalError, UnsuccessfulReturnCode), Query, ) import Hasura.Prelude import Test.Hspec -- | Describe a TransactionSpec test, see 'runTest' for additional details. data TestCase a = TestCase { -- | Specifies which transactions to run. They will be executed sequentially, -- and the value of the last query in the last transaction will be compared -- against 'expectation'. -- -- Has to be non-empty (but kept as a list for simplicity/convenience). transactions :: [Transaction], -- | Expected result of the test. 'Right' represents a successful outcome. -- -- Left is presented as a function because we want to be able to partially -- match error (messages). expectation :: Either (MSSQLTxError -> Expectation) a, -- | Which kind of parser to use on the last query of the last transaction: -- -- * 'unitQuery' for queries returning '()' -- * 'singleRowQuery' for queries returning 'a' -- * 'multiRowQuery' for queries returning '[a]' -- -- Use 'TypeApplications' to specify the return type (needed to disambiguate -- the 'a' type parameter when needed). runWith :: Query -> TxT IO a, -- | Description for the test, used in the test output. description :: String } newtype Transaction = Transaction { unTransaction :: [Query] } spec :: Text -> Spec spec connString = do runBasicChecks connString transactionStateTests connString runBasicChecks :: Text -> Spec runBasicChecks connString = describe "runTx transaction basic checks" $ do run TestCase { transactions = [ Transaction [ "CREATE TABLE SingleCol (ID INT)", "INSERT INTO SingleCol VALUES (2)", "SELECT ID FROM SingleCol" ] ], expectation = Right 2, runWith = singleRowQuery @Int, description = "CREATE, INSERT, SELECT single column" } run TestCase { transactions = [ Transaction [ "CREATE TABLE MultiCol (ID INT, NAME VARCHAR(1))", "INSERT INTO MultiCol VALUES (2, 'A')", "SELECT ID, NAME FROM MultiCol" ] ], expectation = Right (2, "A"), runWith = singleRowQuery @(Int, ByteString), description = "CREATE, INSERT, SELECT single multiple columns" } run TestCase { transactions = [Transaction ["SELECT 'hello'"]], expectation = matchDataRetrievalError "Expected Int, but got: ByteStringValue \"hello\"", runWith = singleRowQuery @Int, description = "SELECT the wrong type" } run TestCase { transactions = [Transaction ["select * from (values (1), (2)) as x(a)"]], expectation = Right [1, 2], runWith = multiRowQuery @Int, description = "SELECT multiple rows" } run TestCase { transactions = [Transaction ["select * from (values (1), (2)) as x(a)"]], expectation = matchDataRetrievalError "expecting single row", runWith = singleRowQuery @Int, description = "SELECT multiple rows, expect single row" } run TestCase { transactions = [ Transaction [ "CREATE TABLE BadQuery (ID INT, INVALID_SYNTAX)", "INSERT INTO BadQuery VALUES (3)" ] ], expectation = matchQueryError (UnsuccessfulReturnCode "odbc_SQLExecDirectW" (-1) invalidSyntaxError (Just "42000")), runWith = unitQuery, description = "Bad syntax error/transaction rollback" } where -- Partially apply connString to runTest for convenience run :: forall a. Eq a => Show a => TestCase a -> Spec run = runTest connString -- | Test COMMIT and ROLLBACK for Active and NoActive states. -- -- The Uncommittable state can be achieved by running the transaction enclosed -- in a TRY..CATCH block, which is not currently doable with our current API. -- Consider changing the API to allow such a test if we ever end up having -- bugs because of it. transactionStateTests :: Text -> Spec transactionStateTests connString = describe "runTx Transaction State -> Action" $ do run TestCase { transactions = [Transaction ["SELECT 1"]], expectation = Right 1, runWith = singleRowQuery @Int, description = "Active -> COMMIT" } run TestCase { transactions = [ Transaction [ "CREATE TABLE SingleCol (ID INT)", "INSERT INTO SingleCol VALUES (2)" ], Transaction -- Fail [ "CREATE TABLE BadQuery (ID INT, INVALID_SYNTAX)", "UPDATE SingleCol SET ID=3" ], Transaction ["SELECT ID FROM SingleCol"] -- Grab data from setup ], expectation = Right 2, runWith = singleRowQuery @Int, description = "Active -> ROLLBACK" } run TestCase { transactions = [ Transaction -- Fail ["COMMIT; SELECT 1"] ], expectation = Left (`shouldBe` MSSQLInternal "No active transaction exist; cannot commit"), runWith = singleRowQuery @Int, description = "NoActive -> COMMIT" } run TestCase { transactions = [ Transaction [ "COMMIT;", "CREATE TABLE BadQuery (ID INT, INVALID_SYNTAX)" ] ], -- We should get the error rather than the cannot commit error from the -- NoActive -> Commit test. expectation = matchQueryError (UnsuccessfulReturnCode "odbc_SQLExecDirectW" (-1) invalidSyntaxError (Just "42000")), runWith = unitQuery, description = "NoActive -> ROLLBACK" } where -- Partially apply connString to runTest for convenience run :: forall a. Eq a => Show a => TestCase a -> Spec run = runTest connString -- | Run a 'TestCase' by executing the queries in order. The last 'ODBC.Query' -- is the one we check he result against. -- -- Beacuse we don't know the type of the result, we need it supplied as part -- of the 'TestCase': -- -- * 'unitQuery' for queries returning '()' -- * 'singleRowQuery' for queries returning 'a' -- * 'multiRowQuery' for queries returning '[a]' -- -- Note that we need to use TypeApplications on the 'runWith' function for type -- checking to work, especially if the values are polymorphic -- (e.g. numbers or strings). -- -- Please also note that we are discarding 'Left's from "setup" transactions -- (all but the last transaction). See the 'runSetup' helper below. runTest :: forall a. Eq a => Show a => Text -> TestCase a -> Spec runTest connString TestCase {..} = it description do case reverse transactions of [] -> expectationFailure "Empty transaction list: nothing to do." (mainTransaction : leadingTransactions) -> do -- Run all transactions before the last (main) transaction. runSetup (reverse leadingTransactions) -- Get the result from the last transaction. result <- runInConn connString $ runQueries runWith $ unTransaction mainTransaction case (result, expectation) of -- Validate the error is the one we were expecting. (Left err, Left expected) -> expected err -- Verify the success result is the expected one. (Right res, Right expected) -> res `shouldBe` expected -- Expected success but got error. Needs special case because the expected -- Left is a validator (function). (Left err, Right expected) -> expectationFailure $ "Expected " <> show expected <> " but got error: " <> show err -- Expected error but got success. Needs special case because the expected -- Left is a validator (function). (Right res, Left _) -> expectationFailure $ "Expected error but got success: " <> show res where runSetup :: [Transaction] -> IO () runSetup [] = pure () runSetup (t : ts) = do -- Discards 'Left's. _ <- runInConn connString (runQueries unitQuery $ unTransaction t) runSetup ts runQueries :: (Query -> TxT IO x) -> [Query] -> TxT IO x runQueries _ [] = error $ "Expected at least one query per transaction in " <> description runQueries f [q] = f q runQueries f (x : xs) = unitQuery x *> runQueries f xs -- | spec helper functions runInConn :: Text -> TxT IO a -> IO (Either MSSQLTxError a) runInConn connString query = bracket (createMinimalPool connString) drainMSSQLPool (runExceptT . runTx query) createMinimalPool :: Text -> IO MSSQLPool createMinimalPool connString = initMSSQLPool (ConnectionString connString) $ ConnectionOptions 1 1 5 invalidSyntaxError :: String invalidSyntaxError = "[Microsoft][ODBC Driver 17 for SQL Server][SQL Server]The definition for column 'INVALID_SYNTAX' must include a data type." matchDataRetrievalError :: String -> Either (MSSQLTxError -> Expectation) a matchDataRetrievalError = matchQueryError . DataRetrievalError matchQueryError :: ODBCException -> Either (MSSQLTxError -> Expectation) a matchQueryError expectedErr = Left $ \case MSSQLQueryError _ err -> err `shouldBe` expectedErr MSSQLConnError _ -> expectationFailure unexpectedMSSQLConnError MSSQLInternal _ -> expectationFailure unexpectedMSSQLInternalError unexpectedMSSQLInternalError :: String unexpectedMSSQLInternalError = "Expected MSSQLQueryError, but got: MSSQLInternal" unexpectedMSSQLConnError :: String unexpectedMSSQLConnError = "Expected MSSQLQueryError, but got: MSSQLConnError"