This commit is contained in:
nikita-volkov 2023-10-12 23:24:12 +00:00 committed by github-actions[bot]
parent 8c2aacc464
commit 55dc24bbd5
17 changed files with 568 additions and 533 deletions

View File

@ -25,7 +25,7 @@ main =
sessionBench "manySmallResults" sessionWithManySmallResults sessionBench "manySmallResults" sessionWithManySmallResults
] ]
where where
sessionBench :: NFData a => String -> B.Session a -> Benchmark sessionBench :: (NFData a) => String -> B.Session a -> Benchmark
sessionBench name session = sessionBench name session =
bench name (nfIO (fmap (either (error "") id) (B.run session connection))) bench name (nfIO (fmap (either (error "") id) (B.run session connection)))

View File

@ -306,7 +306,7 @@ refine fn (Value v) = Value (Value.Value (\b -> A.refine fn (Value.run v b)))
-- x = hstore 'replicateM' -- x = hstore 'replicateM'
-- @ -- @
{-# INLINEABLE hstore #-} {-# INLINEABLE hstore #-}
hstore :: (forall m. Monad m => Int -> m (Text, Maybe Text) -> m a) -> Value a hstore :: (forall m. (Monad m) => Int -> m (Text, Maybe Text) -> m a) -> Value a
hstore replicateM = Value (Value.decoder (const (A.hstore replicateM A.text_strict A.text_strict))) hstore replicateM = Value (Value.decoder (const (A.hstore replicateM A.text_strict A.text_strict)))
-- | -- |
@ -348,7 +348,7 @@ listArray = array . dimension replicateM . element
-- Please notice that in case of multidimensional arrays nesting 'vectorArray' decoder -- Please notice that in case of multidimensional arrays nesting 'vectorArray' decoder
-- won't work. You have to explicitly construct the array decoder using 'array'. -- won't work. You have to explicitly construct the array decoder using 'array'.
{-# INLINE vectorArray #-} {-# INLINE vectorArray #-}
vectorArray :: GenericVector.Vector vector element => NullableOrNot Value element -> Value (vector element) vectorArray :: (GenericVector.Vector vector element) => NullableOrNot Value element -> Value (vector element)
vectorArray = array . dimension GenericVector.replicateM . element vectorArray = array . dimension GenericVector.replicateM . element
-- | -- |
@ -383,7 +383,7 @@ newtype Array a = Array (Array.Array a)
-- --
-- * A decoder of its components, which can be either another 'dimension' or 'element'. -- * A decoder of its components, which can be either another 'dimension' or 'element'.
{-# INLINEABLE dimension #-} {-# INLINEABLE dimension #-}
dimension :: (forall m. Monad m => Int -> m a -> m b) -> Array a -> Array b dimension :: (forall m. (Monad m) => Int -> m a -> m b) -> Array a -> Array b
dimension replicateM (Array imp) = Array (Array.dimension replicateM imp) dimension replicateM (Array imp) = Array (Array.dimension replicateM imp)
-- | -- |

View File

@ -13,7 +13,7 @@ run (Array imp) env =
A.array (runReaderT imp env) A.array (runReaderT imp env)
{-# INLINE dimension #-} {-# INLINE dimension #-}
dimension :: (forall m. Monad m => Int -> m a -> m b) -> Array a -> Array b dimension :: (forall m. (Monad m) => Int -> m a -> m b) -> Array a -> Array b
dimension replicateM (Array imp) = dimension replicateM (Array imp) =
Array $ ReaderT $ \env -> A.dimensionArray replicateM (runReaderT imp env) Array $ ReaderT $ \env -> A.dimensionArray replicateM (runReaderT imp env)

View File

@ -34,10 +34,12 @@ rowsAffected =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.CommandOk -> True LibPQ.CommandOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(_, result) -> $ ReaderT
ExceptT $ $ \(_, result) ->
LibPQ.cmdTuples result & fmap cmdTuplesReader ExceptT
$ LibPQ.cmdTuples result
& fmap cmdTuplesReader
where where
cmdTuplesReader = cmdTuplesReader =
notNothing >=> notEmpty >=> decimal notNothing >=> notEmpty >=> decimal
@ -49,8 +51,8 @@ rowsAffected =
then Left (UnexpectedResult "Empty bytes") then Left (UnexpectedResult "Empty bytes")
else Right bytes else Right bytes
decimal bytes = decimal bytes =
mapLeft (\m -> UnexpectedResult ("Decimal parsing failure: " <> fromString m)) $ mapLeft (\m -> UnexpectedResult ("Decimal parsing failure: " <> fromString m))
Attoparsec.parseOnly (Attoparsec.decimal <* Attoparsec.endOfInput) bytes $ Attoparsec.parseOnly (Attoparsec.decimal <* Attoparsec.endOfInput) bytes
{-# INLINE checkExecStatus #-} {-# INLINE checkExecStatus #-}
checkExecStatus :: (LibPQ.ExecStatus -> Bool) -> Result () checkExecStatus :: (LibPQ.ExecStatus -> Bool) -> Result ()
@ -69,14 +71,15 @@ checkExecStatus predicate =
{-# INLINE serverError #-} {-# INLINE serverError #-}
serverError :: Result () serverError :: Result ()
serverError = serverError =
Result $ Result
ReaderT $ \(_, result) -> ExceptT $ do $ ReaderT
$ \(_, result) -> ExceptT $ do
code <- code <-
fmap fold $ fmap fold
LibPQ.resultErrorField result LibPQ.DiagSqlstate $ LibPQ.resultErrorField result LibPQ.DiagSqlstate
message <- message <-
fmap fold $ fmap fold
LibPQ.resultErrorField result LibPQ.DiagMessagePrimary $ LibPQ.resultErrorField result LibPQ.DiagMessagePrimary
detail <- detail <-
LibPQ.resultErrorField result LibPQ.DiagMessageDetail LibPQ.resultErrorField result LibPQ.DiagMessageDetail
hint <- hint <-
@ -99,8 +102,9 @@ maybe rowDec =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.TuplesOk -> True LibPQ.TuplesOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(integerDatetimes, result) -> ExceptT $ do $ ReaderT
$ \(integerDatetimes, result) -> ExceptT $ do
maxRows <- LibPQ.ntuples result maxRows <- LibPQ.ntuples result
case maxRows of case maxRows of
0 -> return (Right Nothing) 0 -> return (Right Nothing)
@ -122,8 +126,9 @@ single rowDec =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.TuplesOk -> True LibPQ.TuplesOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(integerDatetimes, result) -> ExceptT $ do $ ReaderT
$ \(integerDatetimes, result) -> ExceptT $ do
maxRows <- LibPQ.ntuples result maxRows <- LibPQ.ntuples result
case maxRows of case maxRows of
1 -> do 1 -> do
@ -144,8 +149,9 @@ vector rowDec =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.TuplesOk -> True LibPQ.TuplesOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(integerDatetimes, result) -> ExceptT $ do $ ReaderT
$ \(integerDatetimes, result) -> ExceptT $ do
maxRows <- LibPQ.ntuples result maxRows <- LibPQ.ntuples result
maxCols <- LibPQ.nfields result maxCols <- LibPQ.nfields result
mvector <- MutableVector.unsafeNew (rowToInt maxRows) mvector <- MutableVector.unsafeNew (rowToInt maxRows)
@ -172,10 +178,11 @@ foldl step init rowDec =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.TuplesOk -> True LibPQ.TuplesOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(integerDatetimes, result) -> $ ReaderT
ExceptT $ $ \(integerDatetimes, result) ->
{-# SCC "traversal" #-} ExceptT
$ {-# SCC "traversal" #-}
do do
maxRows <- LibPQ.ntuples result maxRows <- LibPQ.ntuples result
maxCols <- LibPQ.nfields result maxCols <- LibPQ.nfields result
@ -203,8 +210,9 @@ foldr step init rowDec =
checkExecStatus $ \case checkExecStatus $ \case
LibPQ.TuplesOk -> True LibPQ.TuplesOk -> True
_ -> False _ -> False
Result $ Result
ReaderT $ \(integerDatetimes, result) -> ExceptT $ do $ ReaderT
$ \(integerDatetimes, result) -> ExceptT $ do
maxRows <- LibPQ.ntuples result maxRows <- LibPQ.ntuples result
maxCols <- LibPQ.nfields result maxCols <- LibPQ.nfields result
accRef <- newIORef init accRef <- newIORef init

View File

@ -29,18 +29,20 @@ run (Results stack) env =
{-# INLINE clientError #-} {-# INLINE clientError #-}
clientError :: Results a clientError :: Results a
clientError = clientError =
Results $ Results
ReaderT $ \(_, connection) -> $ ReaderT
ExceptT $ $ \(_, connection) ->
fmap (Left . ClientError) (LibPQ.errorMessage connection) ExceptT
$ fmap (Left . ClientError) (LibPQ.errorMessage connection)
-- | -- |
-- Parse a single result. -- Parse a single result.
{-# INLINE single #-} {-# INLINE single #-}
single :: Result.Result a -> Results a single :: Result.Result a -> Results a
single resultDec = single resultDec =
Results $ Results
ReaderT $ \(integerDatetimes, connection) -> ExceptT $ do $ ReaderT
$ \(integerDatetimes, connection) -> ExceptT $ do
resultMaybe <- LibPQ.getResult connection resultMaybe <- LibPQ.getResult connection
case resultMaybe of case resultMaybe of
Just result -> Just result ->
@ -53,8 +55,9 @@ single resultDec =
{-# INLINE getResult #-} {-# INLINE getResult #-}
getResult :: Results LibPQ.Result getResult :: Results LibPQ.Result
getResult = getResult =
Results $ Results
ReaderT $ \(_, connection) -> ExceptT $ do $ ReaderT
$ \(_, connection) -> ExceptT $ do
resultMaybe <- LibPQ.getResult connection resultMaybe <- LibPQ.getResult connection
case resultMaybe of case resultMaybe of
Just result -> pure (Right result) Just result -> pure (Right result)
@ -85,7 +88,8 @@ dropRemainders =
ExceptT $ fmap (mapLeft ResultError) $ Result.run Result.noResult (integerDatetimes, result) ExceptT $ fmap (mapLeft ResultError) $ Result.run Result.noResult (integerDatetimes, result)
refine :: (a -> Either Text b) -> Results a -> Results b refine :: (a -> Either Text b) -> Results a -> Results b
refine refiner results = Results $ refine refiner results = Results
ReaderT $ \env -> ExceptT $ do $ ReaderT
$ \env -> ExceptT $ do
resultEither <- run results env resultEither <- run results env
return $ resultEither >>= mapLeft (ResultError . UnexpectedResult) . refiner return $ resultEither >>= mapLeft (ResultError . UnexpectedResult) . refiner

View File

@ -41,21 +41,22 @@ error x =
value :: Value.Value a -> Row (Maybe a) value :: Value.Value a -> Row (Maybe a)
value valueDec = value valueDec =
{-# SCC "value" #-} {-# SCC "value" #-}
Row $ Row
ReaderT $ \(Env result row columnsAmount integerDatetimes columnRef) -> ExceptT $ do $ ReaderT
$ \(Env result row columnsAmount integerDatetimes columnRef) -> ExceptT $ do
col <- readIORef columnRef col <- readIORef columnRef
writeIORef columnRef (succ col) writeIORef columnRef (succ col)
if col < columnsAmount if col < columnsAmount
then do then do
valueMaybe <- {-# SCC "getvalue'" #-} LibPQ.getvalue' result row col valueMaybe <- {-# SCC "getvalue'" #-} LibPQ.getvalue' result row col
pure $ pure
case valueMaybe of $ case valueMaybe of
Nothing -> Nothing ->
Right Nothing Right Nothing
Just value -> Just value ->
fmap Just $ fmap Just
mapLeft ValueError $ $ mapLeft ValueError
{-# SCC "decode" #-} A.valueParser (Value.run valueDec integerDatetimes) value $ {-# SCC "decode" #-} A.valueParser (Value.run valueDec integerDatetimes) value
else pure (Left EndOfInput) else pure (Left EndOfInput)
-- | -- |

View File

@ -330,7 +330,7 @@ composite (Composite encode print) =
-- Please notice that in case of multidimensional arrays nesting 'foldableArray' encoder -- Please notice that in case of multidimensional arrays nesting 'foldableArray' encoder
-- won't work. You have to explicitly construct the array encoder using 'array'. -- won't work. You have to explicitly construct the array encoder using 'array'.
{-# INLINE foldableArray #-} {-# INLINE foldableArray #-}
foldableArray :: Foldable foldable => NullableOrNot Value element -> Value (foldable element) foldableArray :: (Foldable foldable) => NullableOrNot Value element -> Value (foldable element)
foldableArray = array . dimension foldl' . element foldableArray = array . dimension foldl' . element
-- * Array -- * Array

View File

@ -19,8 +19,9 @@ value =
nullableValue :: C.Value a -> Params (Maybe a) nullableValue :: C.Value a -> Params (Maybe a)
nullableValue (C.Value valueOID arrayOID encode render) = nullableValue (C.Value valueOID arrayOID encode render) =
Params $ Params
Op $ \input -> $ Op
$ \input ->
let D.OID _ pqOid format = let D.OID _ pqOid format =
valueOID valueOID
encoder env = encoder env =

View File

@ -19,6 +19,6 @@ unsafePTI pti =
Value (PTI.ptiOID pti) (fromMaybe (error "No array OID") (PTI.ptiArrayOID pti)) Value (PTI.ptiOID pti) (fromMaybe (error "No array OID") (PTI.ptiArrayOID pti))
{-# INLINE unsafePTIWithShow #-} {-# INLINE unsafePTIWithShow #-}
unsafePTIWithShow :: Show a => PTI.PTI -> (Bool -> a -> B.Encoding) -> Value a unsafePTIWithShow :: (Show a) => PTI.PTI -> (Bool -> a -> B.Encoding) -> Value a
unsafePTIWithShow pti encode = unsafePTIWithShow pti encode =
unsafePTI pti encode (C.string . show) unsafePTI pti encode (C.string . show)

View File

@ -39,24 +39,24 @@ data CommandError
data ResultError data ResultError
= -- | An error reported by the DB. = -- | An error reported by the DB.
ServerError ServerError
ByteString -- | __Code__. The SQLSTATE code for the error. It's recommended to use
-- ^ __Code__. The SQLSTATE code for the error. It's recommended to use
-- <http://hackage.haskell.org/package/postgresql-error-codes -- <http://hackage.haskell.org/package/postgresql-error-codes
-- the "postgresql-error-codes" package> to work with those. -- the "postgresql-error-codes" package> to work with those.
ByteString ByteString
-- ^ __Message__. The primary human-readable error message(typically one -- | __Message__. The primary human-readable error message(typically one
-- line). Always present. -- line). Always present.
(Maybe ByteString) ByteString
-- ^ __Details__. An optional secondary error message carrying more -- | __Details__. An optional secondary error message carrying more
-- detail about the problem. Might run to multiple lines. -- detail about the problem. Might run to multiple lines.
(Maybe ByteString) (Maybe ByteString)
-- ^ __Hint__. An optional suggestion on what to do about the problem. -- | __Hint__. An optional suggestion on what to do about the problem.
-- This is intended to differ from detail in that it offers advice -- This is intended to differ from detail in that it offers advice
-- (potentially inappropriate) rather than hard facts. Might run to -- (potentially inappropriate) rather than hard facts. Might run to
-- multiple lines. -- multiple lines.
(Maybe Int) (Maybe ByteString)
-- ^ __Position__. Error cursor position as an index into the original -- | __Position__. Error cursor position as an index into the original
-- statement string. Positions are measured in characters not bytes. -- statement string. Positions are measured in characters not bytes.
(Maybe Int)
| -- | | -- |
-- The database returned an unexpected result. -- The database returned an unexpected result.
-- Indicates an improper statement or a schema mismatch. -- Indicates an improper statement or a schema mismatch.

View File

@ -117,12 +117,12 @@ type TextBuilder =
Data.Text.Lazy.Builder.Builder Data.Text.Lazy.Builder.Builder
{-# INLINE forMToZero_ #-} {-# INLINE forMToZero_ #-}
forMToZero_ :: Applicative m => Int -> (Int -> m a) -> m () forMToZero_ :: (Applicative m) => Int -> (Int -> m a) -> m ()
forMToZero_ !startN f = forMToZero_ !startN f =
($ pred startN) $ fix $ \loop !n -> if n >= 0 then f n *> loop (pred n) else pure () ($ pred startN) $ fix $ \loop !n -> if n >= 0 then f n *> loop (pred n) else pure ()
{-# INLINE forMFromZero_ #-} {-# INLINE forMFromZero_ #-}
forMFromZero_ :: Applicative m => Int -> (Int -> m a) -> m () forMFromZero_ :: (Applicative m) => Int -> (Int -> m a) -> m ()
forMFromZero_ !endN f = forMFromZero_ !endN f =
($ 0) $ fix $ \loop !n -> if n < endN then f n *> loop (succ n) else pure () ($ 0) $ fix $ \loop !n -> if n < endN then f n *> loop (succ n) else pure ()

View File

@ -22,8 +22,8 @@ newtype Session a
-- Executes a bunch of commands on the provided connection. -- Executes a bunch of commands on the provided connection.
run :: Session a -> Connection.Connection -> IO (Either QueryError a) run :: Session a -> Connection.Connection -> IO (Either QueryError a)
run (Session impl) connection = run (Session impl) connection =
runExceptT $ runExceptT
runReaderT impl connection $ runReaderT impl connection
-- | -- |
-- Possibly a multi-statement query, -- Possibly a multi-statement query,
@ -31,14 +31,16 @@ run (Session impl) connection =
-- nor can any results of it be collected. -- nor can any results of it be collected.
sql :: ByteString -> Session () sql :: ByteString -> Session ()
sql sql = sql sql =
Session $ Session
ReaderT $ \(Connection.Connection pqConnectionRef integerDatetimes registry) -> $ ReaderT
ExceptT $ $ \(Connection.Connection pqConnectionRef integerDatetimes registry) ->
fmap (mapLeft (QueryError sql [])) $ ExceptT
withMVar pqConnectionRef $ \pqConnection -> do $ fmap (mapLeft (QueryError sql []))
r1 <- IO.sendNonparametricStatement pqConnection sql $ withMVar pqConnectionRef
r2 <- IO.getResults pqConnection integerDatetimes decoder $ \pqConnection -> do
return $ r1 *> r2 r1 <- IO.sendNonparametricStatement pqConnection sql
r2 <- IO.getResults pqConnection integerDatetimes decoder
return $ r1 *> r2
where where
decoder = decoder =
Decoders.Results.single Decoders.Result.noResult Decoders.Results.single Decoders.Result.noResult
@ -47,14 +49,16 @@ sql sql =
-- Parameters and a specification of a parametric single-statement query to apply them to. -- Parameters and a specification of a parametric single-statement query to apply them to.
statement :: params -> Statement.Statement params result -> Session result statement :: params -> Statement.Statement params result -> Session result
statement input (Statement.Statement template (Encoders.Params paramsEncoder) decoder preparable) = statement input (Statement.Statement template (Encoders.Params paramsEncoder) decoder preparable) =
Session $ Session
ReaderT $ \(Connection.Connection pqConnectionRef integerDatetimes registry) -> $ ReaderT
ExceptT $ $ \(Connection.Connection pqConnectionRef integerDatetimes registry) ->
fmap (mapLeft (QueryError template inputReps)) $ ExceptT
withMVar pqConnectionRef $ \pqConnection -> do $ fmap (mapLeft (QueryError template inputReps))
r1 <- IO.sendParametricStatement pqConnection integerDatetimes registry template paramsEncoder preparable input $ withMVar pqConnectionRef
r2 <- IO.getResults pqConnection integerDatetimes (unsafeCoerce decoder) $ \pqConnection -> do
return $ r1 *> r2 r1 <- IO.sendParametricStatement pqConnection integerDatetimes registry template paramsEncoder preparable input
r2 <- IO.getResults pqConnection integerDatetimes (unsafeCoerce decoder)
return $ r1 *> r2
where where
inputReps = inputReps =
let Encoders.Params.Params (Op encoderOp) = paramsEncoder let Encoders.Params.Params (Op encoderOp) = paramsEncoder

View File

@ -16,19 +16,24 @@ type Settings =
{-# INLINE settings #-} {-# INLINE settings #-}
settings :: ByteString -> Word16 -> ByteString -> ByteString -> ByteString -> Settings settings :: ByteString -> Word16 -> ByteString -> ByteString -> ByteString -> Settings
settings host port user password database = settings host port user password database =
BL.toStrict $ BL.toStrict
BB.toLazyByteString $ $ BB.toLazyByteString
mconcat $ $ mconcat
intersperse (BB.char7 ' ') $ $ intersperse (BB.char7 ' ')
catMaybes $ $ catMaybes
[ mappend (BB.string7 "host=") . BB.byteString $ [ mappend (BB.string7 "host=")
<$> mfilter (not . B.null) (pure host), . BB.byteString
mappend (BB.string7 "port=") . BB.word16Dec <$> mfilter (not . B.null) (pure host),
<$> mfilter (/= 0) (pure port), mappend (BB.string7 "port=")
mappend (BB.string7 "user=") . BB.byteString . BB.word16Dec
<$> mfilter (not . B.null) (pure user), <$> mfilter (/= 0) (pure port),
mappend (BB.string7 "password=") . BB.byteString mappend (BB.string7 "user=")
<$> mfilter (not . B.null) (pure password), . BB.byteString
mappend (BB.string7 "dbname=") . BB.byteString <$> mfilter (not . B.null) (pure user),
<$> mfilter (not . B.null) (pure database) mappend (BB.string7 "password=")
] . BB.byteString
<$> mfilter (not . B.null) (pure password),
mappend (BB.string7 "dbname=")
. BB.byteString
<$> mfilter (not . B.null) (pure database)
]

View File

@ -20,141 +20,147 @@ main =
defaultMain tree defaultMain tree
tree = tree =
localOption (NumThreads 1) $ localOption (NumThreads 1)
testGroup $ testGroup
"All tests" "All tests"
[ testGroup "Roundtrips" $ [ testGroup "Roundtrips"
let roundtrip encoder decoder input = $ let roundtrip encoder decoder input =
let session = let session =
let statement = Statement.Statement "select $1" encoder decoder True let statement = Statement.Statement "select $1" encoder decoder True
in Session.statement input statement in Session.statement input statement
in unsafePerformIO $ do in unsafePerformIO $ do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
return (Right (Right input) === x) return (Right (Right input) === x)
in [ testProperty "Array" $ in [ testProperty "Array"
let encoder = Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8))))) $ let encoder = Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8)))))
decoder = Decoders.singleRow (Decoders.column (Decoders.nonNullable (Decoders.array (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8)))))) decoder = Decoders.singleRow (Decoders.column (Decoders.nonNullable (Decoders.array (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8))))))
in roundtrip encoder decoder, in roundtrip encoder decoder,
testProperty "2D Array" $ testProperty "2D Array"
let encoder = Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8)))))) $ let encoder = Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8))))))
decoder = Decoders.singleRow (Decoders.column (Decoders.nonNullable (Decoders.array (Decoders.dimension replicateM (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8))))))) decoder = Decoders.singleRow (Decoders.column (Decoders.nonNullable (Decoders.array (Decoders.dimension replicateM (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8)))))))
in \list -> list /= [] ==> roundtrip encoder decoder (replicate 3 list) in \list -> list /= [] ==> roundtrip encoder decoder (replicate 3 list)
], ],
testCase "Failed query" $ testCase "Failed query"
let statement = $ let statement =
Statement.Statement "select true where 1 = any ($1) and $2" encoder decoder True Statement.Statement "select true where 1 = any ($1) and $2" encoder decoder True
where where
encoder = encoder =
contrazip2 contrazip2
(Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8)))))) (Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8))))))
(Encoders.param (Encoders.nonNullable (Encoders.text))) (Encoders.param (Encoders.nonNullable (Encoders.text)))
decoder = decoder =
fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool)) fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool))
session = session =
Session.statement ([3, 7], "a") statement Session.statement ([3, 7], "a") statement
in do in do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
assertBool (show x) $ case x of assertBool (show x) $ case x of
Right (Left (Session.QueryError "select true where 1 = any ($1) and $2" ["[3, 7]", "\"a\""] _)) -> True Right (Left (Session.QueryError "select true where 1 = any ($1) and $2" ["[3, 7]", "\"a\""] _)) -> True
_ -> False, _ -> False,
testCase "IN simulation" $ testCase "IN simulation"
let statement = $ let statement =
Statement.Statement "select true where 1 = any ($1)" encoder decoder True Statement.Statement "select true where 1 = any ($1)" encoder decoder True
where where
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8))))) Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8)))))
decoder = decoder =
fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool)) fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool))
session = session =
do do
result1 <- Session.statement [1, 2] statement result1 <- Session.statement [1, 2] statement
result2 <- Session.statement [2, 3] statement result2 <- Session.statement [2, 3] statement
return (result1, result2) return (result1, result2)
in do in do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
assertEqual (show x) (Right (Right (True, False))) x, assertEqual (show x) (Right (Right (True, False))) x,
testCase "NOT IN simulation" $ testCase "NOT IN simulation"
let statement = $ let statement =
Statement.Statement "select true where 3 <> all ($1)" encoder decoder True Statement.Statement "select true where 3 <> all ($1)" encoder decoder True
where where
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8))))) Encoders.param (Encoders.nonNullable (Encoders.array (Encoders.dimension foldl' (Encoders.element (Encoders.nonNullable Encoders.int8)))))
decoder = decoder =
fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool)) fmap (maybe False (const True)) (Decoders.rowMaybe ((Decoders.column . Decoders.nonNullable) Decoders.bool))
session = session =
do do
result1 <- Session.statement [1, 2] statement result1 <- Session.statement [1, 2] statement
result2 <- Session.statement [2, 3] statement result2 <- Session.statement [2, 3] statement
return (result1, result2) return (result1, result2)
in do in do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
assertEqual (show x) (Right (Right (True, False))) x, assertEqual (show x) (Right (Right (True, False))) x,
testCase "Composite decoding" $ testCase "Composite decoding"
let statement = $ let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select (1, true)" "select (1, true)"
encoder = encoder =
mempty mempty
decoder = decoder =
Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.composite ((,) <$> (Decoders.field . Decoders.nonNullable) Decoders.int8 <*> (Decoders.field . Decoders.nonNullable) Decoders.bool))) Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.composite ((,) <$> (Decoders.field . Decoders.nonNullable) Decoders.int8 <*> (Decoders.field . Decoders.nonNullable) Decoders.bool)))
session = session =
Session.statement () statement Session.statement () statement
in do in do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
assertEqual (show x) (Right (Right (1, True))) x, assertEqual (show x) (Right (Right (1, True))) x,
testCase "Complex composite decoding" $ testCase "Complex composite decoding"
let statement = $ let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select (1, true) as entity1, ('hello', 3) as entity2" "select (1, true) as entity1, ('hello', 3) as entity2"
encoder = encoder =
mempty mempty
decoder = decoder =
Decoders.singleRow $ Decoders.singleRow
(,) <$> (Decoders.column . Decoders.nonNullable) entity1 <*> (Decoders.column . Decoders.nonNullable) entity2 $ (,)
where <$> (Decoders.column . Decoders.nonNullable) entity1
entity1 = <*> (Decoders.column . Decoders.nonNullable) entity2
Decoders.composite $
(,) <$> (Decoders.field . Decoders.nonNullable) Decoders.int8 <*> (Decoders.field . Decoders.nonNullable) Decoders.bool
entity2 =
Decoders.composite $
(,) <$> (Decoders.field . Decoders.nonNullable) Decoders.text <*> (Decoders.field . Decoders.nonNullable) Decoders.int8
session =
Session.statement () statement
in do
x <- Connection.with (Session.run session)
assertEqual (show x) (Right (Right ((1, True), ("hello", 3)))) x,
testGroup "unknownEnum" $
[ testCase "" $ do
res <- DSL.session $ do
let statement =
Statement.Statement sql mempty Decoders.noResult True
where where
sql = entity1 =
"drop type if exists mood" Decoders.composite
in DSL.statement () statement $ (,)
let statement = <$> (Decoders.field . Decoders.nonNullable) Decoders.int8
Statement.Statement sql mempty Decoders.noResult True <*> (Decoders.field . Decoders.nonNullable) Decoders.bool
where entity2 =
sql = Decoders.composite
"create type mood as enum ('sad', 'ok', 'happy')" $ (,)
in DSL.statement () statement <$> (Decoders.field . Decoders.nonNullable) Decoders.text
let statement = <*> (Decoders.field . Decoders.nonNullable) Decoders.int8
Statement.Statement sql encoder decoder True session =
where Session.statement () statement
sql = in do
"select $1" x <- Connection.with (Session.run session)
decoder = assertEqual (show x) (Right (Right ((1, True), ("hello", 3)))) x,
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.enum (Just . id)))) testGroup "unknownEnum"
encoder = $ [ testCase "" $ do
Encoders.param (Encoders.nonNullable (Encoders.unknownEnum id)) res <- DSL.session $ do
in DSL.statement "ok" statement let statement =
Statement.Statement sql mempty Decoders.noResult True
where
sql =
"drop type if exists mood"
in DSL.statement () statement
let statement =
Statement.Statement sql mempty Decoders.noResult True
where
sql =
"create type mood as enum ('sad', 'ok', 'happy')"
in DSL.statement () statement
let statement =
Statement.Statement sql encoder decoder True
where
sql =
"select $1"
decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.enum (Just . id))))
encoder =
Encoders.param (Encoders.nonNullable (Encoders.unknownEnum id))
in DSL.statement "ok" statement
assertEqual "" (Right "ok") res assertEqual "" (Right "ok") res
], ],
testCase "Composite encoding" $ do testCase "Composite encoding" $ do
let value = let value =
(123, 456, 789, "abc") (123, 456, 789, "abc")
@ -165,15 +171,18 @@ tree =
sql = sql =
"select $1 :: pg_enum" "select $1 :: pg_enum"
encoder = encoder =
Encoders.param . Encoders.nonNullable . Encoders.composite . mconcat $ Encoders.param
[ contramap (\(a, _, _, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.oid, . Encoders.nonNullable
contramap (\(_, a, _, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.oid, . Encoders.composite
contramap (\(_, _, a, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.float4, . mconcat
contramap (\(_, _, _, a) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.name $ [ contramap (\(a, _, _, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.oid,
] contramap (\(_, a, _, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.oid,
contramap (\(_, _, a, _) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.float4,
contramap (\(_, _, _, a) -> a) . Encoders.field . Encoders.nonNullable $ Encoders.name
]
decoder = decoder =
Decoders.singleRow $ Decoders.singleRow
(Decoders.column . Decoders.nonNullable . Decoders.composite) $ (Decoders.column . Decoders.nonNullable . Decoders.composite)
( (,,,) ( (,,,)
<$> (Decoders.field . Decoders.nonNullable) Decoders.int4 <$> (Decoders.field . Decoders.nonNullable) Decoders.int4
<*> (Decoders.field . Decoders.nonNullable) Decoders.int4 <*> (Decoders.field . Decoders.nonNullable) Decoders.int4
@ -182,270 +191,270 @@ tree =
) )
in Connection.with $ Session.run $ Session.statement value statement in Connection.with $ Session.run $ Session.statement value statement
assertEqual "" (Right (Right value)) res, assertEqual "" (Right (Right value)) res,
testCase "Empty array" $ testCase "Empty array"
let io = $ let io =
do do
x <- Connection.with (Session.run session) x <- Connection.with (Session.run session)
assertEqual (show x) (Right (Right [])) x assertEqual (show x) (Right (Right [])) x
where where
session = session =
Session.statement () statement Session.statement () statement
where where
statement = statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select array[]::int8[]" "select array[]::int8[]"
encoder = encoder =
mempty mempty
decoder = decoder =
Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.array (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8))))) Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.array (Decoders.dimension replicateM (Decoders.element (Decoders.nonNullable Decoders.int8)))))
in io, in io,
testCase "Failing prepared statements" $ testCase "Failing prepared statements"
let io = $ let io =
Connection.with (Session.run session) Connection.with (Session.run session)
>>= (assertBool <$> show <*> resultTest) >>= (assertBool <$> show <*> resultTest)
where where
resultTest = resultTest =
\case \case
Right (Left (Session.QueryError _ _ (Session.ResultError (Session.ServerError "26000" _ _ _ _)))) -> False Right (Left (Session.QueryError _ _ (Session.ResultError (Session.ServerError "26000" _ _ _ _)))) -> False
_ -> True _ -> True
session = session =
catchError session (const (pure ())) *> session catchError session (const (pure ())) *> session
where where
session = session =
Session.statement () statement Session.statement () statement
where where
statement = statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"absurd" "absurd"
encoder = encoder =
mempty mempty
decoder = decoder =
Decoders.noResult Decoders.noResult
in io, in io,
testCase "Prepared statements after error" $ testCase "Prepared statements after error"
let io = $ let io =
Connection.with (Session.run session) Connection.with (Session.run session)
>>= \x -> assertBool (show x) (either (const False) isRight x) >>= \x -> assertBool (show x) (either (const False) isRight x)
where where
session = session =
try *> fail *> try try *> fail *> try
where where
try = try =
Session.statement 1 statement Session.statement 1 statement
where where
statement = statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1 :: int8" "select $1 :: int8"
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.int8)) Encoders.param (Encoders.nonNullable (Encoders.int8))
decoder = decoder =
Decoders.singleRow $ (Decoders.column . Decoders.nonNullable) Decoders.int8 Decoders.singleRow $ (Decoders.column . Decoders.nonNullable) Decoders.int8
fail = fail =
catchError (Session.sql "absurd") (const (pure ())) catchError (Session.sql "absurd") (const (pure ()))
in io, in io,
testCase "\"in progress after error\" bugfix" $ testCase "\"in progress after error\" bugfix"
let sumStatement :: Statement.Statement (Int64, Int64) Int64 $ let sumStatement :: Statement.Statement (Int64, Int64) Int64
sumStatement = sumStatement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select ($1 + $2)" "select ($1 + $2)"
encoder = encoder =
contramap fst (Encoders.param (Encoders.nonNullable (Encoders.int8))) contramap fst (Encoders.param (Encoders.nonNullable (Encoders.int8)))
<> contramap snd (Encoders.param (Encoders.nonNullable (Encoders.int8))) <> contramap snd (Encoders.param (Encoders.nonNullable (Encoders.int8)))
decoder = decoder =
Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8) Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8)
sumSession :: Session.Session Int64 sumSession :: Session.Session Int64
sumSession = sumSession =
Session.sql "begin" *> Session.statement (1, 1) sumStatement <* Session.sql "end" Session.sql "begin" *> Session.statement (1, 1) sumStatement <* Session.sql "end"
errorSession :: Session.Session () errorSession :: Session.Session ()
errorSession = errorSession =
Session.sql "asldfjsldk" Session.sql "asldfjsldk"
io = io =
Connection.with $ \c -> do Connection.with $ \c -> do
Session.run errorSession c Session.run errorSession c
Session.run sumSession c Session.run sumSession c
in io >>= \x -> assertBool (show x) (either (const False) isRight x), in io >>= \x -> assertBool (show x) (either (const False) isRight x),
testCase "\"another command is already in progress\" bugfix" $ testCase "\"another command is already in progress\" bugfix"
let sumStatement :: Statement.Statement (Int64, Int64) Int64 $ let sumStatement :: Statement.Statement (Int64, Int64) Int64
sumStatement = sumStatement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select ($1 + $2)" "select ($1 + $2)"
encoder = encoder =
contramap fst (Encoders.param (Encoders.nonNullable (Encoders.int8))) contramap fst (Encoders.param (Encoders.nonNullable (Encoders.int8)))
<> contramap snd (Encoders.param (Encoders.nonNullable (Encoders.int8))) <> contramap snd (Encoders.param (Encoders.nonNullable (Encoders.int8)))
decoder = decoder =
Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8) Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8)
session :: Session.Session Int64 session :: Session.Session Int64
session = session =
do do
Session.sql "begin;" Session.sql "begin;"
s <- Session.statement (1, 1) sumStatement s <- Session.statement (1, 1) sumStatement
Session.sql "end;" Session.sql "end;"
return s return s
in DSL.session session >>= \x -> assertEqual (show x) (Right 2) x, in DSL.session session >>= \x -> assertEqual (show x) (Right 2) x,
testCase "Executing the same query twice" $ testCase "Executing the same query twice"
pure (), $ pure (),
testCase "Interval Encoding" $ testCase "Interval Encoding"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1 = interval '10 seconds'" "select $1 = interval '10 seconds'"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.bool))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.bool)))
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.interval)) Encoders.param (Encoders.nonNullable (Encoders.interval))
in DSL.statement (10 :: DiffTime) statement in DSL.statement (10 :: DiffTime) statement
in actualIO >>= \x -> assertEqual (show x) (Right True) x, in actualIO >>= \x -> assertEqual (show x) (Right True) x,
testCase "Interval Decoding" $ testCase "Interval Decoding"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select interval '10 seconds'" "select interval '10 seconds'"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.interval))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.interval)))
encoder = encoder =
Encoders.noParams Encoders.noParams
in DSL.statement () statement in DSL.statement () statement
in actualIO >>= \x -> assertEqual (show x) (Right (10 :: DiffTime)) x, in actualIO >>= \x -> assertEqual (show x) (Right (10 :: DiffTime)) x,
testCase "Interval Encoding/Decoding" $ testCase "Interval Encoding/Decoding"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1" "select $1"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.interval))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.interval)))
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.interval)) Encoders.param (Encoders.nonNullable (Encoders.interval))
in DSL.statement (10 :: DiffTime) statement in DSL.statement (10 :: DiffTime) statement
in actualIO >>= \x -> assertEqual (show x) (Right (10 :: DiffTime)) x, in actualIO >>= \x -> assertEqual (show x) (Right (10 :: DiffTime)) x,
testCase "Unknown" $ testCase "Unknown"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"drop type if exists mood" "drop type if exists mood"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"create type mood as enum ('sad', 'ok', 'happy')" "create type mood as enum ('sad', 'ok', 'happy')"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1 = ('ok' :: mood)" "select $1 = ('ok' :: mood)"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.bool))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.bool)))
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.unknown)) Encoders.param (Encoders.nonNullable (Encoders.unknown))
in DSL.statement "ok" statement in DSL.statement "ok" statement
in actualIO >>= assertEqual "" (Right True), in actualIO >>= assertEqual "" (Right True),
testCase "Textual Unknown" $ testCase "Textual Unknown"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"create or replace function overloaded(a int, b int) returns int as $$ select a + b $$ language sql;" "create or replace function overloaded(a int, b int) returns int as $$ select a + b $$ language sql;"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"create or replace function overloaded(a text, b text, c text) returns text as $$ select a || b || c $$ language sql;" "create or replace function overloaded(a text, b text, c text) returns text as $$ select a || b || c $$ language sql;"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select overloaded($1, $2) || overloaded($3, $4, $5)" "select overloaded($1, $2) || overloaded($3, $4, $5)"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.text))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.text)))
encoder = encoder =
contramany (Encoders.param (Encoders.nonNullable (Encoders.unknown))) contramany (Encoders.param (Encoders.nonNullable (Encoders.unknown)))
in DSL.statement ["1", "2", "4", "5", "6"] statement in DSL.statement ["1", "2", "4", "5", "6"] statement
in actualIO >>= assertEqual "" (Right "3456"), in actualIO >>= assertEqual "" (Right "3456"),
testCase "Enum" $ testCase "Enum"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"drop type if exists mood" "drop type if exists mood"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql mempty Decoders.noResult True Statement.Statement sql mempty Decoders.noResult True
where where
sql = sql =
"create type mood as enum ('sad', 'ok', 'happy')" "create type mood as enum ('sad', 'ok', 'happy')"
in DSL.statement () statement in DSL.statement () statement
let statement = let statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select ($1 :: mood)" "select ($1 :: mood)"
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.enum (Just . id)))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.enum (Just . id))))
encoder = encoder =
Encoders.param (Encoders.nonNullable ((Encoders.enum id))) Encoders.param (Encoders.nonNullable ((Encoders.enum id)))
in DSL.statement "ok" statement in DSL.statement "ok" statement
in actualIO >>= assertEqual "" (Right "ok"), in actualIO >>= assertEqual "" (Right "ok"),
testCase "The same prepared statement used on different types" $ testCase "The same prepared statement used on different types"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
let effect1 = let effect1 =
DSL.statement "ok" statement DSL.statement "ok" statement
where where
statement = statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1" "select $1"
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.text)) Encoders.param (Encoders.nonNullable (Encoders.text))
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.text))) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.text)))
effect2 = effect2 =
DSL.statement 1 statement DSL.statement 1 statement
where where
statement = statement =
Statement.Statement sql encoder decoder True Statement.Statement sql encoder decoder True
where where
sql = sql =
"select $1" "select $1"
encoder = encoder =
Encoders.param (Encoders.nonNullable (Encoders.int8)) Encoders.param (Encoders.nonNullable (Encoders.int8))
decoder = decoder =
(Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8)) (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int8))
in (,) <$> effect1 <*> effect2 in (,) <$> effect1 <*> effect2
in actualIO >>= assertEqual "" (Right ("ok", 1)), in actualIO >>= assertEqual "" (Right ("ok", 1)),
testCase "Affected rows counting" $ testCase "Affected rows counting"
replicateM_ 13 $ $ replicateM_ 13
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
dropTable dropTable
createTable createTable
@ -453,17 +462,17 @@ tree =
deleteRows <* dropTable deleteRows <* dropTable
where where
dropTable = dropTable =
DSL.statement () $ DSL.statement ()
Statements.plain $ $ Statements.plain
"drop table if exists a" $ "drop table if exists a"
createTable = createTable =
DSL.statement () $ DSL.statement ()
Statements.plain $ $ Statements.plain
"create table a (id bigserial not null, name varchar not null, primary key (id))" $ "create table a (id bigserial not null, name varchar not null, primary key (id))"
insertRow = insertRow =
DSL.statement () $ DSL.statement ()
Statements.plain $ $ Statements.plain
"insert into a (name) values ('a')" $ "insert into a (name) values ('a')"
deleteRows = deleteRows =
DSL.statement () $ Statement.Statement sql mempty decoder False DSL.statement () $ Statement.Statement sql mempty decoder False
where where
@ -472,18 +481,18 @@ tree =
decoder = decoder =
Decoders.rowsAffected Decoders.rowsAffected
in actualIO >>= assertEqual "" (Right 100), in actualIO >>= assertEqual "" (Right 100),
testCase "Result of an auto-incremented column" $ testCase "Result of an auto-incremented column"
let actualIO = $ let actualIO =
DSL.session $ do DSL.session $ do
DSL.statement () $ Statements.plain $ "drop table if exists a" DSL.statement () $ Statements.plain $ "drop table if exists a"
DSL.statement () $ Statements.plain $ "create table a (id serial not null, v char not null, primary key (id))" DSL.statement () $ Statements.plain $ "create table a (id serial not null, v char not null, primary key (id))"
id1 <- DSL.statement () $ Statement.Statement "insert into a (v) values ('a') returning id" mempty (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int4)) False id1 <- DSL.statement () $ Statement.Statement "insert into a (v) values ('a') returning id" mempty (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int4)) False
id2 <- DSL.statement () $ Statement.Statement "insert into a (v) values ('b') returning id" mempty (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int4)) False id2 <- DSL.statement () $ Statement.Statement "insert into a (v) values ('b') returning id" mempty (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) Decoders.int4)) False
DSL.statement () $ Statements.plain $ "drop table if exists a" DSL.statement () $ Statements.plain $ "drop table if exists a"
pure (id1, id2) pure (id1, id2)
in assertEqual "" (Right (1, 2)) =<< actualIO, in assertEqual "" (Right (1, 2)) =<< actualIO,
testCase "List decoding" $ testCase "List decoding"
let actualIO = $ let actualIO =
DSL.session $ DSL.statement () $ Statements.selectList DSL.session $ DSL.statement () $ Statements.selectList
in assertEqual "" (Right [(1, 2), (3, 4), (5, 6)]) =<< actualIO in assertEqual "" (Right [(1, 2), (3, 4), (5, 6)]) =<< actualIO
] ]

