diff --git a/hasql.cabal b/hasql.cabal index dc0f775..be7234f 100644 --- a/hasql.cabal +++ b/hasql.cabal @@ -147,7 +147,9 @@ library testing-kit hs-source-dirs: testing-kit exposed-modules: Hasql.TestingKit.Constants + Hasql.TestingKit.Statements.BrokenSyntax Hasql.TestingKit.Statements.GenerateSeries + Hasql.TestingKit.Statements.WrongDecoder Hasql.TestingKit.TestingDsl build-depends: diff --git a/hspec/Hasql/PipelineSpec.hs b/hspec/Hasql/PipelineSpec.hs index 56d172a..25bc179 100644 --- a/hspec/Hasql/PipelineSpec.hs +++ b/hspec/Hasql/PipelineSpec.hs @@ -1,6 +1,8 @@ module Hasql.PipelineSpec (spec) where +import Hasql.TestingKit.Statements.BrokenSyntax qualified as BrokenSyntax import Hasql.TestingKit.Statements.GenerateSeries qualified as GenerateSeries +import Hasql.TestingKit.Statements.WrongDecoder qualified as WrongDecoder import Hasql.TestingKit.TestingDsl qualified as Dsl import Test.Hspec import Prelude @@ -22,7 +24,7 @@ spec = do $ GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} shouldBe result (Right [0 .. 2]) - describe "Normally" do + describe "Multi-statement" do describe "On unprepared statements" do it "Collects results and sends params" do result <- @@ -39,6 +41,51 @@ spec = do $ GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} shouldBe result (Right [[0 .. 2], [0 .. 2]]) - describe "When some part fails" do - it "Works" do - pending + describe "When a part in the middle fails" do + describe "With query error" do + it "Captures the error" do + result <- + Dsl.runPipelineOnLocalDb + $ (,,) + <$> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + <*> BrokenSyntax.pipeline True BrokenSyntax.Params {start = 0, end = 2} + <*> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + case result of + Left (Dsl.SessionError (Dsl.QuerySessionError _ _ _)) -> pure () + _ -> expectationFailure $ "Unexpected result: " <> show result + + it "Leaves the connection usable" do + result <- + Dsl.runSessionOnLocalDb do + tryError + $ Dsl.runPipelineInSession + $ (,,) + <$> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + <*> BrokenSyntax.pipeline True BrokenSyntax.Params {start = 0, end = 2} + <*> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + GenerateSeries.session True GenerateSeries.Params {start = 0, end = 0} + shouldBe result (Right [0]) + + describe "With decoding error" do + it "Captures the error" do + result <- + Dsl.runPipelineOnLocalDb + $ (,,) + <$> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + <*> WrongDecoder.pipeline True WrongDecoder.Params {start = 0, end = 2} + <*> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + case result of + Left (Dsl.SessionError (Dsl.QuerySessionError _ _ _)) -> pure () + _ -> expectationFailure $ "Unexpected result: " <> show result + + it "Leaves the connection usable" do + result <- + Dsl.runSessionOnLocalDb do + tryError + $ Dsl.runPipelineInSession + $ (,,) + <$> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + <*> WrongDecoder.pipeline True WrongDecoder.Params {start = 0, end = 2} + <*> GenerateSeries.pipeline True GenerateSeries.Params {start = 0, end = 2} + GenerateSeries.session True GenerateSeries.Params {start = 0, end = 0} + shouldBe result (Right [0]) diff --git a/library/Hasql/Errors.hs b/library/Hasql/Errors.hs index e4b72df..85e7eb1 100644 --- a/library/Hasql/Errors.hs +++ b/library/Hasql/Errors.hs @@ -5,8 +5,7 @@ import Hasql.Prelude -- | Error during execution of a session. data SessionError - = -- | - -- An error during the execution of a query. + = -- | Error during the execution of a query. -- Comes packed with the query template and a textual representation of the provided params. QuerySessionError -- | SQL template. diff --git a/testing-kit/Hasql/TestingKit/Statements/BrokenSyntax.hs b/testing-kit/Hasql/TestingKit/Statements/BrokenSyntax.hs new file mode 100644 index 0000000..1d47e70 --- /dev/null +++ b/testing-kit/Hasql/TestingKit/Statements/BrokenSyntax.hs @@ -0,0 +1,44 @@ +module Hasql.TestingKit.Statements.BrokenSyntax where + +import Hasql.Decoders qualified as Decoders +import Hasql.Encoders qualified as Encoders +import Hasql.Pipeline qualified as Pipeline +import Hasql.Session qualified as Session +import Hasql.Statement qualified as Statement +import Prelude + +data Params = Params + { start :: Int64, + end :: Int64 + } + +type Result = [Int64] + +session :: Bool -> Params -> Session.Session Result +session prepared params = + Session.statement params (statement prepared) + +pipeline :: Bool -> Params -> Pipeline.Pipeline Result +pipeline prepared params = + Pipeline.statement params (statement prepared) + +statement :: Bool -> Statement.Statement Params Result +statement = + Statement.Statement sql encoder decoder + +sql :: ByteString +sql = + "S" + +encoder :: Encoders.Params Params +encoder = + mconcat + [ (.start) >$< Encoders.param (Encoders.nonNullable Encoders.int8), + (.end) >$< Encoders.param (Encoders.nonNullable Encoders.int8) + ] + +decoder :: Decoders.Result Result +decoder = + Decoders.rowList + ( Decoders.column (Decoders.nonNullable Decoders.int8) + ) diff --git a/testing-kit/Hasql/TestingKit/Statements/WrongDecoder.hs b/testing-kit/Hasql/TestingKit/Statements/WrongDecoder.hs new file mode 100644 index 0000000..d139327 --- /dev/null +++ b/testing-kit/Hasql/TestingKit/Statements/WrongDecoder.hs @@ -0,0 +1,44 @@ +module Hasql.TestingKit.Statements.WrongDecoder where + +import Hasql.Decoders qualified as Decoders +import Hasql.Encoders qualified as Encoders +import Hasql.Pipeline qualified as Pipeline +import Hasql.Session qualified as Session +import Hasql.Statement qualified as Statement +import Prelude + +data Params = Params + { start :: Int64, + end :: Int64 + } + +type Result = [UUID] + +session :: Bool -> Params -> Session.Session Result +session prepared params = + Session.statement params (statement prepared) + +pipeline :: Bool -> Params -> Pipeline.Pipeline Result +pipeline prepared params = + Pipeline.statement params (statement prepared) + +statement :: Bool -> Statement.Statement Params Result +statement = + Statement.Statement sql encoder decoder + +sql :: ByteString +sql = + "SELECT generate_series($1, $2)" + +encoder :: Encoders.Params Params +encoder = + mconcat + [ (.start) >$< Encoders.param (Encoders.nonNullable Encoders.int8), + (.end) >$< Encoders.param (Encoders.nonNullable Encoders.int8) + ] + +decoder :: Decoders.Result Result +decoder = + Decoders.rowList + ( Decoders.column (Decoders.nonNullable Decoders.uuid) + ) diff --git a/testing-kit/Hasql/TestingKit/TestingDsl.hs b/testing-kit/Hasql/TestingKit/TestingDsl.hs index ba762d0..5183939 100644 --- a/testing-kit/Hasql/TestingKit/TestingDsl.hs +++ b/testing-kit/Hasql/TestingKit/TestingDsl.hs @@ -1,10 +1,18 @@ module Hasql.TestingKit.TestingDsl - ( Session.Session, + ( -- * Errors Error (..), Session.SessionError (..), Session.CommandError (..), + Session.ResultError (..), + Session.RowError (..), + Session.ColumnError (..), + + -- * Abstractions + Session.Session, Pipeline.Pipeline, Statement.Statement (..), + + -- * Execution runSessionOnLocalDb, runPipelineOnLocalDb, runStatementInSession,