View File

@ -38,8 +38,8 @@ session session =
password = "" password = ""
database = "postgres" database = "postgres"
use connection = use connection =
ExceptT $ ExceptT
fmap (mapLeft SessionError) $ $ fmap (mapLeft SessionError)
Hasql.Session.run session connection $ Hasql.Session.run session connection
release connection = release connection =
lift $ HC.release connection lift $ HC.release connection

View File

@ -12,15 +12,18 @@ plain sql =
dropType :: ByteString -> HQ.Statement () () dropType :: ByteString -> HQ.Statement () ()
dropType name = dropType name =
plain $ plain
"drop type if exists " <> name $ "drop type if exists "
<> name
createEnum :: ByteString -> [ByteString] -> HQ.Statement () () createEnum :: ByteString -> [ByteString] -> HQ.Statement () ()
createEnum name values = createEnum name values =
plain $ plain
"create type " <> name <> " as enum (" $ "create type "
<> mconcat (intersperse ", " (map (\x -> "'" <> x <> "'") values)) <> name
<> ")" <> " as enum ("
<> mconcat (intersperse ", " (map (\x -> "'" <> x <> "'") values))
<> ")"
selectList :: HQ.Statement () ([] (Int64, Int64)) selectList :: HQ.Statement () ([] (Int64, Int64))
selectList = selectList =

View File

@ -15,9 +15,9 @@ main =
(,) <$> acquire <*> acquire (,) <$> acquire <*> acquire
where where
acquire = acquire =
join $ join
fmap (either (fail . show) return) $ $ fmap (either (fail . show) return)
Hasql.Connection.acquire connectionSettings $ Hasql.Connection.acquire connectionSettings
where where
connectionSettings = connectionSettings =
Hasql.Connection.settings "localhost" 5432 "postgres" "" "postgres" Hasql.Connection.settings "localhost" 5432 "postgres" "" "postgres"